From 1553004efcc97fb5df51b6c18f9e32be5cd30a0b Mon Sep 17 00:00:00 2001 From: c-d-p Date: Wed, 23 Apr 2025 01:00:56 +0200 Subject: [PATCH] [REFORMAT] Ran black reformat --- backend/alembic/env.py | 6 +- ..._initial_migration_with_existing_tables.py | 6 +- .../versions/9a82960db482_add_todo_table.py | 6 +- backend/core/celery_app.py | 9 +- backend/core/config.py | 10 +- backend/core/database.py | 9 +- backend/core/exceptions.py | 8 +- backend/main.py | 15 +- backend/modules/admin/api.py | 9 +- backend/modules/admin/services.py | 2 +- backend/modules/admin/tasks.py | 3 +- backend/modules/auth/api.py | 70 ++++-- backend/modules/auth/dependencies.py | 8 +- backend/modules/auth/models.py | 2 + backend/modules/auth/schemas.py | 10 +- backend/modules/auth/security.py | 74 ++++-- backend/modules/auth/services.py | 8 +- backend/modules/calendar/api.py | 39 ++- backend/modules/calendar/models.py | 21 +- backend/modules/calendar/schemas.py | 28 ++- backend/modules/calendar/service.py | 62 +++-- backend/modules/nlp/api.py | 223 ++++++++++++++---- backend/modules/nlp/models.py | 5 +- backend/modules/nlp/schemas.py | 2 + backend/modules/nlp/service.py | 41 ++-- backend/modules/todo/api.py | 25 +- backend/modules/todo/models.py | 5 +- backend/modules/todo/schemas.py | 4 + backend/modules/todo/service.py | 33 ++- backend/modules/user/api.py | 38 ++- backend/tests/conftest.py | 17 +- backend/tests/helpers/generators.py | 27 ++- backend/tests/test_admin.py | 50 ++-- backend/tests/test_auth.py | 65 +++-- backend/tests/test_calendar.py | 180 ++++++++++---- backend/tests/test_main.py | 1 + backend/tests/test_nlp.py | 176 ++++++++++---- backend/tests/test_todo.py | 92 +++++--- 38 files changed, 1005 insertions(+), 384 deletions(-) diff --git a/backend/alembic/env.py b/backend/alembic/env.py index c9fe209..f78bcd6 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -7,7 +7,7 @@ from sqlalchemy import pool from alembic import context -from core.database import Base # Import your Base +from core.database import Base # Import your Base # --- Add project root to sys.path --- @@ -77,9 +77,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/backend/alembic/versions/69069d6184b3_initial_migration_with_existing_tables.py b/backend/alembic/versions/69069d6184b3_initial_migration_with_existing_tables.py index 6dece1c..be37ff4 100644 --- a/backend/alembic/versions/69069d6184b3_initial_migration_with_existing_tables.py +++ b/backend/alembic/versions/69069d6184b3_initial_migration_with_existing_tables.py @@ -1,16 +1,16 @@ """Initial migration with existing tables Revision ID: 69069d6184b3 -Revises: +Revises: Create Date: 2025-04-21 01:14:33.233195 """ + from typing import Sequence, Union - # revision identifiers, used by Alembic. -revision: str = '69069d6184b3' +revision: str = "69069d6184b3" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/9a82960db482_add_todo_table.py b/backend/alembic/versions/9a82960db482_add_todo_table.py index c264d75..d38d4fd 100644 --- a/backend/alembic/versions/9a82960db482_add_todo_table.py +++ b/backend/alembic/versions/9a82960db482_add_todo_table.py @@ -5,13 +5,13 @@ Revises: 69069d6184b3 Create Date: 2025-04-21 20:33:27.028529 """ + from typing import Sequence, Union - # revision identifiers, used by Alembic. -revision: str = '9a82960db482' -down_revision: Union[str, None] = '69069d6184b3' +revision: str = "9a82960db482" +down_revision: Union[str, None] = "69069d6184b3" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/core/celery_app.py b/backend/core/celery_app.py index 51e7355..1bb2eb4 100644 --- a/backend/core/celery_app.py +++ b/backend/core/celery_app.py @@ -1,14 +1,17 @@ # core/celery_app.py from celery import Celery -from core.config import settings # Import your settings +from core.config import settings # Import your settings celery_app = Celery( "worker", broker=settings.REDIS_URL, backend=settings.REDIS_URL, - include=["modules.auth.tasks", "modules.admin.tasks"] # Add paths to modules containing tasks + include=[ + "modules.auth.tasks", + "modules.admin.tasks", + ], # Add paths to modules containing tasks # Add other modules with tasks here, e.g., "modules.some_other_module.tasks" ) # Optional: Update Celery configuration directly if needed -# celery_app.conf.update(task_track_started=True) \ No newline at end of file +# celery_app.conf.update(task_track_started=True) diff --git a/backend/core/config.py b/backend/core/config.py index c6b88b2..30b9b76 100644 --- a/backend/core/config.py +++ b/backend/core/config.py @@ -4,12 +4,13 @@ import os DOTENV_PATH = os.path.join(os.path.dirname(__file__), "../.env") + class Settings(BaseSettings): # Database settings - reads DB_URL from environment or .env DB_URL: str = "postgresql://maia:maia@localhost:5432/maia" # Redis settings - reads REDIS_URL from environment or .env, also used for Celery. - REDIS_URL: str ="redis://localhost:6379/0" + REDIS_URL: str = "redis://localhost:6379/0" # JWT settings - reads from environment or .env JWT_ALGORITHM: str = "HS256" @@ -19,13 +20,14 @@ class Settings(BaseSettings): JWT_SECRET_KEY: str # Other settings - GOOGLE_API_KEY: str = "" # Example with a default + GOOGLE_API_KEY: str = "" # Example with a default class Config: # Tell pydantic-settings to load variables from a .env file env_file = DOTENV_PATH - env_file_encoding = 'utf-8' - extra = 'ignore' + env_file_encoding = "utf-8" + extra = "ignore" + # Create a single instance of the settings settings = Settings() diff --git a/backend/core/database.py b/backend/core/database.py index 4897177..cfec7b6 100644 --- a/backend/core/database.py +++ b/backend/core/database.py @@ -10,6 +10,7 @@ Base = declarative_base() # Used for models _engine = None _SessionLocal = None + def get_engine(): global _engine if _engine is None: @@ -20,10 +21,13 @@ def get_engine(): try: _engine.connect() except Exception: - raise Exception("Database connection failed. Is the database server running?") + raise Exception( + "Database connection failed. Is the database server running?" + ) Base.metadata.create_all(_engine) # Create tables here return _engine + def get_sessionmaker(): global _SessionLocal if _SessionLocal is None: @@ -31,10 +35,11 @@ def get_sessionmaker(): _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) return _SessionLocal + def get_db() -> Generator[Session, None, None]: SessionLocal = get_sessionmaker() db = SessionLocal() try: yield db finally: - db.close() \ No newline at end of file + db.close() diff --git a/backend/core/exceptions.py b/backend/core/exceptions.py index 0d6a1b5..0f16d79 100644 --- a/backend/core/exceptions.py +++ b/backend/core/exceptions.py @@ -8,20 +8,26 @@ from starlette.status import ( HTTP_409_CONFLICT, ) + def bad_request_exception(detail: str = "Bad Request"): return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail) + def unauthorized_exception(detail: str = "Unauthorized"): return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail) + def forbidden_exception(detail: str = "Forbidden"): return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail) + def not_found_exception(detail: str = "Not Found"): return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail) + def internal_server_error_exception(detail: str = "Internal Server Error"): return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail) + def conflict_exception(detail: str = "Conflict"): - return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail) \ No newline at end of file + return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail) diff --git a/backend/main.py b/backend/main.py index 5af39b7..46eb1f7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,11 +11,12 @@ import logging # import all models to ensure they are registered before create_all -logging.getLogger('passlib').setLevel(logging.ERROR) # fix bc package logging is broken +logging.getLogger("passlib").setLevel(logging.ERROR) # fix bc package logging is broken + # Create DB tables (remove in production; use migrations instead) def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]: - + @asynccontextmanager async def lifespan(app: FastAPI): # Base.metadata.drop_all(bind=get_engine()) @@ -24,6 +25,7 @@ def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any] return lifespan + lifespan = lifespan_factory() app = FastAPI(lifespan=lifespan) @@ -34,17 +36,18 @@ app.include_router(router) app.add_middleware( CORSMiddleware, allow_origins=[ - "http://localhost:8081", # Keep for web testing if needed - "http://192.168.1.9:8081", # Add your mobile device/emulator origin (adjust port if needed) + "http://localhost:8081", # Keep for web testing if needed + "http://192.168.1.9:8081", # Add your mobile device/emulator origin (adjust port if needed) "http://192.168.255.221:8081", # Add other origins if necessary, e.g., production frontend URL ], allow_credentials=True, allow_methods=["*"], - allow_headers=["*"] + allow_headers=["*"], ) + # Health endpoint @app.get("/api/health") def health(): - return {"status": "ok"} \ No newline at end of file + return {"status": "ok"} diff --git a/backend/modules/admin/api.py b/backend/modules/admin/api.py index 120ab82..9b3c260 100644 --- a/backend/modules/admin/api.py +++ b/backend/modules/admin/api.py @@ -1,7 +1,7 @@ # modules/admin/api.py from typing import Annotated -from fastapi import APIRouter, Depends # Import Body -from pydantic import BaseModel # Import BaseModel +from fastapi import APIRouter, Depends # Import Body +from pydantic import BaseModel # Import BaseModel from sqlalchemy.orm import Session from core.database import get_db from modules.auth.dependencies import admin_only @@ -9,14 +9,17 @@ from .tasks import cleardb router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)]) + # Define a Pydantic model for the request body class ClearDbRequest(BaseModel): hard: bool + @router.get("/") def read_admin(): return {"message": "Admin route"} + # Change to POST and use the request body model @router.post("/cleardb") def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]): @@ -25,6 +28,6 @@ def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]): 'hard'=True: Drop and recreate all tables. 'hard'=False: Delete data from tables except users. """ - hard = payload.hard # Get 'hard' from the payload + hard = payload.hard # Get 'hard' from the payload cleardb.delay(hard) return {"message": "Clearing database in the background", "hard": hard} diff --git a/backend/modules/admin/services.py b/backend/modules/admin/services.py index c1becdb..00b45ac 100644 --- a/backend/modules/admin/services.py +++ b/backend/modules/admin/services.py @@ -1,4 +1,4 @@ # modules/admin/services.py -## temp \ No newline at end of file +## temp diff --git a/backend/modules/admin/tasks.py b/backend/modules/admin/tasks.py index 3c03c28..1ba029f 100644 --- a/backend/modules/admin/tasks.py +++ b/backend/modules/admin/tasks.py @@ -1,5 +1,6 @@ from core.celery_app import celery_app + @celery_app.task def cleardb(hard: bool): """ @@ -32,4 +33,4 @@ def cleardb(hard: bool): print(f"Deleting table: {table_name}") db.execute(table.delete()) db.commit() - return {"message": "Database cleared"} \ No newline at end of file + return {"message": "Database cleared"} diff --git a/backend/modules/auth/api.py b/backend/modules/auth/api.py index d959e6b..29ded0b 100644 --- a/backend/modules/auth/api.py +++ b/backend/modules/auth/api.py @@ -3,9 +3,24 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from jose import JWTError from modules.auth.models import User -from modules.auth.schemas import UserCreate, UserResponse, Token, RefreshTokenRequest, LogoutRequest +from modules.auth.schemas import ( + UserCreate, + UserResponse, + Token, + RefreshTokenRequest, + LogoutRequest, +) from modules.auth.services import create_user -from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens +from modules.auth.security import ( + TokenType, + get_current_user, + oauth2_scheme, + create_access_token, + create_refresh_token, + verify_token, + authenticate_user, + blacklist_tokens, +) from sqlalchemy.orm import Session from typing import Annotated from core.database import get_db @@ -15,12 +30,19 @@ from core.exceptions import unauthorized_exception router = APIRouter(prefix="/auth", tags=["auth"]) -@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) + +@router.post( + "/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED +) def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]): return create_user(user.username, user.password, user.name, db) + @router.post("/login", response_model=Token) -def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]): +def login( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + db: Annotated[Session, Depends(get_db)], +): """ Authenticate user and return JWT tokens in the response body. """ @@ -30,39 +52,53 @@ def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annota status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", ) - - access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)) + + access_token = create_access_token( + data={"sub": user.username}, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), + ) refresh_token = create_refresh_token(data={"sub": user.username}) - return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"} + return { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "bearer", + } + @router.post("/refresh") -def refresh_token(payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]): +def refresh_token( + payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)] +): print("Refreshing token...") refresh_token = payload.refresh_token if not refresh_token: raise unauthorized_exception("Refresh token missing in request body") - user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db) + user_data = verify_token( + refresh_token, expected_token_type=TokenType.REFRESH, db=db + ) if not user_data: raise unauthorized_exception("Invalid refresh token") new_access_token = create_access_token(data={"sub": user_data.username}) return {"access_token": new_access_token, "token_type": "bearer"} + @router.post("/logout") -def logout(payload: LogoutRequest, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme)): +def logout( + payload: LogoutRequest, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + access_token: str = Depends(oauth2_scheme), +): try: refresh_token = payload.refresh_token if not refresh_token: raise unauthorized_exception("Refresh token not found in request body") - - blacklist_tokens( - access_token=access_token, - refresh_token=refresh_token, - db=db - ) + + blacklist_tokens(access_token=access_token, refresh_token=refresh_token, db=db) return {"message": "Logged out successfully"} except JWTError: - raise unauthorized_exception("Invalid token") \ No newline at end of file + raise unauthorized_exception("Invalid token") diff --git a/backend/modules/auth/dependencies.py b/backend/modules/auth/dependencies.py index b8eb91c..ec84bea 100644 --- a/backend/modules/auth/dependencies.py +++ b/backend/modules/auth/dependencies.py @@ -5,14 +5,18 @@ from modules.auth.schemas import UserRole from modules.auth.models import User from core.exceptions import forbidden_exception + class RoleChecker: def __init__(self, allowed_roles: list[UserRole]): self.allowed_roles = allowed_roles def __call__(self, user: User = Depends(get_current_user)): if user.role not in self.allowed_roles: - raise forbidden_exception("You do not have permission to perform this action.") + raise forbidden_exception( + "You do not have permission to perform this action." + ) return user + admin_only = RoleChecker([UserRole.ADMIN]) -any_user = RoleChecker([UserRole.ADMIN, UserRole.USER]) \ No newline at end of file +any_user = RoleChecker([UserRole.ADMIN, UserRole.USER]) diff --git a/backend/modules/auth/models.py b/backend/modules/auth/models.py index e650234..03e3bc0 100644 --- a/backend/modules/auth/models.py +++ b/backend/modules/auth/models.py @@ -4,10 +4,12 @@ from sqlalchemy import Column, Integer, String, Enum, DateTime from sqlalchemy.orm import relationship from enum import Enum as PyEnum + class UserRole(str, PyEnum): ADMIN = "admin" USER = "user" + class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) diff --git a/backend/modules/auth/schemas.py b/backend/modules/auth/schemas.py index 9668549..e04cd50 100644 --- a/backend/modules/auth/schemas.py +++ b/backend/modules/auth/schemas.py @@ -2,33 +2,41 @@ from enum import Enum as PyEnum from pydantic import BaseModel + class Token(BaseModel): access_token: str token_type: str refresh_token: str | None = None + class TokenData(BaseModel): username: str | None = None scopes: list[str] = [] + class RefreshTokenRequest(BaseModel): refresh_token: str + class LogoutRequest(BaseModel): refresh_token: str + class UserRole(str, PyEnum): ADMIN = "admin" USER = "user" + class UserCreate(BaseModel): username: str password: str name: str + class UserPatch(BaseModel): name: str | None = None - + + class UserResponse(BaseModel): uuid: str username: str diff --git a/backend/modules/auth/security.py b/backend/modules/auth/security.py index 521be25..1dd2f0d 100644 --- a/backend/modules/auth/security.py +++ b/backend/modules/auth/security.py @@ -18,6 +18,7 @@ from modules.auth.schemas import TokenData oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") + class TokenType(str, Enum): ACCESS = "access" REFRESH = "refresh" @@ -25,11 +26,13 @@ class TokenType(str, Enum): password_hasher = PasswordHasher() + def hash_password(password: str) -> str: """Hash a password with Argon2 (and optional pepper).""" peppered_password = password + settings.PEPPER # Prepend/append pepper return password_hasher.hash(peppered_password) + def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a password against its hashed version using Argon2.""" peppered_password = plain_password + settings.PEPPER @@ -38,6 +41,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: except VerifyMismatchError: return False + def authenticate_user(username: str, password: str, db: Session) -> User | None: """ Authenticate a user by checking username/password against the database. @@ -45,41 +49,46 @@ def authenticate_user(username: str, password: str, db: Session) -> User | None: """ # Get user from database user = db.query(User).filter(User.username == username).first() - + # If user not found or password doesn't match if not user or not verify_password(password, user.hashed_password): return None - + return user + def create_access_token(data: dict, expires_delta: timedelta | None = None): to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(timezone.utc) + timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) # expire = datetime.now(timezone.utc) + timedelta(seconds=5) to_encode.update({"exp": expire, "token_type": TokenType.ACCESS}) return jwt.encode( - to_encode, - settings.JWT_SECRET_KEY, - algorithm=settings.JWT_ALGORITHM + to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM ) + def create_refresh_token(data: dict, expires_delta: timedelta | None = None): to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + expire = datetime.now(timezone.utc) + timedelta( + days=settings.REFRESH_TOKEN_EXPIRE_DAYS + ) to_encode.update({"exp": expire, "token_type": TokenType.REFRESH}) return jwt.encode( - to_encode, - settings.JWT_SECRET_KEY, - algorithm=settings.JWT_ALGORITHM + to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM ) -def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None: + +def verify_token( + token: str, expected_token_type: TokenType, db: Session +) -> TokenData | None: """Verify a JWT token and return TokenData if valid. Parameters @@ -96,24 +105,32 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok TokenData | None TokenData instance if the token is valid, None otherwise. """ - is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None + is_blacklisted = ( + db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() + is not None + ) if is_blacklisted: return None - + try: - payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) + payload = jwt.decode( + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] + ) username: str = payload.get("sub") token_type: str = payload.get("token_type") - + if username is None or token_type != expected_token_type: return None - + return TokenData(username=username) except JWTError: return None -def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User: + +def get_current_user( + db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme) +) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -121,26 +138,28 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen ) # Check if the token is blacklisted - is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None + is_blacklisted = ( + db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() + is not None + ) if is_blacklisted: raise credentials_exception try: payload = jwt.decode( - token, - settings.JWT_SECRET_KEY, - algorithms=[settings.JWT_ALGORITHM] + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] ) username: str = payload.get("sub") if username is None: raise credentials_exception except JWTError: raise credentials_exception - + user: User = db.query(User).filter(User.username == username).first() if user is None: raise credentials_exception return user + def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None: """Blacklist both access and refresh tokens. @@ -154,7 +173,9 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None Database session to perform the operation. """ for token in [access_token, refresh_token]: - payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) + payload = jwt.decode( + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] + ) expires_at = datetime.fromtimestamp(payload.get("exp")) # Add the token to the blacklist @@ -163,10 +184,13 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None db.commit() + def blacklist_token(token: str, db: Session) -> None: - payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) + payload = jwt.decode( + token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM] + ) expires_at = datetime.fromtimestamp(payload.get("exp")) - + # Add the token to the blacklist blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at) db.add(blacklisted_token) diff --git a/backend/modules/auth/services.py b/backend/modules/auth/services.py index fadf03e..3f50a37 100644 --- a/backend/modules/auth/services.py +++ b/backend/modules/auth/services.py @@ -20,11 +20,13 @@ def create_user(username: str, password: str, name: str, db: Session) -> UserRes existing_user = db.query(User).filter(User.username == username).first() if existing_user: raise conflict_exception("Username already exists") - + hashed_password = hash_password(password) user_uuid = str(uuid.uuid4()) - user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid) + user = User( + username=username, hashed_password=hashed_password, name=name, uuid=user_uuid + ) db.add(user) db.commit() db.refresh(user) # Loads the generated ID - return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic \ No newline at end of file + return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic diff --git a/backend/modules/calendar/api.py b/backend/modules/calendar/api.py index fdf079c..cdb1cf6 100644 --- a/backend/modules/calendar/api.py +++ b/backend/modules/calendar/api.py @@ -6,50 +6,63 @@ from typing import List, Optional from modules.auth.dependencies import get_current_user from core.database import get_db from modules.auth.models import User -from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate, CalendarEventResponse -from modules.calendar.service import create_calendar_event, get_calendar_event_by_id, get_calendar_events, update_calendar_event, delete_calendar_event +from modules.calendar.schemas import ( + CalendarEventCreate, + CalendarEventUpdate, + CalendarEventResponse, +) +from modules.calendar.service import ( + create_calendar_event, + get_calendar_event_by_id, + get_calendar_events, + update_calendar_event, + delete_calendar_event, +) router = APIRouter(prefix="/calendar", tags=["calendar"]) -@router.post("/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED) + +@router.post( + "/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED +) def create_event( event: CalendarEventCreate, user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): return create_calendar_event(db, user.id, event) + @router.get("/events", response_model=List[CalendarEventResponse]) def get_events( user: User = Depends(get_current_user), db: Session = Depends(get_db), start: Optional[datetime] = None, - end: Optional[datetime] = None + end: Optional[datetime] = None, ): return get_calendar_events(db, user.id, start, end) + @router.get("/events/{event_id}", response_model=CalendarEventResponse) def get_event_by_id( - event_id: int, - user: User = Depends(get_current_user), - db: Session = Depends(get_db) + event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db) ): event = get_calendar_event_by_id(db, user.id, event_id) return event + @router.patch("/events/{event_id}", response_model=CalendarEventResponse) def update_event( event_id: int, event: CalendarEventUpdate, user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): return update_calendar_event(db, user.id, event_id, event) + @router.delete("/events/{event_id}", status_code=204) def delete_event( - event_id: int, - user: User = Depends(get_current_user), - db: Session = Depends(get_db) + event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db) ): - delete_calendar_event(db, user.id, event_id) \ No newline at end of file + delete_calendar_event(db, user.id, event_id) diff --git a/backend/modules/calendar/models.py b/backend/modules/calendar/models.py index 17cde49..f5d1554 100644 --- a/backend/modules/calendar/models.py +++ b/backend/modules/calendar/models.py @@ -1,8 +1,17 @@ # modules/calendar/models.py -from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Boolean # Add Boolean +from sqlalchemy import ( + Column, + Integer, + String, + DateTime, + ForeignKey, + JSON, + Boolean, +) # Add Boolean from sqlalchemy.orm import relationship from core.database import Base + class CalendarEvent(Base): __tablename__ = "calendar_events" @@ -12,10 +21,12 @@ class CalendarEvent(Base): start = Column(DateTime, nullable=False) end = Column(DateTime) location = Column(String) - all_day = Column(Boolean, default=False) # Add all_day column + all_day = Column(Boolean, default=False) # Add all_day column tags = Column(JSON) - color = Column(String) # hex code for color - user_id = Column(Integer, ForeignKey("users.id"), nullable=False) # <-- Relationship + color = Column(String) # hex code for color + user_id = Column( + Integer, ForeignKey("users.id"), nullable=False + ) # <-- Relationship # Bi-directional relationship (for eager loading) - user = relationship("User", back_populates="calendar_events") \ No newline at end of file + user = relationship("User", back_populates="calendar_events") diff --git a/backend/modules/calendar/schemas.py b/backend/modules/calendar/schemas.py index 760da49..7505c1a 100644 --- a/backend/modules/calendar/schemas.py +++ b/backend/modules/calendar/schemas.py @@ -1,7 +1,8 @@ # modules/calendar/schemas.py from datetime import datetime -from pydantic import BaseModel, field_validator # Add field_validator -from typing import List, Optional # Add List and Optional +from pydantic import BaseModel, field_validator # Add field_validator +from typing import List, Optional # Add List and Optional + # Base schema for common fields, including tags class CalendarEventBase(BaseModel): @@ -10,21 +11,23 @@ class CalendarEventBase(BaseModel): start: datetime end: Optional[datetime] = None location: Optional[str] = None - color: Optional[str] = None # Assuming color exists - all_day: Optional[bool] = None # Add all_day field - tags: Optional[List[str]] = None # Add optional tags + color: Optional[str] = None # Assuming color exists + all_day: Optional[bool] = None # Add all_day field + tags: Optional[List[str]] = None # Add optional tags - @field_validator('tags', mode='before') + @field_validator("tags", mode="before") @classmethod def tags_validate_null_string(cls, v): if v == "Null": return None return v + # Schema for creating an event (inherits from Base) class CalendarEventCreate(CalendarEventBase): pass + # Schema for updating an event (all fields optional) class CalendarEventUpdate(BaseModel): title: Optional[str] = None @@ -33,23 +36,24 @@ class CalendarEventUpdate(BaseModel): end: Optional[datetime] = None location: Optional[str] = None color: Optional[str] = None - all_day: Optional[bool] = None # Add all_day field - tags: Optional[List[str]] = None # Add optional tags for update + all_day: Optional[bool] = None # Add all_day field + tags: Optional[List[str]] = None # Add optional tags for update - @field_validator('tags', mode='before') + @field_validator("tags", mode="before") @classmethod def tags_validate_null_string(cls, v): if v == "Null": return None return v + # Schema for the response (inherits from Base, adds ID and user_id) class CalendarEventResponse(CalendarEventBase): id: int user_id: int - tags: List[str] # Keep as List[str], remove default [] + tags: List[str] # Keep as List[str], remove default [] - @field_validator('tags', mode='before') + @field_validator("tags", mode="before") @classmethod def tags_validate_none_to_list(cls, v): # If the value from the source object (e.g., ORM model) is None, @@ -59,4 +63,4 @@ class CalendarEventResponse(CalendarEventBase): return v class Config: - from_attributes = True \ No newline at end of file + from_attributes = True diff --git a/backend/modules/calendar/service.py b/backend/modules/calendar/service.py index d973aca..9e0afcb 100644 --- a/backend/modules/calendar/service.py +++ b/backend/modules/calendar/service.py @@ -1,25 +1,34 @@ # modules/calendar/service.py from sqlalchemy.orm import Session -from sqlalchemy import or_ # Import or_ +from sqlalchemy import or_ # Import or_ from datetime import datetime from modules.calendar.models import CalendarEvent from core.exceptions import not_found_exception -from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate # Import schemas +from modules.calendar.schemas import ( + CalendarEventCreate, + CalendarEventUpdate, +) # Import schemas + def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate): # Ensure tags is None if not provided or empty list, matching model tags_to_store = event_data.tags if event_data.tags else None event = CalendarEvent( - **event_data.model_dump(exclude={'tags'}), # Use model_dump and exclude tags initially - tags=tags_to_store, # Set tags separately - user_id=user_id + **event_data.model_dump( + exclude={"tags"} + ), # Use model_dump and exclude tags initially + tags=tags_to_store, # Set tags separately + user_id=user_id, ) db.add(event) db.commit() db.refresh(event) return event -def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: datetime | None): + +def get_calendar_events( + db: Session, user_id: int, start: datetime | None, end: datetime | None +): """ Retrieves calendar events for a user, optionally filtered by a date range. @@ -46,9 +55,13 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: query = query.filter( or_( # Case 1: Event has duration and overlaps - (CalendarEvent.end is not None) & (CalendarEvent.start < end) & (CalendarEvent.end > start), + (CalendarEvent.end is not None) + & (CalendarEvent.start < end) + & (CalendarEvent.end > start), # Case 2: Event is a point event within the range - (CalendarEvent.end is None) & (CalendarEvent.start >= start) & (CalendarEvent.start < end) + (CalendarEvent.end is None) + & (CalendarEvent.start >= start) + & (CalendarEvent.start < end), ) ) # If only start is provided, filter events starting on or after start @@ -60,37 +73,41 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: elif end: # Includes events with duration ending <= end (or starting before end if end is None) # Includes point events occurring < end - query = query.filter( + query = query.filter( or_( # Event ends before the specified end time (CalendarEvent.end is not None) & (CalendarEvent.end <= end), # Point event occurs before the specified end time - (CalendarEvent.end is None) & (CalendarEvent.start < end) + (CalendarEvent.end is None) & (CalendarEvent.start < end), ) ) - # Alternative interpretation for "ending before end": include events that *start* before end - # query = query.filter(CalendarEvent.start < end) + # Alternative interpretation for "ending before end": include events that *start* before end + # query = query.filter(CalendarEvent.start < end) + return query.order_by(CalendarEvent.start).all() # Order by start time - return query.order_by(CalendarEvent.start).all() # Order by start time def get_calendar_event_by_id(db: Session, user_id: int, event_id: int): - event = db.query(CalendarEvent).filter( - CalendarEvent.id == event_id, - CalendarEvent.user_id == user_id - ).first() + event = ( + db.query(CalendarEvent) + .filter(CalendarEvent.id == event_id, CalendarEvent.user_id == user_id) + .first() + ) if not event: raise not_found_exception() return event -def update_calendar_event(db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate): - event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check + +def update_calendar_event( + db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate +): + event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check # Use model_dump with exclude_unset=True to only update provided fields update_data = event_data.model_dump(exclude_unset=True) for key, value in update_data.items(): # Ensure tags is handled correctly (set to None if empty list provided) - if key == 'tags' and isinstance(value, list) and not value: + if key == "tags" and isinstance(value, list) and not value: setattr(event, key, None) else: setattr(event, key, value) @@ -99,7 +116,8 @@ def update_calendar_event(db: Session, user_id: int, event_id: int, event_data: db.refresh(event) return event + def delete_calendar_event(db: Session, user_id: int, event_id: int): - event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check + event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check db.delete(event) - db.commit() \ No newline at end of file + db.commit() diff --git a/backend/modules/nlp/api.py b/backend/modules/nlp/api.py index 14e3b8a..0ce4878 100644 --- a/backend/modules/nlp/api.py +++ b/backend/modules/nlp/api.py @@ -7,13 +7,27 @@ from core.database import get_db from modules.auth.dependencies import get_current_user from modules.auth.models import User + # Import the new service functions and Enum -from modules.nlp.service import process_request, ask_ai, save_chat_message, get_chat_history, MessageSender +from modules.nlp.service import ( + process_request, + ask_ai, + save_chat_message, + get_chat_history, + MessageSender, +) + # Import the response schema and the new ChatMessage model for response type hinting from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse -from modules.calendar.service import create_calendar_event, get_calendar_events, update_calendar_event, delete_calendar_event +from modules.calendar.service import ( + create_calendar_event, + get_calendar_events, + update_calendar_event, + delete_calendar_event, +) from modules.calendar.models import CalendarEvent from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate + # Import TODO services, schemas, and model from modules.todo import service as todo_service from modules.todo.models import Todo @@ -21,17 +35,20 @@ from modules.todo.schemas import TodoCreate, TodoUpdate from pydantic import BaseModel from datetime import datetime + class ChatMessageResponse(BaseModel): id: int - sender: MessageSender # Use the enum directly + sender: MessageSender # Use the enum directly text: str timestamp: datetime class Config: - from_attributes = True # Allow Pydantic to work with ORM models - + from_attributes = True # Allow Pydantic to work with ORM models + + router = APIRouter(prefix="/nlp", tags=["nlp"]) + # Helper to format calendar events (expects list of CalendarEvent models) def format_calendar_events(events: List[CalendarEvent]) -> List[str]: if not events: @@ -39,12 +56,15 @@ def format_calendar_events(events: List[CalendarEvent]) -> List[str]: formatted = ["Here are the events:"] for event in events: # Access attributes directly from the model instance - start_str = event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time" + start_str = ( + event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time" + ) end_str = event.end.strftime("%H:%M") if event.end else "" title = event.title or "Untitled Event" formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})") return formatted + # Helper to format TODO items (expects list of Todo models) def format_todos(todos: List[Todo]) -> List[str]: if not todos: @@ -54,19 +74,28 @@ def format_todos(todos: List[Todo]) -> List[str]: status = "[X]" if todo.complete else "[ ]" date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else "" remind_str = " (Reminder)" if todo.remind else "" - formatted.append(f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})") + formatted.append( + f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})" + ) return formatted + # Update the response model for the endpoint @router.post("/process-command", response_model=ProcessCommandResponse) -def process_command(request_data: ProcessCommandRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): +def process_command( + request_data: ProcessCommandRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): """ Process the user command, save messages, execute action, save response, and return user-friendly responses. """ user_input = request_data.user_input # --- Save User Message --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.USER, text=user_input) + save_chat_message( + db, user_id=current_user.id, sender=MessageSender.USER, text=user_input + ) # ------------------------ command_data = process_request(user_input) @@ -74,11 +103,13 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De params = command_data["params"] response_text = command_data["response_text"] - responses = [response_text] # Start with the initial response + responses = [response_text] # Start with the initial response # --- Save Initial AI Response --- # Save the first response generated by process_request - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text) + save_chat_message( + db, user_id=current_user.id, sender=MessageSender.AI, text=response_text + ) # ----------------------------- if intent == "error": @@ -97,139 +128,233 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De ai_answer = ask_ai(**params) responses.append(ai_answer) # --- Save Additional AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer) + save_chat_message( + db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer + ) # --------------------------------- return ProcessCommandResponse(responses=responses) case "get_calendar_events": - events: List[CalendarEvent] = get_calendar_events(db, current_user.id, **params) + events: List[CalendarEvent] = get_calendar_events( + db, current_user.id, **params + ) formatted_responses = format_calendar_events(events) responses.extend(formatted_responses) # --- Save Additional AI Responses --- for resp in formatted_responses: - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp) + save_chat_message( + db, user_id=current_user.id, sender=MessageSender.AI, text=resp + ) # ---------------------------------- return ProcessCommandResponse(responses=responses) case "add_calendar_event": event_data = CalendarEventCreate(**params) created_event = create_calendar_event(db, current_user.id, event_data) - start_str = created_event.start.strftime("%Y-%m-%d %H:%M") if created_event.start else "No start time" + start_str = ( + created_event.start.strftime("%Y-%m-%d %H:%M") + if created_event.start + else "No start time" + ) title = created_event.title or "Untitled Event" add_response = f"Added: {title} starting at {start_str}." responses.append(add_response) # --- Save Additional AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=add_response, + ) # --------------------------------- return ProcessCommandResponse(responses=responses) case "update_calendar_event": - event_id = params.pop('event_id', None) + event_id = params.pop("event_id", None) if event_id is None: # Save the error message before raising error_msg = "Event ID is required for update." - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=error_msg, + ) raise HTTPException(status_code=400, detail=error_msg) event_data = CalendarEventUpdate(**params) - updated_event = update_calendar_event(db, current_user.id, event_id, event_data=event_data) + updated_event = update_calendar_event( + db, current_user.id, event_id, event_data=event_data + ) title = updated_event.title or "Untitled Event" update_response = f"Updated event ID {updated_event.id}: {title}." responses.append(update_response) # --- Save Additional AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=update_response, + ) # --------------------------------- return ProcessCommandResponse(responses=responses) case "delete_calendar_event": - event_id = params.get('event_id') + event_id = params.get("event_id") if event_id is None: # Save the error message before raising error_msg = "Event ID is required for delete." - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=error_msg, + ) raise HTTPException(status_code=400, detail=error_msg) delete_calendar_event(db, current_user.id, event_id) delete_response = f"Deleted event ID {event_id}." responses.append(delete_response) # --- Save Additional AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=delete_response, + ) # --------------------------------- return ProcessCommandResponse(responses=responses) - # --- Add TODO Cases --- + # --- Add TODO Cases --- case "get_todos": - todos: List[Todo] = todo_service.get_todos(db, user=current_user, **params) + todos: List[Todo] = todo_service.get_todos( + db, user=current_user, **params + ) formatted_responses = format_todos(todos) responses.extend(formatted_responses) # --- Save Additional AI Responses --- for resp in formatted_responses: - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp) + save_chat_message( + db, user_id=current_user.id, sender=MessageSender.AI, text=resp + ) # ---------------------------------- return ProcessCommandResponse(responses=responses) case "add_todo": todo_data = TodoCreate(**params) - created_todo = todo_service.create_todo(db, todo=todo_data, user=current_user) - add_response = f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})." + created_todo = todo_service.create_todo( + db, todo=todo_data, user=current_user + ) + add_response = ( + f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})." + ) responses.append(add_response) # --- Save Additional AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=add_response, + ) # --------------------------------- return ProcessCommandResponse(responses=responses) case "update_todo": - todo_id = params.pop('todo_id', None) + todo_id = params.pop("todo_id", None) if todo_id is None: error_msg = "TODO ID is required for update." - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=error_msg, + ) raise HTTPException(status_code=400, detail=error_msg) todo_data = TodoUpdate(**params) - updated_todo = todo_service.update_todo(db, todo_id=todo_id, todo_update=todo_data, user=current_user) - update_response = f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'." - if 'complete' in params: - status = "complete" if params['complete'] else "incomplete" + updated_todo = todo_service.update_todo( + db, todo_id=todo_id, todo_update=todo_data, user=current_user + ) + update_response = ( + f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'." + ) + if "complete" in params: + status = "complete" if params["complete"] else "incomplete" update_response += f" Marked as {status}." responses.append(update_response) # --- Save Additional AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=update_response, + ) # --------------------------------- return ProcessCommandResponse(responses=responses) case "delete_todo": - todo_id = params.get('todo_id') + todo_id = params.get("todo_id") if todo_id is None: error_msg = "TODO ID is required for delete." - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=error_msg, + ) raise HTTPException(status_code=400, detail=error_msg) - deleted_todo = todo_service.delete_todo(db, todo_id=todo_id, user=current_user) - delete_response = f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'." + deleted_todo = todo_service.delete_todo( + db, todo_id=todo_id, user=current_user + ) + delete_response = ( + f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'." + ) responses.append(delete_response) # --- Save Additional AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response) + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=delete_response, + ) # --------------------------------- return ProcessCommandResponse(responses=responses) # --- End TODO Cases --- - case _: - print(f"Warning: Unhandled intent '{intent}' reached api.py match statement.") + case _: + print( + f"Warning: Unhandled intent '{intent}' reached api.py match statement." + ) # The initial response_text was already saved return ProcessCommandResponse(responses=responses) except HTTPException as http_exc: # Don't save again if already saved before raising - if http_exc.status_code != 400 or ('event_id' not in http_exc.detail.lower()): - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=http_exc.detail) + if http_exc.status_code != 400 or ("event_id" not in http_exc.detail.lower()): + save_chat_message( + db, + user_id=current_user.id, + sender=MessageSender.AI, + text=http_exc.detail, + ) raise http_exc except Exception as e: print(f"Error executing intent '{intent}': {e}") - error_response = "Sorry, I encountered an error while trying to perform that action." + error_response = ( + "Sorry, I encountered an error while trying to perform that action." + ) # --- Save Final Error AI Response --- - save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_response) + save_chat_message( + db, user_id=current_user.id, sender=MessageSender.AI, text=error_response + ) # ---------------------------------- return ProcessCommandResponse(responses=[error_response]) + @router.get("/history", response_model=List[ChatMessageResponse]) -def read_chat_history(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): +def read_chat_history( + current_user: User = Depends(get_current_user), db: Session = Depends(get_db) +): """Retrieves the last 50 chat messages for the current user.""" history = get_chat_history(db, user_id=current_user.id, limit=50) return history -# ------------------------------------- \ No newline at end of file + + +# ------------------------------------- diff --git a/backend/modules/nlp/models.py b/backend/modules/nlp/models.py index ddeba47..120fcee 100644 --- a/backend/modules/nlp/models.py +++ b/backend/modules/nlp/models.py @@ -1,4 +1,3 @@ -\ # /home/cdp/code/MAIA/backend/modules/nlp/models.py from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum from sqlalchemy.orm import relationship @@ -7,10 +6,12 @@ import enum from core.database import Base + class MessageSender(enum.Enum): USER = "user" AI = "ai" + class ChatMessage(Base): __tablename__ = "chat_messages" @@ -20,4 +21,4 @@ class ChatMessage(Base): text = Column(Text, nullable=False) timestamp = Column(DateTime(timezone=True), server_default=func.now()) - owner = relationship("User") # Relationship to the User model + owner = relationship("User") # Relationship to the User model diff --git a/backend/modules/nlp/schemas.py b/backend/modules/nlp/schemas.py index 48b1870..d2a2807 100644 --- a/backend/modules/nlp/schemas.py +++ b/backend/modules/nlp/schemas.py @@ -2,9 +2,11 @@ from pydantic import BaseModel from typing import List + class ProcessCommandRequest(BaseModel): user_input: str + class ProcessCommandResponse(BaseModel): responses: List[str] # Optional: Keep details if needed for specific frontend logic beyond display diff --git a/backend/modules/nlp/service.py b/backend/modules/nlp/service.py index f1e2046..ef4bca4 100644 --- a/backend/modules/nlp/service.py +++ b/backend/modules/nlp/service.py @@ -1,11 +1,11 @@ # modules/nlp/service.py from sqlalchemy.orm import Session -from sqlalchemy import desc # Import desc for ordering +from sqlalchemy import desc # Import desc for ordering from google import genai import json from datetime import datetime, timezone -from typing import List # Import List +from typing import List # Import List # Import the new model and Enum from .models import ChatMessage, MessageSender @@ -14,7 +14,8 @@ from core.config import settings client = genai.Client(api_key=settings.GOOGLE_API_KEY) ### Base prompt for MAIA, used for inital user requests -SYSTEM_PROMPT = """ +SYSTEM_PROMPT = ( + """ You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text. Available functions/intents: @@ -109,8 +110,11 @@ MAIA: "response_text": "Okay, I've deleted task 2 from your list." } -The datetime right now is """+str(datetime.now(timezone.utc))+""". +The datetime right now is """ + + str(datetime.now(timezone.utc)) + + """. """ +) ### Prompt for MAIA to forward user request to AI SYSTEM_FORWARD_PROMPT = f""" @@ -123,6 +127,7 @@ Here is the user request: # --- Chat History Service Functions --- + def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str): """Saves a chat message to the database.""" db_message = ChatMessage(user_id=user_id, sender=sender, text=text) @@ -131,16 +136,21 @@ def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: st db.refresh(db_message) return db_message + def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]: """Retrieves the last 'limit' chat messages for a user.""" - return db.query(ChatMessage)\ - .filter(ChatMessage.user_id == user_id)\ - .order_by(desc(ChatMessage.timestamp))\ - .limit(limit)\ - .all()[::-1] # Reverse to get oldest first for display order + return ( + db.query(ChatMessage) + .filter(ChatMessage.user_id == user_id) + .order_by(desc(ChatMessage.timestamp)) + .limit(limit) + .all()[::-1] + ) # Reverse to get oldest first for display order + # --- Existing NLP Service Functions --- + def process_request(request: str): """ Process the user request using the Google GenAI API. @@ -152,7 +162,7 @@ def process_request(request: str): config={ "temperature": 0.3, # Less creativity, more factual "response_mime_type": "application/json", - } + }, ) # Parse the JSON response @@ -160,7 +170,9 @@ def process_request(request: str): parsed_response = json.loads(response.text) # Validate required fields if not all(k in parsed_response for k in ("intent", "params", "response_text")): - raise ValueError("AI response missing required fields (intent, params, response_text)") + raise ValueError( + "AI response missing required fields (intent, params, response_text)" + ) return parsed_response except (json.JSONDecodeError, ValueError) as e: print(f"Error parsing AI response: {e}") @@ -169,9 +181,10 @@ def process_request(request: str): return { "intent": "error", "params": {}, - "response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?" + "response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?", } + def ask_ai(request: str): """ Ask the AI a question. @@ -179,6 +192,6 @@ def ask_ai(request: str): """ response = client.models.generate_content( model="gemini-2.0-flash", - contents=SYSTEM_FORWARD_PROMPT+request, + contents=SYSTEM_FORWARD_PROMPT + request, ) - return response.text \ No newline at end of file + return response.text diff --git a/backend/modules/todo/api.py b/backend/modules/todo/api.py index ae57ceb..29cb261 100644 --- a/backend/modules/todo/api.py +++ b/backend/modules/todo/api.py @@ -5,58 +5,65 @@ from typing import List from . import service, schemas from core.database import get_db -from modules.auth.dependencies import get_current_user # Corrected import -from modules.auth.models import User # Assuming User model is in auth.models +from modules.auth.dependencies import get_current_user # Corrected import +from modules.auth.models import User # Assuming User model is in auth.models router = APIRouter( prefix="/todos", tags=["todos"], - dependencies=[Depends(get_current_user)], # Corrected dependency + dependencies=[Depends(get_current_user)], # Corrected dependency responses={404: {"description": "Not found"}}, ) + @router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED) def create_todo_endpoint( todo: schemas.TodoCreate, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) # Corrected dependency + current_user: User = Depends(get_current_user), # Corrected dependency ): return service.create_todo(db=db, todo=todo, user=current_user) + @router.get("/", response_model=List[schemas.Todo]) def read_todos_endpoint( skip: int = 0, limit: int = 100, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) # Corrected dependency + current_user: User = Depends(get_current_user), # Corrected dependency ): todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit) return todos + @router.get("/{todo_id}", response_model=schemas.Todo) def read_todo_endpoint( todo_id: int, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) # Corrected dependency + current_user: User = Depends(get_current_user), # Corrected dependency ): db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user) if db_todo is None: raise HTTPException(status_code=404, detail="Todo not found") return db_todo + @router.put("/{todo_id}", response_model=schemas.Todo) def update_todo_endpoint( todo_id: int, todo_update: schemas.TodoUpdate, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) # Corrected dependency + current_user: User = Depends(get_current_user), # Corrected dependency ): - return service.update_todo(db=db, todo_id=todo_id, todo_update=todo_update, user=current_user) + return service.update_todo( + db=db, todo_id=todo_id, todo_update=todo_update, user=current_user + ) + @router.delete("/{todo_id}", response_model=schemas.Todo) def delete_todo_endpoint( todo_id: int, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) # Corrected dependency + current_user: User = Depends(get_current_user), # Corrected dependency ): return service.delete_todo(db=db, todo_id=todo_id, user=current_user) diff --git a/backend/modules/todo/models.py b/backend/modules/todo/models.py index dc702b5..f3b713a 100644 --- a/backend/modules/todo/models.py +++ b/backend/modules/todo/models.py @@ -3,6 +3,7 @@ from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey from sqlalchemy.orm import relationship from core.database import Base + class Todo(Base): __tablename__ = "todos" @@ -13,4 +14,6 @@ class Todo(Base): complete = Column(Boolean, default=False) owner_id = Column(Integer, ForeignKey("users.id")) - owner = relationship("User") # Add relationship if needed, assuming User model exists in auth.models + owner = relationship( + "User" + ) # Add relationship if needed, assuming User model exists in auth.models diff --git a/backend/modules/todo/schemas.py b/backend/modules/todo/schemas.py index a857393..5efe9c9 100644 --- a/backend/modules/todo/schemas.py +++ b/backend/modules/todo/schemas.py @@ -3,21 +3,25 @@ from pydantic import BaseModel from typing import Optional import datetime + class TodoBase(BaseModel): task: str date: Optional[datetime.datetime] = None remind: bool = False complete: bool = False + class TodoCreate(TodoBase): pass + class TodoUpdate(BaseModel): task: Optional[str] = None date: Optional[datetime.datetime] = None remind: Optional[bool] = None complete: Optional[bool] = None + class Todo(TodoBase): id: int owner_id: int diff --git a/backend/modules/todo/service.py b/backend/modules/todo/service.py index bff56f6..a4f203b 100644 --- a/backend/modules/todo/service.py +++ b/backend/modules/todo/service.py @@ -1,9 +1,10 @@ # backend/modules/todo/service.py from sqlalchemy.orm import Session from . import models, schemas -from modules.auth.models import User # Assuming User model is in auth.models +from modules.auth.models import User # Assuming User model is in auth.models from fastapi import HTTPException, status + def create_todo(db: Session, todo: schemas.TodoCreate, user: User): db_todo = models.Todo(**todo.dict(), owner_id=user.id) db.add(db_todo) @@ -11,17 +12,34 @@ def create_todo(db: Session, todo: schemas.TodoCreate, user: User): db.refresh(db_todo) return db_todo + def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100): - return db.query(models.Todo).filter(models.Todo.owner_id == user.id).offset(skip).limit(limit).all() + return ( + db.query(models.Todo) + .filter(models.Todo.owner_id == user.id) + .offset(skip) + .limit(limit) + .all() + ) + def get_todo(db: Session, todo_id: int, user: User): - db_todo = db.query(models.Todo).filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id).first() + db_todo = ( + db.query(models.Todo) + .filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id) + .first() + ) if db_todo is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found" + ) return db_todo + def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User): - db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence + db_todo = get_todo( + db=db, todo_id=todo_id, user=user + ) # Reuse get_todo to check ownership and existence update_data = todo_update.dict(exclude_unset=True) for key, value in update_data.items(): setattr(db_todo, key, value) @@ -29,8 +47,11 @@ def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user db.refresh(db_todo) return db_todo + def delete_todo(db: Session, todo_id: int, user: User): - db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence + db_todo = get_todo( + db=db, todo_id=todo_id, user=user + ) # Reuse get_todo to check ownership and existence db.delete(db_todo) db.commit() return db_todo diff --git a/backend/modules/user/api.py b/backend/modules/user/api.py index 1804394..34317f7 100644 --- a/backend/modules/user/api.py +++ b/backend/modules/user/api.py @@ -11,37 +11,52 @@ from modules.auth.models import User router = APIRouter(prefix="/user", tags=["user"]) + @router.get("/me", response_model=UserResponse) -def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: +def me( + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], +) -> UserResponse: """ Get the current user. Requires user to be logged in. Returns the user object. - """ + """ return current_user + @router.get("/{username}", response_model=UserResponse) -def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: +def get_user( + username: str, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], +) -> UserResponse: """ Get a user by username. Returns the user object. """ if current_user.username != username: raise forbidden_exception("You can only view your own profile") - + user = db.query(User).filter(User.username == username).first() if not user: raise not_found_exception("User not found") return user + @router.patch("/{username}", response_model=UserResponse) -def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: +def update_user( + username: str, + user_data: UserPatch, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], +) -> UserResponse: """ Update a user by username. Returns the updated user object. """ if current_user.username != username: raise forbidden_exception("You can only update your own profile") - + user = db.query(User).filter(User.username == username).first() if not user: raise not_found_exception("User not found") @@ -60,19 +75,24 @@ def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depe db.refresh(user) return user + @router.delete("/{username}", response_model=UserResponse) -def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: +def delete_user( + username: str, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], +) -> UserResponse: """ Delete a user by username. Returns the deleted user object. """ if current_user.username != username: raise forbidden_exception("You can only delete your own profile") - + user = db.query(User).filter(User.username == username).first() if not user: raise not_found_exception("User not found") db.delete(user) db.commit() - return user \ No newline at end of file + return user diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index f83fe56..217f4ef 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -12,6 +12,7 @@ from core.database import get_db, get_sessionmaker fake = Faker() + @pytest.fixture(scope="session") def postgres_container() -> Generator[PostgresContainer, None, None]: """Fixture to create a PostgreSQL container for testing.""" @@ -21,7 +22,8 @@ def postgres_container() -> Generator[PostgresContainer, None, None]: print(f"Postgres container started at {settings.DB_URL}") yield postgres print("Postgres container stopped.") - + + @pytest.fixture(scope="function") def db(postgres_container) -> Generator[Session, None, None]: """Function-scoped database session with rollback""" @@ -34,25 +36,28 @@ def db(postgres_container) -> Generator[Session, None, None]: session.rollback() session.close() + @pytest.fixture(scope="function") def client(db: Session) -> Generator[TestClient, None, None]: """Function-scoped test client with dependency override""" from main import app - + # Override the database dependency def override_get_db(): try: yield db finally: pass # Don't close session here - + app.dependency_overrides[get_db] = override_get_db - + with TestClient(app) as test_client: yield test_client - + app.dependency_overrides.clear() + def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None: from main import app - app.dependency_overrides[dependency] = lambda: mocked_response \ No newline at end of file + + app.dependency_overrides[dependency] = lambda: mocked_response diff --git a/backend/tests/helpers/generators.py b/backend/tests/helpers/generators.py index 9ed550d..cb3719f 100644 --- a/backend/tests/helpers/generators.py +++ b/backend/tests/helpers/generators.py @@ -5,17 +5,24 @@ from sqlalchemy.orm import Session from core.config import settings from modules.auth.models import User -from modules.auth.security import authenticate_user, create_access_token, create_refresh_token, hash_password +from modules.auth.security import ( + authenticate_user, + create_access_token, + create_refresh_token, + hash_password, +) from modules.auth.schemas import UserRole from tests.conftest import fake -from typing import Optional # Import Optional +from typing import Optional # Import Optional -def create_user(db: Session, is_admin: bool = False, username: Optional[str] = None) -> User: +def create_user( + db: Session, is_admin: bool = False, username: Optional[str] = None +) -> User: unhashed_password = fake.password() _user = User( name=fake.name(), - username=username or fake.user_name(), # Use provided username or generate one + username=username or fake.user_name(), # Use provided username or generate one hashed_password=hash_password(unhashed_password), uuid=uuid_pkg.uuid4(), role=UserRole.ADMIN if is_admin else UserRole.USER, @@ -24,14 +31,18 @@ def create_user(db: Session, is_admin: bool = False, username: Optional[str] = N db.add(_user) db.commit() db.refresh(_user) - return _user, unhashed_password # return for testing + return _user, unhashed_password # return for testing + def login(db: Session, username: str, password: str) -> str: user = authenticate_user(username, password, db) if not user: raise Exception("Incorrect username or password") - - access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)) + + access_token = create_access_token( + data={"sub": user.username}, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), + ) refresh_token = create_refresh_token(data={"sub": user.username}) max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 @@ -40,4 +51,4 @@ def login(db: Session, username: str, password: str) -> str: "access_token": access_token, "refresh_token": refresh_token, "max_age": max_age, - } \ No newline at end of file + } diff --git a/backend/tests/test_admin.py b/backend/tests/test_admin.py index 3dc906a..93cf2ee 100644 --- a/backend/tests/test_admin.py +++ b/backend/tests/test_admin.py @@ -7,71 +7,93 @@ from tests.helpers import generators # Test admin routes require admin privileges + def test_read_admin_unauthorized(client: TestClient) -> None: """Test accessing admin route without authentication.""" response = client.get("/api/admin/") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_read_admin_forbidden(db: Session, client: TestClient) -> None: """Test accessing admin route as a non-admin user.""" - user, password = generators.create_user(db, is_admin=False) # Use is_admin=False + user, password = generators.create_user(db, is_admin=False) # Use is_admin=False login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] - response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"}) + response = client.get( + "/api/admin/", headers={"Authorization": f"Bearer {access_token}"} + ) assert response.status_code == status.HTTP_403_FORBIDDEN + def test_read_admin_success(db: Session, client: TestClient) -> None: """Test accessing admin route as an admin user.""" - admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True + admin_user, password = generators.create_user( + db, is_admin=True + ) # Use is_admin=True login_rsp = generators.login(db, admin_user.username, password) access_token = login_rsp["access_token"] - response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"}) + response = client.get( + "/api/admin/", headers={"Authorization": f"Bearer {access_token}"} + ) assert response.status_code == status.HTTP_200_OK assert response.json() == {"message": "Admin route"} -@patch("modules.admin.api.cleardb.delay") # Mock the celery task + +@patch("modules.admin.api.cleardb.delay") # Mock the celery task def test_clear_db_soft(mock_cleardb_delay, db: Session, client: TestClient) -> None: """Test soft clearing the database as admin.""" - admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True + admin_user, password = generators.create_user( + db, is_admin=True + ) # Use is_admin=True login_rsp = generators.login(db, admin_user.username, password) access_token = login_rsp["access_token"] response = client.post( "/api/admin/cleardb", headers={"Authorization": f"Bearer {access_token}"}, - json={"hard": False} + json={"hard": False}, ) assert response.status_code == status.HTTP_200_OK - assert response.json() == {"message": "Clearing database in the background", "hard": False} + assert response.json() == { + "message": "Clearing database in the background", + "hard": False, + } mock_cleardb_delay.assert_called_once_with(False) -@patch("modules.admin.api.cleardb.delay") # Mock the celery task + +@patch("modules.admin.api.cleardb.delay") # Mock the celery task def test_clear_db_hard(mock_cleardb_delay, db: Session, client: TestClient) -> None: """Test hard clearing the database as admin.""" - admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True + admin_user, password = generators.create_user( + db, is_admin=True + ) # Use is_admin=True login_rsp = generators.login(db, admin_user.username, password) access_token = login_rsp["access_token"] response = client.post( "/api/admin/cleardb", headers={"Authorization": f"Bearer {access_token}"}, - json={"hard": True} + json={"hard": True}, ) assert response.status_code == status.HTTP_200_OK - assert response.json() == {"message": "Clearing database in the background", "hard": True} + assert response.json() == { + "message": "Clearing database in the background", + "hard": True, + } mock_cleardb_delay.assert_called_once_with(True) + def test_clear_db_forbidden(db: Session, client: TestClient) -> None: """Test clearing the database as a non-admin user.""" - user, password = generators.create_user(db, is_admin=False) # Use is_admin=False + user, password = generators.create_user(db, is_admin=False) # Use is_admin=False login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] response = client.post( "/api/admin/cleardb", headers={"Authorization": f"Bearer {access_token}"}, - json={"hard": False} + json={"hard": False}, ) assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index 22c0115..53acc81 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -34,6 +34,7 @@ def test_register(client: TestClient) -> None: ) assert response.status_code == status.HTTP_201_CREATED + def test_login(db: Session, client: TestClient) -> None: user, unhashed_password = generators.create_user(db) @@ -51,17 +52,21 @@ def test_login(db: Session, client: TestClient) -> None: assert "token_type" in response_data assert response_data["token_type"] == "bearer" + def test_refresh_token(db: Session, client: TestClient) -> None: user, unhashed_password = generators.create_user(db) rsp = generators.login(db, user.username, unhashed_password) access_token = rsp["access_token"] refresh_token = rsp["refresh_token"] - time.sleep(1) # Sleep to ensure tokens won't be identical + time.sleep(1) # Sleep to ensure tokens won't be identical response = client.post( "/api/auth/refresh", - headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}, + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + }, json={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_200_OK @@ -70,7 +75,10 @@ def test_refresh_token(db: Session, client: TestClient) -> None: assert "access_token" in response_data assert "token_type" in response_data assert response_data["token_type"] == "bearer" - assert response_data["access_token"] != access_token # Ensure the token is refreshed + assert ( + response_data["access_token"] != access_token + ) # Ensure the token is refreshed + def test_logout(db: Session, client: TestClient) -> None: user, unhashed_password = generators.create_user(db) @@ -80,15 +88,20 @@ def test_logout(db: Session, client: TestClient) -> None: response = client.post( "/api/auth/logout", - headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}, + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + }, json={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_200_OK # Verify that the token is blacklisted - blacklisted_token = db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first() + blacklisted_token = ( + db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first() + ) assert blacklisted_token is not None - + # Verify that we can't still actually do anything response = client.get( "/api/user/me", @@ -98,7 +111,10 @@ def test_logout(db: Session, client: TestClient) -> None: response = client.post( "/api/auth/refresh", - headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}, + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + }, json={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -106,7 +122,9 @@ def test_logout(db: Session, client: TestClient) -> None: def test_get_me(db: Session, client: TestClient) -> None: user, unhashed_password = generators.create_user(db) - access_token = generators.login(db, user.username, unhashed_password)["access_token"] + access_token = generators.login(db, user.username, unhashed_password)[ + "access_token" + ] response = client.get( "/api/user/me", @@ -119,14 +137,18 @@ def test_get_me(db: Session, client: TestClient) -> None: assert response_data["uuid"] == user.uuid assert response_data["username"] == user.username + def test_get_me_unauthorized(client: TestClient) -> None: ### This test should fail (unauthorized) because the user isn't logged in response = client.get("/api/user/me") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_get_user(db: Session, client: TestClient) -> None: user, unhashed_password = generators.create_user(db) - access_token = generators.login(db, user.username, unhashed_password)["access_token"] + access_token = generators.login(db, user.username, unhashed_password)[ + "access_token" + ] response = client.get( f"/api/user/{user.username}", @@ -139,11 +161,14 @@ def test_get_user(db: Session, client: TestClient) -> None: assert response_data["uuid"] == user.uuid assert response_data["username"] == user.username + def test_get_user_unauthorized(db: Session, client: TestClient) -> None: ### This test should fail (unauthorized) because the user isn't us user, unhashed_password = generators.create_user(db) user2, _ = generators.create_user(db) - access_token = generators.login(db, user.username, unhashed_password)["access_token"] + access_token = generators.login(db, user.username, unhashed_password)[ + "access_token" + ] response = client.get( f"/api/user/{user2.username}", @@ -151,11 +176,14 @@ def test_get_user_unauthorized(db: Session, client: TestClient) -> None: ) assert response.status_code == status.HTTP_403_FORBIDDEN + def test_update_user(db: Session, client: TestClient) -> None: user, unhashed_password = generators.create_user(db) new_name = fake.name() - access_token = generators.login(db, user.username, unhashed_password)["access_token"] + access_token = generators.login(db, user.username, unhashed_password)[ + "access_token" + ] response = client.patch( f"/api/user/{user.username}", headers={"Authorization": f"Bearer {access_token}"}, @@ -168,7 +196,9 @@ def test_update_user(db: Session, client: TestClient) -> None: def test_delete_user(db: Session, client: TestClient) -> None: user, unhashed_password = generators.create_user(db) - access_token = generators.login(db, user.username, unhashed_password)["access_token"] + access_token = generators.login(db, user.username, unhashed_password)[ + "access_token" + ] response = client.delete( f"/api/user/{user.username}", headers={"Authorization": f"Bearer {access_token}"}, @@ -179,6 +209,7 @@ def test_delete_user(db: Session, client: TestClient) -> None: deleted_user = db.query(User).filter(User.username == user.username).first() assert deleted_user is None + def test_get_user_forbidden(db: Session, client: TestClient) -> None: """Test getting another user's profile (should be forbidden).""" user1, password_user1 = generators.create_user(db, username="user1_get_forbidden") @@ -195,9 +226,12 @@ def test_get_user_forbidden(db: Session, client: TestClient) -> None: ) assert response.status_code == status.HTTP_403_FORBIDDEN + def test_update_user_forbidden(db: Session, client: TestClient) -> None: """Test updating another user's profile (should be forbidden).""" - user1, password_user1 = generators.create_user(db, username="user1_update_forbidden") + user1, password_user1 = generators.create_user( + db, username="user1_update_forbidden" + ) user2, _ = generators.create_user(db, username="user2_update_forbidden") new_name = fake.name() @@ -213,9 +247,12 @@ def test_update_user_forbidden(db: Session, client: TestClient) -> None: ) assert response.status_code == status.HTTP_403_FORBIDDEN + def test_delete_user_forbidden(db: Session, client: TestClient) -> None: """Test deleting another user's profile (should be forbidden).""" - user1, password_user1 = generators.create_user(db, username="user1_delete_forbidden") + user1, password_user1 = generators.create_user( + db, username="user1_delete_forbidden" + ) user2, _ = generators.create_user(db, username="user2_delete_forbidden") # Log in as user1 diff --git a/backend/tests/test_calendar.py b/backend/tests/test_calendar.py index 5e250d9..35f3bec 100644 --- a/backend/tests/test_calendar.py +++ b/backend/tests/test_calendar.py @@ -4,9 +4,10 @@ from sqlalchemy.orm import Session from datetime import datetime, timedelta from tests.helpers import generators -from modules.calendar.models import CalendarEvent # Assuming model exists +from modules.calendar.models import CalendarEvent # Assuming model exists from tests.conftest import fake + # Helper function to create an event payload def create_event_payload(start_offset_days=0, end_offset_days=1): start_time = datetime.utcnow() + timedelta(days=start_offset_days) @@ -14,19 +15,22 @@ def create_event_payload(start_offset_days=0, end_offset_days=1): return { "title": fake.sentence(nb_words=3), "description": fake.text(), - "start": start_time.isoformat(), # Rename start_time to start - "end": end_time.isoformat(), # Rename end_time to end + "start": start_time.isoformat(), # Rename start_time to start + "end": end_time.isoformat(), # Rename end_time to end "all_day": fake.boolean(), } + # --- Test Create Event --- + def test_create_event_unauthorized(client: TestClient) -> None: """Test creating an event without authentication.""" payload = create_event_payload() response = client.post("/api/calendar/events", json=payload) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_create_event_success(db: Session, client: TestClient) -> None: """Test creating a calendar event successfully.""" user, password = generators.create_user(db) @@ -37,9 +41,11 @@ def test_create_event_success(db: Session, client: TestClient) -> None: response = client.post( "/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, - json=payload + json=payload, ) - assert response.status_code == status.HTTP_201_CREATED # Change expected status to 201 + assert ( + response.status_code == status.HTTP_201_CREATED + ) # Change expected status to 201 data = response.json() assert data["title"] == payload["title"] assert data["description"] == payload["description"] @@ -56,13 +62,16 @@ def test_create_event_success(db: Session, client: TestClient) -> None: assert event_in_db.user_id == user.id assert event_in_db.title == payload["title"] + # --- Test Get Events --- + def test_get_events_unauthorized(client: TestClient) -> None: """Test getting events without authentication.""" response = client.get("/api/calendar/events") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_get_events_success(db: Session, client: TestClient) -> None: """Test getting all calendar events for a user.""" user, password = generators.create_user(db) @@ -71,21 +80,31 @@ def test_get_events_success(db: Session, client: TestClient) -> None: # Create a couple of events for the user payload1 = create_event_payload(0, 1) - client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1) + client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload1, + ) payload2 = create_event_payload(2, 3) - client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2) + client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload2, + ) # Create an event for another user (should not be returned) other_user, other_password = generators.create_user(db) other_login_rsp = generators.login(db, other_user.username, other_password) other_access_token = other_login_rsp["access_token"] other_payload = create_event_payload(4, 5) - client.post("/api/calendar/events", headers={"Authorization": f"Bearer {other_access_token}"}, json=other_payload) - + client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {other_access_token}"}, + json=other_payload, + ) response = client.get( - "/api/calendar/events", - headers={"Authorization": f"Bearer {access_token}"} + "/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"} ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -103,12 +122,24 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None: access_token = login_rsp["access_token"] # Create events - payload1 = create_event_payload(0, 1) # Today -> Tomorrow - client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1) - payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days - client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2) - payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days - client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload3) + payload1 = create_event_payload(0, 1) # Today -> Tomorrow + client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload1, + ) + payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days + client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload2, + ) + payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days + client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload3, + ) # Filter for events starting within the next week start_filter = datetime.utcnow().isoformat() @@ -117,11 +148,11 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None: response = client.get( "/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, - params={"start": start_filter, "end": end_filter} + params={"start": start_filter, "end": end_filter}, ) assert response.status_code == status.HTTP_200_OK data = response.json() - assert len(data) == 2 # Should get event 1 and 2 + assert len(data) == 2 # Should get event 1 and 2 assert data[0]["title"] == payload1["title"] assert data[1]["title"] == payload2["title"] @@ -130,40 +161,50 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None: response = client.get( "/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, - params={"start": start_filter_late} + params={"start": start_filter_late}, ) assert response.status_code == status.HTTP_200_OK data = response.json() - assert len(data) == 1 # Should get event 3 + assert len(data) == 1 # Should get event 3 assert data[0]["title"] == payload3["title"] # --- Test Get Event By ID --- + def test_get_event_by_id_unauthorized(db: Session, client: TestClient) -> None: """Test getting a specific event without authentication.""" user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload, + ) event_id = create_response.json()["id"] response = client.get(f"/api/calendar/events/{event_id}") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_get_event_by_id_success(db: Session, client: TestClient) -> None: """Test getting a specific event successfully.""" user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload, + ) event_id = create_response.json()["id"] response = client.get( f"/api/calendar/events/{event_id}", - headers={"Authorization": f"Bearer {access_token}"} + headers={"Authorization": f"Bearer {access_token}"}, ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -171,6 +212,7 @@ def test_get_event_by_id_success(db: Session, client: TestClient) -> None: assert data["title"] == payload["title"] assert data["user_id"] == user.id + def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None: """Test getting a non-existent event.""" user, password = generators.create_user(db) @@ -180,10 +222,11 @@ def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None: response = client.get( f"/api/calendar/events/{non_existent_id}", - headers={"Authorization": f"Bearer {access_token}"} + headers={"Authorization": f"Bearer {access_token}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None: """Test getting another user's event.""" user1, password_user1 = generators.create_user(db) @@ -193,7 +236,11 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None: login_rsp1 = generators.login(db, user1.username, password_user1) access_token1 = login_rsp1["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload) + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token1}"}, + json=payload, + ) event_id = create_response.json()["id"] # Log in as user2 and try to get user1's event @@ -202,45 +249,60 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None: response = client.get( f"/api/calendar/events/{event_id}", - headers={"Authorization": f"Bearer {access_token2}"} + headers={"Authorization": f"Bearer {access_token2}"}, ) - assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match + assert ( + response.status_code == status.HTTP_404_NOT_FOUND + ) # Service layer returns 404 if user_id doesn't match + # --- Test Update Event --- + def test_update_event_unauthorized(db: Session, client: TestClient) -> None: """Test updating an event without authentication.""" user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload, + ) event_id = create_response.json()["id"] update_payload = {"title": "Updated Title"} response = client.patch(f"/api/calendar/events/{event_id}", json=update_payload) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_update_event_success(db: Session, client: TestClient) -> None: """Test updating an event successfully.""" user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) - assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201 + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload, + ) + assert ( + create_response.status_code == status.HTTP_201_CREATED + ) # Ensure creation check uses 201 event_id = create_response.json()["id"] update_payload = { "title": "Updated Title", "description": "Updated description.", - "all_day": not payload["all_day"] # Toggle all_day + "all_day": not payload["all_day"], # Toggle all_day } response = client.patch( f"/api/calendar/events/{event_id}", headers={"Authorization": f"Bearer {access_token}"}, - json=update_payload + json=update_payload, ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -248,7 +310,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None: assert data["title"] == update_payload["title"] assert data["description"] == update_payload["description"] assert data["all_day"] == update_payload["all_day"] - assert data["start"] == payload["start"] # Check correct field name 'start' + assert data["start"] == payload["start"] # Check correct field name 'start' assert data["user_id"] == user.id # Verify in DB @@ -258,6 +320,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None: assert event_in_db.description == update_payload["description"] assert event_in_db.all_day == update_payload["all_day"] + def test_update_event_not_found(db: Session, client: TestClient) -> None: """Test updating a non-existent event.""" user, password = generators.create_user(db) @@ -269,10 +332,11 @@ def test_update_event_not_found(db: Session, client: TestClient) -> None: response = client.patch( f"/api/calendar/events/{non_existent_id}", headers={"Authorization": f"Bearer {access_token}"}, - json=update_payload + json=update_payload, ) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_update_event_forbidden(db: Session, client: TestClient) -> None: """Test updating another user's event.""" user1, password_user1 = generators.create_user(db) @@ -282,7 +346,11 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None: login_rsp1 = generators.login(db, user1.username, password_user1) access_token1 = login_rsp1["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload) + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token1}"}, + json=payload, + ) event_id = create_response.json()["id"] # Log in as user2 and try to update user1's event @@ -293,32 +361,47 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None: response = client.patch( f"/api/calendar/events/{event_id}", headers={"Authorization": f"Bearer {access_token2}"}, - json=update_payload + json=update_payload, ) - assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match + assert ( + response.status_code == status.HTTP_404_NOT_FOUND + ) # Service layer returns 404 if user_id doesn't match + # --- Test Delete Event --- + def test_delete_event_unauthorized(db: Session, client: TestClient) -> None: """Test deleting an event without authentication.""" user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload, + ) event_id = create_response.json()["id"] response = client.delete(f"/api/calendar/events/{event_id}") assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_delete_event_success(db: Session, client: TestClient) -> None: """Test deleting an event successfully.""" user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) access_token = login_rsp["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) - assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201 + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token}"}, + json=payload, + ) + assert ( + create_response.status_code == status.HTTP_201_CREATED + ) # Ensure creation check uses 201 event_id = create_response.json()["id"] # Verify event exists before delete @@ -327,7 +410,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None: response = client.delete( f"/api/calendar/events/{event_id}", - headers={"Authorization": f"Bearer {access_token}"} + headers={"Authorization": f"Bearer {access_token}"}, ) assert response.status_code == status.HTTP_204_NO_CONTENT @@ -338,7 +421,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None: # Try getting the deleted event (should be 404) get_response = client.get( f"/api/calendar/events/{event_id}", - headers={"Authorization": f"Bearer {access_token}"} + headers={"Authorization": f"Bearer {access_token}"}, ) assert get_response.status_code == status.HTTP_404_NOT_FOUND @@ -352,7 +435,7 @@ def test_delete_event_not_found(db: Session, client: TestClient) -> None: response = client.delete( f"/api/calendar/events/{non_existent_id}", - headers={"Authorization": f"Bearer {access_token}"} + headers={"Authorization": f"Bearer {access_token}"}, ) # The service layer raises NotFound, which should result in 404 assert response.status_code == status.HTTP_404_NOT_FOUND @@ -367,7 +450,11 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None: login_rsp1 = generators.login(db, user1.username, password_user1) access_token1 = login_rsp1["access_token"] payload = create_event_payload() - create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload) + create_response = client.post( + "/api/calendar/events", + headers={"Authorization": f"Bearer {access_token1}"}, + json=payload, + ) event_id = create_response.json()["id"] # Log in as user2 and try to delete user1's event @@ -376,7 +463,7 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None: response = client.delete( f"/api/calendar/events/{event_id}", - headers={"Authorization": f"Bearer {access_token2}"} + headers={"Authorization": f"Bearer {access_token2}"}, ) # The service layer raises NotFound if user_id doesn't match, resulting in 404 assert response.status_code == status.HTTP_404_NOT_FOUND @@ -385,4 +472,3 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None: event_in_db = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first() assert event_in_db is not None assert event_in_db.user_id == user1.id - diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 22c5cf6..29ea03c 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -2,6 +2,7 @@ from fastapi.testclient import TestClient # No database needed for this simple test + def test_health_check(client: TestClient): """Test the health check endpoint.""" response = client.get("/api/health") diff --git a/backend/tests/test_nlp.py b/backend/tests/test_nlp.py index ab770fd..4086d22 100644 --- a/backend/tests/test_nlp.py +++ b/backend/tests/test_nlp.py @@ -7,24 +7,37 @@ from datetime import datetime from tests.helpers import generators from modules.nlp.schemas import ProcessCommandResponse -from modules.nlp.models import MessageSender, ChatMessage # Import necessary models/enums +from modules.nlp.models import ( + MessageSender, + ChatMessage, +) # Import necessary models/enums + # --- Mocks --- # Mock the external AI call and internal service functions @pytest.fixture(autouse=True) def mock_nlp_services(): - with patch("modules.nlp.api.process_request") as mock_process, \ - patch("modules.nlp.api.ask_ai") as mock_ask, \ - patch("modules.nlp.api.save_chat_message") as mock_save, \ - patch("modules.nlp.api.get_chat_history") as mock_get_history, \ - patch("modules.nlp.api.create_calendar_event") as mock_create_event, \ - patch("modules.nlp.api.get_calendar_events") as mock_get_events, \ - patch("modules.nlp.api.update_calendar_event") as mock_update_event, \ - patch("modules.nlp.api.delete_calendar_event") as mock_delete_event, \ - patch("modules.nlp.api.todo_service.create_todo") as mock_create_todo, \ - patch("modules.nlp.api.todo_service.get_todos") as mock_get_todos, \ - patch("modules.nlp.api.todo_service.update_todo") as mock_update_todo, \ - patch("modules.nlp.api.todo_service.delete_todo") as mock_delete_todo: + with patch("modules.nlp.api.process_request") as mock_process, patch( + "modules.nlp.api.ask_ai" + ) as mock_ask, patch("modules.nlp.api.save_chat_message") as mock_save, patch( + "modules.nlp.api.get_chat_history" + ) as mock_get_history, patch( + "modules.nlp.api.create_calendar_event" + ) as mock_create_event, patch( + "modules.nlp.api.get_calendar_events" + ) as mock_get_events, patch( + "modules.nlp.api.update_calendar_event" + ) as mock_update_event, patch( + "modules.nlp.api.delete_calendar_event" + ) as mock_delete_event, patch( + "modules.nlp.api.todo_service.create_todo" + ) as mock_create_todo, patch( + "modules.nlp.api.todo_service.get_todos" + ) as mock_get_todos, patch( + "modules.nlp.api.todo_service.update_todo" + ) as mock_update_todo, patch( + "modules.nlp.api.todo_service.delete_todo" + ) as mock_delete_todo: mocks = { "process_request": mock_process, "ask_ai": mock_ask, @@ -41,21 +54,24 @@ def mock_nlp_services(): } yield mocks + # --- Helper Function --- def _login_user(db: Session, client: TestClient): user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) return user, login_rsp["access_token"], login_rsp["refresh_token"] + # --- Tests for /process-command --- + def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services): user, access_token, refresh_token = _login_user(db, client) user_input = "What is the capital of France?" mock_nlp_services["process_request"].return_value = { "intent": "ask_ai", "params": {"request": user_input}, - "response_text": "Let me check that for you." + "response_text": "Let me check that for you.", } mock_nlp_services["ask_ai"].return_value = "The capital of France is Paris." @@ -63,25 +79,45 @@ def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_servic "/api/nlp/process-command", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"user_input": user_input} + json={"user_input": user_input}, ) assert response.status_code == status.HTTP_200_OK - assert response.json() == ProcessCommandResponse(responses=["Let me check that for you.", "The capital of France is Paris."]).model_dump() + assert ( + response.json() + == ProcessCommandResponse( + responses=["Let me check that for you.", "The capital of France is Paris."] + ).model_dump() + ) # Verify save calls: user message, initial AI response, final AI answer assert mock_nlp_services["save_chat_message"].call_count == 3 - mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input) - mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you.") - mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="The capital of France is Paris.") + mock_nlp_services["save_chat_message"].assert_any_call( + db, user_id=user.id, sender=MessageSender.USER, text=user_input + ) + mock_nlp_services["save_chat_message"].assert_any_call( + db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you." + ) + mock_nlp_services["save_chat_message"].assert_any_call( + db, + user_id=user.id, + sender=MessageSender.AI, + text="The capital of France is Paris.", + ) mock_nlp_services["ask_ai"].assert_called_once_with(request=user_input) -def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_services): + +def test_process_command_get_calendar( + client: TestClient, db: Session, mock_nlp_services +): user, access_token, refresh_token = _login_user(db, client) user_input = "What are my events today?" mock_nlp_services["process_request"].return_value = { "intent": "get_calendar_events", - "params": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T23:59:59Z"}, # Example params - "response_text": "Okay, fetching your events." + "params": { + "start": "2024-01-01T00:00:00Z", + "end": "2024-01-01T23:59:59Z", + }, # Example params + "response_text": "Okay, fetching your events.", } # Mock the actual event model returned by the service mock_event = MagicMock() @@ -94,26 +130,32 @@ def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_ "/api/nlp/process-command", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"user_input": user_input} + json={"user_input": user_input}, ) assert response.status_code == status.HTTP_200_OK expected_responses = [ "Okay, fetching your events.", "Here are the events:", - "- Team Meeting (2024-01-01 10:00 - 11:00)" + "- Team Meeting (2024-01-01 10:00 - 11:00)", ] - assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump() - assert mock_nlp_services["save_chat_message"].call_count == 4 # User, Initial AI, Header, Event + assert ( + response.json() + == ProcessCommandResponse(responses=expected_responses).model_dump() + ) + assert ( + mock_nlp_services["save_chat_message"].call_count == 4 + ) # User, Initial AI, Header, Event mock_nlp_services["get_calendar_events"].assert_called_once() + def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services): user, access_token, refresh_token = _login_user(db, client) user_input = "Add buy milk to my list" mock_nlp_services["process_request"].return_value = { "intent": "add_todo", "params": {"task": "buy milk"}, - "response_text": "Adding it now." + "response_text": "Adding it now.", } # Mock the actual Todo model returned by the service mock_todo = MagicMock() @@ -125,81 +167,119 @@ def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_serv "/api/nlp/process-command", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"user_input": user_input} + json={"user_input": user_input}, ) assert response.status_code == status.HTTP_200_OK expected_responses = ["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."] - assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump() - assert mock_nlp_services["save_chat_message"].call_count == 3 # User, Initial AI, Confirmation AI + assert ( + response.json() + == ProcessCommandResponse(responses=expected_responses).model_dump() + ) + assert ( + mock_nlp_services["save_chat_message"].call_count == 3 + ) # User, Initial AI, Confirmation AI mock_nlp_services["create_todo"].assert_called_once() -def test_process_command_clarification(client: TestClient, db: Session, mock_nlp_services): + +def test_process_command_clarification( + client: TestClient, db: Session, mock_nlp_services +): user, access_token, refresh_token = _login_user(db, client) user_input = "Delete the event" clarification_text = "Which event do you mean? Please provide the ID." mock_nlp_services["process_request"].return_value = { "intent": "clarification_needed", "params": {"request": user_input}, - "response_text": clarification_text + "response_text": clarification_text, } response = client.post( "/api/nlp/process-command", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"user_input": user_input} + json={"user_input": user_input}, ) assert response.status_code == status.HTTP_200_OK - assert response.json() == ProcessCommandResponse(responses=[clarification_text]).model_dump() + assert ( + response.json() + == ProcessCommandResponse(responses=[clarification_text]).model_dump() + ) # Verify save calls: user message, clarification AI response assert mock_nlp_services["save_chat_message"].call_count == 2 - mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input) - mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=clarification_text) + mock_nlp_services["save_chat_message"].assert_any_call( + db, user_id=user.id, sender=MessageSender.USER, text=user_input + ) + mock_nlp_services["save_chat_message"].assert_any_call( + db, user_id=user.id, sender=MessageSender.AI, text=clarification_text + ) # Ensure no action services were called mock_nlp_services["delete_calendar_event"].assert_not_called() -def test_process_command_error_intent(client: TestClient, db: Session, mock_nlp_services): + +def test_process_command_error_intent( + client: TestClient, db: Session, mock_nlp_services +): user, access_token, refresh_token = _login_user(db, client) user_input = "Gibberish request" error_text = "Sorry, I didn't understand that." mock_nlp_services["process_request"].return_value = { "intent": "error", "params": {}, - "response_text": error_text + "response_text": error_text, } response = client.post( "/api/nlp/process-command", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"user_input": user_input} + json={"user_input": user_input}, ) assert response.status_code == status.HTTP_200_OK - assert response.json() == ProcessCommandResponse(responses=[error_text]).model_dump() + assert ( + response.json() == ProcessCommandResponse(responses=[error_text]).model_dump() + ) # Verify save calls: user message, error AI response assert mock_nlp_services["save_chat_message"].call_count == 2 - mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input) - mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=error_text) + mock_nlp_services["save_chat_message"].assert_any_call( + db, user_id=user.id, sender=MessageSender.USER, text=user_input + ) + mock_nlp_services["save_chat_message"].assert_any_call( + db, user_id=user.id, sender=MessageSender.AI, text=error_text + ) + # --- Tests for /history --- + def test_get_history(client: TestClient, db: Session, mock_nlp_services): user, access_token, refresh_token = _login_user(db, client) # Mock the history data returned by the service mock_history = [ - ChatMessage(id=1, user_id=user.id, sender=MessageSender.USER, text="Hello", timestamp=datetime.now()), - ChatMessage(id=2, user_id=user.id, sender=MessageSender.AI, text="Hi there!", timestamp=datetime.now()) + ChatMessage( + id=1, + user_id=user.id, + sender=MessageSender.USER, + text="Hello", + timestamp=datetime.now(), + ), + ChatMessage( + id=2, + user_id=user.id, + sender=MessageSender.AI, + text="Hi there!", + timestamp=datetime.now(), + ), ] mock_nlp_services["get_chat_history"].return_value = mock_history response = client.get( "/api/nlp/history", headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_200_OK @@ -208,11 +288,15 @@ def test_get_history(client: TestClient, db: Session, mock_nlp_services): assert len(response_data) == 2 assert response_data[0]["text"] == "Hello" assert response_data[1]["text"] == "Hi there!" - mock_nlp_services["get_chat_history"].assert_called_once_with(db, user_id=user.id, limit=50) + mock_nlp_services["get_chat_history"].assert_called_once_with( + db, user_id=user.id, limit=50 + ) + def test_get_history_unauthorized(client: TestClient): response = client.get("/api/nlp/history") assert response.status_code == status.HTTP_401_UNAUTHORIZED + # Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.) # Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete) diff --git a/backend/tests/test_todo.py b/backend/tests/test_todo.py index 7497a57..1a1d45b 100644 --- a/backend/tests/test_todo.py +++ b/backend/tests/test_todo.py @@ -5,14 +5,17 @@ from datetime import date from tests.helpers import generators + # Helper Function def _login_user(db: Session, client: TestClient): user, password = generators.create_user(db) login_rsp = generators.login(db, user.username, password) return user, login_rsp["access_token"], login_rsp["refresh_token"] + # --- Test CRUD Operations --- + def test_create_todo(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) today_date = date.today() @@ -20,14 +23,14 @@ def test_create_todo(client: TestClient, db: Session): todo_data = { "task": "Test TODO", "date": f"{today_date.isoformat()}T00:00:00", - "remind": True + "remind": True, } response = client.post( "/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json=todo_data + json=todo_data, ) assert response.status_code == status.HTTP_201_CREATED @@ -35,50 +38,66 @@ def test_create_todo(client: TestClient, db: Session): assert data["task"] == todo_data["task"] assert data["date"] == todo_data["date"] assert data["remind"] == todo_data["remind"] - assert data["complete"] is False # Default + assert data["complete"] is False # Default assert "id" in data assert data["owner_id"] == user.id + def test_read_todos(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) # Create some todos for the user - client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 1"}) - client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 2"}) + client.post( + "/api/todos/", + headers={"Authorization": f"Bearer {access_token}"}, + cookies={"refresh_token": refresh_token}, + json={"task": "Todo 1"}, + ) + client.post( + "/api/todos/", + headers={"Authorization": f"Bearer {access_token}"}, + cookies={"refresh_token": refresh_token}, + json={"task": "Todo 2"}, + ) # Create a todo for another user other_user, other_password = generators.create_user(db) other_login_rsp = generators.login(db, other_user.username, other_password) other_access_token = other_login_rsp["access_token"] other_refresh_token = other_login_rsp["refresh_token"] - client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"}) - + client.post( + "/api/todos/", + headers={"Authorization": f"Bearer {other_access_token}"}, + cookies={"refresh_token": other_refresh_token}, + json={"task": "Other User Todo"}, + ) response = client.get( "/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_200_OK data = response.json() - assert len(data) == 2 # Should only get todos for the logged-in user + assert len(data) == 2 # Should only get todos for the logged-in user assert data[0]["task"] == "Todo 1" assert data[1]["task"] == "Todo 2" + def test_read_single_todo(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) create_response = client.post( "/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"task": "Specific Todo"} + json={"task": "Specific Todo"}, ) todo_id = create_response.json()["id"] response = client.get( f"/api/todos/{todo_id}", headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_200_OK @@ -87,15 +106,17 @@ def test_read_single_todo(client: TestClient, db: Session): assert data["task"] == "Specific Todo" assert data["owner_id"] == user.id + def test_read_single_todo_not_found(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) response = client.get( - "/api/todos/9999", # Non-existent ID + "/api/todos/9999", # Non-existent ID headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_read_single_todo_forbidden(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) @@ -104,16 +125,26 @@ def test_read_single_todo_forbidden(client: TestClient, db: Session): other_login_rsp = generators.login(db, other_user.username, other_password) other_access_token = other_login_rsp["access_token"] other_refresh_token = other_login_rsp["refresh_token"] - other_create_response = client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"}) + other_create_response = client.post( + "/api/todos/", + headers={"Authorization": f"Bearer {other_access_token}"}, + cookies={"refresh_token": other_refresh_token}, + json={"task": "Other User Todo"}, + ) other_todo_id = other_create_response.json()["id"] # Try to access the other user's todo response = client.get( f"/api/todos/{other_todo_id}", - headers={"Authorization": f"Bearer {access_token}"}, # Using the first user's token - cookies={"refresh_token": refresh_token} + headers={ + "Authorization": f"Bearer {access_token}" + }, # Using the first user's token + cookies={"refresh_token": refresh_token}, ) - assert response.status_code == status.HTTP_404_NOT_FOUND # Service raises 404 if not found for *this* user + assert ( + response.status_code == status.HTTP_404_NOT_FOUND + ) # Service raises 404 if not found for *this* user + def test_update_todo(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) @@ -121,7 +152,7 @@ def test_update_todo(client: TestClient, db: Session): "/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"task": "Update Me"} + json={"task": "Update Me"}, ) todo_id = create_response.json()["id"] @@ -130,7 +161,7 @@ def test_update_todo(client: TestClient, db: Session): f"/api/todos/{todo_id}", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json=update_data + json=update_data, ) assert response.status_code == status.HTTP_200_OK @@ -144,7 +175,7 @@ def test_update_todo(client: TestClient, db: Session): get_response = client.get( f"/api/todos/{todo_id}", headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) assert get_response.json()["task"] == update_data["task"] assert get_response.json()["complete"] == update_data["complete"] @@ -154,55 +185,60 @@ def test_update_todo_not_found(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) update_data = {"task": "Updated Task", "complete": True} response = client.put( - "/api/todos/9999", # Non-existent ID + "/api/todos/9999", # Non-existent ID headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json=update_data + json=update_data, ) assert response.status_code == status.HTTP_404_NOT_FOUND + def test_delete_todo(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) create_response = client.post( "/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, - json={"task": "Delete Me"} + json={"task": "Delete Me"}, ) todo_id = create_response.json()["id"] response = client.delete( f"/api/todos/{todo_id}", headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) - assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item + assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item assert response.json()["id"] == todo_id # Verify deletion by trying to read get_response = client.get( f"/api/todos/{todo_id}", headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) assert get_response.status_code == status.HTTP_404_NOT_FOUND + def test_delete_todo_not_found(client: TestClient, db: Session): user, access_token, refresh_token = _login_user(db, client) response = client.delete( - "/api/todos/9999", # Non-existent ID + "/api/todos/9999", # Non-existent ID headers={"Authorization": f"Bearer {access_token}"}, - cookies={"refresh_token": refresh_token} + cookies={"refresh_token": refresh_token}, ) assert response.status_code == status.HTTP_404_NOT_FOUND + # --- Test Authentication/Authorization --- + def test_create_todo_unauthorized(client: TestClient): response = client.post("/api/todos/", json={"task": "No Auth"}) assert response.status_code == status.HTTP_401_UNAUTHORIZED + def test_read_todos_unauthorized(client: TestClient): response = client.get("/api/todos/") assert response.status_code == status.HTTP_401_UNAUTHORIZED