[REFORMAT] Ran black reformat

This commit is contained in:
c-d-p
2025-04-23 01:00:56 +02:00
parent d5d0a24403
commit 1553004efc
38 changed files with 1005 additions and 384 deletions

View File

@@ -77,9 +77,7 @@ def run_migrations_online() -> None:
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View File

@@ -5,12 +5,12 @@ Revises:
Create Date: 2025-04-21 01:14:33.233195 Create Date: 2025-04-21 01:14:33.233195
""" """
from typing import Sequence, Union from typing import Sequence, Union
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '69069d6184b3' revision: str = "69069d6184b3"
down_revision: Union[str, None] = None down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None

View File

@@ -5,13 +5,13 @@ Revises: 69069d6184b3
Create Date: 2025-04-21 20:33:27.028529 Create Date: 2025-04-21 20:33:27.028529
""" """
from typing import Sequence, Union from typing import Sequence, Union
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '9a82960db482' revision: str = "9a82960db482"
down_revision: Union[str, None] = '69069d6184b3' down_revision: Union[str, None] = "69069d6184b3"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None

View File

@@ -6,7 +6,10 @@ celery_app = Celery(
"worker", "worker",
broker=settings.REDIS_URL, broker=settings.REDIS_URL,
backend=settings.REDIS_URL, backend=settings.REDIS_URL,
include=["modules.auth.tasks", "modules.admin.tasks"] # Add paths to modules containing tasks include=[
"modules.auth.tasks",
"modules.admin.tasks",
], # Add paths to modules containing tasks
# Add other modules with tasks here, e.g., "modules.some_other_module.tasks" # Add other modules with tasks here, e.g., "modules.some_other_module.tasks"
) )

View File

@@ -4,6 +4,7 @@ import os
DOTENV_PATH = os.path.join(os.path.dirname(__file__), "../.env") DOTENV_PATH = os.path.join(os.path.dirname(__file__), "../.env")
class Settings(BaseSettings): class Settings(BaseSettings):
# Database settings - reads DB_URL from environment or .env # Database settings - reads DB_URL from environment or .env
DB_URL: str = "postgresql://maia:maia@localhost:5432/maia" DB_URL: str = "postgresql://maia:maia@localhost:5432/maia"
@@ -24,8 +25,9 @@ class Settings(BaseSettings):
class Config: class Config:
# Tell pydantic-settings to load variables from a .env file # Tell pydantic-settings to load variables from a .env file
env_file = DOTENV_PATH env_file = DOTENV_PATH
env_file_encoding = 'utf-8' env_file_encoding = "utf-8"
extra = 'ignore' extra = "ignore"
# Create a single instance of the settings # Create a single instance of the settings
settings = Settings() settings = Settings()

View File

@@ -10,6 +10,7 @@ Base = declarative_base() # Used for models
_engine = None _engine = None
_SessionLocal = None _SessionLocal = None
def get_engine(): def get_engine():
global _engine global _engine
if _engine is None: if _engine is None:
@@ -20,10 +21,13 @@ def get_engine():
try: try:
_engine.connect() _engine.connect()
except Exception: except Exception:
raise Exception("Database connection failed. Is the database server running?") raise Exception(
"Database connection failed. Is the database server running?"
)
Base.metadata.create_all(_engine) # Create tables here Base.metadata.create_all(_engine) # Create tables here
return _engine return _engine
def get_sessionmaker(): def get_sessionmaker():
global _SessionLocal global _SessionLocal
if _SessionLocal is None: if _SessionLocal is None:
@@ -31,6 +35,7 @@ def get_sessionmaker():
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return _SessionLocal return _SessionLocal
def get_db() -> Generator[Session, None, None]: def get_db() -> Generator[Session, None, None]:
SessionLocal = get_sessionmaker() SessionLocal = get_sessionmaker()
db = SessionLocal() db = SessionLocal()

View File

@@ -8,20 +8,26 @@ from starlette.status import (
HTTP_409_CONFLICT, HTTP_409_CONFLICT,
) )
def bad_request_exception(detail: str = "Bad Request"): def bad_request_exception(detail: str = "Bad Request"):
return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail) return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
def unauthorized_exception(detail: str = "Unauthorized"): def unauthorized_exception(detail: str = "Unauthorized"):
return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail) return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
def forbidden_exception(detail: str = "Forbidden"): def forbidden_exception(detail: str = "Forbidden"):
return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail) return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
def not_found_exception(detail: str = "Not Found"): def not_found_exception(detail: str = "Not Found"):
return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail) return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
def internal_server_error_exception(detail: str = "Internal Server Error"): def internal_server_error_exception(detail: str = "Internal Server Error"):
return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail) return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
def conflict_exception(detail: str = "Conflict"): def conflict_exception(detail: str = "Conflict"):
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail) return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)

View File

@@ -11,7 +11,8 @@ import logging
# import all models to ensure they are registered before create_all # import all models to ensure they are registered before create_all
logging.getLogger('passlib').setLevel(logging.ERROR) # fix bc package logging is broken logging.getLogger("passlib").setLevel(logging.ERROR) # fix bc package logging is broken
# Create DB tables (remove in production; use migrations instead) # Create DB tables (remove in production; use migrations instead)
def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]: def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]:
@@ -24,6 +25,7 @@ def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]
return lifespan return lifespan
lifespan = lifespan_factory() lifespan = lifespan_factory()
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
@@ -41,9 +43,10 @@ app.add_middleware(
], ],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"] allow_headers=["*"],
) )
# Health endpoint # Health endpoint
@app.get("/api/health") @app.get("/api/health")
def health(): def health():

View File

@@ -9,14 +9,17 @@ from .tasks import cleardb
router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)]) router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)])
# Define a Pydantic model for the request body # Define a Pydantic model for the request body
class ClearDbRequest(BaseModel): class ClearDbRequest(BaseModel):
hard: bool hard: bool
@router.get("/") @router.get("/")
def read_admin(): def read_admin():
return {"message": "Admin route"} return {"message": "Admin route"}
# Change to POST and use the request body model # Change to POST and use the request body model
@router.post("/cleardb") @router.post("/cleardb")
def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]): def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):

View File

@@ -1,5 +1,6 @@
from core.celery_app import celery_app from core.celery_app import celery_app
@celery_app.task @celery_app.task
def cleardb(hard: bool): def cleardb(hard: bool):
""" """

View File

@@ -3,9 +3,24 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError from jose import JWTError
from modules.auth.models import User from modules.auth.models import User
from modules.auth.schemas import UserCreate, UserResponse, Token, RefreshTokenRequest, LogoutRequest from modules.auth.schemas import (
UserCreate,
UserResponse,
Token,
RefreshTokenRequest,
LogoutRequest,
)
from modules.auth.services import create_user from modules.auth.services import create_user
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens from modules.auth.security import (
TokenType,
get_current_user,
oauth2_scheme,
create_access_token,
create_refresh_token,
verify_token,
authenticate_user,
blacklist_tokens,
)
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import Annotated from typing import Annotated
from core.database import get_db from core.database import get_db
@@ -15,12 +30,19 @@ from core.exceptions import unauthorized_exception
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
)
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]): def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
return create_user(user.username, user.password, user.name, db) return create_user(user.username, user.password, user.name, db)
@router.post("/login", response_model=Token) @router.post("/login", response_model=Token)
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]): def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[Session, Depends(get_db)],
):
""" """
Authenticate user and return JWT tokens in the response body. Authenticate user and return JWT tokens in the response body.
""" """
@@ -31,37 +53,51 @@ def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annota
detail="Incorrect username or password", detail="Incorrect username or password",
) )
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)) access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
)
refresh_token = create_refresh_token(data={"sub": user.username}) refresh_token = create_refresh_token(data={"sub": user.username})
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"} return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}
@router.post("/refresh") @router.post("/refresh")
def refresh_token(payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]): def refresh_token(
payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]
):
print("Refreshing token...") print("Refreshing token...")
refresh_token = payload.refresh_token refresh_token = payload.refresh_token
if not refresh_token: if not refresh_token:
raise unauthorized_exception("Refresh token missing in request body") raise unauthorized_exception("Refresh token missing in request body")
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db) user_data = verify_token(
refresh_token, expected_token_type=TokenType.REFRESH, db=db
)
if not user_data: if not user_data:
raise unauthorized_exception("Invalid refresh token") raise unauthorized_exception("Invalid refresh token")
new_access_token = create_access_token(data={"sub": user_data.username}) new_access_token = create_access_token(data={"sub": user_data.username})
return {"access_token": new_access_token, "token_type": "bearer"} return {"access_token": new_access_token, "token_type": "bearer"}
@router.post("/logout") @router.post("/logout")
def logout(payload: LogoutRequest, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme)): def logout(
payload: LogoutRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
access_token: str = Depends(oauth2_scheme),
):
try: try:
refresh_token = payload.refresh_token refresh_token = payload.refresh_token
if not refresh_token: if not refresh_token:
raise unauthorized_exception("Refresh token not found in request body") raise unauthorized_exception("Refresh token not found in request body")
blacklist_tokens( blacklist_tokens(access_token=access_token, refresh_token=refresh_token, db=db)
access_token=access_token,
refresh_token=refresh_token,
db=db
)
return {"message": "Logged out successfully"} return {"message": "Logged out successfully"}
except JWTError: except JWTError:

View File

@@ -5,14 +5,18 @@ from modules.auth.schemas import UserRole
from modules.auth.models import User from modules.auth.models import User
from core.exceptions import forbidden_exception from core.exceptions import forbidden_exception
class RoleChecker: class RoleChecker:
def __init__(self, allowed_roles: list[UserRole]): def __init__(self, allowed_roles: list[UserRole]):
self.allowed_roles = allowed_roles self.allowed_roles = allowed_roles
def __call__(self, user: User = Depends(get_current_user)): def __call__(self, user: User = Depends(get_current_user)):
if user.role not in self.allowed_roles: if user.role not in self.allowed_roles:
raise forbidden_exception("You do not have permission to perform this action.") raise forbidden_exception(
"You do not have permission to perform this action."
)
return user return user
admin_only = RoleChecker([UserRole.ADMIN]) admin_only = RoleChecker([UserRole.ADMIN])
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER]) any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])

View File

@@ -4,10 +4,12 @@ from sqlalchemy import Column, Integer, String, Enum, DateTime
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from enum import Enum as PyEnum from enum import Enum as PyEnum
class UserRole(str, PyEnum): class UserRole(str, PyEnum):
ADMIN = "admin" ADMIN = "admin"
USER = "user" USER = "user"
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)

View File

@@ -2,33 +2,41 @@
from enum import Enum as PyEnum from enum import Enum as PyEnum
from pydantic import BaseModel from pydantic import BaseModel
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
refresh_token: str | None = None refresh_token: str | None = None
class TokenData(BaseModel): class TokenData(BaseModel):
username: str | None = None username: str | None = None
scopes: list[str] = [] scopes: list[str] = []
class RefreshTokenRequest(BaseModel): class RefreshTokenRequest(BaseModel):
refresh_token: str refresh_token: str
class LogoutRequest(BaseModel): class LogoutRequest(BaseModel):
refresh_token: str refresh_token: str
class UserRole(str, PyEnum): class UserRole(str, PyEnum):
ADMIN = "admin" ADMIN = "admin"
USER = "user" USER = "user"
class UserCreate(BaseModel): class UserCreate(BaseModel):
username: str username: str
password: str password: str
name: str name: str
class UserPatch(BaseModel): class UserPatch(BaseModel):
name: str | None = None name: str | None = None
class UserResponse(BaseModel): class UserResponse(BaseModel):
uuid: str uuid: str
username: str username: str

View File

