[REFORMAT] Ran black reformat

This commit is contained in:
c-d-p
2025-04-23 01:00:56 +02:00
parent d5d0a24403
commit 1553004efc
38 changed files with 1005 additions and 384 deletions

View File

@@ -7,7 +7,7 @@ from sqlalchemy import pool
from alembic import context
from core.database import Base # Import your Base
from core.database import Base # Import your Base
# --- Add project root to sys.path ---
@@ -77,9 +77,7 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()

View File

@@ -1,16 +1,16 @@
"""Initial migration with existing tables
Revision ID: 69069d6184b3
Revises:
Revises:
Create Date: 2025-04-21 01:14:33.233195
"""
from typing import Sequence, Union
# revision identifiers, used by Alembic.
revision: str = '69069d6184b3'
revision: str = "69069d6184b3"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

View File

@@ -5,13 +5,13 @@ Revises: 69069d6184b3
Create Date: 2025-04-21 20:33:27.028529
"""
from typing import Sequence, Union
# revision identifiers, used by Alembic.
revision: str = '9a82960db482'
down_revision: Union[str, None] = '69069d6184b3'
revision: str = "9a82960db482"
down_revision: Union[str, None] = "69069d6184b3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

View File

@@ -1,14 +1,17 @@
# core/celery_app.py
from celery import Celery
from core.config import settings # Import your settings
from core.config import settings # Import your settings
celery_app = Celery(
"worker",
broker=settings.REDIS_URL,
backend=settings.REDIS_URL,
include=["modules.auth.tasks", "modules.admin.tasks"] # Add paths to modules containing tasks
include=[
"modules.auth.tasks",
"modules.admin.tasks",
], # Add paths to modules containing tasks
# Add other modules with tasks here, e.g., "modules.some_other_module.tasks"
)
# Optional: Update Celery configuration directly if needed
# celery_app.conf.update(task_track_started=True)
# celery_app.conf.update(task_track_started=True)

View File

@@ -4,12 +4,13 @@ import os
DOTENV_PATH = os.path.join(os.path.dirname(__file__), "../.env")
class Settings(BaseSettings):
# Database settings - reads DB_URL from environment or .env
DB_URL: str = "postgresql://maia:maia@localhost:5432/maia"
# Redis settings - reads REDIS_URL from environment or .env, also used for Celery.
REDIS_URL: str ="redis://localhost:6379/0"
REDIS_URL: str = "redis://localhost:6379/0"
# JWT settings - reads from environment or .env
JWT_ALGORITHM: str = "HS256"
@@ -19,13 +20,14 @@ class Settings(BaseSettings):
JWT_SECRET_KEY: str
# Other settings
GOOGLE_API_KEY: str = "" # Example with a default
GOOGLE_API_KEY: str = "" # Example with a default
class Config:
# Tell pydantic-settings to load variables from a .env file
env_file = DOTENV_PATH
env_file_encoding = 'utf-8'
extra = 'ignore'
env_file_encoding = "utf-8"
extra = "ignore"
# Create a single instance of the settings
settings = Settings()

View File

@@ -10,6 +10,7 @@ Base = declarative_base() # Used for models
_engine = None
_SessionLocal = None
def get_engine():
global _engine
if _engine is None:
@@ -20,10 +21,13 @@ def get_engine():
try:
_engine.connect()
except Exception:
raise Exception("Database connection failed. Is the database server running?")
raise Exception(
"Database connection failed. Is the database server running?"
)
Base.metadata.create_all(_engine) # Create tables here
return _engine
def get_sessionmaker():
global _SessionLocal
if _SessionLocal is None:
@@ -31,10 +35,11 @@ def get_sessionmaker():
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return _SessionLocal
def get_db() -> Generator[Session, None, None]:
SessionLocal = get_sessionmaker()
db = SessionLocal()
try:
yield db
finally:
db.close()
db.close()

View File

@@ -8,20 +8,26 @@ from starlette.status import (
HTTP_409_CONFLICT,
)
def bad_request_exception(detail: str = "Bad Request"):
return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
def unauthorized_exception(detail: str = "Unauthorized"):
return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
def forbidden_exception(detail: str = "Forbidden"):
return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
def not_found_exception(detail: str = "Not Found"):
return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
def internal_server_error_exception(detail: str = "Internal Server Error"):
return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
def conflict_exception(detail: str = "Conflict"):
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)

View File

@@ -11,11 +11,12 @@ import logging
# import all models to ensure they are registered before create_all
logging.getLogger('passlib').setLevel(logging.ERROR) # fix bc package logging is broken
logging.getLogger("passlib").setLevel(logging.ERROR) # fix bc package logging is broken
# Create DB tables (remove in production; use migrations instead)
def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]:
@asynccontextmanager
async def lifespan(app: FastAPI):
# Base.metadata.drop_all(bind=get_engine())
@@ -24,6 +25,7 @@ def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]
return lifespan
lifespan = lifespan_factory()
app = FastAPI(lifespan=lifespan)
@@ -34,17 +36,18 @@ app.include_router(router)
app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:8081", # Keep for web testing if needed
"http://192.168.1.9:8081", # Add your mobile device/emulator origin (adjust port if needed)
"http://localhost:8081", # Keep for web testing if needed
"http://192.168.1.9:8081", # Add your mobile device/emulator origin (adjust port if needed)
"http://192.168.255.221:8081",
# Add other origins if necessary, e.g., production frontend URL
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
allow_headers=["*"],
)
# Health endpoint
@app.get("/api/health")
def health():
return {"status": "ok"}
return {"status": "ok"}

View File

@@ -1,7 +1,7 @@
# modules/admin/api.py
from typing import Annotated
from fastapi import APIRouter, Depends # Import Body
from pydantic import BaseModel # Import BaseModel
from fastapi import APIRouter, Depends # Import Body
from pydantic import BaseModel # Import BaseModel
from sqlalchemy.orm import Session
from core.database import get_db
from modules.auth.dependencies import admin_only
@@ -9,14 +9,17 @@ from .tasks import cleardb
router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)])
# Define a Pydantic model for the request body
class ClearDbRequest(BaseModel):
hard: bool
@router.get("/")
def read_admin():
return {"message": "Admin route"}
# Change to POST and use the request body model
@router.post("/cleardb")
def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
@@ -25,6 +28,6 @@ def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
'hard'=True: Drop and recreate all tables.
'hard'=False: Delete data from tables except users.
"""
hard = payload.hard # Get 'hard' from the payload
hard = payload.hard # Get 'hard' from the payload
cleardb.delay(hard)
return {"message": "Clearing database in the background", "hard": hard}

View File

@@ -1,4 +1,4 @@
# modules/admin/services.py
## temp
## temp

View File

