[REFORMAT] Ran black reformat
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# modules/admin/api.py
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends # Import Body
|
||||
from pydantic import BaseModel # Import BaseModel
|
||||
from fastapi import APIRouter, Depends # Import Body
|
||||
from pydantic import BaseModel # Import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from core.database import get_db
|
||||
from modules.auth.dependencies import admin_only
|
||||
@@ -9,14 +9,17 @@ from .tasks import cleardb
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)])
|
||||
|
||||
|
||||
# Define a Pydantic model for the request body
|
||||
class ClearDbRequest(BaseModel):
|
||||
hard: bool
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def read_admin():
|
||||
return {"message": "Admin route"}
|
||||
|
||||
|
||||
# Change to POST and use the request body model
|
||||
@router.post("/cleardb")
|
||||
def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
|
||||
@@ -25,6 +28,6 @@ def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
|
||||
'hard'=True: Drop and recreate all tables.
|
||||
'hard'=False: Delete data from tables except users.
|
||||
"""
|
||||
hard = payload.hard # Get 'hard' from the payload
|
||||
hard = payload.hard # Get 'hard' from the payload
|
||||
cleardb.delay(hard)
|
||||
return {"message": "Clearing database in the background", "hard": hard}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# modules/admin/services.py
|
||||
|
||||
|
||||
## temp
|
||||
## temp
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from core.celery_app import celery_app
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def cleardb(hard: bool):
|
||||
"""
|
||||
@@ -32,4 +33,4 @@ def cleardb(hard: bool):
|
||||
print(f"Deleting table: {table_name}")
|
||||
db.execute(table.delete())
|
||||
db.commit()
|
||||
return {"message": "Database cleared"}
|
||||
return {"message": "Database cleared"}
|
||||
|
||||
@@ -3,9 +3,24 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from jose import JWTError
|
||||
from modules.auth.models import User
|
||||
from modules.auth.schemas import UserCreate, UserResponse, Token, RefreshTokenRequest, LogoutRequest
|
||||
from modules.auth.schemas import (
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
RefreshTokenRequest,
|
||||
LogoutRequest,
|
||||
)
|
||||
from modules.auth.services import create_user
|
||||
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens
|
||||
from modules.auth.security import (
|
||||
TokenType,
|
||||
get_current_user,
|
||||
oauth2_scheme,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_token,
|
||||
authenticate_user,
|
||||
blacklist_tokens,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
from core.database import get_db
|
||||
@@ -15,12 +30,19 @@ from core.exceptions import unauthorized_exception
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
|
||||
return create_user(user.username, user.password, user.name, db)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
|
||||
def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
):
|
||||
"""
|
||||
Authenticate user and return JWT tokens in the response body.
|
||||
"""
|
||||
@@ -30,39 +52,53 @@ def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annota
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username},
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
)
|
||||
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||
|
||||
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
def refresh_token(payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]):
|
||||
def refresh_token(
|
||||
payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]
|
||||
):
|
||||
print("Refreshing token...")
|
||||
refresh_token = payload.refresh_token
|
||||
if not refresh_token:
|
||||
raise unauthorized_exception("Refresh token missing in request body")
|
||||
|
||||
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db)
|
||||
user_data = verify_token(
|
||||
refresh_token, expected_token_type=TokenType.REFRESH, db=db
|
||||
)
|
||||
if not user_data:
|
||||
raise unauthorized_exception("Invalid refresh token")
|
||||
|
||||
new_access_token = create_access_token(data={"sub": user_data.username})
|
||||
return {"access_token": new_access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(payload: LogoutRequest, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme)):
|
||||
def logout(
|
||||
payload: LogoutRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
access_token: str = Depends(oauth2_scheme),
|
||||
):
|
||||
try:
|
||||
refresh_token = payload.refresh_token
|
||||
if not refresh_token:
|
||||
raise unauthorized_exception("Refresh token not found in request body")
|
||||
|
||||
blacklist_tokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
db=db
|
||||
)
|
||||
|
||||
blacklist_tokens(access_token=access_token, refresh_token=refresh_token, db=db)
|
||||
|
||||
return {"message": "Logged out successfully"}
|
||||
except JWTError:
|
||||
raise unauthorized_exception("Invalid token")
|
||||
raise unauthorized_exception("Invalid token")
|
||||
|
||||
@@ -5,14 +5,18 @@ from modules.auth.schemas import UserRole
|
||||
from modules.auth.models import User
|
||||
from core.exceptions import forbidden_exception
|
||||
|
||||
|
||||
class RoleChecker:
|
||||
def __init__(self, allowed_roles: list[UserRole]):
|
||||
self.allowed_roles = allowed_roles
|
||||
|
||||
def __call__(self, user: User = Depends(get_current_user)):
|
||||
if user.role not in self.allowed_roles:
|
||||
raise forbidden_exception("You do not have permission to perform this action.")
|
||||
raise forbidden_exception(
|
||||
"You do not have permission to perform this action."
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
admin_only = RoleChecker([UserRole.ADMIN])
|
||||
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
|
||||
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
|
||||
|
||||
@@ -4,10 +4,12 @@ from sqlalchemy import Column, Integer, String, Enum, DateTime
|
||||
from sqlalchemy.orm import relationship
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class UserRole(str, PyEnum):
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
@@ -2,33 +2,41 @@
|
||||
from enum import Enum as PyEnum
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
scopes: list[str] = []
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class UserRole(str, PyEnum):
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
name: str
|
||||
|
||||
|
||||
class UserPatch(BaseModel):
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
uuid: str
|
||||
username: str
|
||||
|
||||
@@ -18,6 +18,7 @@ from modules.auth.schemas import TokenData
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
ACCESS = "access"
|
||||
REFRESH = "refresh"
|
||||
@@ -25,11 +26,13 @@ class TokenType(str, Enum):
|
||||
|
||||
password_hasher = PasswordHasher()
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password with Argon2 (and optional pepper)."""
|
||||
peppered_password = password + settings.PEPPER # Prepend/append pepper
|
||||
return password_hasher.hash(peppered_password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hashed version using Argon2."""
|
||||
peppered_password = plain_password + settings.PEPPER
|
||||
@@ -38,6 +41,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
except VerifyMismatchError:
|
||||
return False
|
||||
|
||||
|
||||
def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||
"""
|
||||
Authenticate a user by checking username/password against the database.
|
||||
@@ -45,41 +49,46 @@ def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||
"""
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
|
||||
|
||||
# If user not found or password doesn't match
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
# expire = datetime.now(timezone.utc) + timedelta(seconds=5)
|
||||
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
|
||||
return jwt.encode(
|
||||
to_encode,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM
|
||||
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
|
||||
return jwt.encode(
|
||||
to_encode,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM
|
||||
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
|
||||
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
|
||||
|
||||
def verify_token(
|
||||
token: str, expected_token_type: TokenType, db: Session
|
||||
) -> TokenData | None:
|
||||
"""Verify a JWT token and return TokenData if valid.
|
||||
|
||||
Parameters
|
||||
@@ -96,24 +105,32 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
|
||||
TokenData | None
|
||||
TokenData instance if the token is valid, None otherwise.
|
||||
"""
|
||||
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
||||
is_blacklisted = (
|
||||
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
|
||||
is not None
|
||||
)
|
||||
if is_blacklisted:
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
token_type: str = payload.get("token_type")
|
||||
|
||||
|
||||
if username is None or token_type != expected_token_type:
|
||||
return None
|
||||
|
||||
|
||||
return TokenData(username=username)
|
||||
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
|
||||
|
||||
def get_current_user(
|
||||
db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@@ -121,26 +138,28 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
|
||||
)
|
||||
|
||||
# Check if the token is blacklisted
|
||||
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
||||
is_blacklisted = (
|
||||
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
|
||||
is not None
|
||||
)
|
||||
if is_blacklisted:
|
||||
raise credentials_exception
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM]
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
user: User = db.query(User).filter(User.username == username).first()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
|
||||
"""Blacklist both access and refresh tokens.
|
||||
|
||||
@@ -154,7 +173,9 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
|
||||
Database session to perform the operation.
|
||||
"""
|
||||
for token in [access_token, refresh_token]:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||
|
||||
# Add the token to the blacklist
|
||||
@@ -163,10 +184,13 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def blacklist_token(token: str, db: Session) -> None:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||
|
||||
|
||||
# Add the token to the blacklist
|
||||
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
|
||||
db.add(blacklisted_token)
|
||||
|
||||
@@ -20,11 +20,13 @@ def create_user(username: str, password: str, name: str, db: Session) -> UserRes
|
||||
existing_user = db.query(User).filter(User.username == username).first()
|
||||
if existing_user:
|
||||
raise conflict_exception("Username already exists")
|
||||
|
||||
|
||||
hashed_password = hash_password(password)
|
||||
user_uuid = str(uuid.uuid4())
|
||||
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid)
|
||||
user = User(
|
||||
username=username, hashed_password=hashed_password, name=name, uuid=user_uuid
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user) # Loads the generated ID
|
||||
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic
|
||||
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic
|
||||
|
||||
@@ -6,50 +6,63 @@ from typing import List, Optional
|
||||
from modules.auth.dependencies import get_current_user
|
||||
from core.database import get_db
|
||||
from modules.auth.models import User
|
||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate, CalendarEventResponse
|
||||
from modules.calendar.service import create_calendar_event, get_calendar_event_by_id, get_calendar_events, update_calendar_event, delete_calendar_event
|
||||
from modules.calendar.schemas import (
|
||||
CalendarEventCreate,
|
||||
CalendarEventUpdate,
|
||||
CalendarEventResponse,
|
||||
)
|
||||
from modules.calendar.service import (
|
||||
create_calendar_event,
|
||||
get_calendar_event_by_id,
|
||||
get_calendar_events,
|
||||
update_calendar_event,
|
||||
delete_calendar_event,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/calendar", tags=["calendar"])
|
||||
|
||||
@router.post("/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
def create_event(
|
||||
event: CalendarEventCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return create_calendar_event(db, user.id, event)
|
||||
|
||||
|
||||
@router.get("/events", response_model=List[CalendarEventResponse])
|
||||
def get_events(
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
start: Optional[datetime] = None,
|
||||
end: Optional[datetime] = None
|
||||
end: Optional[datetime] = None,
|
||||
):
|
||||
return get_calendar_events(db, user.id, start, end)
|
||||
|
||||
|
||||
@router.get("/events/{event_id}", response_model=CalendarEventResponse)
|
||||
def get_event_by_id(
|
||||
event_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
event = get_calendar_event_by_id(db, user.id, event_id)
|
||||
return event
|
||||
|
||||
|
||||
@router.patch("/events/{event_id}", response_model=CalendarEventResponse)
|
||||
def update_event(
|
||||
event_id: int,
|
||||
event: CalendarEventUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return update_calendar_event(db, user.id, event_id, event)
|
||||
|
||||
|
||||
@router.delete("/events/{event_id}", status_code=204)
|
||||
def delete_event(
|
||||
event_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
delete_calendar_event(db, user.id, event_id)
|
||||
delete_calendar_event(db, user.id, event_id)
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
# modules/calendar/models.py
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Boolean # Add Boolean
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
Boolean,
|
||||
) # Add Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class CalendarEvent(Base):
|
||||
__tablename__ = "calendar_events"
|
||||
|
||||
@@ -12,10 +21,12 @@ class CalendarEvent(Base):
|
||||
start = Column(DateTime, nullable=False)
|
||||
end = Column(DateTime)
|
||||
location = Column(String)
|
||||
all_day = Column(Boolean, default=False) # Add all_day column
|
||||
all_day = Column(Boolean, default=False) # Add all_day column
|
||||
tags = Column(JSON)
|
||||
color = Column(String) # hex code for color
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) # <-- Relationship
|
||||
color = Column(String) # hex code for color
|
||||
user_id = Column(
|
||||
Integer, ForeignKey("users.id"), nullable=False
|
||||
) # <-- Relationship
|
||||
|
||||
# Bi-directional relationship (for eager loading)
|
||||
user = relationship("User", back_populates="calendar_events")
|
||||
user = relationship("User", back_populates="calendar_events")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# modules/calendar/schemas.py
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, field_validator # Add field_validator
|
||||
from typing import List, Optional # Add List and Optional
|
||||
from pydantic import BaseModel, field_validator # Add field_validator
|
||||
from typing import List, Optional # Add List and Optional
|
||||
|
||||
|
||||
# Base schema for common fields, including tags
|
||||
class CalendarEventBase(BaseModel):
|
||||
@@ -10,21 +11,23 @@ class CalendarEventBase(BaseModel):
|
||||
start: datetime
|
||||
end: Optional[datetime] = None
|
||||
location: Optional[str] = None
|
||||
color: Optional[str] = None # Assuming color exists
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags
|
||||
color: Optional[str] = None # Assuming color exists
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags
|
||||
|
||||
@field_validator('tags', mode='before')
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def tags_validate_null_string(cls, v):
|
||||
if v == "Null":
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
# Schema for creating an event (inherits from Base)
|
||||
class CalendarEventCreate(CalendarEventBase):
|
||||
pass
|
||||
|
||||
|
||||
# Schema for updating an event (all fields optional)
|
||||
class CalendarEventUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
@@ -33,23 +36,24 @@ class CalendarEventUpdate(BaseModel):
|
||||
end: Optional[datetime] = None
|
||||
location: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags for update
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags for update
|
||||
|
||||
@field_validator('tags', mode='before')
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def tags_validate_null_string(cls, v):
|
||||
if v == "Null":
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
# Schema for the response (inherits from Base, adds ID and user_id)
|
||||
class CalendarEventResponse(CalendarEventBase):
|
||||
id: int
|
||||
user_id: int
|
||||
tags: List[str] # Keep as List[str], remove default []
|
||||
tags: List[str] # Keep as List[str], remove default []
|
||||
|
||||
@field_validator('tags', mode='before')
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def tags_validate_none_to_list(cls, v):
|
||||
# If the value from the source object (e.g., ORM model) is None,
|
||||
@@ -59,4 +63,4 @@ class CalendarEventResponse(CalendarEventBase):
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
from_attributes = True
|
||||
|
||||
@@ -1,25 +1,34 @@
|
||||
# modules/calendar/service.py
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_ # Import or_
|
||||
from sqlalchemy import or_ # Import or_
|
||||
from datetime import datetime
|
||||
from modules.calendar.models import CalendarEvent
|
||||
from core.exceptions import not_found_exception
|
||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate # Import schemas
|
||||
from modules.calendar.schemas import (
|
||||
CalendarEventCreate,
|
||||
CalendarEventUpdate,
|
||||
) # Import schemas
|
||||
|
||||
|
||||
def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate):
|
||||
# Ensure tags is None if not provided or empty list, matching model
|
||||
tags_to_store = event_data.tags if event_data.tags else None
|
||||
event = CalendarEvent(
|
||||
**event_data.model_dump(exclude={'tags'}), # Use model_dump and exclude tags initially
|
||||
tags=tags_to_store, # Set tags separately
|
||||
user_id=user_id
|
||||
**event_data.model_dump(
|
||||
exclude={"tags"}
|
||||
), # Use model_dump and exclude tags initially
|
||||
tags=tags_to_store, # Set tags separately
|
||||
user_id=user_id,
|
||||
)
|
||||
db.add(event)
|
||||
db.commit()
|
||||
db.refresh(event)
|
||||
return event
|
||||
|
||||
def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: datetime | None):
|
||||
|
||||
def get_calendar_events(
|
||||
db: Session, user_id: int, start: datetime | None, end: datetime | None
|
||||
):
|
||||
"""
|
||||
Retrieves calendar events for a user, optionally filtered by a date range.
|
||||
|
||||
@@ -46,9 +55,13 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
|
||||
query = query.filter(
|
||||
or_(
|
||||
# Case 1: Event has duration and overlaps
|
||||
(CalendarEvent.end is not None) & (CalendarEvent.start < end) & (CalendarEvent.end > start),
|
||||
(CalendarEvent.end is not None)
|
||||
& (CalendarEvent.start < end)
|
||||
& (CalendarEvent.end > start),
|
||||
# Case 2: Event is a point event within the range
|
||||
(CalendarEvent.end is None) & (CalendarEvent.start >= start) & (CalendarEvent.start < end)
|
||||
(CalendarEvent.end is None)
|
||||
& (CalendarEvent.start >= start)
|
||||
& (CalendarEvent.start < end),
|
||||
)
|
||||
)
|
||||
# If only start is provided, filter events starting on or after start
|
||||
@@ -60,37 +73,41 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
|
||||
elif end:
|
||||
# Includes events with duration ending <= end (or starting before end if end is None)
|
||||
# Includes point events occurring < end
|
||||
query = query.filter(
|
||||
query = query.filter(
|
||||
or_(
|
||||
# Event ends before the specified end time
|
||||
(CalendarEvent.end is not None) & (CalendarEvent.end <= end),
|
||||
# Point event occurs before the specified end time
|
||||
(CalendarEvent.end is None) & (CalendarEvent.start < end)
|
||||
(CalendarEvent.end is None) & (CalendarEvent.start < end),
|
||||
)
|
||||
)
|
||||
# Alternative interpretation for "ending before end": include events that *start* before end
|
||||
# query = query.filter(CalendarEvent.start < end)
|
||||
# Alternative interpretation for "ending before end": include events that *start* before end
|
||||
# query = query.filter(CalendarEvent.start < end)
|
||||
|
||||
return query.order_by(CalendarEvent.start).all() # Order by start time
|
||||
|
||||
return query.order_by(CalendarEvent.start).all() # Order by start time
|
||||
|
||||
def get_calendar_event_by_id(db: Session, user_id: int, event_id: int):
|
||||
event = db.query(CalendarEvent).filter(
|
||||
CalendarEvent.id == event_id,
|
||||
CalendarEvent.user_id == user_id
|
||||
).first()
|
||||
event = (
|
||||
db.query(CalendarEvent)
|
||||
.filter(CalendarEvent.id == event_id, CalendarEvent.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not event:
|
||||
raise not_found_exception()
|
||||
return event
|
||||
|
||||
def update_calendar_event(db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate):
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
|
||||
def update_calendar_event(
|
||||
db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate
|
||||
):
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
# Use model_dump with exclude_unset=True to only update provided fields
|
||||
update_data = event_data.model_dump(exclude_unset=True)
|
||||
|
||||
for key, value in update_data.items():
|
||||
# Ensure tags is handled correctly (set to None if empty list provided)
|
||||
if key == 'tags' and isinstance(value, list) and not value:
|
||||
if key == "tags" and isinstance(value, list) and not value:
|
||||
setattr(event, key, None)
|
||||
else:
|
||||
setattr(event, key, value)
|
||||
@@ -99,7 +116,8 @@ def update_calendar_event(db: Session, user_id: int, event_id: int, event_data:
|
||||
db.refresh(event)
|
||||
return event
|
||||
|
||||
|
||||
def delete_calendar_event(db: Session, user_id: int, event_id: int):
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
db.delete(event)
|
||||
db.commit()
|
||||
db.commit()
|
||||
|
||||
@@ -7,13 +7,27 @@ from core.database import get_db
|
||||
|
||||
from modules.auth.dependencies import get_current_user
|
||||
from modules.auth.models import User
|
||||
|
||||
# Import the new service functions and Enum
|
||||
from modules.nlp.service import process_request, ask_ai, save_chat_message, get_chat_history, MessageSender
|
||||
from modules.nlp.service import (
|
||||
process_request,
|
||||
ask_ai,
|
||||
save_chat_message,
|
||||
get_chat_history,
|
||||
MessageSender,
|
||||
)
|
||||
|
||||
# Import the response schema and the new ChatMessage model for response type hinting
|
||||
from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse
|
||||
from modules.calendar.service import create_calendar_event, get_calendar_events, update_calendar_event, delete_calendar_event
|
||||
from modules.calendar.service import (
|
||||
create_calendar_event,
|
||||
get_calendar_events,
|
||||
update_calendar_event,
|
||||
delete_calendar_event,
|
||||
)
|
||||
from modules.calendar.models import CalendarEvent
|
||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate
|
||||
|
||||
# Import TODO services, schemas, and model
|
||||
from modules.todo import service as todo_service
|
||||
from modules.todo.models import Todo
|
||||
@@ -21,17 +35,20 @@ from modules.todo.schemas import TodoCreate, TodoUpdate
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
id: int
|
||||
sender: MessageSender # Use the enum directly
|
||||
sender: MessageSender # Use the enum directly
|
||||
text: str
|
||||
timestamp: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True # Allow Pydantic to work with ORM models
|
||||
|
||||
from_attributes = True # Allow Pydantic to work with ORM models
|
||||
|
||||
|
||||
router = APIRouter(prefix="/nlp", tags=["nlp"])
|
||||
|
||||
|
||||
# Helper to format calendar events (expects list of CalendarEvent models)
|
||||
def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
|
||||
if not events:
|
||||
@@ -39,12 +56,15 @@ def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
|
||||
formatted = ["Here are the events:"]
|
||||
for event in events:
|
||||
# Access attributes directly from the model instance
|
||||
start_str = event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
|
||||
start_str = (
|
||||
event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
|
||||
)
|
||||
end_str = event.end.strftime("%H:%M") if event.end else ""
|
||||
title = event.title or "Untitled Event"
|
||||
formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})")
|
||||
return formatted
|
||||
|
||||
|
||||
# Helper to format TODO items (expects list of Todo models)
|
||||
def format_todos(todos: List[Todo]) -> List[str]:
|
||||
if not todos:
|
||||
@@ -54,19 +74,28 @@ def format_todos(todos: List[Todo]) -> List[str]:
|
||||
status = "[X]" if todo.complete else "[ ]"
|
||||
date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else ""
|
||||
remind_str = " (Reminder)" if todo.remind else ""
|
||||
formatted.append(f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})")
|
||||
formatted.append(
|
||||
f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})"
|
||||
)
|
||||
return formatted
|
||||
|
||||
|
||||
# Update the response model for the endpoint
|
||||
@router.post("/process-command", response_model=ProcessCommandResponse)
|
||||
def process_command(request_data: ProcessCommandRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
def process_command(
|
||||
request_data: ProcessCommandRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Process the user command, save messages, execute action, save response, and return user-friendly responses.
|
||||
"""
|
||||
user_input = request_data.user_input
|
||||
|
||||
# --- Save User Message ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.USER, text=user_input)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.USER, text=user_input
|
||||
)
|
||||
# ------------------------
|
||||
|
||||
command_data = process_request(user_input)
|
||||
@@ -74,11 +103,13 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
|
||||
params = command_data["params"]
|
||||
response_text = command_data["response_text"]
|
||||
|
||||
responses = [response_text] # Start with the initial response
|
||||
responses = [response_text] # Start with the initial response
|
||||
|
||||
# --- Save Initial AI Response ---
|
||||
# Save the first response generated by process_request
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=response_text
|
||||
)
|
||||
# -----------------------------
|
||||
|
||||
if intent == "error":
|
||||
@@ -97,139 +128,233 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
|
||||
ai_answer = ask_ai(**params)
|
||||
responses.append(ai_answer)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "get_calendar_events":
|
||||
events: List[CalendarEvent] = get_calendar_events(db, current_user.id, **params)
|
||||
events: List[CalendarEvent] = get_calendar_events(
|
||||
db, current_user.id, **params
|
||||
)
|
||||
formatted_responses = format_calendar_events(events)
|
||||
responses.extend(formatted_responses)
|
||||
# --- Save Additional AI Responses ---
|
||||
for resp in formatted_responses:
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
|
||||
)
|
||||
# ----------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "add_calendar_event":
|
||||
event_data = CalendarEventCreate(**params)
|
||||
created_event = create_calendar_event(db, current_user.id, event_data)
|
||||
start_str = created_event.start.strftime("%Y-%m-%d %H:%M") if created_event.start else "No start time"
|
||||
start_str = (
|
||||
created_event.start.strftime("%Y-%m-%d %H:%M")
|
||||
if created_event.start
|
||||
else "No start time"
|
||||
)
|
||||
title = created_event.title or "Untitled Event"
|
||||
add_response = f"Added: {title} starting at {start_str}."
|
||||
responses.append(add_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=add_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "update_calendar_event":
|
||||
event_id = params.pop('event_id', None)
|
||||
event_id = params.pop("event_id", None)
|
||||
if event_id is None:
|
||||
# Save the error message before raising
|
||||
error_msg = "Event ID is required for update."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
event_data = CalendarEventUpdate(**params)
|
||||
updated_event = update_calendar_event(db, current_user.id, event_id, event_data=event_data)
|
||||
updated_event = update_calendar_event(
|
||||
db, current_user.id, event_id, event_data=event_data
|
||||
)
|
||||
title = updated_event.title or "Untitled Event"
|
||||
update_response = f"Updated event ID {updated_event.id}: {title}."
|
||||
responses.append(update_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=update_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "delete_calendar_event":
|
||||
event_id = params.get('event_id')
|
||||
event_id = params.get("event_id")
|
||||
if event_id is None:
|
||||
# Save the error message before raising
|
||||
error_msg = "Event ID is required for delete."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
delete_calendar_event(db, current_user.id, event_id)
|
||||
delete_response = f"Deleted event ID {event_id}."
|
||||
responses.append(delete_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=delete_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
# --- Add TODO Cases ---
|
||||
# --- Add TODO Cases ---
|
||||
case "get_todos":
|
||||
todos: List[Todo] = todo_service.get_todos(db, user=current_user, **params)
|
||||
todos: List[Todo] = todo_service.get_todos(
|
||||
db, user=current_user, **params
|
||||
)
|
||||
formatted_responses = format_todos(todos)
|
||||
responses.extend(formatted_responses)
|
||||
# --- Save Additional AI Responses ---
|
||||
for resp in formatted_responses:
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
|
||||
)
|
||||
# ----------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "add_todo":
|
||||
todo_data = TodoCreate(**params)
|
||||
created_todo = todo_service.create_todo(db, todo=todo_data, user=current_user)
|
||||
add_response = f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
|
||||
created_todo = todo_service.create_todo(
|
||||
db, todo=todo_data, user=current_user
|
||||
)
|
||||
add_response = (
|
||||
f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
|
||||
)
|
||||
responses.append(add_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=add_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "update_todo":
|
||||
todo_id = params.pop('todo_id', None)
|
||||
todo_id = params.pop("todo_id", None)
|
||||
if todo_id is None:
|
||||
error_msg = "TODO ID is required for update."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
todo_data = TodoUpdate(**params)
|
||||
updated_todo = todo_service.update_todo(db, todo_id=todo_id, todo_update=todo_data, user=current_user)
|
||||
update_response = f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
|
||||
if 'complete' in params:
|
||||
status = "complete" if params['complete'] else "incomplete"
|
||||
updated_todo = todo_service.update_todo(
|
||||
db, todo_id=todo_id, todo_update=todo_data, user=current_user
|
||||
)
|
||||
update_response = (
|
||||
f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
|
||||
)
|
||||
if "complete" in params:
|
||||
status = "complete" if params["complete"] else "incomplete"
|
||||
update_response += f" Marked as {status}."
|
||||
responses.append(update_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=update_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "delete_todo":
|
||||
todo_id = params.get('todo_id')
|
||||
todo_id = params.get("todo_id")
|
||||
if todo_id is None:
|
||||
error_msg = "TODO ID is required for delete."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
deleted_todo = todo_service.delete_todo(db, todo_id=todo_id, user=current_user)
|
||||
delete_response = f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
|
||||
deleted_todo = todo_service.delete_todo(
|
||||
db, todo_id=todo_id, user=current_user
|
||||
)
|
||||
delete_response = (
|
||||
f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
|
||||
)
|
||||
responses.append(delete_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=delete_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
# --- End TODO Cases ---
|
||||
|
||||
case _:
|
||||
print(f"Warning: Unhandled intent '{intent}' reached api.py match statement.")
|
||||
case _:
|
||||
print(
|
||||
f"Warning: Unhandled intent '{intent}' reached api.py match statement."
|
||||
)
|
||||
# The initial response_text was already saved
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
except HTTPException as http_exc:
|
||||
# Don't save again if already saved before raising
|
||||
if http_exc.status_code != 400 or ('event_id' not in http_exc.detail.lower()):
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=http_exc.detail)
|
||||
if http_exc.status_code != 400 or ("event_id" not in http_exc.detail.lower()):
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=http_exc.detail,
|
||||
)
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
print(f"Error executing intent '{intent}': {e}")
|
||||
error_response = "Sorry, I encountered an error while trying to perform that action."
|
||||
error_response = (
|
||||
"Sorry, I encountered an error while trying to perform that action."
|
||||
)
|
||||
# --- Save Final Error AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_response)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=error_response
|
||||
)
|
||||
# ----------------------------------
|
||||
return ProcessCommandResponse(responses=[error_response])
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[ChatMessageResponse])
|
||||
def read_chat_history(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
def read_chat_history(
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieves the last 50 chat messages for the current user."""
|
||||
history = get_chat_history(db, user_id=current_user.id, limit=50)
|
||||
return history
|
||||
# -------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
\
|
||||
# /home/cdp/code/MAIA/backend/modules/nlp/models.py
|
||||
from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum
|
||||
from sqlalchemy.orm import relationship
|
||||
@@ -7,10 +6,12 @@ import enum
|
||||
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class MessageSender(enum.Enum):
|
||||
USER = "user"
|
||||
AI = "ai"
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = "chat_messages"
|
||||
|
||||
@@ -20,4 +21,4 @@ class ChatMessage(Base):
|
||||
text = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
owner = relationship("User") # Relationship to the User model
|
||||
owner = relationship("User") # Relationship to the User model
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
|
||||
class ProcessCommandRequest(BaseModel):
|
||||
user_input: str
|
||||
|
||||
|
||||
class ProcessCommandResponse(BaseModel):
|
||||
responses: List[str]
|
||||
# Optional: Keep details if needed for specific frontend logic beyond display
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# modules/nlp/service.py
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc # Import desc for ordering
|
||||
from sqlalchemy import desc # Import desc for ordering
|
||||
from google import genai
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import List # Import List
|
||||
from typing import List # Import List
|
||||
|
||||
# Import the new model and Enum
|
||||
from .models import ChatMessage, MessageSender
|
||||
@@ -14,7 +14,8 @@ from core.config import settings
|
||||
client = genai.Client(api_key=settings.GOOGLE_API_KEY)
|
||||
|
||||
### Base prompt for MAIA, used for inital user requests
|
||||
SYSTEM_PROMPT = """
|
||||
SYSTEM_PROMPT = (
|
||||
"""
|
||||
You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text.
|
||||
|
||||
Available functions/intents:
|
||||
@@ -109,8 +110,11 @@ MAIA:
|
||||
"response_text": "Okay, I've deleted task 2 from your list."
|
||||
}
|
||||
|
||||
The datetime right now is """+str(datetime.now(timezone.utc))+""".
|
||||
The datetime right now is """
|
||||
+ str(datetime.now(timezone.utc))
|
||||
+ """.
|
||||
"""
|
||||
)
|
||||
|
||||
### Prompt for MAIA to forward user request to AI
|
||||
SYSTEM_FORWARD_PROMPT = f"""
|
||||
@@ -123,6 +127,7 @@ Here is the user request:
|
||||
|
||||
# --- Chat History Service Functions ---
|
||||
|
||||
|
||||
def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str):
|
||||
"""Saves a chat message to the database."""
|
||||
db_message = ChatMessage(user_id=user_id, sender=sender, text=text)
|
||||
@@ -131,16 +136,21 @@ def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: st
|
||||
db.refresh(db_message)
|
||||
return db_message
|
||||
|
||||
|
||||
def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]:
|
||||
"""Retrieves the last 'limit' chat messages for a user."""
|
||||
return db.query(ChatMessage)\
|
||||
.filter(ChatMessage.user_id == user_id)\
|
||||
.order_by(desc(ChatMessage.timestamp))\
|
||||
.limit(limit)\
|
||||
.all()[::-1] # Reverse to get oldest first for display order
|
||||
return (
|
||||
db.query(ChatMessage)
|
||||
.filter(ChatMessage.user_id == user_id)
|
||||
.order_by(desc(ChatMessage.timestamp))
|
||||
.limit(limit)
|
||||
.all()[::-1]
|
||||
) # Reverse to get oldest first for display order
|
||||
|
||||
|
||||
# --- Existing NLP Service Functions ---
|
||||
|
||||
|
||||
def process_request(request: str):
|
||||
"""
|
||||
Process the user request using the Google GenAI API.
|
||||
@@ -152,7 +162,7 @@ def process_request(request: str):
|
||||
config={
|
||||
"temperature": 0.3, # Less creativity, more factual
|
||||
"response_mime_type": "application/json",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Parse the JSON response
|
||||
@@ -160,7 +170,9 @@ def process_request(request: str):
|
||||
parsed_response = json.loads(response.text)
|
||||
# Validate required fields
|
||||
if not all(k in parsed_response for k in ("intent", "params", "response_text")):
|
||||
raise ValueError("AI response missing required fields (intent, params, response_text)")
|
||||
raise ValueError(
|
||||
"AI response missing required fields (intent, params, response_text)"
|
||||
)
|
||||
return parsed_response
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Error parsing AI response: {e}")
|
||||
@@ -169,9 +181,10 @@ def process_request(request: str):
|
||||
return {
|
||||
"intent": "error",
|
||||
"params": {},
|
||||
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?"
|
||||
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?",
|
||||
}
|
||||
|
||||
|
||||
def ask_ai(request: str):
|
||||
"""
|
||||
Ask the AI a question.
|
||||
@@ -179,6 +192,6 @@ def ask_ai(request: str):
|
||||
"""
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents=SYSTEM_FORWARD_PROMPT+request,
|
||||
contents=SYSTEM_FORWARD_PROMPT + request,
|
||||
)
|
||||
return response.text
|
||||
return response.text
|
||||
|
||||
@@ -5,58 +5,65 @@ from typing import List
|
||||
|
||||
from . import service, schemas
|
||||
from core.database import get_db
|
||||
from modules.auth.dependencies import get_current_user # Corrected import
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
from modules.auth.dependencies import get_current_user # Corrected import
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/todos",
|
||||
tags=["todos"],
|
||||
dependencies=[Depends(get_current_user)], # Corrected dependency
|
||||
dependencies=[Depends(get_current_user)], # Corrected dependency
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED)
|
||||
def create_todo_endpoint(
|
||||
todo: schemas.TodoCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
return service.create_todo(db=db, todo=todo, user=current_user)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[schemas.Todo])
|
||||
def read_todos_endpoint(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit)
|
||||
return todos
|
||||
|
||||
|
||||
@router.get("/{todo_id}", response_model=schemas.Todo)
|
||||
def read_todo_endpoint(
|
||||
todo_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user)
|
||||
if db_todo is None:
|
||||
raise HTTPException(status_code=404, detail="Todo not found")
|
||||
return db_todo
|
||||
|
||||
|
||||
@router.put("/{todo_id}", response_model=schemas.Todo)
|
||||
def update_todo_endpoint(
|
||||
todo_id: int,
|
||||
todo_update: schemas.TodoUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
return service.update_todo(db=db, todo_id=todo_id, todo_update=todo_update, user=current_user)
|
||||
return service.update_todo(
|
||||
db=db, todo_id=todo_id, todo_update=todo_update, user=current_user
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{todo_id}", response_model=schemas.Todo)
|
||||
def delete_todo_endpoint(
|
||||
todo_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
return service.delete_todo(db=db, todo_id=todo_id, user=current_user)
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class Todo(Base):
|
||||
__tablename__ = "todos"
|
||||
|
||||
@@ -13,4 +14,6 @@ class Todo(Base):
|
||||
complete = Column(Boolean, default=False)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||
|
||||
owner = relationship("User") # Add relationship if needed, assuming User model exists in auth.models
|
||||
owner = relationship(
|
||||
"User"
|
||||
) # Add relationship if needed, assuming User model exists in auth.models
|
||||
|
||||
@@ -3,21 +3,25 @@ from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import datetime
|
||||
|
||||
|
||||
class TodoBase(BaseModel):
|
||||
task: str
|
||||
date: Optional[datetime.datetime] = None
|
||||
remind: bool = False
|
||||
complete: bool = False
|
||||
|
||||
|
||||
class TodoCreate(TodoBase):
|
||||
pass
|
||||
|
||||
|
||||
class TodoUpdate(BaseModel):
|
||||
task: Optional[str] = None
|
||||
date: Optional[datetime.datetime] = None
|
||||
remind: Optional[bool] = None
|
||||
complete: Optional[bool] = None
|
||||
|
||||
|
||||
class Todo(TodoBase):
|
||||
id: int
|
||||
owner_id: int
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# backend/modules/todo/service.py
|
||||
from sqlalchemy.orm import Session
|
||||
from . import models, schemas
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
|
||||
db_todo = models.Todo(**todo.dict(), owner_id=user.id)
|
||||
db.add(db_todo)
|
||||
@@ -11,17 +12,34 @@ def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
|
||||
db.refresh(db_todo)
|
||||
return db_todo
|
||||
|
||||
|
||||
def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100):
|
||||
return db.query(models.Todo).filter(models.Todo.owner_id == user.id).offset(skip).limit(limit).all()
|
||||
return (
|
||||
db.query(models.Todo)
|
||||
.filter(models.Todo.owner_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_todo(db: Session, todo_id: int, user: User):
|
||||
db_todo = db.query(models.Todo).filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id).first()
|
||||
db_todo = (
|
||||
db.query(models.Todo)
|
||||
.filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id)
|
||||
.first()
|
||||
)
|
||||
if db_todo is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found"
|
||||
)
|
||||
return db_todo
|
||||
|
||||
|
||||
def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User):
|
||||
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
|
||||
db_todo = get_todo(
|
||||
db=db, todo_id=todo_id, user=user
|
||||
) # Reuse get_todo to check ownership and existence
|
||||
update_data = todo_update.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_todo, key, value)
|
||||
@@ -29,8 +47,11 @@ def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user
|
||||
db.refresh(db_todo)
|
||||
return db_todo
|
||||
|
||||
|
||||
def delete_todo(db: Session, todo_id: int, user: User):
|
||||
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
|
||||
db_todo = get_todo(
|
||||
db=db, todo_id=todo_id, user=user
|
||||
) # Reuse get_todo to check ownership and existence
|
||||
db.delete(db_todo)
|
||||
db.commit()
|
||||
return db_todo
|
||||
|
||||
@@ -11,37 +11,52 @@ from modules.auth.models import User
|
||||
|
||||
router = APIRouter(prefix="/user", tags=["user"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def me(
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Get the current user. Requires user to be logged in.
|
||||
Returns the user object.
|
||||
"""
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/{username}", response_model=UserResponse)
|
||||
def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def get_user(
|
||||
username: str,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Get a user by username.
|
||||
Returns the user object.
|
||||
"""
|
||||
if current_user.username != username:
|
||||
raise forbidden_exception("You can only view your own profile")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise not_found_exception("User not found")
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/{username}", response_model=UserResponse)
|
||||
def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def update_user(
|
||||
username: str,
|
||||
user_data: UserPatch,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Update a user by username.
|
||||
Returns the updated user object.
|
||||
"""
|
||||
if current_user.username != username:
|
||||
raise forbidden_exception("You can only update your own profile")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise not_found_exception("User not found")
|
||||
@@ -60,19 +75,24 @@ def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depe
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/{username}", response_model=UserResponse)
|
||||
def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def delete_user(
|
||||
username: str,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Delete a user by username.
|
||||
Returns the deleted user object.
|
||||
"""
|
||||
if current_user.username != username:
|
||||
raise forbidden_exception("You can only delete your own profile")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise not_found_exception("User not found")
|
||||
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
return user
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user