@@ -18,6 +18,7 @@ from modules.auth.schemas import TokenData
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
class TokenType(str, Enum): class TokenType(str, Enum):
ACCESS = "access" ACCESS = "access"
REFRESH = "refresh" REFRESH = "refresh"
@@ -25,11 +26,13 @@ class TokenType(str, Enum):
password_hasher = PasswordHasher() password_hasher = PasswordHasher()
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
"""Hash a password with Argon2 (and optional pepper).""" """Hash a password with Argon2 (and optional pepper)."""
peppered_password = password + settings.PEPPER # Prepend/append pepper peppered_password = password + settings.PEPPER # Prepend/append pepper
return password_hasher.hash(peppered_password) return password_hasher.hash(peppered_password)
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hashed version using Argon2.""" """Verify a password against its hashed version using Argon2."""
peppered_password = plain_password + settings.PEPPER peppered_password = plain_password + settings.PEPPER
@@ -38,6 +41,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
except VerifyMismatchError: except VerifyMismatchError:
return False return False
def authenticate_user(username: str, password: str, db: Session) -> User | None: def authenticate_user(username: str, password: str, db: Session) -> User | None:
""" """
Authenticate a user by checking username/password against the database. Authenticate a user by checking username/password against the database.
@@ -52,34 +56,39 @@ def authenticate_user(username: str, password: str, db: Session) -> User | None:
return user return user
def create_access_token(data: dict, expires_delta: timedelta | None = None): def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(timezone.utc) + expires_delta
else: else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# expire = datetime.now(timezone.utc) + timedelta(seconds=5) # expire = datetime.now(timezone.utc) + timedelta(seconds=5)
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS}) to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
return jwt.encode( return jwt.encode(
to_encode, to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
) )
def create_refresh_token(data: dict, expires_delta: timedelta | None = None): def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(timezone.utc) + expires_delta
else: else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) expire = datetime.now(timezone.utc) + timedelta(
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
)
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH}) to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
return jwt.encode( return jwt.encode(
to_encode, to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
) )
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
def verify_token(
token: str, expected_token_type: TokenType, db: Session
) -> TokenData | None:
"""Verify a JWT token and return TokenData if valid. """Verify a JWT token and return TokenData if valid.
Parameters Parameters
@@ -96,12 +105,17 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
TokenData | None TokenData | None
TokenData instance if the token is valid, None otherwise. TokenData instance if the token is valid, None otherwise.
""" """
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None is_blacklisted = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
is not None
)
if is_blacklisted: if is_blacklisted:
return None return None
try: try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
username: str = payload.get("sub") username: str = payload.get("sub")
token_type: str = payload.get("token_type") token_type: str = payload.get("token_type")
@@ -113,7 +127,10 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
except JWTError: except JWTError:
return None return None
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
def get_current_user(
db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
) -> User:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials", detail="Could not validate credentials",
@@ -121,14 +138,15 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
) )
# Check if the token is blacklisted # Check if the token is blacklisted
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None is_blacklisted = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
is not None
)
if is_blacklisted: if is_blacklisted:
raise credentials_exception raise credentials_exception
try: try:
payload = jwt.decode( payload = jwt.decode(
token, token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
) )
username: str = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:
@@ -141,6 +159,7 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
raise credentials_exception raise credentials_exception
return user return user
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None: def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
"""Blacklist both access and refresh tokens. """Blacklist both access and refresh tokens.
@@ -154,7 +173,9 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
Database session to perform the operation. Database session to perform the operation.
""" """
for token in [access_token, refresh_token]: for token in [access_token, refresh_token]:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
expires_at = datetime.fromtimestamp(payload.get("exp")) expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist # Add the token to the blacklist
@@ -163,8 +184,11 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
db.commit() db.commit()
def blacklist_token(token: str, db: Session) -> None: def blacklist_token(token: str, db: Session) -> None:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]) payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
expires_at = datetime.fromtimestamp(payload.get("exp")) expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist # Add the token to the blacklist

View File

@@ -23,7 +23,9 @@ def create_user(username: str, password: str, name: str, db: Session) -> UserRes
hashed_password = hash_password(password) hashed_password = hash_password(password)
user_uuid = str(uuid.uuid4()) user_uuid = str(uuid.uuid4())
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid) user = User(
username=username, hashed_password=hashed_password, name=name, uuid=user_uuid
)
db.add(user) db.add(user)
db.commit() db.commit()
db.refresh(user) # Loads the generated ID db.refresh(user) # Loads the generated ID

View File

@@ -6,50 +6,63 @@ from typing import List, Optional
from modules.auth.dependencies import get_current_user from modules.auth.dependencies import get_current_user
from core.database import get_db from core.database import get_db
from modules.auth.models import User from modules.auth.models import User
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate, CalendarEventResponse from modules.calendar.schemas import (
from modules.calendar.service import create_calendar_event, get_calendar_event_by_id, get_calendar_events, update_calendar_event, delete_calendar_event CalendarEventCreate,
CalendarEventUpdate,
CalendarEventResponse,
)
from modules.calendar.service import (
create_calendar_event,
get_calendar_event_by_id,
get_calendar_events,
update_calendar_event,
delete_calendar_event,
)
router = APIRouter(prefix="/calendar", tags=["calendar"]) router = APIRouter(prefix="/calendar", tags=["calendar"])
@router.post("/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED
)
def create_event( def create_event(
event: CalendarEventCreate, event: CalendarEventCreate,
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
return create_calendar_event(db, user.id, event) return create_calendar_event(db, user.id, event)
@router.get("/events", response_model=List[CalendarEventResponse]) @router.get("/events", response_model=List[CalendarEventResponse])
def get_events( def get_events(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
start: Optional[datetime] = None, start: Optional[datetime] = None,
end: Optional[datetime] = None end: Optional[datetime] = None,
): ):
return get_calendar_events(db, user.id, start, end) return get_calendar_events(db, user.id, start, end)
@router.get("/events/{event_id}", response_model=CalendarEventResponse) @router.get("/events/{event_id}", response_model=CalendarEventResponse)
def get_event_by_id( def get_event_by_id(
event_id: int, event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
event = get_calendar_event_by_id(db, user.id, event_id) event = get_calendar_event_by_id(db, user.id, event_id)
return event return event
@router.patch("/events/{event_id}", response_model=CalendarEventResponse) @router.patch("/events/{event_id}", response_model=CalendarEventResponse)
def update_event( def update_event(
event_id: int, event_id: int,
event: CalendarEventUpdate, event: CalendarEventUpdate,
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
return update_calendar_event(db, user.id, event_id, event) return update_calendar_event(db, user.id, event_id, event)
@router.delete("/events/{event_id}", status_code=204) @router.delete("/events/{event_id}", status_code=204)
def delete_event( def delete_event(
event_id: int, event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
delete_calendar_event(db, user.id, event_id) delete_calendar_event(db, user.id, event_id)

View File

@@ -1,8 +1,17 @@
# modules/calendar/models.py # modules/calendar/models.py
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Boolean # Add Boolean from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
JSON,
Boolean,
) # Add Boolean
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from core.database import Base from core.database import Base
class CalendarEvent(Base): class CalendarEvent(Base):
__tablename__ = "calendar_events" __tablename__ = "calendar_events"
@@ -15,7 +24,9 @@ class CalendarEvent(Base):
all_day = Column(Boolean, default=False) # Add all_day column all_day = Column(Boolean, default=False) # Add all_day column
tags = Column(JSON) tags = Column(JSON)
color = Column(String) # hex code for color color = Column(String) # hex code for color
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) # <-- Relationship user_id = Column(
Integer, ForeignKey("users.id"), nullable=False
) # <-- Relationship
# Bi-directional relationship (for eager loading) # Bi-directional relationship (for eager loading)
user = relationship("User", back_populates="calendar_events") user = relationship("User", back_populates="calendar_events")

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from pydantic import BaseModel, field_validator # Add field_validator from pydantic import BaseModel, field_validator # Add field_validator
from typing import List, Optional # Add List and Optional from typing import List, Optional # Add List and Optional
# Base schema for common fields, including tags # Base schema for common fields, including tags
class CalendarEventBase(BaseModel): class CalendarEventBase(BaseModel):
title: str title: str
@@ -14,17 +15,19 @@ class CalendarEventBase(BaseModel):
all_day: Optional[bool] = None # Add all_day field all_day: Optional[bool] = None # Add all_day field
tags: Optional[List[str]] = None # Add optional tags tags: Optional[List[str]] = None # Add optional tags
@field_validator('tags', mode='before') @field_validator("tags", mode="before")
@classmethod @classmethod
def tags_validate_null_string(cls, v): def tags_validate_null_string(cls, v):
if v == "Null": if v == "Null":
return None return None
return v return v
# Schema for creating an event (inherits from Base) # Schema for creating an event (inherits from Base)
class CalendarEventCreate(CalendarEventBase): class CalendarEventCreate(CalendarEventBase):
pass pass
# Schema for updating an event (all fields optional) # Schema for updating an event (all fields optional)
class CalendarEventUpdate(BaseModel): class CalendarEventUpdate(BaseModel):
title: Optional[str] = None title: Optional[str] = None
@@ -36,20 +39,21 @@ class CalendarEventUpdate(BaseModel):
all_day: Optional[bool] = None # Add all_day field all_day: Optional[bool] = None # Add all_day field
tags: Optional[List[str]] = None # Add optional tags for update tags: Optional[List[str]] = None # Add optional tags for update
@field_validator('tags', mode='before') @field_validator("tags", mode="before")
@classmethod @classmethod
def tags_validate_null_string(cls, v): def tags_validate_null_string(cls, v):
if v == "Null": if v == "Null":
return None return None
return v return v
# Schema for the response (inherits from Base, adds ID and user_id) # Schema for the response (inherits from Base, adds ID and user_id)
class CalendarEventResponse(CalendarEventBase): class CalendarEventResponse(CalendarEventBase):
id: int id: int
user_id: int user_id: int
tags: List[str] # Keep as List[str], remove default [] tags: List[str] # Keep as List[str], remove default []
@field_validator('tags', mode='before') @field_validator("tags", mode="before")
@classmethod @classmethod
def tags_validate_none_to_list(cls, v): def tags_validate_none_to_list(cls, v):
# If the value from the source object (e.g., ORM model) is None, # If the value from the source object (e.g., ORM model) is None,

View File

@@ -4,22 +4,31 @@ from sqlalchemy import or_ # Import or_
from datetime import datetime from datetime import datetime
from modules.calendar.models import CalendarEvent from modules.calendar.models import CalendarEvent
from core.exceptions import not_found_exception from core.exceptions import not_found_exception
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate # Import schemas from modules.calendar.schemas import (
CalendarEventCreate,
CalendarEventUpdate,
) # Import schemas
def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate): def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate):
# Ensure tags is None if not provided or empty list, matching model # Ensure tags is None if not provided or empty list, matching model
tags_to_store = event_data.tags if event_data.tags else None tags_to_store = event_data.tags if event_data.tags else None
event = CalendarEvent( event = CalendarEvent(
**event_data.model_dump(exclude={'tags'}), # Use model_dump and exclude tags initially **event_data.model_dump(
exclude={"tags"}
), # Use model_dump and exclude tags initially
tags=tags_to_store, # Set tags separately tags=tags_to_store, # Set tags separately
user_id=user_id user_id=user_id,
) )
db.add(event) db.add(event)
db.commit() db.commit()
db.refresh(event) db.refresh(event)
return event return event
def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: datetime | None):
def get_calendar_events(
db: Session, user_id: int, start: datetime | None, end: datetime | None
):
""" """
Retrieves calendar events for a user, optionally filtered by a date range. Retrieves calendar events for a user, optionally filtered by a date range.
@@ -46,9 +55,13 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
query = query.filter( query = query.filter(
or_( or_(
# Case 1: Event has duration and overlaps # Case 1: Event has duration and overlaps
(CalendarEvent.end is not None) & (CalendarEvent.start < end) & (CalendarEvent.end > start), (CalendarEvent.end is not None)
& (CalendarEvent.start < end)
& (CalendarEvent.end > start),
# Case 2: Event is a point event within the range # Case 2: Event is a point event within the range
(CalendarEvent.end is None) & (CalendarEvent.start >= start) & (CalendarEvent.start < end) (CalendarEvent.end is None)
& (CalendarEvent.start >= start)
& (CalendarEvent.start < end),
) )
) )
# If only start is provided, filter events starting on or after start # If only start is provided, filter events starting on or after start
@@ -65,32 +78,36 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
# Event ends before the specified end time # Event ends before the specified end time
(CalendarEvent.end is not None) & (CalendarEvent.end <= end), (CalendarEvent.end is not None) & (CalendarEvent.end <= end),
# Point event occurs before the specified end time # Point event occurs before the specified end time
(CalendarEvent.end is None) & (CalendarEvent.start < end) (CalendarEvent.end is None) & (CalendarEvent.start < end),
) )
) )
# Alternative interpretation for "ending before end": include events that *start* before end # Alternative interpretation for "ending before end": include events that *start* before end
# query = query.filter(CalendarEvent.start < end) # query = query.filter(CalendarEvent.start < end)
return query.order_by(CalendarEvent.start).all() # Order by start time return query.order_by(CalendarEvent.start).all() # Order by start time
def get_calendar_event_by_id(db: Session, user_id: int, event_id: int): def get_calendar_event_by_id(db: Session, user_id: int, event_id: int):
event = db.query(CalendarEvent).filter( event = (
CalendarEvent.id == event_id, db.query(CalendarEvent)
CalendarEvent.user_id == user_id .filter(CalendarEvent.id == event_id, CalendarEvent.user_id == user_id)
).first() .first()
)
if not event: if not event:
raise not_found_exception() raise not_found_exception()
return event return event
def update_calendar_event(db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate):
def update_calendar_event(
db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate
):
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
# Use model_dump with exclude_unset=True to only update provided fields # Use model_dump with exclude_unset=True to only update provided fields
update_data = event_data.model_dump(exclude_unset=True) update_data = event_data.model_dump(exclude_unset=True)
for key, value in update_data.items(): for key, value in update_data.items():
# Ensure tags is handled correctly (set to None if empty list provided) # Ensure tags is handled correctly (set to None if empty list provided)
if key == 'tags' and isinstance(value, list) and not value: if key == "tags" and isinstance(value, list) and not value:
setattr(event, key, None) setattr(event, key, None)
else: else:
setattr(event, key, value) setattr(event, key, value)
@@ -99,6 +116,7 @@ def update_calendar_event(db: Session, user_id: int, event_id: int, event_data:
db.refresh(event) db.refresh(event)
return event return event
def delete_calendar_event(db: Session, user_id: int, event_id: int): def delete_calendar_event(db: Session, user_id: int, event_id: int):
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
db.delete(event) db.delete(event)

View File

@@ -7,13 +7,27 @@ from core.database import get_db
from modules.auth.dependencies import get_current_user from modules.auth.dependencies import get_current_user
from modules.auth.models import User from modules.auth.models import User
# Import the new service functions and Enum # Import the new service functions and Enum
from modules.nlp.service import process_request, ask_ai, save_chat_message, get_chat_history, MessageSender from modules.nlp.service import (
process_request,
ask_ai,
save_chat_message,
get_chat_history,
MessageSender,
)
# Import the response schema and the new ChatMessage model for response type hinting # Import the response schema and the new ChatMessage model for response type hinting
from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse
from modules.calendar.service import create_calendar_event, get_calendar_events, update_calendar_event, delete_calendar_event from modules.calendar.service import (
create_calendar_event,
get_calendar_events,
update_calendar_event,
delete_calendar_event,
)
from modules.calendar.models import CalendarEvent from modules.calendar.models import CalendarEvent
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate
# Import TODO services, schemas, and model # Import TODO services, schemas, and model
from modules.todo import service as todo_service from modules.todo import service as todo_service
from modules.todo.models import Todo from modules.todo.models import Todo
@@ -21,6 +35,7 @@ from modules.todo.schemas import TodoCreate, TodoUpdate
from pydantic import BaseModel from pydantic import BaseModel
from datetime import datetime from datetime import datetime
class ChatMessageResponse(BaseModel): class ChatMessageResponse(BaseModel):
id: int id: int
sender: MessageSender # Use the enum directly sender: MessageSender # Use the enum directly
@@ -30,8 +45,10 @@ class ChatMessageResponse(BaseModel):
class Config: class Config:
from_attributes = True # Allow Pydantic to work with ORM models from_attributes = True # Allow Pydantic to work with ORM models
router = APIRouter(prefix="/nlp", tags=["nlp"]) router = APIRouter(prefix="/nlp", tags=["nlp"])
# Helper to format calendar events (expects list of CalendarEvent models) # Helper to format calendar events (expects list of CalendarEvent models)
def format_calendar_events(events: List[CalendarEvent]) -> List[str]: def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
if not events: if not events:
@@ -39,12 +56,15 @@ def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
formatted = ["Here are the events:"] formatted = ["Here are the events:"]
for event in events: for event in events:
# Access attributes directly from the model instance # Access attributes directly from the model instance
start_str = event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time" start_str = (
event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
)
end_str = event.end.strftime("%H:%M") if event.end else "" end_str = event.end.strftime("%H:%M") if event.end else ""
title = event.title or "Untitled Event" title = event.title or "Untitled Event"
formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})") formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})")
return formatted return formatted
# Helper to format TODO items (expects list of Todo models) # Helper to format TODO items (expects list of Todo models)
def format_todos(todos: List[Todo]) -> List[str]: def format_todos(todos: List[Todo]) -> List[str]:
if not todos: if not todos:
@@ -54,19 +74,28 @@ def format_todos(todos: List[Todo]) -> List[str]:
status = "[X]" if todo.complete else "[ ]" status = "[X]" if todo.complete else "[ ]"
date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else "" date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else ""
remind_str = " (Reminder)" if todo.remind else "" remind_str = " (Reminder)" if todo.remind else ""
formatted.append(f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})") formatted.append(
f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})"
)
return formatted return formatted
# Update the response model for the endpoint # Update the response model for the endpoint
@router.post("/process-command", response_model=ProcessCommandResponse) @router.post("/process-command", response_model=ProcessCommandResponse)
def process_command(request_data: ProcessCommandRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): def process_command(
request_data: ProcessCommandRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
""" """
Process the user command, save messages, execute action, save response, and return user-friendly responses. Process the user command, save messages, execute action, save response, and return user-friendly responses.
""" """
user_input = request_data.user_input user_input = request_data.user_input
# --- Save User Message --- # --- Save User Message ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.USER, text=user_input) save_chat_message(
db, user_id=current_user.id, sender=MessageSender.USER, text=user_input
)
# ------------------------ # ------------------------
command_data = process_request(user_input) command_data = process_request(user_input)
@@ -78,7 +107,9 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
# --- Save Initial AI Response --- # --- Save Initial AI Response ---
# Save the first response generated by process_request # Save the first response generated by process_request
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text) save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=response_text
)
# ----------------------------- # -----------------------------
if intent == "error": if intent == "error":
@@ -97,139 +128,233 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
ai_answer = ask_ai(**params) ai_answer = ask_ai(**params)
responses.append(ai_answer) responses.append(ai_answer)
# --- Save Additional AI Response --- # --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer) save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer
)
# --------------------------------- # ---------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
case "get_calendar_events": case "get_calendar_events":
events: List[CalendarEvent] = get_calendar_events(db, current_user.id, **params) events: List[CalendarEvent] = get_calendar_events(
db, current_user.id, **params
)
formatted_responses = format_calendar_events(events) formatted_responses = format_calendar_events(events)
responses.extend(formatted_responses) responses.extend(formatted_responses)
# --- Save Additional AI Responses --- # --- Save Additional AI Responses ---
for resp in formatted_responses: for resp in formatted_responses:
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp) save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
)
# ---------------------------------- # ----------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
case "add_calendar_event": case "add_calendar_event":
event_data = CalendarEventCreate(**params) event_data = CalendarEventCreate(**params)
created_event = create_calendar_event(db, current_user.id, event_data) created_event = create_calendar_event(db, current_user.id, event_data)
start_str = created_event.start.strftime("%Y-%m-%d %H:%M") if created_event.start else "No start time" start_str = (
created_event.start.strftime("%Y-%m-%d %H:%M")
if created_event.start
else "No start time"
)
title = created_event.title or "Untitled Event" title = created_event.title or "Untitled Event"
add_response = f"Added: {title} starting at {start_str}." add_response = f"Added: {title} starting at {start_str}."
responses.append(add_response) responses.append(add_response)
# --- Save Additional AI Response --- # --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=add_response,
)
# --------------------------------- # ---------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
case "update_calendar_event": case "update_calendar_event":
event_id = params.pop('event_id', None) event_id = params.pop("event_id", None)
if event_id is None: if event_id is None:
# Save the error message before raising # Save the error message before raising
error_msg = "Event ID is required for update." error_msg = "Event ID is required for update."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg) raise HTTPException(status_code=400, detail=error_msg)
event_data = CalendarEventUpdate(**params) event_data = CalendarEventUpdate(**params)
updated_event = update_calendar_event(db, current_user.id, event_id, event_data=event_data) updated_event = update_calendar_event(
db, current_user.id, event_id, event_data=event_data
)
title = updated_event.title or "Untitled Event" title = updated_event.title or "Untitled Event"
update_response = f"Updated event ID {updated_event.id}: {title}." update_response = f"Updated event ID {updated_event.id}: {title}."
responses.append(update_response) responses.append(update_response)
# --- Save Additional AI Response --- # --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=update_response,
)
# --------------------------------- # ---------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
case "delete_calendar_event": case "delete_calendar_event":
event_id = params.get('event_id') event_id = params.get("event_id")
if event_id is None: if event_id is None:
# Save the error message before raising # Save the error message before raising
error_msg = "Event ID is required for delete." error_msg = "Event ID is required for delete."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg) raise HTTPException(status_code=400, detail=error_msg)
delete_calendar_event(db, current_user.id, event_id) delete_calendar_event(db, current_user.id, event_id)
delete_response = f"Deleted event ID {event_id}." delete_response = f"Deleted event ID {event_id}."
responses.append(delete_response) responses.append(delete_response)
# --- Save Additional AI Response --- # --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=delete_response,
)
# --------------------------------- # ---------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
# --- Add TODO Cases --- # --- Add TODO Cases ---
case "get_todos": case "get_todos":
todos: List[Todo] = todo_service.get_todos(db, user=current_user, **params) todos: List[Todo] = todo_service.get_todos(
db, user=current_user, **params
)
formatted_responses = format_todos(todos) formatted_responses = format_todos(todos)
responses.extend(formatted_responses) responses.extend(formatted_responses)
# --- Save Additional AI Responses --- # --- Save Additional AI Responses ---
for resp in formatted_responses: for resp in formatted_responses:
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp) save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
)
# ---------------------------------- # ----------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
case "add_todo": case "add_todo":
todo_data = TodoCreate(**params) todo_data = TodoCreate(**params)
created_todo = todo_service.create_todo(db, todo=todo_data, user=current_user) created_todo = todo_service.create_todo(
add_response = f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})." db, todo=todo_data, user=current_user
)
add_response = (
f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
)
responses.append(add_response) responses.append(add_response)
# --- Save Additional AI Response --- # --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=add_response,
)
# --------------------------------- # ---------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
case "update_todo": case "update_todo":
todo_id = params.pop('todo_id', None) todo_id = params.pop("todo_id", None)
if todo_id is None: if todo_id is None:
error_msg = "TODO ID is required for update." error_msg = "TODO ID is required for update."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg) raise HTTPException(status_code=400, detail=error_msg)
todo_data = TodoUpdate(**params) todo_data = TodoUpdate(**params)
updated_todo = todo_service.update_todo(db, todo_id=todo_id, todo_update=todo_data, user=current_user) updated_todo = todo_service.update_todo(
update_response = f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'." db, todo_id=todo_id, todo_update=todo_data, user=current_user
if 'complete' in params: )
status = "complete" if params['complete'] else "incomplete" update_response = (
f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
)
if "complete" in params:
status = "complete" if params["complete"] else "incomplete"
update_response += f" Marked as {status}." update_response += f" Marked as {status}."
responses.append(update_response) responses.append(update_response)
# --- Save Additional AI Response --- # --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=update_response,
)
# --------------------------------- # ---------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
case "delete_todo": case "delete_todo":
todo_id = params.get('todo_id') todo_id = params.get("todo_id")
if todo_id is None: if todo_id is None:
error_msg = "TODO ID is required for delete." error_msg = "TODO ID is required for delete."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg) raise HTTPException(status_code=400, detail=error_msg)
deleted_todo = todo_service.delete_todo(db, todo_id=todo_id, user=current_user) deleted_todo = todo_service.delete_todo(
delete_response = f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'." db, todo_id=todo_id, user=current_user
)
delete_response = (
f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
)
responses.append(delete_response) responses.append(delete_response)
# --- Save Additional AI Response --- # --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=delete_response,
)
# --------------------------------- # ---------------------------------
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
# --- End TODO Cases --- # --- End TODO Cases ---
case _: case _:
print(f"Warning: Unhandled intent '{intent}' reached api.py match statement.") print(
f"Warning: Unhandled intent '{intent}' reached api.py match statement."
)
# The initial response_text was already saved # The initial response_text was already saved
return ProcessCommandResponse(responses=responses) return ProcessCommandResponse(responses=responses)
except HTTPException as http_exc: except HTTPException as http_exc:
# Don't save again if already saved before raising # Don't save again if already saved before raising
if http_exc.status_code != 400 or ('event_id' not in http_exc.detail.lower()): if http_exc.status_code != 400 or ("event_id" not in http_exc.detail.lower()):
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=http_exc.detail) save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=http_exc.detail,
)
raise http_exc raise http_exc
except Exception as e: except Exception as e:
print(f"Error executing intent '{intent}': {e}") print(f"Error executing intent '{intent}': {e}")
error_response = "Sorry, I encountered an error while trying to perform that action." error_response = (
"Sorry, I encountered an error while trying to perform that action."
)
# --- Save Final Error AI Response --- # --- Save Final Error AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_response) save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=error_response
)
# ---------------------------------- # ----------------------------------
return ProcessCommandResponse(responses=[error_response]) return ProcessCommandResponse(responses=[error_response])
@router.get("/history", response_model=List[ChatMessageResponse]) @router.get("/history", response_model=List[ChatMessageResponse])
def read_chat_history(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): def read_chat_history(
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Retrieves the last 50 chat messages for the current user.""" """Retrieves the last 50 chat messages for the current user."""
history = get_chat_history(db, user_id=current_user.id, limit=50) history = get_chat_history(db, user_id=current_user.id, limit=50)
return history return history
# ------------------------------------- # -------------------------------------