@@ -1,5 +1,6 @@
from core.celery_app import celery_app
@celery_app.task
def cleardb(hard: bool):
"""
@@ -32,4 +33,4 @@ def cleardb(hard: bool):
print(f"Deleting table: {table_name}")
db.execute(table.delete())
db.commit()
return {"message": "Database cleared"}
return {"message": "Database cleared"}

View File

@@ -3,9 +3,24 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError
from modules.auth.models import User
from modules.auth.schemas import UserCreate, UserResponse, Token, RefreshTokenRequest, LogoutRequest
from modules.auth.schemas import (
UserCreate,
UserResponse,
Token,
RefreshTokenRequest,
LogoutRequest,
)
from modules.auth.services import create_user
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens
from modules.auth.security import (
TokenType,
get_current_user,
oauth2_scheme,
create_access_token,
create_refresh_token,
verify_token,
authenticate_user,
blacklist_tokens,
)
from sqlalchemy.orm import Session
from typing import Annotated
from core.database import get_db
@@ -15,12 +30,19 @@ from core.exceptions import unauthorized_exception
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
)
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
return create_user(user.username, user.password, user.name, db)
@router.post("/login", response_model=Token)
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[Session, Depends(get_db)],
):
"""
Authenticate user and return JWT tokens in the response body.
"""
@@ -30,39 +52,53 @@ def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annota
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
)
refresh_token = create_refresh_token(data={"sub": user.username})
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}
@router.post("/refresh")
def refresh_token(payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]):
def refresh_token(
payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]
):
print("Refreshing token...")
refresh_token = payload.refresh_token
if not refresh_token:
raise unauthorized_exception("Refresh token missing in request body")
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db)
user_data = verify_token(
refresh_token, expected_token_type=TokenType.REFRESH, db=db
)
if not user_data:
raise unauthorized_exception("Invalid refresh token")
new_access_token = create_access_token(data={"sub": user_data.username})
return {"access_token": new_access_token, "token_type": "bearer"}
@router.post("/logout")
def logout(payload: LogoutRequest, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme)):
def logout(
payload: LogoutRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
access_token: str = Depends(oauth2_scheme),
):
try:
refresh_token = payload.refresh_token
if not refresh_token:
raise unauthorized_exception("Refresh token not found in request body")
blacklist_tokens(
access_token=access_token,
refresh_token=refresh_token,
db=db
)
blacklist_tokens(access_token=access_token, refresh_token=refresh_token, db=db)
return {"message": "Logged out successfully"}
except JWTError:
raise unauthorized_exception("Invalid token")
raise unauthorized_exception("Invalid token")

View File

@@ -5,14 +5,18 @@ from modules.auth.schemas import UserRole
from modules.auth.models import User
from core.exceptions import forbidden_exception
class RoleChecker:
def __init__(self, allowed_roles: list[UserRole]):
self.allowed_roles = allowed_roles
def __call__(self, user: User = Depends(get_current_user)):
if user.role not in self.allowed_roles:
raise forbidden_exception("You do not have permission to perform this action.")
raise forbidden_exception(
"You do not have permission to perform this action."
)
return user
admin_only = RoleChecker([UserRole.ADMIN])
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])

View File

@@ -4,10 +4,12 @@ from sqlalchemy import Column, Integer, String, Enum, DateTime
from sqlalchemy.orm import relationship
from enum import Enum as PyEnum
class UserRole(str, PyEnum):
ADMIN = "admin"
USER = "user"
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)

View File

@@ -2,33 +2,41 @@
from enum import Enum as PyEnum
from pydantic import BaseModel
class Token(BaseModel):
access_token: str
token_type: str
refresh_token: str | None = None
class TokenData(BaseModel):
username: str | None = None
scopes: list[str] = []
class RefreshTokenRequest(BaseModel):
refresh_token: str
class LogoutRequest(BaseModel):
refresh_token: str
class UserRole(str, PyEnum):
ADMIN = "admin"
USER = "user"
class UserCreate(BaseModel):
username: str
password: str
name: str
class UserPatch(BaseModel):
name: str | None = None
class UserResponse(BaseModel):
uuid: str
username: str

View File

@@ -18,6 +18,7 @@ from modules.auth.schemas import TokenData
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
class TokenType(str, Enum):
ACCESS = "access"
REFRESH = "refresh"
@@ -25,11 +26,13 @@ class TokenType(str, Enum):
password_hasher = PasswordHasher()
def hash_password(password: str) -> str:
"""Hash a password with Argon2 (and optional pepper)."""
peppered_password = password + settings.PEPPER # Prepend/append pepper
return password_hasher.hash(peppered_password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hashed version using Argon2."""
peppered_password = plain_password + settings.PEPPER
@@ -38,6 +41,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
except VerifyMismatchError:
return False
def authenticate_user(username: str, password: str, db: Session) -> User | None:
"""
Authenticate a user by checking username/password against the database.
@@ -45,41 +49,46 @@ def authenticate_user(username: str, password: str, db: Session) -> User | None:
"""
# Get user from database
user = db.query(User).filter(User.username == username).first()
# If user not found or password doesn't match
if not user or not verify_password(password, user.hashed_password):
return None
return user
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# expire = datetime.now(timezone.utc) + timedelta(seconds=5)
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
return jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
expire = datetime.now(timezone.utc) + timedelta(
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
)
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
return jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
def verify_token(
token: str, expected_token_type: TokenType, db: Session
) -> TokenData | None:
"""Verify a JWT token and return TokenData if valid.
Parameters
@@ -96,24 +105,32 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
TokenData | None
TokenData instance if the token is valid, None otherwise.
"""
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
is_blacklisted = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
is not None
)
if is_blacklisted:
return None
try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
username: str = payload.get("sub")
token_type: str = payload.get("token_type")
if username is None or token_type != expected_token_type:
return None
return TokenData(username=username)
except JWTError:
return None
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
def get_current_user(
db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@@ -121,26 +138,28 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
)
# Check if the token is blacklisted
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
is_blacklisted = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
is not None
)
if is_blacklisted:
raise credentials_exception
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user: User = db.query(User).filter(User.username == username).first()
if user is None:
raise credentials_exception
return user
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
"""Blacklist both access and refresh tokens.
@@ -154,7 +173,9 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
Database session to perform the operation.
"""
for token in [access_token, refresh_token]:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist
@@ -163,10 +184,13 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
db.commit()
def blacklist_token(token: str, db: Session) -> None:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
db.add(blacklisted_token)

View File

@@ -20,11 +20,13 @@ def create_user(username: str, password: str, name: str, db: Session) -> UserRes
existing_user = db.query(User).filter(User.username == username).first()
if existing_user:
raise conflict_exception("Username already exists")
hashed_password = hash_password(password)
user_uuid = str(uuid.uuid4())
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid)
user = User(
username=username, hashed_password=hashed_password, name=name, uuid=user_uuid
)
db.add(user)
db.commit()
db.refresh(user) # Loads the generated ID
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic

View File