View File

@@ -1,4 +1,3 @@
\
# /home/cdp/code/MAIA/backend/modules/nlp/models.py # /home/cdp/code/MAIA/backend/modules/nlp/models.py
from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -7,10 +6,12 @@ import enum
from core.database import Base from core.database import Base
class MessageSender(enum.Enum): class MessageSender(enum.Enum):
USER = "user" USER = "user"
AI = "ai" AI = "ai"
class ChatMessage(Base): class ChatMessage(Base):
__tablename__ = "chat_messages" __tablename__ = "chat_messages"

View File

@@ -2,9 +2,11 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import List from typing import List
class ProcessCommandRequest(BaseModel): class ProcessCommandRequest(BaseModel):
user_input: str user_input: str
class ProcessCommandResponse(BaseModel): class ProcessCommandResponse(BaseModel):
responses: List[str] responses: List[str]
# Optional: Keep details if needed for specific frontend logic beyond display # Optional: Keep details if needed for specific frontend logic beyond display

View File

@@ -14,7 +14,8 @@ from core.config import settings
client = genai.Client(api_key=settings.GOOGLE_API_KEY) client = genai.Client(api_key=settings.GOOGLE_API_KEY)
### Base prompt for MAIA, used for inital user requests ### Base prompt for MAIA, used for inital user requests
SYSTEM_PROMPT = """ SYSTEM_PROMPT = (
"""
You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text. You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text.
Available functions/intents: Available functions/intents:
@@ -109,8 +110,11 @@ MAIA:
"response_text": "Okay, I've deleted task 2 from your list." "response_text": "Okay, I've deleted task 2 from your list."
} }
The datetime right now is """+str(datetime.now(timezone.utc))+""". The datetime right now is """
+ str(datetime.now(timezone.utc))
+ """.
""" """
)
### Prompt for MAIA to forward user request to AI ### Prompt for MAIA to forward user request to AI
SYSTEM_FORWARD_PROMPT = f""" SYSTEM_FORWARD_PROMPT = f"""
@@ -123,6 +127,7 @@ Here is the user request:
# --- Chat History Service Functions --- # --- Chat History Service Functions ---
def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str): def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str):
"""Saves a chat message to the database.""" """Saves a chat message to the database."""
db_message = ChatMessage(user_id=user_id, sender=sender, text=text) db_message = ChatMessage(user_id=user_id, sender=sender, text=text)
@@ -131,16 +136,21 @@ def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: st
db.refresh(db_message) db.refresh(db_message)
return db_message return db_message
def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]: def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]:
"""Retrieves the last 'limit' chat messages for a user.""" """Retrieves the last 'limit' chat messages for a user."""
return db.query(ChatMessage)\ return (
.filter(ChatMessage.user_id == user_id)\ db.query(ChatMessage)
.order_by(desc(ChatMessage.timestamp))\ .filter(ChatMessage.user_id == user_id)
.limit(limit)\ .order_by(desc(ChatMessage.timestamp))
.all()[::-1] # Reverse to get oldest first for display order .limit(limit)
.all()[::-1]
) # Reverse to get oldest first for display order
# --- Existing NLP Service Functions --- # --- Existing NLP Service Functions ---
def process_request(request: str): def process_request(request: str):
""" """
Process the user request using the Google GenAI API. Process the user request using the Google GenAI API.
@@ -152,7 +162,7 @@ def process_request(request: str):
config={ config={
"temperature": 0.3, # Less creativity, more factual "temperature": 0.3, # Less creativity, more factual
"response_mime_type": "application/json", "response_mime_type": "application/json",
} },
) )
# Parse the JSON response # Parse the JSON response
@@ -160,7 +170,9 @@ def process_request(request: str):
parsed_response = json.loads(response.text) parsed_response = json.loads(response.text)
# Validate required fields # Validate required fields
if not all(k in parsed_response for k in ("intent", "params", "response_text")): if not all(k in parsed_response for k in ("intent", "params", "response_text")):
raise ValueError("AI response missing required fields (intent, params, response_text)") raise ValueError(
"AI response missing required fields (intent, params, response_text)"
)
return parsed_response return parsed_response
except (json.JSONDecodeError, ValueError) as e: except (json.JSONDecodeError, ValueError) as e:
print(f"Error parsing AI response: {e}") print(f"Error parsing AI response: {e}")
@@ -169,9 +181,10 @@ def process_request(request: str):
return { return {
"intent": "error", "intent": "error",
"params": {}, "params": {},
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?" "response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?",
} }
def ask_ai(request: str): def ask_ai(request: str):
""" """
Ask the AI a question. Ask the AI a question.

View File

@@ -15,48 +15,55 @@ router = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
@router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED)
def create_todo_endpoint( def create_todo_endpoint(
todo: schemas.TodoCreate, todo: schemas.TodoCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency current_user: User = Depends(get_current_user), # Corrected dependency
): ):
return service.create_todo(db=db, todo=todo, user=current_user) return service.create_todo(db=db, todo=todo, user=current_user)
@router.get("/", response_model=List[schemas.Todo]) @router.get("/", response_model=List[schemas.Todo])
def read_todos_endpoint( def read_todos_endpoint(
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency current_user: User = Depends(get_current_user), # Corrected dependency
): ):
todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit) todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit)
return todos return todos
@router.get("/{todo_id}", response_model=schemas.Todo) @router.get("/{todo_id}", response_model=schemas.Todo)
def read_todo_endpoint( def read_todo_endpoint(
todo_id: int, todo_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency current_user: User = Depends(get_current_user), # Corrected dependency
): ):
db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user) db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user)
if db_todo is None: if db_todo is None:
raise HTTPException(status_code=404, detail="Todo not found") raise HTTPException(status_code=404, detail="Todo not found")
return db_todo return db_todo
@router.put("/{todo_id}", response_model=schemas.Todo) @router.put("/{todo_id}", response_model=schemas.Todo)
def update_todo_endpoint( def update_todo_endpoint(
todo_id: int, todo_id: int,
todo_update: schemas.TodoUpdate, todo_update: schemas.TodoUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency current_user: User = Depends(get_current_user), # Corrected dependency
): ):
return service.update_todo(db=db, todo_id=todo_id, todo_update=todo_update, user=current_user) return service.update_todo(
db=db, todo_id=todo_id, todo_update=todo_update, user=current_user
)
@router.delete("/{todo_id}", response_model=schemas.Todo) @router.delete("/{todo_id}", response_model=schemas.Todo)
def delete_todo_endpoint( def delete_todo_endpoint(
todo_id: int, todo_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency current_user: User = Depends(get_current_user), # Corrected dependency
): ):
return service.delete_todo(db=db, todo_id=todo_id, user=current_user) return service.delete_todo(db=db, todo_id=todo_id, user=current_user)

View File

@@ -3,6 +3,7 @@ from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from core.database import Base from core.database import Base
class Todo(Base): class Todo(Base):
__tablename__ = "todos" __tablename__ = "todos"
@@ -13,4 +14,6 @@ class Todo(Base):
complete = Column(Boolean, default=False) complete = Column(Boolean, default=False)
owner_id = Column(Integer, ForeignKey("users.id")) owner_id = Column(Integer, ForeignKey("users.id"))
owner = relationship("User") # Add relationship if needed, assuming User model exists in auth.models owner = relationship(
"User"
) # Add relationship if needed, assuming User model exists in auth.models

View File

@@ -3,21 +3,25 @@ from pydantic import BaseModel
from typing import Optional from typing import Optional
import datetime import datetime
class TodoBase(BaseModel): class TodoBase(BaseModel):
task: str task: str
date: Optional[datetime.datetime] = None date: Optional[datetime.datetime] = None
remind: bool = False remind: bool = False
complete: bool = False complete: bool = False
class TodoCreate(TodoBase): class TodoCreate(TodoBase):
pass pass
class TodoUpdate(BaseModel): class TodoUpdate(BaseModel):
task: Optional[str] = None task: Optional[str] = None
date: Optional[datetime.datetime] = None date: Optional[datetime.datetime] = None
remind: Optional[bool] = None remind: Optional[bool] = None
complete: Optional[bool] = None complete: Optional[bool] = None
class Todo(TodoBase): class Todo(TodoBase):
id: int id: int
owner_id: int owner_id: int

View File