@@ -6,50 +6,63 @@ from typing import List, Optional
from modules.auth.dependencies import get_current_user
from core.database import get_db
from modules.auth.models import User
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate, CalendarEventResponse
from modules.calendar.service import create_calendar_event, get_calendar_event_by_id, get_calendar_events, update_calendar_event, delete_calendar_event
from modules.calendar.schemas import (
CalendarEventCreate,
CalendarEventUpdate,
CalendarEventResponse,
)
from modules.calendar.service import (
create_calendar_event,
get_calendar_event_by_id,
get_calendar_events,
update_calendar_event,
delete_calendar_event,
)
router = APIRouter(prefix="/calendar", tags=["calendar"])
@router.post("/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED
)
def create_event(
event: CalendarEventCreate,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
return create_calendar_event(db, user.id, event)
@router.get("/events", response_model=List[CalendarEventResponse])
def get_events(
user: User = Depends(get_current_user),
db: Session = Depends(get_db),
start: Optional[datetime] = None,
end: Optional[datetime] = None
end: Optional[datetime] = None,
):
return get_calendar_events(db, user.id, start, end)
@router.get("/events/{event_id}", response_model=CalendarEventResponse)
def get_event_by_id(
event_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
event = get_calendar_event_by_id(db, user.id, event_id)
return event
@router.patch("/events/{event_id}", response_model=CalendarEventResponse)
def update_event(
event_id: int,
event: CalendarEventUpdate,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
):
return update_calendar_event(db, user.id, event_id, event)
@router.delete("/events/{event_id}", status_code=204)
def delete_event(
event_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
delete_calendar_event(db, user.id, event_id)
delete_calendar_event(db, user.id, event_id)

View File

@@ -1,8 +1,17 @@
# modules/calendar/models.py
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Boolean # Add Boolean
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
ForeignKey,
JSON,
Boolean,
) # Add Boolean
from sqlalchemy.orm import relationship
from core.database import Base
class CalendarEvent(Base):
__tablename__ = "calendar_events"
@@ -12,10 +21,12 @@ class CalendarEvent(Base):
start = Column(DateTime, nullable=False)
end = Column(DateTime)
location = Column(String)
all_day = Column(Boolean, default=False) # Add all_day column
all_day = Column(Boolean, default=False) # Add all_day column
tags = Column(JSON)
color = Column(String) # hex code for color
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) # <-- Relationship
color = Column(String) # hex code for color
user_id = Column(
Integer, ForeignKey("users.id"), nullable=False
) # <-- Relationship
# Bi-directional relationship (for eager loading)
user = relationship("User", back_populates="calendar_events")
user = relationship("User", back_populates="calendar_events")

View File

@@ -1,7 +1,8 @@
# modules/calendar/schemas.py
from datetime import datetime
from pydantic import BaseModel, field_validator # Add field_validator
from typing import List, Optional # Add List and Optional
from pydantic import BaseModel, field_validator # Add field_validator
from typing import List, Optional # Add List and Optional
# Base schema for common fields, including tags
class CalendarEventBase(BaseModel):
@@ -10,21 +11,23 @@ class CalendarEventBase(BaseModel):
start: datetime
end: Optional[datetime] = None
location: Optional[str] = None
color: Optional[str] = None # Assuming color exists
all_day: Optional[bool] = None # Add all_day field
tags: Optional[List[str]] = None # Add optional tags
color: Optional[str] = None # Assuming color exists
all_day: Optional[bool] = None # Add all_day field
tags: Optional[List[str]] = None # Add optional tags
@field_validator('tags', mode='before')
@field_validator("tags", mode="before")
@classmethod
def tags_validate_null_string(cls, v):
if v == "Null":
return None
return v
# Schema for creating an event (inherits from Base)
class CalendarEventCreate(CalendarEventBase):
pass
# Schema for updating an event (all fields optional)
class CalendarEventUpdate(BaseModel):
title: Optional[str] = None
@@ -33,23 +36,24 @@ class CalendarEventUpdate(BaseModel):
end: Optional[datetime] = None
location: Optional[str] = None
color: Optional[str] = None
all_day: Optional[bool] = None # Add all_day field
tags: Optional[List[str]] = None # Add optional tags for update
all_day: Optional[bool] = None # Add all_day field
tags: Optional[List[str]] = None # Add optional tags for update
@field_validator('tags', mode='before')
@field_validator("tags", mode="before")
@classmethod
def tags_validate_null_string(cls, v):
if v == "Null":
return None
return v
# Schema for the response (inherits from Base, adds ID and user_id)
class CalendarEventResponse(CalendarEventBase):
id: int
user_id: int
tags: List[str] # Keep as List[str], remove default []
tags: List[str] # Keep as List[str], remove default []
@field_validator('tags', mode='before')
@field_validator("tags", mode="before")
@classmethod
def tags_validate_none_to_list(cls, v):
# If the value from the source object (e.g., ORM model) is None,
@@ -59,4 +63,4 @@ class CalendarEventResponse(CalendarEventBase):
return v
class Config:
from_attributes = True
from_attributes = True

View File

@@ -1,25 +1,34 @@
# modules/calendar/service.py
from sqlalchemy.orm import Session
from sqlalchemy import or_ # Import or_
from sqlalchemy import or_ # Import or_
from datetime import datetime
from modules.calendar.models import CalendarEvent
from core.exceptions import not_found_exception
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate # Import schemas
from modules.calendar.schemas import (
CalendarEventCreate,
CalendarEventUpdate,
) # Import schemas
def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate):
# Ensure tags is None if not provided or empty list, matching model
tags_to_store = event_data.tags if event_data.tags else None
event = CalendarEvent(
**event_data.model_dump(exclude={'tags'}), # Use model_dump and exclude tags initially
tags=tags_to_store, # Set tags separately
user_id=user_id
**event_data.model_dump(
exclude={"tags"}
), # Use model_dump and exclude tags initially
tags=tags_to_store, # Set tags separately
user_id=user_id,
)
db.add(event)
db.commit()
db.refresh(event)
return event
def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: datetime | None):
def get_calendar_events(
db: Session, user_id: int, start: datetime | None, end: datetime | None
):
"""
Retrieves calendar events for a user, optionally filtered by a date range.
@@ -46,9 +55,13 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
query = query.filter(
or_(
# Case 1: Event has duration and overlaps
(CalendarEvent.end is not None) & (CalendarEvent.start < end) & (CalendarEvent.end > start),
(CalendarEvent.end is not None)
& (CalendarEvent.start < end)
& (CalendarEvent.end > start),
# Case 2: Event is a point event within the range
(CalendarEvent.end is None) & (CalendarEvent.start >= start) & (CalendarEvent.start < end)
(CalendarEvent.end is None)
& (CalendarEvent.start >= start)
& (CalendarEvent.start < end),
)
)
# If only start is provided, filter events starting on or after start
@@ -60,37 +73,41 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
elif end:
# Includes events with duration ending <= end (or starting before end if end is None)
# Includes point events occurring < end
query = query.filter(
query = query.filter(
or_(
# Event ends before the specified end time
(CalendarEvent.end is not None) & (CalendarEvent.end <= end),
# Point event occurs before the specified end time
(CalendarEvent.end is None) & (CalendarEvent.start < end)
(CalendarEvent.end is None) & (CalendarEvent.start < end),
)
)
# Alternative interpretation for "ending before end": include events that *start* before end
# query = query.filter(CalendarEvent.start < end)
# Alternative interpretation for "ending before end": include events that *start* before end
# query = query.filter(CalendarEvent.start < end)
return query.order_by(CalendarEvent.start).all() # Order by start time
return query.order_by(CalendarEvent.start).all() # Order by start time
def get_calendar_event_by_id(db: Session, user_id: int, event_id: int):
event = db.query(CalendarEvent).filter(
CalendarEvent.id == event_id,
CalendarEvent.user_id == user_id
).first()
event = (
db.query(CalendarEvent)
.filter(CalendarEvent.id == event_id, CalendarEvent.user_id == user_id)
.first()
)
if not event:
raise not_found_exception()
return event
def update_calendar_event(db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate):
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
def update_calendar_event(
db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate
):
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
# Use model_dump with exclude_unset=True to only update provided fields
update_data = event_data.model_dump(exclude_unset=True)
for key, value in update_data.items():
# Ensure tags is handled correctly (set to None if empty list provided)
if key == 'tags' and isinstance(value, list) and not value:
if key == "tags" and isinstance(value, list) and not value:
setattr(event, key, None)
else:
setattr(event, key, value)
@@ -99,7 +116,8 @@ def update_calendar_event(db: Session, user_id: int, event_id: int, event_data:
db.refresh(event)
return event
def delete_calendar_event(db: Session, user_id: int, event_id: int):
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
db.delete(event)
db.commit()
db.commit()

View File

@@ -7,13 +7,27 @@ from core.database import get_db
from modules.auth.dependencies import get_current_user
from modules.auth.models import User
# Import the new service functions and Enum
from modules.nlp.service import process_request, ask_ai, save_chat_message, get_chat_history, MessageSender
from modules.nlp.service import (
process_request,
ask_ai,
save_chat_message,
get_chat_history,
MessageSender,
)
# Import the response schema and the new ChatMessage model for response type hinting
from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse
from modules.calendar.service import create_calendar_event, get_calendar_events, update_calendar_event, delete_calendar_event
from modules.calendar.service import (
create_calendar_event,
get_calendar_events,
update_calendar_event,
delete_calendar_event,
)
from modules.calendar.models import CalendarEvent
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate
# Import TODO services, schemas, and model
from modules.todo import service as todo_service
from modules.todo.models import Todo
@@ -21,17 +35,20 @@ from modules.todo.schemas import TodoCreate, TodoUpdate
from pydantic import BaseModel
from datetime import datetime
class ChatMessageResponse(BaseModel):
id: int
sender: MessageSender # Use the enum directly
sender: MessageSender # Use the enum directly
text: str
timestamp: datetime
class Config:
from_attributes = True # Allow Pydantic to work with ORM models
from_attributes = True # Allow Pydantic to work with ORM models
router = APIRouter(prefix="/nlp", tags=["nlp"])
# Helper to format calendar events (expects list of CalendarEvent models)
def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
if not events:
@@ -39,12 +56,15 @@ def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
formatted = ["Here are the events:"]
for event in events:
# Access attributes directly from the model instance
start_str = event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
start_str = (
event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
)
end_str = event.end.strftime("%H:%M") if event.end else ""
title = event.title or "Untitled Event"
formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})")
return formatted
# Helper to format TODO items (expects list of Todo models)
def format_todos(todos: List[Todo]) -> List[str]:
if not todos:
@@ -54,19 +74,28 @@ def format_todos(todos: List[Todo]) -> List[str]:
status = "[X]" if todo.complete else "[ ]"
date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else ""
remind_str = " (Reminder)" if todo.remind else ""
formatted.append(f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})")
formatted.append(
f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})"
)
return formatted
# Update the response model for the endpoint
@router.post("/process-command", response_model=ProcessCommandResponse)
def process_command(request_data: ProcessCommandRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
def process_command(
request_data: ProcessCommandRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Process the user command, save messages, execute action, save response, and return user-friendly responses.
"""
user_input = request_data.user_input
# --- Save User Message ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.USER, text=user_input)
save_chat_message(
db, user_id=current_user.id, sender=MessageSender.USER, text=user_input
)
# ------------------------
command_data = process_request(user_input)
@@ -74,11 +103,13 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
params = command_data["params"]
response_text = command_data["response_text"]
responses = [response_text] # Start with the initial response
responses = [response_text] # Start with the initial response
# --- Save Initial AI Response ---
# Save the first response generated by process_request
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text)
save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=response_text
)
# -----------------------------
if intent == "error":
@@ -97,139 +128,233 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
ai_answer = ask_ai(**params)
responses.append(ai_answer)
# --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer)
save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer
)
# ---------------------------------
return ProcessCommandResponse(responses=responses)
case "get_calendar_events":
events: List[CalendarEvent] = get_calendar_events(db, current_user.id, **params)
events: List[CalendarEvent] = get_calendar_events(
db, current_user.id, **params
)
formatted_responses = format_calendar_events(events)
responses.extend(formatted_responses)
# --- Save Additional AI Responses ---
for resp in formatted_responses:
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
)
# ----------------------------------
return ProcessCommandResponse(responses=responses)
case "add_calendar_event":
event_data = CalendarEventCreate(**params)
created_event = create_calendar_event(db, current_user.id, event_data)
start_str = created_event.start.strftime("%Y-%m-%d %H:%M") if created_event.start else "No start time"
start_str = (
created_event.start.strftime("%Y-%m-%d %H:%M")
if created_event.start
else "No start time"
)
title = created_event.title or "Untitled Event"
add_response = f"Added: {title} starting at {start_str}."
responses.append(add_response)
# --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=add_response,
)
# ---------------------------------
return ProcessCommandResponse(responses=responses)
case "update_calendar_event":
event_id = params.pop('event_id', None)
event_id = params.pop("event_id", None)
if event_id is None:
# Save the error message before raising
error_msg = "Event ID is required for update."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg)
event_data = CalendarEventUpdate(**params)
updated_event = update_calendar_event(db, current_user.id, event_id, event_data=event_data)
updated_event = update_calendar_event(
db, current_user.id, event_id, event_data=event_data
)
title = updated_event.title or "Untitled Event"
update_response = f"Updated event ID {updated_event.id}: {title}."
responses.append(update_response)
# --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=update_response,
)
# ---------------------------------
return ProcessCommandResponse(responses=responses)
case "delete_calendar_event":
event_id = params.get('event_id')
event_id = params.get("event_id")
if event_id is None:
# Save the error message before raising
error_msg = "Event ID is required for delete."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg)
delete_calendar_event(db, current_user.id, event_id)
delete_response = f"Deleted event ID {event_id}."
responses.append(delete_response)
# --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=delete_response,
)
# ---------------------------------
return ProcessCommandResponse(responses=responses)
# --- Add TODO Cases ---
# --- Add TODO Cases ---
case "get_todos":
todos: List[Todo] = todo_service.get_todos(db, user=current_user, **params)
todos: List[Todo] = todo_service.get_todos(
db, user=current_user, **params
)
formatted_responses = format_todos(todos)
responses.extend(formatted_responses)
# --- Save Additional AI Responses ---
for resp in formatted_responses:
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
)
# ----------------------------------
return ProcessCommandResponse(responses=responses)
case "add_todo":
todo_data = TodoCreate(**params)
created_todo = todo_service.create_todo(db, todo=todo_data, user=current_user)
add_response = f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
created_todo = todo_service.create_todo(
db, todo=todo_data, user=current_user
)
add_response = (
f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
)
responses.append(add_response)
# --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=add_response,
)
# ---------------------------------
return ProcessCommandResponse(responses=responses)
case "update_todo":
todo_id = params.pop('todo_id', None)
todo_id = params.pop("todo_id", None)
if todo_id is None:
error_msg = "TODO ID is required for update."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg)
todo_data = TodoUpdate(**params)
updated_todo = todo_service.update_todo(db, todo_id=todo_id, todo_update=todo_data, user=current_user)
update_response = f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
if 'complete' in params:
status = "complete" if params['complete'] else "incomplete"
updated_todo = todo_service.update_todo(
db, todo_id=todo_id, todo_update=todo_data, user=current_user
)
update_response = (
f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
)
if "complete" in params:
status = "complete" if params["complete"] else "incomplete"
update_response += f" Marked as {status}."
responses.append(update_response)
# --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=update_response,
)
# ---------------------------------
return ProcessCommandResponse(responses=responses)
case "delete_todo":
todo_id = params.get('todo_id')
todo_id = params.get("todo_id")
if todo_id is None:
error_msg = "TODO ID is required for delete."
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg)
deleted_todo = todo_service.delete_todo(db, todo_id=todo_id, user=current_user)
delete_response = f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
deleted_todo = todo_service.delete_todo(
db, todo_id=todo_id, user=current_user
)
delete_response = (
f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
)
responses.append(delete_response)
# --- Save Additional AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=delete_response,
)
# ---------------------------------
return ProcessCommandResponse(responses=responses)
# --- End TODO Cases ---
case _:
print(f"Warning: Unhandled intent '{intent}' reached api.py match statement.")
case _:
print(
f"Warning: Unhandled intent '{intent}' reached api.py match statement."
)
# The initial response_text was already saved
return ProcessCommandResponse(responses=responses)
except HTTPException as http_exc:
# Don't save again if already saved before raising
if http_exc.status_code != 400 or ('event_id' not in http_exc.detail.lower()):
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=http_exc.detail)
if http_exc.status_code != 400 or ("event_id" not in http_exc.detail.lower()):
save_chat_message(
db,
user_id=current_user.id,
sender=MessageSender.AI,
text=http_exc.detail,
)
raise http_exc
except Exception as e:
print(f"Error executing intent '{intent}': {e}")
error_response = "Sorry, I encountered an error while trying to perform that action."
error_response = (
"Sorry, I encountered an error while trying to perform that action."
)
# --- Save Final Error AI Response ---
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_response)
save_chat_message(
db, user_id=current_user.id, sender=MessageSender.AI, text=error_response
)
# ----------------------------------
return ProcessCommandResponse(responses=[error_response])
@router.get("/history", response_model=List[ChatMessageResponse])
def read_chat_history(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
def read_chat_history(
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
):
"""Retrieves the last 50 chat messages for the current user."""
history = get_chat_history(db, user_id=current_user.id, limit=50)
return history
# -------------------------------------
# -------------------------------------

View File

@@ -1,4 +1,3 @@
\
# /home/cdp/code/MAIA/backend/modules/nlp/models.py
from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum
from sqlalchemy.orm import relationship
@@ -7,10 +6,12 @@ import enum
from core.database import Base
class MessageSender(enum.Enum):
USER = "user"
AI = "ai"
class ChatMessage(Base):
__tablename__ = "chat_messages"
@@ -20,4 +21,4 @@ class ChatMessage(Base):
text = Column(Text, nullable=False)
timestamp = Column(DateTime(timezone=True), server_default=func.now())
owner = relationship("User") # Relationship to the User model
owner = relationship("User") # Relationship to the User model

View File

@@ -2,9 +2,11 @@
from pydantic import BaseModel
from typing import List
class ProcessCommandRequest(BaseModel):
user_input: str
class ProcessCommandResponse(BaseModel):
responses: List[str]
# Optional: Keep details if needed for specific frontend logic beyond display

View File

@@ -1,11 +1,11 @@
# modules/nlp/service.py
from sqlalchemy.orm import Session
from sqlalchemy import desc # Import desc for ordering
from sqlalchemy import desc # Import desc for ordering
from google import genai
import json
from datetime import datetime, timezone
from typing import List # Import List
from typing import List # Import List
# Import the new model and Enum
from .models import ChatMessage, MessageSender
@@ -14,7 +14,8 @@ from core.config import settings
client = genai.Client(api_key=settings.GOOGLE_API_KEY)
### Base prompt for MAIA, used for inital user requests
SYSTEM_PROMPT = """
SYSTEM_PROMPT = (
"""
You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text.
Available functions/intents:
@@ -109,8 +110,11 @@ MAIA:
"response_text": "Okay, I've deleted task 2 from your list."
}
The datetime right now is """+str(datetime.now(timezone.utc))+""".
The datetime right now is """
+ str(datetime.now(timezone.utc))
+ """.
"""
)
### Prompt for MAIA to forward user request to AI
SYSTEM_FORWARD_PROMPT = f"""
@@ -123,6 +127,7 @@ Here is the user request:
# --- Chat History Service Functions ---
def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str):
"""Saves a chat message to the database."""
db_message = ChatMessage(user_id=user_id, sender=sender, text=text)
@@ -131,16 +136,21 @@ def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: st
db.refresh(db_message)
return db_message
def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]:
"""Retrieves the last 'limit' chat messages for a user."""
return db.query(ChatMessage)\
.filter(ChatMessage.user_id == user_id)\
.order_by(desc(ChatMessage.timestamp))\
.limit(limit)\
.all()[::-1] # Reverse to get oldest first for display order
return (
db.query(ChatMessage)
.filter(ChatMessage.user_id == user_id)
.order_by(desc(ChatMessage.timestamp))
.limit(limit)
.all()[::-1]
) # Reverse to get oldest first for display order
# --- Existing NLP Service Functions ---
def process_request(request: str):
"""
Process the user request using the Google GenAI API.
@@ -152,7 +162,7 @@ def process_request(request: str):
config={
"temperature": 0.3, # Less creativity, more factual
"response_mime_type": "application/json",
}
},
)
# Parse the JSON response
@@ -160,7 +170,9 @@ def process_request(request: str):
parsed_response = json.loads(response.text)
# Validate required fields
if not all(k in parsed_response for k in ("intent", "params", "response_text")):
raise ValueError("AI response missing required fields (intent, params, response_text)")
raise ValueError(
"AI response missing required fields (intent, params, response_text)"
)
return parsed_response
except (json.JSONDecodeError, ValueError) as e:
print(f"Error parsing AI response: {e}")
@@ -169,9 +181,10 @@ def process_request(request: str):
return {
"intent": "error",
"params": {},
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?"
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?",
}
def ask_ai(request: str):
"""
Ask the AI a question.
@@ -179,6 +192,6 @@ def ask_ai(request: str):
"""
response = client.models.generate_content(
model="gemini-2.0-flash",
contents=SYSTEM_FORWARD_PROMPT+request,
contents=SYSTEM_FORWARD_PROMPT + request,
)
return response.text
return response.text

View File

@@ -5,58 +5,65 @@ from typing import List
from . import service, schemas
from core.database import get_db
from modules.auth.dependencies import get_current_user # Corrected import
from modules.auth.models import User # Assuming User model is in auth.models
from modules.auth.dependencies import get_current_user # Corrected import
from modules.auth.models import User # Assuming User model is in auth.models
router = APIRouter(
prefix="/todos",
tags=["todos"],
dependencies=[Depends(get_current_user)], # Corrected dependency
dependencies=[Depends(get_current_user)], # Corrected dependency
responses={404: {"description": "Not found"}},
)
@router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED)
def create_todo_endpoint(
todo: schemas.TodoCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency
current_user: User = Depends(get_current_user), # Corrected dependency
):
return service.create_todo(db=db, todo=todo, user=current_user)
@router.get("/", response_model=List[schemas.Todo])
def read_todos_endpoint(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency
current_user: User = Depends(get_current_user), # Corrected dependency
):
todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit)
return todos
@router.get("/{todo_id}", response_model=schemas.Todo)
def read_todo_endpoint(
todo_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency
current_user: User = Depends(get_current_user), # Corrected dependency
):
db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user)
if db_todo is None:
raise HTTPException(status_code=404, detail="Todo not found")
return db_todo
@router.put("/{todo_id}", response_model=schemas.Todo)
def update_todo_endpoint(
todo_id: int,
todo_update: schemas.TodoUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency
current_user: User = Depends(get_current_user), # Corrected dependency
):
return service.update_todo(db=db, todo_id=todo_id, todo_update=todo_update, user=current_user)
return service.update_todo(
db=db, todo_id=todo_id, todo_update=todo_update, user=current_user
)
@router.delete("/{todo_id}", response_model=schemas.Todo)
def delete_todo_endpoint(
todo_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) # Corrected dependency
current_user: User = Depends(get_current_user), # Corrected dependency
):
return service.delete_todo(db=db, todo_id=todo_id, user=current_user)

View File

@@ -3,6 +3,7 @@ from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from core.database import Base
class Todo(Base):
__tablename__ = "todos"
@@ -13,4 +14,6 @@ class Todo(Base):
complete = Column(Boolean, default=False)
owner_id = Column(Integer, ForeignKey("users.id"))
owner = relationship("User") # Add relationship if needed, assuming User model exists in auth.models
owner = relationship(
"User"
) # Add relationship if needed, assuming User model exists in auth.models

View File

@@ -3,21 +3,25 @@ from pydantic import BaseModel
from typing import Optional
import datetime
class TodoBase(BaseModel):
task: str
date: Optional[datetime.datetime] = None
remind: bool = False
complete: bool = False
class TodoCreate(TodoBase):
pass
class TodoUpdate(BaseModel):
task: Optional[str] = None
date: Optional[datetime.datetime] = None
remind: Optional[bool] = None
complete: Optional[bool] = None
class Todo(TodoBase):
id: int
owner_id: int

View File

@@ -1,9 +1,10 @@
# backend/modules/todo/service.py
from sqlalchemy.orm import Session
from . import models, schemas
from modules.auth.models import User # Assuming User model is in auth.models
from modules.auth.models import User # Assuming User model is in auth.models
from fastapi import HTTPException, status
def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
db_todo = models.Todo(**todo.dict(), owner_id=user.id)
db.add(db_todo)
@@ -11,17 +12,34 @@ def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
db.refresh(db_todo)
return db_todo
def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100):
return db.query(models.Todo).filter(models.Todo.owner_id == user.id).offset(skip).limit(limit).all()
return (
db.query(models.Todo)
.filter(models.Todo.owner_id == user.id)
.offset(skip)
.limit(limit)
.all()
)
def get_todo(db: Session, todo_id: int, user: User):
db_todo = db.query(models.Todo).filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id).first()
db_todo = (
db.query(models.Todo)
.filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id)
.first()
)
if db_todo is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found"
)
return db_todo
def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User):
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
db_todo = get_todo(
db=db, todo_id=todo_id, user=user
) # Reuse get_todo to check ownership and existence
update_data = todo_update.dict(exclude_unset=True)
for key, value in update_data.items():
setattr(db_todo, key, value)
@@ -29,8 +47,11 @@ def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user
db.refresh(db_todo)
return db_todo
def delete_todo(db: Session, todo_id: int, user: User):
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
db_todo = get_todo(
db=db, todo_id=todo_id, user=user
) # Reuse get_todo to check ownership and existence
db.delete(db_todo)
db.commit()
return db_todo

View File

@@ -11,37 +11,52 @@ from modules.auth.models import User
router = APIRouter(prefix="/user", tags=["user"])
@router.get("/me", response_model=UserResponse)
def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
def me(
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
"""
Get the current user. Requires user to be logged in.
Returns the user object.
"""
"""
return current_user
@router.get("/{username}", response_model=UserResponse)
def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
def get_user(
username: str,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
"""
Get a user by username.
Returns the user object.
"""
if current_user.username != username:
raise forbidden_exception("You can only view your own profile")
user = db.query(User).filter(User.username == username).first()
if not user:
raise not_found_exception("User not found")
return user
@router.patch("/{username}", response_model=UserResponse)
def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
def update_user(
username: str,
user_data: UserPatch,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
"""
Update a user by username.
Returns the updated user object.
"""
if current_user.username != username:
raise forbidden_exception("You can only update your own profile")
user = db.query(User).filter(User.username == username).first()
if not user:
raise not_found_exception("User not found")
@@ -60,19 +75,24 @@ def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depe
db.refresh(user)
return user
@router.delete("/{username}", response_model=UserResponse)
def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
def delete_user(
username: str,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
) -> UserResponse:
"""
Delete a user by username.
Returns the deleted user object.
"""
if current_user.username != username:
raise forbidden_exception("You can only delete your own profile")
user = db.query(User).filter(User.username == username).first()
if not user:
raise not_found_exception("User not found")
db.delete(user)
db.commit()
return user
return user

View File

@@ -12,6 +12,7 @@ from core.database import get_db, get_sessionmaker
fake = Faker()
@pytest.fixture(scope="session")
def postgres_container() -> Generator[PostgresContainer, None, None]:
"""Fixture to create a PostgreSQL container for testing."""
@@ -21,7 +22,8 @@ def postgres_container() -> Generator[PostgresContainer, None, None]:
print(f"Postgres container started at {settings.DB_URL}")
yield postgres
print("Postgres container stopped.")
@pytest.fixture(scope="function")
def db(postgres_container) -> Generator[Session, None, None]:
"""Function-scoped database session with rollback"""
@@ -34,25 +36,28 @@ def db(postgres_container) -> Generator[Session, None, None]:
session.rollback()
session.close()
@pytest.fixture(scope="function")
def client(db: Session) -> Generator[TestClient, None, None]:
"""Function-scoped test client with dependency override"""
from main import app
# Override the database dependency
def override_get_db():
try:
yield db
finally:
pass # Don't close session here
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
app.dependency_overrides.clear()
def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None:
from main import app
app.dependency_overrides[dependency] = lambda: mocked_response
app.dependency_overrides[dependency] = lambda: mocked_response

View File

@@ -5,17 +5,24 @@ from sqlalchemy.orm import Session
from core.config import settings
from modules.auth.models import User
from modules.auth.security import authenticate_user, create_access_token, create_refresh_token, hash_password
from modules.auth.security import (
authenticate_user,
create_access_token,
create_refresh_token,
hash_password,
)
from modules.auth.schemas import UserRole
from tests.conftest import fake
from typing import Optional # Import Optional
from typing import Optional # Import Optional
def create_user(db: Session, is_admin: bool = False, username: Optional[str] = None) -> User:
def create_user(
db: Session, is_admin: bool = False, username: Optional[str] = None
) -> User:
unhashed_password = fake.password()
_user = User(
name=fake.name(),
username=username or fake.user_name(), # Use provided username or generate one
username=username or fake.user_name(), # Use provided username or generate one
hashed_password=hash_password(unhashed_password),
uuid=uuid_pkg.uuid4(),
role=UserRole.ADMIN if is_admin else UserRole.USER,
@@ -24,14 +31,18 @@ def create_user(db: Session, is_admin: bool = False, username: Optional[str] = N
db.add(_user)
db.commit()
db.refresh(_user)
return _user, unhashed_password # return for testing
return _user, unhashed_password # return for testing
def login(db: Session, username: str, password: str) -> str:
user = authenticate_user(username, password, db)
if not user:
raise Exception("Incorrect username or password")
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
)
refresh_token = create_refresh_token(data={"sub": user.username})
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
@@ -40,4 +51,4 @@ def login(db: Session, username: str, password: str) -> str:
"access_token": access_token,
"refresh_token": refresh_token,
"max_age": max_age,
}
}

View File

@@ -7,71 +7,93 @@ from tests.helpers import generators
# Test admin routes require admin privileges
def test_read_admin_unauthorized(client: TestClient) -> None:
"""Test accessing admin route without authentication."""
response = client.get("/api/admin/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_read_admin_forbidden(db: Session, client: TestClient) -> None:
"""Test accessing admin route as a non-admin user."""
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"})
response = client.get(
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_read_admin_success(db: Session, client: TestClient) -> None:
"""Test accessing admin route as an admin user."""
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
admin_user, password = generators.create_user(
db, is_admin=True
) # Use is_admin=True
login_rsp = generators.login(db, admin_user.username, password)
access_token = login_rsp["access_token"]
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"})
response = client.get(
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"message": "Admin route"}
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
def test_clear_db_soft(mock_cleardb_delay, db: Session, client: TestClient) -> None:
"""Test soft clearing the database as admin."""
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
admin_user, password = generators.create_user(
db, is_admin=True
) # Use is_admin=True
login_rsp = generators.login(db, admin_user.username, password)
access_token = login_rsp["access_token"]
response = client.post(
"/api/admin/cleardb",
headers={"Authorization": f"Bearer {access_token}"},
json={"hard": False}
json={"hard": False},
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"message": "Clearing database in the background", "hard": False}
assert response.json() == {
"message": "Clearing database in the background",
"hard": False,
}
mock_cleardb_delay.assert_called_once_with(False)
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
def test_clear_db_hard(mock_cleardb_delay, db: Session, client: TestClient) -> None:
"""Test hard clearing the database as admin."""
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
admin_user, password = generators.create_user(
db, is_admin=True
) # Use is_admin=True
login_rsp = generators.login(db, admin_user.username, password)
access_token = login_rsp["access_token"]
response = client.post(
"/api/admin/cleardb",
headers={"Authorization": f"Bearer {access_token}"},
json={"hard": True}
json={"hard": True},
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"message": "Clearing database in the background", "hard": True}
assert response.json() == {
"message": "Clearing database in the background",
"hard": True,
}
mock_cleardb_delay.assert_called_once_with(True)
def test_clear_db_forbidden(db: Session, client: TestClient) -> None:
"""Test clearing the database as a non-admin user."""
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
response = client.post(
"/api/admin/cleardb",
headers={"Authorization": f"Bearer {access_token}"},
json={"hard": False}
json={"hard": False},
)
assert response.status_code == status.HTTP_403_FORBIDDEN

View File

@@ -34,6 +34,7 @@ def test_register(client: TestClient) -> None:
)
assert response.status_code == status.HTTP_201_CREATED
def test_login(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
@@ -51,17 +52,21 @@ def test_login(db: Session, client: TestClient) -> None:
assert "token_type" in response_data
assert response_data["token_type"] == "bearer"
def test_refresh_token(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
rsp = generators.login(db, user.username, unhashed_password)
access_token = rsp["access_token"]
refresh_token = rsp["refresh_token"]
time.sleep(1) # Sleep to ensure tokens won't be identical
time.sleep(1) # Sleep to ensure tokens won't be identical
response = client.post(
"/api/auth/refresh",
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK
@@ -70,7 +75,10 @@ def test_refresh_token(db: Session, client: TestClient) -> None:
assert "access_token" in response_data
assert "token_type" in response_data
assert response_data["token_type"] == "bearer"
assert response_data["access_token"] != access_token # Ensure the token is refreshed
assert (
response_data["access_token"] != access_token
) # Ensure the token is refreshed
def test_logout(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
@@ -80,15 +88,20 @@ def test_logout(db: Session, client: TestClient) -> None:
response = client.post(
"/api/auth/logout",
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK
# Verify that the token is blacklisted
blacklisted_token = db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
blacklisted_token = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
)
assert blacklisted_token is not None
# Verify that we can't still actually do anything
response = client.get(
"/api/user/me",
@@ -98,7 +111,10 @@ def test_logout(db: Session, client: TestClient) -> None:
response = client.post(
"/api/auth/refresh",
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
},
json={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -106,7 +122,9 @@ def test_logout(db: Session, client: TestClient) -> None:
def test_get_me(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.get(
"/api/user/me",
@@ -119,14 +137,18 @@ def test_get_me(db: Session, client: TestClient) -> None:
assert response_data["uuid"] == user.uuid
assert response_data["username"] == user.username
def test_get_me_unauthorized(client: TestClient) -> None:
### This test should fail (unauthorized) because the user isn't logged in
response = client.get("/api/user/me")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.get(
f"/api/user/{user.username}",
@@ -139,11 +161,14 @@ def test_get_user(db: Session, client: TestClient) -> None:
assert response_data["uuid"] == user.uuid
assert response_data["username"] == user.username
def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
### This test should fail (unauthorized) because the user isn't us
user, unhashed_password = generators.create_user(db)
user2, _ = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.get(
f"/api/user/{user2.username}",
@@ -151,11 +176,14 @@ def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_update_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
new_name = fake.name()
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.patch(
f"/api/user/{user.username}",
headers={"Authorization": f"Bearer {access_token}"},
@@ -168,7 +196,9 @@ def test_update_user(db: Session, client: TestClient) -> None:
def test_delete_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
access_token = generators.login(db, user.username, unhashed_password)[
"access_token"
]
response = client.delete(
f"/api/user/{user.username}",
headers={"Authorization": f"Bearer {access_token}"},
@@ -179,6 +209,7 @@ def test_delete_user(db: Session, client: TestClient) -> None:
deleted_user = db.query(User).filter(User.username == user.username).first()
assert deleted_user is None
def test_get_user_forbidden(db: Session, client: TestClient) -> None:
"""Test getting another user's profile (should be forbidden)."""
user1, password_user1 = generators.create_user(db, username="user1_get_forbidden")
@@ -195,9 +226,12 @@ def test_get_user_forbidden(db: Session, client: TestClient) -> None:
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_update_user_forbidden(db: Session, client: TestClient) -> None:
"""Test updating another user's profile (should be forbidden)."""
user1, password_user1 = generators.create_user(db, username="user1_update_forbidden")
user1, password_user1 = generators.create_user(
db, username="user1_update_forbidden"
)
user2, _ = generators.create_user(db, username="user2_update_forbidden")
new_name = fake.name()
@@ -213,9 +247,12 @@ def test_update_user_forbidden(db: Session, client: TestClient) -> None:
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_delete_user_forbidden(db: Session, client: TestClient) -> None:
"""Test deleting another user's profile (should be forbidden)."""
user1, password_user1 = generators.create_user(db, username="user1_delete_forbidden")
user1, password_user1 = generators.create_user(
db, username="user1_delete_forbidden"
)
user2, _ = generators.create_user(db, username="user2_delete_forbidden")
# Log in as user1

View File

@@ -4,9 +4,10 @@ from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from tests.helpers import generators
from modules.calendar.models import CalendarEvent # Assuming model exists
from modules.calendar.models import CalendarEvent # Assuming model exists
from tests.conftest import fake
# Helper function to create an event payload
def create_event_payload(start_offset_days=0, end_offset_days=1):
start_time = datetime.utcnow() + timedelta(days=start_offset_days)
@@ -14,19 +15,22 @@ def create_event_payload(start_offset_days=0, end_offset_days=1):
return {
"title": fake.sentence(nb_words=3),
"description": fake.text(),
"start": start_time.isoformat(), # Rename start_time to start
"end": end_time.isoformat(), # Rename end_time to end
"start": start_time.isoformat(), # Rename start_time to start
"end": end_time.isoformat(), # Rename end_time to end
"all_day": fake.boolean(),
}
# --- Test Create Event ---
def test_create_event_unauthorized(client: TestClient) -> None:
"""Test creating an event without authentication."""
payload = create_event_payload()
response = client.post("/api/calendar/events", json=payload)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_create_event_success(db: Session, client: TestClient) -> None:
"""Test creating a calendar event successfully."""
user, password = generators.create_user(db)
@@ -37,9 +41,11 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload
json=payload,
)
assert response.status_code == status.HTTP_201_CREATED # Change expected status to 201
assert (
response.status_code == status.HTTP_201_CREATED
) # Change expected status to 201
data = response.json()
assert data["title"] == payload["title"]
assert data["description"] == payload["description"]
@@ -56,13 +62,16 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
assert event_in_db.user_id == user.id
assert event_in_db.title == payload["title"]
# --- Test Get Events ---
def test_get_events_unauthorized(client: TestClient) -> None:
"""Test getting events without authentication."""
response = client.get("/api/calendar/events")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_events_success(db: Session, client: TestClient) -> None:
"""Test getting all calendar events for a user."""
user, password = generators.create_user(db)
@@ -71,21 +80,31 @@ def test_get_events_success(db: Session, client: TestClient) -> None:
# Create a couple of events for the user
payload1 = create_event_payload(0, 1)
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1)
client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload1,
)
payload2 = create_event_payload(2, 3)
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2)
client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload2,
)
# Create an event for another user (should not be returned)
other_user, other_password = generators.create_user(db)
other_login_rsp = generators.login(db, other_user.username, other_password)
other_access_token = other_login_rsp["access_token"]
other_payload = create_event_payload(4, 5)
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {other_access_token}"}, json=other_payload)
client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {other_access_token}"},
json=other_payload,
)
response = client.get(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"}
"/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@@ -103,12 +122,24 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
access_token = login_rsp["access_token"]
# Create events
payload1 = create_event_payload(0, 1) # Today -> Tomorrow
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1)
payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2)
payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload3)
payload1 = create_event_payload(0, 1) # Today -> Tomorrow
client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload1,
)
payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days
client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload2,
)
payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days
client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload3,
)
# Filter for events starting within the next week
start_filter = datetime.utcnow().isoformat()
@@ -117,11 +148,11 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
response = client.get(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
params={"start": start_filter, "end": end_filter}
params={"start": start_filter, "end": end_filter},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 2 # Should get event 1 and 2
assert len(data) == 2 # Should get event 1 and 2
assert data[0]["title"] == payload1["title"]
assert data[1]["title"] == payload2["title"]
@@ -130,40 +161,50 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
response = client.get(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
params={"start": start_filter_late}
params={"start": start_filter_late},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 1 # Should get event 3
assert len(data) == 1 # Should get event 3
assert data[0]["title"] == payload3["title"]
# --- Test Get Event By ID ---
def test_get_event_by_id_unauthorized(db: Session, client: TestClient) -> None:
"""Test getting a specific event without authentication."""
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"]
response = client.get(f"/api/calendar/events/{event_id}")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
"""Test getting a specific event successfully."""
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"]
response = client.get(
f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@@ -171,6 +212,7 @@ def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
assert data["title"] == payload["title"]
assert data["user_id"] == user.id
def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
"""Test getting a non-existent event."""
user, password = generators.create_user(db)
@@ -180,10 +222,11 @@ def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
response = client.get(
f"/api/calendar/events/{non_existent_id}",
headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
"""Test getting another user's event."""
user1, password_user1 = generators.create_user(db)
@@ -193,7 +236,11 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
login_rsp1 = generators.login(db, user1.username, password_user1)
access_token1 = login_rsp1["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token1}"},
json=payload,
)
event_id = create_response.json()["id"]
# Log in as user2 and try to get user1's event
@@ -202,45 +249,60 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
response = client.get(
f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token2}"}
headers={"Authorization": f"Bearer {access_token2}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match
assert (
response.status_code == status.HTTP_404_NOT_FOUND
) # Service layer returns 404 if user_id doesn't match
# --- Test Update Event ---
def test_update_event_unauthorized(db: Session, client: TestClient) -> None:
"""Test updating an event without authentication."""
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"]
update_payload = {"title": "Updated Title"}
response = client.patch(f"/api/calendar/events/{event_id}", json=update_payload)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_update_event_success(db: Session, client: TestClient) -> None:
"""Test updating an event successfully."""
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
assert (
create_response.status_code == status.HTTP_201_CREATED
) # Ensure creation check uses 201
event_id = create_response.json()["id"]
update_payload = {
"title": "Updated Title",
"description": "Updated description.",
"all_day": not payload["all_day"] # Toggle all_day
"all_day": not payload["all_day"], # Toggle all_day
}
response = client.patch(
f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"},
json=update_payload
json=update_payload,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
@@ -248,7 +310,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None:
assert data["title"] == update_payload["title"]
assert data["description"] == update_payload["description"]
assert data["all_day"] == update_payload["all_day"]
assert data["start"] == payload["start"] # Check correct field name 'start'
assert data["start"] == payload["start"] # Check correct field name 'start'
assert data["user_id"] == user.id
# Verify in DB
@@ -258,6 +320,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None:
assert event_in_db.description == update_payload["description"]
assert event_in_db.all_day == update_payload["all_day"]
def test_update_event_not_found(db: Session, client: TestClient) -> None:
"""Test updating a non-existent event."""
user, password = generators.create_user(db)
@@ -269,10 +332,11 @@ def test_update_event_not_found(db: Session, client: TestClient) -> None:
response = client.patch(
f"/api/calendar/events/{non_existent_id}",
headers={"Authorization": f"Bearer {access_token}"},
json=update_payload
json=update_payload,
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_update_event_forbidden(db: Session, client: TestClient) -> None:
"""Test updating another user's event."""
user1, password_user1 = generators.create_user(db)
@@ -282,7 +346,11 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
login_rsp1 = generators.login(db, user1.username, password_user1)
access_token1 = login_rsp1["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token1}"},
json=payload,
)
event_id = create_response.json()["id"]
# Log in as user2 and try to update user1's event
@@ -293,32 +361,47 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
response = client.patch(
f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token2}"},
json=update_payload
json=update_payload,
)
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match
assert (
response.status_code == status.HTTP_404_NOT_FOUND
) # Service layer returns 404 if user_id doesn't match
# --- Test Delete Event ---
def test_delete_event_unauthorized(db: Session, client: TestClient) -> None:
"""Test deleting an event without authentication."""
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
event_id = create_response.json()["id"]
response = client.delete(f"/api/calendar/events/{event_id}")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_delete_event_success(db: Session, client: TestClient) -> None:
"""Test deleting an event successfully."""
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
access_token = login_rsp["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token}"},
json=payload,
)
assert (
create_response.status_code == status.HTTP_201_CREATED
) # Ensure creation check uses 201
event_id = create_response.json()["id"]
# Verify event exists before delete
@@ -327,7 +410,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
response = client.delete(
f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_204_NO_CONTENT
@@ -338,7 +421,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
# Try getting the deleted event (should be 404)
get_response = client.get(
f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"},
)
assert get_response.status_code == status.HTTP_404_NOT_FOUND
@@ -352,7 +435,7 @@ def test_delete_event_not_found(db: Session, client: TestClient) -> None:
response = client.delete(
f"/api/calendar/events/{non_existent_id}",
headers={"Authorization": f"Bearer {access_token}"}
headers={"Authorization": f"Bearer {access_token}"},
)
# The service layer raises NotFound, which should result in 404
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -367,7 +450,11 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
login_rsp1 = generators.login(db, user1.username, password_user1)
access_token1 = login_rsp1["access_token"]
payload = create_event_payload()
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
create_response = client.post(
"/api/calendar/events",
headers={"Authorization": f"Bearer {access_token1}"},
json=payload,
)
event_id = create_response.json()["id"]
# Log in as user2 and try to delete user1's event
@@ -376,7 +463,7 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
response = client.delete(
f"/api/calendar/events/{event_id}",
headers={"Authorization": f"Bearer {access_token2}"}
headers={"Authorization": f"Bearer {access_token2}"},
)
# The service layer raises NotFound if user_id doesn't match, resulting in 404
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -385,4 +472,3 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
event_in_db = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first()
assert event_in_db is not None
assert event_in_db.user_id == user1.id

View File

@@ -2,6 +2,7 @@ from fastapi.testclient import TestClient
# No database needed for this simple test
def test_health_check(client: TestClient):
"""Test the health check endpoint."""
response = client.get("/api/health")

View File

@@ -7,24 +7,37 @@ from datetime import datetime
from tests.helpers import generators
from modules.nlp.schemas import ProcessCommandResponse
from modules.nlp.models import MessageSender, ChatMessage # Import necessary models/enums
from modules.nlp.models import (
MessageSender,
ChatMessage,
) # Import necessary models/enums
# --- Mocks ---
# Mock the external AI call and internal service functions
@pytest.fixture(autouse=True)
def mock_nlp_services():
with patch("modules.nlp.api.process_request") as mock_process, \
patch("modules.nlp.api.ask_ai") as mock_ask, \
patch("modules.nlp.api.save_chat_message") as mock_save, \
patch("modules.nlp.api.get_chat_history") as mock_get_history, \
patch("modules.nlp.api.create_calendar_event") as mock_create_event, \
patch("modules.nlp.api.get_calendar_events") as mock_get_events, \
patch("modules.nlp.api.update_calendar_event") as mock_update_event, \
patch("modules.nlp.api.delete_calendar_event") as mock_delete_event, \
patch("modules.nlp.api.todo_service.create_todo") as mock_create_todo, \
patch("modules.nlp.api.todo_service.get_todos") as mock_get_todos, \
patch("modules.nlp.api.todo_service.update_todo") as mock_update_todo, \
patch("modules.nlp.api.todo_service.delete_todo") as mock_delete_todo:
with patch("modules.nlp.api.process_request") as mock_process, patch(
"modules.nlp.api.ask_ai"
) as mock_ask, patch("modules.nlp.api.save_chat_message") as mock_save, patch(
"modules.nlp.api.get_chat_history"
) as mock_get_history, patch(
"modules.nlp.api.create_calendar_event"
) as mock_create_event, patch(
"modules.nlp.api.get_calendar_events"
) as mock_get_events, patch(
"modules.nlp.api.update_calendar_event"
) as mock_update_event, patch(
"modules.nlp.api.delete_calendar_event"
) as mock_delete_event, patch(
"modules.nlp.api.todo_service.create_todo"
) as mock_create_todo, patch(
"modules.nlp.api.todo_service.get_todos"
) as mock_get_todos, patch(
"modules.nlp.api.todo_service.update_todo"
) as mock_update_todo, patch(
"modules.nlp.api.todo_service.delete_todo"
) as mock_delete_todo:
mocks = {
"process_request": mock_process,
"ask_ai": mock_ask,
@@ -41,21 +54,24 @@ def mock_nlp_services():
}
yield mocks
# --- Helper Function ---
def _login_user(db: Session, client: TestClient):
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
return user, login_rsp["access_token"], login_rsp["refresh_token"]
# --- Tests for /process-command ---
def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services):
user, access_token, refresh_token = _login_user(db, client)
user_input = "What is the capital of France?"
mock_nlp_services["process_request"].return_value = {
"intent": "ask_ai",
"params": {"request": user_input},
"response_text": "Let me check that for you."
"response_text": "Let me check that for you.",
}
mock_nlp_services["ask_ai"].return_value = "The capital of France is Paris."
@@ -63,25 +79,45 @@ def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_servic
"/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"user_input": user_input}
json={"user_input": user_input},
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ProcessCommandResponse(responses=["Let me check that for you.", "The capital of France is Paris."]).model_dump()
assert (
response.json()
== ProcessCommandResponse(
responses=["Let me check that for you.", "The capital of France is Paris."]
).model_dump()
)
# Verify save calls: user message, initial AI response, final AI answer
assert mock_nlp_services["save_chat_message"].call_count == 3
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you.")
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="The capital of France is Paris.")
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.USER, text=user_input
)
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you."
)
mock_nlp_services["save_chat_message"].assert_any_call(
db,
user_id=user.id,
sender=MessageSender.AI,
text="The capital of France is Paris.",
)
mock_nlp_services["ask_ai"].assert_called_once_with(request=user_input)
def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_services):
def test_process_command_get_calendar(
client: TestClient, db: Session, mock_nlp_services
):
user, access_token, refresh_token = _login_user(db, client)
user_input = "What are my events today?"
mock_nlp_services["process_request"].return_value = {
"intent": "get_calendar_events",
"params": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T23:59:59Z"}, # Example params
"response_text": "Okay, fetching your events."
"params": {
"start": "2024-01-01T00:00:00Z",
"end": "2024-01-01T23:59:59Z",
}, # Example params
"response_text": "Okay, fetching your events.",
}
# Mock the actual event model returned by the service
mock_event = MagicMock()
@@ -94,26 +130,32 @@ def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_
"/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"user_input": user_input}
json={"user_input": user_input},
)
assert response.status_code == status.HTTP_200_OK
expected_responses = [
"Okay, fetching your events.",
"Here are the events:",
"- Team Meeting (2024-01-01 10:00 - 11:00)"
"- Team Meeting (2024-01-01 10:00 - 11:00)",
]
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump()
assert mock_nlp_services["save_chat_message"].call_count == 4 # User, Initial AI, Header, Event
assert (
response.json()
== ProcessCommandResponse(responses=expected_responses).model_dump()
)
assert (
mock_nlp_services["save_chat_message"].call_count == 4
) # User, Initial AI, Header, Event
mock_nlp_services["get_calendar_events"].assert_called_once()
def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services):
user, access_token, refresh_token = _login_user(db, client)
user_input = "Add buy milk to my list"
mock_nlp_services["process_request"].return_value = {
"intent": "add_todo",
"params": {"task": "buy milk"},
"response_text": "Adding it now."
"response_text": "Adding it now.",
}
# Mock the actual Todo model returned by the service
mock_todo = MagicMock()
@@ -125,81 +167,119 @@ def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_serv
"/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"user_input": user_input}
json={"user_input": user_input},
)
assert response.status_code == status.HTTP_200_OK
expected_responses = ["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."]
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump()
assert mock_nlp_services["save_chat_message"].call_count == 3 # User, Initial AI, Confirmation AI
assert (
response.json()
== ProcessCommandResponse(responses=expected_responses).model_dump()
)
assert (
mock_nlp_services["save_chat_message"].call_count == 3
) # User, Initial AI, Confirmation AI
mock_nlp_services["create_todo"].assert_called_once()
def test_process_command_clarification(client: TestClient, db: Session, mock_nlp_services):
def test_process_command_clarification(
client: TestClient, db: Session, mock_nlp_services
):
user, access_token, refresh_token = _login_user(db, client)
user_input = "Delete the event"
clarification_text = "Which event do you mean? Please provide the ID."
mock_nlp_services["process_request"].return_value = {
"intent": "clarification_needed",
"params": {"request": user_input},
"response_text": clarification_text
"response_text": clarification_text,
}
response = client.post(
"/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"user_input": user_input}
json={"user_input": user_input},
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ProcessCommandResponse(responses=[clarification_text]).model_dump()
assert (
response.json()
== ProcessCommandResponse(responses=[clarification_text]).model_dump()
)
# Verify save calls: user message, clarification AI response
assert mock_nlp_services["save_chat_message"].call_count == 2
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=clarification_text)
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.USER, text=user_input
)
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.AI, text=clarification_text
)
# Ensure no action services were called
mock_nlp_services["delete_calendar_event"].assert_not_called()
def test_process_command_error_intent(client: TestClient, db: Session, mock_nlp_services):
def test_process_command_error_intent(
client: TestClient, db: Session, mock_nlp_services
):
user, access_token, refresh_token = _login_user(db, client)
user_input = "Gibberish request"
error_text = "Sorry, I didn't understand that."
mock_nlp_services["process_request"].return_value = {
"intent": "error",
"params": {},
"response_text": error_text
"response_text": error_text,
}
response = client.post(
"/api/nlp/process-command",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"user_input": user_input}
json={"user_input": user_input},
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
assert (
response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
)
# Verify save calls: user message, error AI response
assert mock_nlp_services["save_chat_message"].call_count == 2
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=error_text)
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.USER, text=user_input
)
mock_nlp_services["save_chat_message"].assert_any_call(
db, user_id=user.id, sender=MessageSender.AI, text=error_text
)
# --- Tests for /history ---
def test_get_history(client: TestClient, db: Session, mock_nlp_services):
user, access_token, refresh_token = _login_user(db, client)
# Mock the history data returned by the service
mock_history = [
ChatMessage(id=1, user_id=user.id, sender=MessageSender.USER, text="Hello", timestamp=datetime.now()),
ChatMessage(id=2, user_id=user.id, sender=MessageSender.AI, text="Hi there!", timestamp=datetime.now())
ChatMessage(
id=1,
user_id=user.id,
sender=MessageSender.USER,
text="Hello",
timestamp=datetime.now(),
),
ChatMessage(
id=2,
user_id=user.id,
sender=MessageSender.AI,
text="Hi there!",
timestamp=datetime.now(),
),
]
mock_nlp_services["get_chat_history"].return_value = mock_history
response = client.get(
"/api/nlp/history",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK
@@ -208,11 +288,15 @@ def test_get_history(client: TestClient, db: Session, mock_nlp_services):
assert len(response_data) == 2
assert response_data[0]["text"] == "Hello"
assert response_data[1]["text"] == "Hi there!"
mock_nlp_services["get_chat_history"].assert_called_once_with(db, user_id=user.id, limit=50)
mock_nlp_services["get_chat_history"].assert_called_once_with(
db, user_id=user.id, limit=50
)
def test_get_history_unauthorized(client: TestClient):
response = client.get("/api/nlp/history")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
# Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.)
# Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete)

View File

@@ -5,14 +5,17 @@ from datetime import date
from tests.helpers import generators
# Helper Function
def _login_user(db: Session, client: TestClient):
user, password = generators.create_user(db)
login_rsp = generators.login(db, user.username, password)
return user, login_rsp["access_token"], login_rsp["refresh_token"]
# --- Test CRUD Operations ---
def test_create_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
today_date = date.today()
@@ -20,14 +23,14 @@ def test_create_todo(client: TestClient, db: Session):
todo_data = {
"task": "Test TODO",
"date": f"{today_date.isoformat()}T00:00:00",
"remind": True
"remind": True,
}
response = client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json=todo_data
json=todo_data,
)
assert response.status_code == status.HTTP_201_CREATED
@@ -35,50 +38,66 @@ def test_create_todo(client: TestClient, db: Session):
assert data["task"] == todo_data["task"]
assert data["date"] == todo_data["date"]
assert data["remind"] == todo_data["remind"]
assert data["complete"] is False # Default
assert data["complete"] is False # Default
assert "id" in data
assert data["owner_id"] == user.id
def test_read_todos(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
# Create some todos for the user
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 1"})
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 2"})
client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"task": "Todo 1"},
)
client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"task": "Todo 2"},
)
# Create a todo for another user
other_user, other_password = generators.create_user(db)
other_login_rsp = generators.login(db, other_user.username, other_password)
other_access_token = other_login_rsp["access_token"]
other_refresh_token = other_login_rsp["refresh_token"]
client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"})
client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {other_access_token}"},
cookies={"refresh_token": other_refresh_token},
json={"task": "Other User Todo"},
)
response = client.get(
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data) == 2 # Should only get todos for the logged-in user
assert len(data) == 2 # Should only get todos for the logged-in user
assert data[0]["task"] == "Todo 1"
assert data[1]["task"] == "Todo 2"
def test_read_single_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
create_response = client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"task": "Specific Todo"}
json={"task": "Specific Todo"},
)
todo_id = create_response.json()["id"]
response = client.get(
f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK
@@ -87,15 +106,17 @@ def test_read_single_todo(client: TestClient, db: Session):
assert data["task"] == "Specific Todo"
assert data["owner_id"] == user.id
def test_read_single_todo_not_found(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
response = client.get(
"/api/todos/9999", # Non-existent ID
"/api/todos/9999", # Non-existent ID
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_read_single_todo_forbidden(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
@@ -104,16 +125,26 @@ def test_read_single_todo_forbidden(client: TestClient, db: Session):
other_login_rsp = generators.login(db, other_user.username, other_password)
other_access_token = other_login_rsp["access_token"]
other_refresh_token = other_login_rsp["refresh_token"]
other_create_response = client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"})
other_create_response = client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {other_access_token}"},
cookies={"refresh_token": other_refresh_token},
json={"task": "Other User Todo"},
)
other_todo_id = other_create_response.json()["id"]
# Try to access the other user's todo
response = client.get(
f"/api/todos/{other_todo_id}",
headers={"Authorization": f"Bearer {access_token}"}, # Using the first user's token
cookies={"refresh_token": refresh_token}
headers={
"Authorization": f"Bearer {access_token}"
}, # Using the first user's token
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_404_NOT_FOUND # Service raises 404 if not found for *this* user
assert (
response.status_code == status.HTTP_404_NOT_FOUND
) # Service raises 404 if not found for *this* user
def test_update_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
@@ -121,7 +152,7 @@ def test_update_todo(client: TestClient, db: Session):
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"task": "Update Me"}
json={"task": "Update Me"},
)
todo_id = create_response.json()["id"]
@@ -130,7 +161,7 @@ def test_update_todo(client: TestClient, db: Session):
f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json=update_data
json=update_data,
)
assert response.status_code == status.HTTP_200_OK
@@ -144,7 +175,7 @@ def test_update_todo(client: TestClient, db: Session):
get_response = client.get(
f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert get_response.json()["task"] == update_data["task"]
assert get_response.json()["complete"] == update_data["complete"]
@@ -154,55 +185,60 @@ def test_update_todo_not_found(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
update_data = {"task": "Updated Task", "complete": True}
response = client.put(
"/api/todos/9999", # Non-existent ID
"/api/todos/9999", # Non-existent ID
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json=update_data
json=update_data,
)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_delete_todo(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
create_response = client.post(
"/api/todos/",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
json={"task": "Delete Me"}
json={"task": "Delete Me"},
)
todo_id = create_response.json()["id"]
response = client.delete(
f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item
assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item
assert response.json()["id"] == todo_id
# Verify deletion by trying to read
get_response = client.get(
f"/api/todos/{todo_id}",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert get_response.status_code == status.HTTP_404_NOT_FOUND
def test_delete_todo_not_found(client: TestClient, db: Session):
user, access_token, refresh_token = _login_user(db, client)
response = client.delete(
"/api/todos/9999", # Non-existent ID
"/api/todos/9999", # Non-existent ID
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token}
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
# --- Test Authentication/Authorization ---
def test_create_todo_unauthorized(client: TestClient):
response = client.post("/api/todos/", json={"task": "No Auth"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_read_todos_unauthorized(client: TestClient):
response = client.get("/api/todos/")
assert response.status_code == status.HTTP_401_UNAUTHORIZED