@@ -4,6 +4,7 @@ from . import models, schemas
from modules.auth.models import User # Assuming User model is in auth.models from modules.auth.models import User # Assuming User model is in auth.models
from fastapi import HTTPException, status from fastapi import HTTPException, status
def create_todo(db: Session, todo: schemas.TodoCreate, user: User): def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
db_todo = models.Todo(**todo.dict(), owner_id=user.id) db_todo = models.Todo(**todo.dict(), owner_id=user.id)
db.add(db_todo) db.add(db_todo)
@@ -11,17 +12,34 @@ def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
db.refresh(db_todo) db.refresh(db_todo)
return db_todo return db_todo
def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100): def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100):
return db.query(models.Todo).filter(models.Todo.owner_id == user.id).offset(skip).limit(limit).all() return (
db.query(models.Todo)
.filter(models.Todo.owner_id == user.id)
.offset(skip)
.limit(limit)
.all()
)
def get_todo(db: Session, todo_id: int, user: User): def get_todo(db: Session, todo_id: int, user: User):
db_todo = db.query(models.Todo).filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id).first() db_todo = (
db.query(models.Todo)
.filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id)
.first()
)
if db_todo is None: if db_todo is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found") raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found"
)
return db_todo return db_todo
def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User): def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User):
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence db_todo = get_todo(
db=db, todo_id=todo_id, user=user
) # Reuse get_todo to check ownership and existence
update_data = todo_update.dict(exclude_unset=True) update_data = todo_update.dict(exclude_unset=True)
for key, value in update_data.items(): for key, value in update_data.items():
setattr(db_todo, key, value) setattr(db_todo, key, value)
@@ -29,8 +47,11 @@ def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user
db.refresh(db_todo) db.refresh(db_todo)
return db_todo return db_todo
def delete_todo(db: Session, todo_id: int, user: User): def delete_todo(db: Session, todo_id: int, user: User):
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence db_todo = get_todo(
db=db, todo_id=todo_id, user=user
) # Reuse get_todo to check ownership and existence
db.delete(db_todo) db.delete(db_todo)
db.commit() db.commit()
return db_todo return db_todo

View File

@@ -11,16 +11,25 @@ from modules.auth.models import User
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/user", tags=["user"])
@router.get("/me", response_model=UserResponse) @router.get("/me", response_model=UserResponse)
def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: def me(
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
""" """
Get the current user. Requires user to be logged in. Get the current user. Requires user to be logged in.
Returns the user object. Returns the user object.
""" """
return current_user return current_user
@router.get("/{username}", response_model=UserResponse) @router.get("/{username}", response_model=UserResponse)
def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: def get_user(
username: str,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
""" """
Get a user by username. Get a user by username.
Returns the user object. Returns the user object.
@@ -33,8 +42,14 @@ def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_use
raise not_found_exception("User not found") raise not_found_exception("User not found")
return user return user
@router.patch("/{username}", response_model=UserResponse) @router.patch("/{username}", response_model=UserResponse)
def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: def update_user(
username: str,
user_data: UserPatch,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
""" """
Update a user by username. Update a user by username.
Returns the updated user object. Returns the updated user object.
@@ -60,8 +75,13 @@ def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depe
db.refresh(user) db.refresh(user)
return user return user
@router.delete("/{username}", response_model=UserResponse) @router.delete("/{username}", response_model=UserResponse)
def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse: def delete_user(
username: str,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
""" """
Delete a user by username. Delete a user by username.
Returns the deleted user object. Returns the deleted user object.

View File

@@ -12,6 +12,7 @@ from core.database import get_db, get_sessionmaker
fake = Faker() fake = Faker()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def postgres_container() -> Generator[PostgresContainer, None, None]: def postgres_container() -> Generator[PostgresContainer, None, None]:
"""Fixture to create a PostgreSQL container for testing.""" """Fixture to create a PostgreSQL container for testing."""
@@ -22,6 +23,7 @@ def postgres_container() -> Generator[PostgresContainer, None, None]:
yield postgres yield postgres
print("Postgres container stopped.") print("Postgres container stopped.")
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def db(postgres_container) -> Generator[Session, None, None]: def db(postgres_container) -> Generator[Session, None, None]:
"""Function-scoped database session with rollback""" """Function-scoped database session with rollback"""
@@ -34,6 +36,7 @@ def db(postgres_container) -> Generator[Session, None, None]:
session.rollback() session.rollback()
session.close() session.close()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client(db: Session) -> Generator[TestClient, None, None]: def client(db: Session) -> Generator[TestClient, None, None]:
"""Function-scoped test client with dependency override""" """Function-scoped test client with dependency override"""
@@ -53,6 +56,8 @@ def client(db: Session) -> Generator[TestClient, None, None]:
app.dependency_overrides.clear() app.dependency_overrides.clear()
def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None: def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None:
from main import app from main import app
app.dependency_overrides[dependency] = lambda: mocked_response app.dependency_overrides[dependency] = lambda: mocked_response

View File

@@ -5,13 +5,20 @@ from sqlalchemy.orm import Session
from core.config import settings from core.config import settings
from modules.auth.models import User from modules.auth.models import User
from modules.auth.security import authenticate_user, create_access_token, create_refresh_token, hash_password from modules.auth.security import (
authenticate_user,
create_access_token,
create_refresh_token,
hash_password,
)
from modules.auth.schemas import UserRole from modules.auth.schemas import UserRole
from tests.conftest import fake from tests.conftest import fake
from typing import Optional # Import Optional from typing import Optional # Import Optional
def create_user(db: Session, is_admin: bool = False, username: Optional[str] = None) -> User: def create_user(
db: Session, is_admin: bool = False, username: Optional[str] = None
) -> User:
unhashed_password = fake.password() unhashed_password = fake.password()
_user = User( _user = User(
name=fake.name(), name=fake.name(),
@@ -26,12 +33,16 @@ def create_user(db: Session, is_admin: bool = False, username: Optional[str] = N
db.refresh(_user) db.refresh(_user)
return _user, unhashed_password # return for testing return _user, unhashed_password # return for testing
def login(db: Session, username: str, password: str) -> str: def login(db: Session, username: str, password: str) -> str:
user = authenticate_user(username, password, db) user = authenticate_user(username, password, db)
if not user: if not user:
raise Exception("Incorrect username or password") raise Exception("Incorrect username or password")
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)) access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
)
refresh_token = create_refresh_token(data={"sub": user.username}) refresh_token = create_refresh_token(data={"sub": user.username})
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60 max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60

View File

@@ -7,62 +7,84 @@ from tests.helpers import generators
# Test admin routes require admin privileges # Test admin routes require admin privileges
def test_read_admin_unauthorized(client: TestClient) -> None: def test_read_admin_unauthorized(client: TestClient) -> None:
"""Test accessing admin route without authentication.""" """Test accessing admin route without authentication."""
response = client.get("/api/admin/") response = client.get("/api/admin/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_read_admin_forbidden(db: Session, client: TestClient) -> None: def test_read_admin_forbidden(db: Session, client: TestClient) -> None:
"""Test accessing admin route as a non-admin user.""" """Test accessing admin route as a non-admin user."""
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"}) response = client.get(
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
def test_read_admin_success(db: Session, client: TestClient) -> None: def test_read_admin_success(db: Session, client: TestClient) -> None:
"""Test accessing admin route as an admin user.""" """Test accessing admin route as an admin user."""
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True admin_user, password = generators.create_user(
db, is_admin=True
) # Use is_admin=True
login_rsp = generators.login(db, admin_user.username, password) login_rsp = generators.login(db, admin_user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"}) response = client.get(
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == {"message": "Admin route"} assert response.json() == {"message": "Admin route"}
@patch("modules.admin.api.cleardb.delay") # Mock the celery task @patch("modules.admin.api.cleardb.delay") # Mock the celery task
def test_clear_db_soft(mock_cleardb_delay, db: Session, client: TestClient) -> None: def test_clear_db_soft(mock_cleardb_delay, db: Session, client: TestClient) -> None:
"""Test soft clearing the database as admin.""" """Test soft clearing the database as admin."""
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True admin_user, password = generators.create_user(
db, is_admin=True
) # Use is_admin=True
login_rsp = generators.login(db, admin_user.username, password) login_rsp = generators.login(db, admin_user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
response = client.post( response = client.post(
"/api/admin/cleardb", "/api/admin/cleardb",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json={"hard": False} json={"hard": False},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == {"message": "Clearing database in the background", "hard": False} assert response.json() == {
"message": "Clearing database in the background",
"hard": False,
}
mock_cleardb_delay.assert_called_once_with(False) mock_cleardb_delay.assert_called_once_with(False)
@patch("modules.admin.api.cleardb.delay") # Mock the celery task @patch("modules.admin.api.cleardb.delay") # Mock the celery task
def test_clear_db_hard(mock_cleardb_delay, db: Session, client: TestClient) -> None: def test_clear_db_hard(mock_cleardb_delay, db: Session, client: TestClient) -> None:
"""Test hard clearing the database as admin.""" """Test hard clearing the database as admin."""
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True admin_user, password = generators.create_user(
db, is_admin=True
) # Use is_admin=True
login_rsp = generators.login(db, admin_user.username, password) login_rsp = generators.login(db, admin_user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
response = client.post( response = client.post(
"/api/admin/cleardb", "/api/admin/cleardb",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json={"hard": True} json={"hard": True},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == {"message": "Clearing database in the background", "hard": True} assert response.json() == {
"message": "Clearing database in the background",
"hard": True,
}
mock_cleardb_delay.assert_called_once_with(True) mock_cleardb_delay.assert_called_once_with(True)
def test_clear_db_forbidden(db: Session, client: TestClient) -> None: def test_clear_db_forbidden(db: Session, client: TestClient) -> None:
"""Test clearing the database as a non-admin user.""" """Test clearing the database as a non-admin user."""
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
@@ -72,6 +94,6 @@ def test_clear_db_forbidden(db: Session, client: TestClient) -> None:
response = client.post( response = client.post(
"/api/admin/cleardb", "/api/admin/cleardb",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json={"hard": False} json={"hard": False},
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN

View File

@@ -34,6 +34,7 @@ def test_register(client: TestClient) -> None:
) )
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
def test_login(db: Session, client: TestClient) -> None: def test_login(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
@@ -51,6 +52,7 @@ def test_login(db: Session, client: TestClient) -> None:
assert "token_type" in response_data assert "token_type" in response_data
assert response_data["token_type"] == "bearer" assert response_data["token_type"] == "bearer"
def test_refresh_token(db: Session, client: TestClient) -> None: def test_refresh_token(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
rsp = generators.login(db, user.username, unhashed_password) rsp = generators.login(db, user.username, unhashed_password)
@@ -61,7 +63,10 @@ def test_refresh_token(db: Session, client: TestClient) -> None:
response = client.post( response = client.post(
"/api/auth/refresh", "/api/auth/refresh",
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}, headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"refresh_token": refresh_token}, json={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -70,7 +75,10 @@ def test_refresh_token(db: Session, client: TestClient) -> None:
assert "access_token" in response_data assert "access_token" in response_data
assert "token_type" in response_data assert "token_type" in response_data
assert response_data["token_type"] == "bearer" assert response_data["token_type"] == "bearer"
assert response_data["access_token"] != access_token # Ensure the token is refreshed assert (
response_data["access_token"] != access_token
) # Ensure the token is refreshed
def test_logout(db: Session, client: TestClient) -> None: def test_logout(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
@@ -80,13 +88,18 @@ def test_logout(db: Session, client: TestClient) -> None:
response = client.post( response = client.post(
"/api/auth/logout", "/api/auth/logout",
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}, headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"refresh_token": refresh_token}, json={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
# Verify that the token is blacklisted # Verify that the token is blacklisted
blacklisted_token = db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first() blacklisted_token = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
)
assert blacklisted_token is not None assert blacklisted_token is not None
# Verify that we can't still actually do anything # Verify that we can't still actually do anything
@@ -98,7 +111,10 @@ def test_logout(db: Session, client: TestClient) -> None:
response = client.post( response = client.post(
"/api/auth/refresh", "/api/auth/refresh",
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}, headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"refresh_token": refresh_token}, json={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -106,7 +122,9 @@ def test_logout(db: Session, client: TestClient) -> None:
def test_get_me(db: Session, client: TestClient) -> None: def test_get_me(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"] access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.get( response = client.get(
"/api/user/me", "/api/user/me",
@@ -119,14 +137,18 @@ def test_get_me(db: Session, client: TestClient) -> None:
assert response_data["uuid"] == user.uuid assert response_data["uuid"] == user.uuid
assert response_data["username"] == user.username assert response_data["username"] == user.username
def test_get_me_unauthorized(client: TestClient) -> None: def test_get_me_unauthorized(client: TestClient) -> None:
### This test should fail (unauthorized) because the user isn't logged in ### This test should fail (unauthorized) because the user isn't logged in
response = client.get("/api/user/me") response = client.get("/api/user/me")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_user(db: Session, client: TestClient) -> None: def test_get_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"] access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.get( response = client.get(
f"/api/user/{user.username}", f"/api/user/{user.username}",
@@ -139,11 +161,14 @@ def test_get_user(db: Session, client: TestClient) -> None:
assert response_data["uuid"] == user.uuid assert response_data["uuid"] == user.uuid
assert response_data["username"] == user.username assert response_data["username"] == user.username
def test_get_user_unauthorized(db: Session, client: TestClient) -> None: def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
### This test should fail (unauthorized) because the user isn't us ### This test should fail (unauthorized) because the user isn't us
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
user2, _ = generators.create_user(db) user2, _ = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"] access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.get( response = client.get(
f"/api/user/{user2.username}", f"/api/user/{user2.username}",
@@ -151,11 +176,14 @@ def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
def test_update_user(db: Session, client: TestClient) -> None: def test_update_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
new_name = fake.name() new_name = fake.name()
access_token = generators.login(db, user.username, unhashed_password)["access_token"] access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.patch( response = client.patch(
f"/api/user/{user.username}", f"/api/user/{user.username}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
@@ -168,7 +196,9 @@ def test_update_user(db: Session, client: TestClient) -> None:
def test_delete_user(db: Session, client: TestClient) -> None: def test_delete_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db) user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"] access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.delete( response = client.delete(
f"/api/user/{user.username}", f"/api/user/{user.username}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
@@ -179,6 +209,7 @@ def test_delete_user(db: Session, client: TestClient) -> None:
deleted_user = db.query(User).filter(User.username == user.username).first() deleted_user = db.query(User).filter(User.username == user.username).first()
assert deleted_user is None assert deleted_user is None
def test_get_user_forbidden(db: Session, client: TestClient) -> None: def test_get_user_forbidden(db: Session, client: TestClient) -> None:
"""Test getting another user's profile (should be forbidden).""" """Test getting another user's profile (should be forbidden)."""
user1, password_user1 = generators.create_user(db, username="user1_get_forbidden") user1, password_user1 = generators.create_user(db, username="user1_get_forbidden")
@@ -195,9 +226,12 @@ def test_get_user_forbidden(db: Session, client: TestClient) -> None:
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
def test_update_user_forbidden(db: Session, client: TestClient) -> None: def test_update_user_forbidden(db: Session, client: TestClient) -> None:
"""Test updating another user's profile (should be forbidden).""" """Test updating another user's profile (should be forbidden)."""
user1, password_user1 = generators.create_user(db, username="user1_update_forbidden") user1, password_user1 = generators.create_user(
db, username="user1_update_forbidden"
)
user2, _ = generators.create_user(db, username="user2_update_forbidden") user2, _ = generators.create_user(db, username="user2_update_forbidden")
new_name = fake.name() new_name = fake.name()
@@ -213,9 +247,12 @@ def test_update_user_forbidden(db: Session, client: TestClient) -> None:
) )
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
def test_delete_user_forbidden(db: Session, client: TestClient) -> None: def test_delete_user_forbidden(db: Session, client: TestClient) -> None:
"""Test deleting another user's profile (should be forbidden).""" """Test deleting another user's profile (should be forbidden)."""
user1, password_user1 = generators.create_user(db, username="user1_delete_forbidden") user1, password_user1 = generators.create_user(
db, username="user1_delete_forbidden"
)
user2, _ = generators.create_user(db, username="user2_delete_forbidden") user2, _ = generators.create_user(db, username="user2_delete_forbidden")
# Log in as user1 # Log in as user1

View File

@@ -7,6 +7,7 @@ from tests.helpers import generators
from modules.calendar.models import CalendarEvent # Assuming model exists from modules.calendar.models import CalendarEvent # Assuming model exists
from tests.conftest import fake from tests.conftest import fake
# Helper function to create an event payload # Helper function to create an event payload
def create_event_payload(start_offset_days=0, end_offset_days=1): def create_event_payload(start_offset_days=0, end_offset_days=1):
start_time = datetime.utcnow() + timedelta(days=start_offset_days) start_time = datetime.utcnow() + timedelta(days=start_offset_days)
@@ -19,14 +20,17 @@ def create_event_payload(start_offset_days=0, end_offset_days=1):
"all_day": fake.boolean(), "all_day": fake.boolean(),
} }
# --- Test Create Event --- # --- Test Create Event ---
def test_create_event_unauthorized(client: TestClient) -> None: def test_create_event_unauthorized(client: TestClient) -> None:
"""Test creating an event without authentication.""" """Test creating an event without authentication."""
payload = create_event_payload() payload = create_event_payload()
response = client.post("/api/calendar/events", json=payload) response = client.post("/api/calendar/events", json=payload)
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_create_event_success(db: Session, client: TestClient) -> None: def test_create_event_success(db: Session, client: TestClient) -> None:
"""Test creating a calendar event successfully.""" """Test creating a calendar event successfully."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
@@ -37,9 +41,11 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
response = client.post( response = client.post(
"/api/calendar/events", "/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json=payload json=payload,
) )
assert response.status_code == status.HTTP_201_CREATED # Change expected status to 201 assert (
response.status_code == status.HTTP_201_CREATED
) # Change expected status to 201
data = response.json() data = response.json()
assert data["title"] == payload["title"] assert data["title"] == payload["title"]
assert data["description"] == payload["description"] assert data["description"] == payload["description"]
@@ -56,13 +62,16 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
assert event_in_db.user_id == user.id assert event_in_db.user_id == user.id
assert event_in_db.title == payload["title"] assert event_in_db.title == payload["title"]
# --- Test Get Events --- # --- Test Get Events ---
def test_get_events_unauthorized(client: TestClient) -> None: def test_get_events_unauthorized(client: TestClient) -> None:
"""Test getting events without authentication.""" """Test getting events without authentication."""
response = client.get("/api/calendar/events") response = client.get("/api/calendar/events")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_events_success(db: Session, client: TestClient) -> None: def test_get_events_success(db: Session, client: TestClient) -> None:
"""Test getting all calendar events for a user.""" """Test getting all calendar events for a user."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
@@ -71,21 +80,31 @@ def test_get_events_success(db: Session, client: TestClient) -> None:
# Create a couple of events for the user # Create a couple of events for the user
payload1 = create_event_payload(0, 1) payload1 = create_event_payload(0, 1)
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1) client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload1,
)
payload2 = create_event_payload(2, 3) payload2 = create_event_payload(2, 3)
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2) client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload2,
)
# Create an event for another user (should not be returned) # Create an event for another user (should not be returned)
other_user, other_password = generators.create_user(db) other_user, other_password = generators.create_user(db)
other_login_rsp = generators.login(db, other_user.username, other_password) other_login_rsp = generators.login(db, other_user.username, other_password)
other_access_token = other_login_rsp["access_token"] other_access_token = other_login_rsp["access_token"]
other_payload = create_event_payload(4, 5) other_payload = create_event_payload(4, 5)
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {other_access_token}"}, json=other_payload) client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {other_access_token}"},
json=other_payload,
)
response = client.get( response = client.get(
"/api/calendar/events", "/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@@ -104,11 +123,23 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
# Create events # Create events
payload1 = create_event_payload(0, 1) # Today -> Tomorrow payload1 = create_event_payload(0, 1) # Today -> Tomorrow
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1) client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload1,
)
payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2) client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload2,
)
payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload3) client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload3,
)
# Filter for events starting within the next week # Filter for events starting within the next week
start_filter = datetime.utcnow().isoformat() start_filter = datetime.utcnow().isoformat()
@@ -117,7 +148,7 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
response = client.get( response = client.get(
"/api/calendar/events", "/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
params={"start": start_filter, "end": end_filter} params={"start": start_filter, "end": end_filter},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@@ -130,7 +161,7 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
response = client.get( response = client.get(
"/api/calendar/events", "/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
params={"start": start_filter_late} params={"start": start_filter_late},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@@ -140,30 +171,40 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
# --- Test Get Event By ID --- # --- Test Get Event By ID ---
def test_get_event_by_id_unauthorized(db: Session, client: TestClient) -> None: def test_get_event_by_id_unauthorized(db: Session, client: TestClient) -> None:
"""Test getting a specific event without authentication.""" """Test getting a specific event without authentication."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
response = client.get(f"/api/calendar/events/{event_id}") response = client.get(f"/api/calendar/events/{event_id}")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_event_by_id_success(db: Session, client: TestClient) -> None: def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
"""Test getting a specific event successfully.""" """Test getting a specific event successfully."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
response = client.get( response = client.get(
f"/api/calendar/events/{event_id}", f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"} headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@@ -171,6 +212,7 @@ def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
assert data["title"] == payload["title"] assert data["title"] == payload["title"]
assert data["user_id"] == user.id assert data["user_id"] == user.id
def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None: def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
"""Test getting a non-existent event.""" """Test getting a non-existent event."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
@@ -180,10 +222,11 @@ def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
response = client.get( response = client.get(
f"/api/calendar/events/{non_existent_id}", f"/api/calendar/events/{non_existent_id}",
headers={"Authorization": f"Bearer {access_token}"} headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None: def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
"""Test getting another user's event.""" """Test getting another user's event."""
user1, password_user1 = generators.create_user(db) user1, password_user1 = generators.create_user(db)
@@ -193,7 +236,11 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
login_rsp1 = generators.login(db, user1.username, password_user1) login_rsp1 = generators.login(db, user1.username, password_user1)
access_token1 = login_rsp1["access_token"] access_token1 = login_rsp1["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload) create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token1}"},
json=payload,
)
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
# Log in as user2 and try to get user1's event # Log in as user2 and try to get user1's event
@@ -202,45 +249,60 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
response = client.get( response = client.get(
f"/api/calendar/events/{event_id}", f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token2}"} headers={"Authorization": f"Bearer {access_token2}"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match assert (
response.status_code == status.HTTP_404_NOT_FOUND
) # Service layer returns 404 if user_id doesn't match
# --- Test Update Event --- # --- Test Update Event ---
def test_update_event_unauthorized(db: Session, client: TestClient) -> None: def test_update_event_unauthorized(db: Session, client: TestClient) -> None:
"""Test updating an event without authentication.""" """Test updating an event without authentication."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
update_payload = {"title": "Updated Title"} update_payload = {"title": "Updated Title"}
response = client.patch(f"/api/calendar/events/{event_id}", json=update_payload) response = client.patch(f"/api/calendar/events/{event_id}", json=update_payload)
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_update_event_success(db: Session, client: TestClient) -> None: def test_update_event_success(db: Session, client: TestClient) -> None:
"""Test updating an event successfully.""" """Test updating an event successfully."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) create_response = client.post(
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201 "/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
assert (
create_response.status_code == status.HTTP_201_CREATED
) # Ensure creation check uses 201
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
update_payload = { update_payload = {
"title": "Updated Title", "title": "Updated Title",
"description": "Updated description.", "description": "Updated description.",
"all_day": not payload["all_day"] # Toggle all_day "all_day": not payload["all_day"], # Toggle all_day
} }
response = client.patch( response = client.patch(
f"/api/calendar/events/{event_id}", f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json=update_payload json=update_payload,
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
data = response.json() data = response.json()
@@ -258,6 +320,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None:
assert event_in_db.description == update_payload["description"] assert event_in_db.description == update_payload["description"]
assert event_in_db.all_day == update_payload["all_day"] assert event_in_db.all_day == update_payload["all_day"]
def test_update_event_not_found(db: Session, client: TestClient) -> None: def test_update_event_not_found(db: Session, client: TestClient) -> None:
"""Test updating a non-existent event.""" """Test updating a non-existent event."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
@@ -269,10 +332,11 @@ def test_update_event_not_found(db: Session, client: TestClient) -> None:
response = client.patch( response = client.patch(
f"/api/calendar/events/{non_existent_id}", f"/api/calendar/events/{non_existent_id}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
json=update_payload json=update_payload,
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_update_event_forbidden(db: Session, client: TestClient) -> None: def test_update_event_forbidden(db: Session, client: TestClient) -> None:
"""Test updating another user's event.""" """Test updating another user's event."""
user1, password_user1 = generators.create_user(db) user1, password_user1 = generators.create_user(db)
@@ -282,7 +346,11 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
login_rsp1 = generators.login(db, user1.username, password_user1) login_rsp1 = generators.login(db, user1.username, password_user1)
access_token1 = login_rsp1["access_token"] access_token1 = login_rsp1["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload) create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token1}"},
json=payload,
)
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
# Log in as user2 and try to update user1's event # Log in as user2 and try to update user1's event
@@ -293,32 +361,47 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
response = client.patch( response = client.patch(
f"/api/calendar/events/{event_id}", f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token2}"}, headers={"Authorization": f"Bearer {access_token2}"},
json=update_payload json=update_payload,
) )
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match assert (
response.status_code == status.HTTP_404_NOT_FOUND
) # Service layer returns 404 if user_id doesn't match
# --- Test Delete Event --- # --- Test Delete Event ---
def test_delete_event_unauthorized(db: Session, client: TestClient) -> None: def test_delete_event_unauthorized(db: Session, client: TestClient) -> None:
"""Test deleting an event without authentication.""" """Test deleting an event without authentication."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
response = client.delete(f"/api/calendar/events/{event_id}") response = client.delete(f"/api/calendar/events/{event_id}")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_delete_event_success(db: Session, client: TestClient) -> None: def test_delete_event_success(db: Session, client: TestClient) -> None:
"""Test deleting an event successfully.""" """Test deleting an event successfully."""
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"] access_token = login_rsp["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload) create_response = client.post(
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201 "/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
assert (
create_response.status_code == status.HTTP_201_CREATED
) # Ensure creation check uses 201
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
# Verify event exists before delete # Verify event exists before delete
@@ -327,7 +410,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
response = client.delete( response = client.delete(
f"/api/calendar/events/{event_id}", f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"} headers={"Authorization": f"Bearer {access_token}"},
) )
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
@@ -338,7 +421,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
# Try getting the deleted event (should be 404) # Try getting the deleted event (should be 404)
get_response = client.get( get_response = client.get(
f"/api/calendar/events/{event_id}", f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"} headers={"Authorization": f"Bearer {access_token}"},
) )
assert get_response.status_code == status.HTTP_404_NOT_FOUND assert get_response.status_code == status.HTTP_404_NOT_FOUND
@@ -352,7 +435,7 @@ def test_delete_event_not_found(db: Session, client: TestClient) -> None:
response = client.delete( response = client.delete(
f"/api/calendar/events/{non_existent_id}", f"/api/calendar/events/{non_existent_id}",
headers={"Authorization": f"Bearer {access_token}"} headers={"Authorization": f"Bearer {access_token}"},
) )
# The service layer raises NotFound, which should result in 404 # The service layer raises NotFound, which should result in 404
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -367,7 +450,11 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
login_rsp1 = generators.login(db, user1.username, password_user1) login_rsp1 = generators.login(db, user1.username, password_user1)
access_token1 = login_rsp1["access_token"] access_token1 = login_rsp1["access_token"]
payload = create_event_payload() payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload) create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token1}"},
json=payload,
)
event_id = create_response.json()["id"] event_id = create_response.json()["id"]
# Log in as user2 and try to delete user1's event # Log in as user2 and try to delete user1's event
@@ -376,7 +463,7 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
response = client.delete( response = client.delete(
f"/api/calendar/events/{event_id}", f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token2}"} headers={"Authorization": f"Bearer {access_token2}"},
) )
# The service layer raises NotFound if user_id doesn't match, resulting in 404 # The service layer raises NotFound if user_id doesn't match, resulting in 404
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -385,4 +472,3 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
event_in_db = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first() event_in_db = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first()
assert event_in_db is not None assert event_in_db is not None
assert event_in_db.user_id == user1.id assert event_in_db.user_id == user1.id

View File

@@ -2,6 +2,7 @@ from fastapi.testclient import TestClient
# No database needed for this simple test # No database needed for this simple test
def test_health_check(client: TestClient): def test_health_check(client: TestClient):
"""Test the health check endpoint.""" """Test the health check endpoint."""
response = client.get("/api/health") response = client.get("/api/health")

View File

@@ -7,24 +7,37 @@ from datetime import datetime
from tests.helpers import generators from tests.helpers import generators
from modules.nlp.schemas import ProcessCommandResponse from modules.nlp.schemas import ProcessCommandResponse
from modules.nlp.models import MessageSender, ChatMessage # Import necessary models/enums from modules.nlp.models import (
MessageSender,
ChatMessage,
) # Import necessary models/enums
# --- Mocks --- # --- Mocks ---
# Mock the external AI call and internal service functions # Mock the external AI call and internal service functions
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_nlp_services(): def mock_nlp_services():
with patch("modules.nlp.api.process_request") as mock_process, \ with patch("modules.nlp.api.process_request") as mock_process, patch(
patch("modules.nlp.api.ask_ai") as mock_ask, \ "modules.nlp.api.ask_ai"
patch("modules.nlp.api.save_chat_message") as mock_save, \ ) as mock_ask, patch("modules.nlp.api.save_chat_message") as mock_save, patch(
patch("modules.nlp.api.get_chat_history") as mock_get_history, \ "modules.nlp.api.get_chat_history"
patch("modules.nlp.api.create_calendar_event") as mock_create_event, \ ) as mock_get_history, patch(
patch("modules.nlp.api.get_calendar_events") as mock_get_events, \ "modules.nlp.api.create_calendar_event"
patch("modules.nlp.api.update_calendar_event") as mock_update_event, \ ) as mock_create_event, patch(
patch("modules.nlp.api.delete_calendar_event") as mock_delete_event, \ "modules.nlp.api.get_calendar_events"
patch("modules.nlp.api.todo_service.create_todo") as mock_create_todo, \ ) as mock_get_events, patch(
patch("modules.nlp.api.todo_service.get_todos") as mock_get_todos, \ "modules.nlp.api.update_calendar_event"
patch("modules.nlp.api.todo_service.update_todo") as mock_update_todo, \ ) as mock_update_event, patch(
patch("modules.nlp.api.todo_service.delete_todo") as mock_delete_todo: "modules.nlp.api.delete_calendar_event"
) as mock_delete_event, patch(
"modules.nlp.api.todo_service.create_todo"
) as mock_create_todo, patch(
"modules.nlp.api.todo_service.get_todos"
) as mock_get_todos, patch(
"modules.nlp.api.todo_service.update_todo"
) as mock_update_todo, patch(
"modules.nlp.api.todo_service.delete_todo"
) as mock_delete_todo:
mocks = { mocks = {
"process_request": mock_process, "process_request": mock_process,
"ask_ai": mock_ask, "ask_ai": mock_ask,
@@ -41,21 +54,24 @@ def mock_nlp_services():
} }
yield mocks yield mocks
# --- Helper Function --- # --- Helper Function ---
def _login_user(db: Session, client: TestClient): def _login_user(db: Session, client: TestClient):
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
return user, login_rsp["access_token"], login_rsp["refresh_token"] return user, login_rsp["access_token"], login_rsp["refresh_token"]
# --- Tests for /process-command --- # --- Tests for /process-command ---
def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services): def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
user_input = "What is the capital of France?" user_input = "What is the capital of France?"
mock_nlp_services["process_request"].return_value = { mock_nlp_services["process_request"].return_value = {
"intent": "ask_ai", "intent": "ask_ai",
"params": {"request": user_input}, "params": {"request": user_input},
"response_text": "Let me check that for you." "response_text": "Let me check that for you.",
} }
mock_nlp_services["ask_ai"].return_value = "The capital of France is Paris." mock_nlp_services["ask_ai"].return_value = "The capital of France is Paris."
@@ -63,25 +79,45 @@ def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_servic
"/api/nlp/process-command", "/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"user_input": user_input} json={"user_input": user_input},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == ProcessCommandResponse(responses=["Let me check that for you.", "The capital of France is Paris."]).model_dump() assert (
response.json()
== ProcessCommandResponse(
responses=["Let me check that for you.", "The capital of France is Paris."]
).model_dump()
)
# Verify save calls: user message, initial AI response, final AI answer # Verify save calls: user message, initial AI response, final AI answer
assert mock_nlp_services["save_chat_message"].call_count == 3 assert mock_nlp_services["save_chat_message"].call_count == 3
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input) mock_nlp_services["save_chat_message"].assert_any_call(
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you.") db, user_id=user.id, sender=MessageSender.USER, text=user_input
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="The capital of France is Paris.") )
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you."
)
mock_nlp_services["save_chat_message"].assert_any_call(
db,
user_id=user.id,
sender=MessageSender.AI,
text="The capital of France is Paris.",
)
mock_nlp_services["ask_ai"].assert_called_once_with(request=user_input) mock_nlp_services["ask_ai"].assert_called_once_with(request=user_input)
def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_services):
def test_process_command_get_calendar(
client: TestClient, db: Session, mock_nlp_services
):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
user_input = "What are my events today?" user_input = "What are my events today?"
mock_nlp_services["process_request"].return_value = { mock_nlp_services["process_request"].return_value = {
"intent": "get_calendar_events", "intent": "get_calendar_events",
"params": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T23:59:59Z"}, # Example params "params": {
"response_text": "Okay, fetching your events." "start": "2024-01-01T00:00:00Z",
"end": "2024-01-01T23:59:59Z",
}, # Example params
"response_text": "Okay, fetching your events.",
} }
# Mock the actual event model returned by the service # Mock the actual event model returned by the service
mock_event = MagicMock() mock_event = MagicMock()
@@ -94,26 +130,32 @@ def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_
"/api/nlp/process-command", "/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"user_input": user_input} json={"user_input": user_input},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
expected_responses = [ expected_responses = [
"Okay, fetching your events.", "Okay, fetching your events.",
"Here are the events:", "Here are the events:",
"- Team Meeting (2024-01-01 10:00 - 11:00)" "- Team Meeting (2024-01-01 10:00 - 11:00)",
] ]
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump() assert (
assert mock_nlp_services["save_chat_message"].call_count == 4 # User, Initial AI, Header, Event response.json()
== ProcessCommandResponse(responses=expected_responses).model_dump()
)
assert (
mock_nlp_services["save_chat_message"].call_count == 4
) # User, Initial AI, Header, Event
mock_nlp_services["get_calendar_events"].assert_called_once() mock_nlp_services["get_calendar_events"].assert_called_once()
def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services): def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
user_input = "Add buy milk to my list" user_input = "Add buy milk to my list"
mock_nlp_services["process_request"].return_value = { mock_nlp_services["process_request"].return_value = {
"intent": "add_todo", "intent": "add_todo",
"params": {"task": "buy milk"}, "params": {"task": "buy milk"},
"response_text": "Adding it now." "response_text": "Adding it now.",
} }
# Mock the actual Todo model returned by the service # Mock the actual Todo model returned by the service
mock_todo = MagicMock() mock_todo = MagicMock()
@@ -125,81 +167,119 @@ def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_serv
"/api/nlp/process-command", "/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"user_input": user_input} json={"user_input": user_input},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
expected_responses = ["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."] expected_responses = ["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."]
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump() assert (
assert mock_nlp_services["save_chat_message"].call_count == 3 # User, Initial AI, Confirmation AI response.json()
== ProcessCommandResponse(responses=expected_responses).model_dump()
)
assert (
mock_nlp_services["save_chat_message"].call_count == 3
) # User, Initial AI, Confirmation AI
mock_nlp_services["create_todo"].assert_called_once() mock_nlp_services["create_todo"].assert_called_once()
def test_process_command_clarification(client: TestClient, db: Session, mock_nlp_services):
def test_process_command_clarification(
client: TestClient, db: Session, mock_nlp_services
):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
user_input = "Delete the event" user_input = "Delete the event"
clarification_text = "Which event do you mean? Please provide the ID." clarification_text = "Which event do you mean? Please provide the ID."
mock_nlp_services["process_request"].return_value = { mock_nlp_services["process_request"].return_value = {
"intent": "clarification_needed", "intent": "clarification_needed",
"params": {"request": user_input}, "params": {"request": user_input},
"response_text": clarification_text "response_text": clarification_text,
} }
response = client.post( response = client.post(
"/api/nlp/process-command", "/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"user_input": user_input} json={"user_input": user_input},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == ProcessCommandResponse(responses=[clarification_text]).model_dump() assert (
response.json()
== ProcessCommandResponse(responses=[clarification_text]).model_dump()
)
# Verify save calls: user message, clarification AI response # Verify save calls: user message, clarification AI response
assert mock_nlp_services["save_chat_message"].call_count == 2 assert mock_nlp_services["save_chat_message"].call_count == 2
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input) mock_nlp_services["save_chat_message"].assert_any_call(
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=clarification_text) db, user_id=user.id, sender=MessageSender.USER, text=user_input
)
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.AI, text=clarification_text
)
# Ensure no action services were called # Ensure no action services were called
mock_nlp_services["delete_calendar_event"].assert_not_called() mock_nlp_services["delete_calendar_event"].assert_not_called()
def test_process_command_error_intent(client: TestClient, db: Session, mock_nlp_services):
def test_process_command_error_intent(
client: TestClient, db: Session, mock_nlp_services
):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
user_input = "Gibberish request" user_input = "Gibberish request"
error_text = "Sorry, I didn't understand that." error_text = "Sorry, I didn't understand that."
mock_nlp_services["process_request"].return_value = { mock_nlp_services["process_request"].return_value = {
"intent": "error", "intent": "error",
"params": {}, "params": {},
"response_text": error_text "response_text": error_text,
} }
response = client.post( response = client.post(
"/api/nlp/process-command", "/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"user_input": user_input} json={"user_input": user_input},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == ProcessCommandResponse(responses=[error_text]).model_dump() assert (
response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
)
# Verify save calls: user message, error AI response # Verify save calls: user message, error AI response
assert mock_nlp_services["save_chat_message"].call_count == 2 assert mock_nlp_services["save_chat_message"].call_count == 2
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input) mock_nlp_services["save_chat_message"].assert_any_call(
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=error_text) db, user_id=user.id, sender=MessageSender.USER, text=user_input
)
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.AI, text=error_text
)
# --- Tests for /history --- # --- Tests for /history ---
def test_get_history(client: TestClient, db: Session, mock_nlp_services): def test_get_history(client: TestClient, db: Session, mock_nlp_services):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
# Mock the history data returned by the service # Mock the history data returned by the service
mock_history = [ mock_history = [
ChatMessage(id=1, user_id=user.id, sender=MessageSender.USER, text="Hello", timestamp=datetime.now()), ChatMessage(
ChatMessage(id=2, user_id=user.id, sender=MessageSender.AI, text="Hi there!", timestamp=datetime.now()) id=1,
user_id=user.id,
sender=MessageSender.USER,
text="Hello",
timestamp=datetime.now(),
),
ChatMessage(
id=2,
user_id=user.id,
sender=MessageSender.AI,
text="Hi there!",
timestamp=datetime.now(),
),
] ]
mock_nlp_services["get_chat_history"].return_value = mock_history mock_nlp_services["get_chat_history"].return_value = mock_history
response = client.get( response = client.get(
"/api/nlp/history", "/api/nlp/history",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -208,11 +288,15 @@ def test_get_history(client: TestClient, db: Session, mock_nlp_services):
assert len(response_data) == 2 assert len(response_data) == 2
assert response_data[0]["text"] == "Hello" assert response_data[0]["text"] == "Hello"
assert response_data[1]["text"] == "Hi there!" assert response_data[1]["text"] == "Hi there!"
mock_nlp_services["get_chat_history"].assert_called_once_with(db, user_id=user.id, limit=50) mock_nlp_services["get_chat_history"].assert_called_once_with(
db, user_id=user.id, limit=50
)
def test_get_history_unauthorized(client: TestClient): def test_get_history_unauthorized(client: TestClient):
response = client.get("/api/nlp/history") response = client.get("/api/nlp/history")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
# Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.) # Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.)
# Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete) # Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete)

View File

@@ -5,14 +5,17 @@ from datetime import date
from tests.helpers import generators from tests.helpers import generators
# Helper Function # Helper Function
def _login_user(db: Session, client: TestClient): def _login_user(db: Session, client: TestClient):
user, password = generators.create_user(db) user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password) login_rsp = generators.login(db, user.username, password)
return user, login_rsp["access_token"], login_rsp["refresh_token"] return user, login_rsp["access_token"], login_rsp["refresh_token"]
# --- Test CRUD Operations --- # --- Test CRUD Operations ---
def test_create_todo(client: TestClient, db: Session): def test_create_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
today_date = date.today() today_date = date.today()
@@ -20,14 +23,14 @@ def test_create_todo(client: TestClient, db: Session):
todo_data = { todo_data = {
"task": "Test TODO", "task": "Test TODO",
"date": f"{today_date.isoformat()}T00:00:00", "date": f"{today_date.isoformat()}T00:00:00",
"remind": True "remind": True,
} }
response = client.post( response = client.post(
"/api/todos/", "/api/todos/",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json=todo_data json=todo_data,
) )
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
@@ -39,24 +42,39 @@ def test_create_todo(client: TestClient, db: Session):
assert "id" in data assert "id" in data
assert data["owner_id"] == user.id assert data["owner_id"] == user.id
def test_read_todos(client: TestClient, db: Session): def test_read_todos(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
# Create some todos for the user # Create some todos for the user
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 1"}) client.post(
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 2"}) "/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"task": "Todo 1"},
)
client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"task": "Todo 2"},
)
# Create a todo for another user # Create a todo for another user
other_user, other_password = generators.create_user(db) other_user, other_password = generators.create_user(db)
other_login_rsp = generators.login(db, other_user.username, other_password) other_login_rsp = generators.login(db, other_user.username, other_password)
other_access_token = other_login_rsp["access_token"] other_access_token = other_login_rsp["access_token"]
other_refresh_token = other_login_rsp["refresh_token"] other_refresh_token = other_login_rsp["refresh_token"]
client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"}) client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {other_access_token}"},
cookies={"refresh_token": other_refresh_token},
json={"task": "Other User Todo"},
)
response = client.get( response = client.get(
"/api/todos/", "/api/todos/",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -65,20 +83,21 @@ def test_read_todos(client: TestClient, db: Session):
assert data[0]["task"] == "Todo 1" assert data[0]["task"] == "Todo 1"
assert data[1]["task"] == "Todo 2" assert data[1]["task"] == "Todo 2"
def test_read_single_todo(client: TestClient, db: Session): def test_read_single_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
create_response = client.post( create_response = client.post(
"/api/todos/", "/api/todos/",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"task": "Specific Todo"} json={"task": "Specific Todo"},
) )
todo_id = create_response.json()["id"] todo_id = create_response.json()["id"]
response = client.get( response = client.get(
f"/api/todos/{todo_id}", f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -87,15 +106,17 @@ def test_read_single_todo(client: TestClient, db: Session):
assert data["task"] == "Specific Todo" assert data["task"] == "Specific Todo"
assert data["owner_id"] == user.id assert data["owner_id"] == user.id
def test_read_single_todo_not_found(client: TestClient, db: Session): def test_read_single_todo_not_found(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
response = client.get( response = client.get(
"/api/todos/9999", # Non-existent ID "/api/todos/9999", # Non-existent ID
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_read_single_todo_forbidden(client: TestClient, db: Session): def test_read_single_todo_forbidden(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
@@ -104,16 +125,26 @@ def test_read_single_todo_forbidden(client: TestClient, db: Session):
other_login_rsp = generators.login(db, other_user.username, other_password) other_login_rsp = generators.login(db, other_user.username, other_password)
other_access_token = other_login_rsp["access_token"] other_access_token = other_login_rsp["access_token"]
other_refresh_token = other_login_rsp["refresh_token"] other_refresh_token = other_login_rsp["refresh_token"]
other_create_response = client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"}) other_create_response = client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {other_access_token}"},
cookies={"refresh_token": other_refresh_token},
json={"task": "Other User Todo"},
)
other_todo_id = other_create_response.json()["id"] other_todo_id = other_create_response.json()["id"]
# Try to access the other user's todo # Try to access the other user's todo
response = client.get( response = client.get(
f"/api/todos/{other_todo_id}", f"/api/todos/{other_todo_id}",
headers={"Authorization": f"Bearer {access_token}"}, # Using the first user's token headers={
cookies={"refresh_token": refresh_token} "Authorization": f"Bearer {access_token}"
}, # Using the first user's token
cookies={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND # Service raises 404 if not found for *this* user assert (
response.status_code == status.HTTP_404_NOT_FOUND
) # Service raises 404 if not found for *this* user
def test_update_todo(client: TestClient, db: Session): def test_update_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
@@ -121,7 +152,7 @@ def test_update_todo(client: TestClient, db: Session):
"/api/todos/", "/api/todos/",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"task": "Update Me"} json={"task": "Update Me"},
) )
todo_id = create_response.json()["id"] todo_id = create_response.json()["id"]
@@ -130,7 +161,7 @@ def test_update_todo(client: TestClient, db: Session):
f"/api/todos/{todo_id}", f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json=update_data json=update_data,
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@@ -144,7 +175,7 @@ def test_update_todo(client: TestClient, db: Session):
get_response = client.get( get_response = client.get(
f"/api/todos/{todo_id}", f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert get_response.json()["task"] == update_data["task"] assert get_response.json()["task"] == update_data["task"]
assert get_response.json()["complete"] == update_data["complete"] assert get_response.json()["complete"] == update_data["complete"]
@@ -157,24 +188,25 @@ def test_update_todo_not_found(client: TestClient, db: Session):
"/api/todos/9999", # Non-existent ID "/api/todos/9999", # Non-existent ID
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json=update_data json=update_data,
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_delete_todo(client: TestClient, db: Session): def test_delete_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
create_response = client.post( create_response = client.post(
"/api/todos/", "/api/todos/",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}, cookies={"refresh_token": refresh_token},
json={"task": "Delete Me"} json={"task": "Delete Me"},
) )
todo_id = create_response.json()["id"] todo_id = create_response.json()["id"]
response = client.delete( response = client.delete(
f"/api/todos/{todo_id}", f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item
@@ -184,25 +216,29 @@ def test_delete_todo(client: TestClient, db: Session):
get_response = client.get( get_response = client.get(
f"/api/todos/{todo_id}", f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert get_response.status_code == status.HTTP_404_NOT_FOUND assert get_response.status_code == status.HTTP_404_NOT_FOUND
def test_delete_todo_not_found(client: TestClient, db: Session): def test_delete_todo_not_found(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client) user, access_token, refresh_token = _login_user(db, client)
response = client.delete( response = client.delete(
"/api/todos/9999", # Non-existent ID "/api/todos/9999", # Non-existent ID
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token} cookies={"refresh_token": refresh_token},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
# --- Test Authentication/Authorization --- # --- Test Authentication/Authorization ---
def test_create_todo_unauthorized(client: TestClient): def test_create_todo_unauthorized(client: TestClient):
response = client.post("/api/todos/", json={"task": "No Auth"}) response = client.post("/api/todos/", json={"task": "No Auth"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_read_todos_unauthorized(client: TestClient): def test_read_todos_unauthorized(client: TestClient):
response = client.get("/api/todos/") response = client.get("/api/todos/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED