[REFORMAT] Ran black reformat
This commit is contained in:
@@ -7,7 +7,7 @@ from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
from core.database import Base # Import your Base
|
||||
from core.database import Base # Import your Base
|
||||
|
||||
|
||||
# --- Add project root to sys.path ---
|
||||
@@ -77,9 +77,7 @@ def run_migrations_online() -> None:
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
"""Initial migration with existing tables
|
||||
|
||||
Revision ID: 69069d6184b3
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2025-04-21 01:14:33.233195
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '69069d6184b3'
|
||||
revision: str = "69069d6184b3"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -5,13 +5,13 @@ Revises: 69069d6184b3
|
||||
Create Date: 2025-04-21 20:33:27.028529
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9a82960db482'
|
||||
down_revision: Union[str, None] = '69069d6184b3'
|
||||
revision: str = "9a82960db482"
|
||||
down_revision: Union[str, None] = "69069d6184b3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
# core/celery_app.py
|
||||
from celery import Celery
|
||||
from core.config import settings # Import your settings
|
||||
from core.config import settings # Import your settings
|
||||
|
||||
celery_app = Celery(
|
||||
"worker",
|
||||
broker=settings.REDIS_URL,
|
||||
backend=settings.REDIS_URL,
|
||||
include=["modules.auth.tasks", "modules.admin.tasks"] # Add paths to modules containing tasks
|
||||
include=[
|
||||
"modules.auth.tasks",
|
||||
"modules.admin.tasks",
|
||||
], # Add paths to modules containing tasks
|
||||
# Add other modules with tasks here, e.g., "modules.some_other_module.tasks"
|
||||
)
|
||||
|
||||
# Optional: Update Celery configuration directly if needed
|
||||
# celery_app.conf.update(task_track_started=True)
|
||||
# celery_app.conf.update(task_track_started=True)
|
||||
|
||||
@@ -4,12 +4,13 @@ import os
|
||||
|
||||
DOTENV_PATH = os.path.join(os.path.dirname(__file__), "../.env")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Database settings - reads DB_URL from environment or .env
|
||||
DB_URL: str = "postgresql://maia:maia@localhost:5432/maia"
|
||||
|
||||
# Redis settings - reads REDIS_URL from environment or .env, also used for Celery.
|
||||
REDIS_URL: str ="redis://localhost:6379/0"
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
|
||||
# JWT settings - reads from environment or .env
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
@@ -19,13 +20,14 @@ class Settings(BaseSettings):
|
||||
JWT_SECRET_KEY: str
|
||||
|
||||
# Other settings
|
||||
GOOGLE_API_KEY: str = "" # Example with a default
|
||||
GOOGLE_API_KEY: str = "" # Example with a default
|
||||
|
||||
class Config:
|
||||
# Tell pydantic-settings to load variables from a .env file
|
||||
env_file = DOTENV_PATH
|
||||
env_file_encoding = 'utf-8'
|
||||
extra = 'ignore'
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
# Create a single instance of the settings
|
||||
settings = Settings()
|
||||
|
||||
@@ -10,6 +10,7 @@ Base = declarative_base() # Used for models
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def get_engine():
|
||||
global _engine
|
||||
if _engine is None:
|
||||
@@ -20,10 +21,13 @@ def get_engine():
|
||||
try:
|
||||
_engine.connect()
|
||||
except Exception:
|
||||
raise Exception("Database connection failed. Is the database server running?")
|
||||
raise Exception(
|
||||
"Database connection failed. Is the database server running?"
|
||||
)
|
||||
Base.metadata.create_all(_engine) # Create tables here
|
||||
return _engine
|
||||
|
||||
|
||||
def get_sessionmaker():
|
||||
global _SessionLocal
|
||||
if _SessionLocal is None:
|
||||
@@ -31,10 +35,11 @@ def get_sessionmaker():
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
return _SessionLocal
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
SessionLocal = get_sessionmaker()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
db.close()
|
||||
|
||||
@@ -8,20 +8,26 @@ from starlette.status import (
|
||||
HTTP_409_CONFLICT,
|
||||
)
|
||||
|
||||
|
||||
def bad_request_exception(detail: str = "Bad Request"):
|
||||
return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
||||
|
||||
|
||||
def unauthorized_exception(detail: str = "Unauthorized"):
|
||||
return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
|
||||
|
||||
|
||||
def forbidden_exception(detail: str = "Forbidden"):
|
||||
return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def not_found_exception(detail: str = "Not Found"):
|
||||
return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
|
||||
|
||||
|
||||
def internal_server_error_exception(detail: str = "Internal Server Error"):
|
||||
return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
|
||||
|
||||
|
||||
def conflict_exception(detail: str = "Conflict"):
|
||||
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
|
||||
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
|
||||
|
||||
@@ -11,11 +11,12 @@ import logging
|
||||
# import all models to ensure they are registered before create_all
|
||||
|
||||
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR) # fix bc package logging is broken
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR) # fix bc package logging is broken
|
||||
|
||||
|
||||
# Create DB tables (remove in production; use migrations instead)
|
||||
def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]:
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Base.metadata.drop_all(bind=get_engine())
|
||||
@@ -24,6 +25,7 @@ def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]
|
||||
|
||||
return lifespan
|
||||
|
||||
|
||||
lifespan = lifespan_factory()
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -34,17 +36,18 @@ app.include_router(router)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:8081", # Keep for web testing if needed
|
||||
"http://192.168.1.9:8081", # Add your mobile device/emulator origin (adjust port if needed)
|
||||
"http://localhost:8081", # Keep for web testing if needed
|
||||
"http://192.168.1.9:8081", # Add your mobile device/emulator origin (adjust port if needed)
|
||||
"http://192.168.255.221:8081",
|
||||
# Add other origins if necessary, e.g., production frontend URL
|
||||
],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"]
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# Health endpoint
|
||||
@app.get("/api/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# modules/admin/api.py
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends # Import Body
|
||||
from pydantic import BaseModel # Import BaseModel
|
||||
from fastapi import APIRouter, Depends # Import Body
|
||||
from pydantic import BaseModel # Import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from core.database import get_db
|
||||
from modules.auth.dependencies import admin_only
|
||||
@@ -9,14 +9,17 @@ from .tasks import cleardb
|
||||
|
||||
router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)])
|
||||
|
||||
|
||||
# Define a Pydantic model for the request body
|
||||
class ClearDbRequest(BaseModel):
|
||||
hard: bool
|
||||
|
||||
|
||||
@router.get("/")
|
||||
def read_admin():
|
||||
return {"message": "Admin route"}
|
||||
|
||||
|
||||
# Change to POST and use the request body model
|
||||
@router.post("/cleardb")
|
||||
def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
|
||||
@@ -25,6 +28,6 @@ def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
|
||||
'hard'=True: Drop and recreate all tables.
|
||||
'hard'=False: Delete data from tables except users.
|
||||
"""
|
||||
hard = payload.hard # Get 'hard' from the payload
|
||||
hard = payload.hard # Get 'hard' from the payload
|
||||
cleardb.delay(hard)
|
||||
return {"message": "Clearing database in the background", "hard": hard}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# modules/admin/services.py
|
||||
|
||||
|
||||
## temp
|
||||
## temp
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from core.celery_app import celery_app
|
||||
|
||||
|
||||
@celery_app.task
|
||||
def cleardb(hard: bool):
|
||||
"""
|
||||
@@ -32,4 +33,4 @@ def cleardb(hard: bool):
|
||||
print(f"Deleting table: {table_name}")
|
||||
db.execute(table.delete())
|
||||
db.commit()
|
||||
return {"message": "Database cleared"}
|
||||
return {"message": "Database cleared"}
|
||||
|
||||
@@ -3,9 +3,24 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from jose import JWTError
|
||||
from modules.auth.models import User
|
||||
from modules.auth.schemas import UserCreate, UserResponse, Token, RefreshTokenRequest, LogoutRequest
|
||||
from modules.auth.schemas import (
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
RefreshTokenRequest,
|
||||
LogoutRequest,
|
||||
)
|
||||
from modules.auth.services import create_user
|
||||
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens
|
||||
from modules.auth.security import (
|
||||
TokenType,
|
||||
get_current_user,
|
||||
oauth2_scheme,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_token,
|
||||
authenticate_user,
|
||||
blacklist_tokens,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
from core.database import get_db
|
||||
@@ -15,12 +30,19 @@ from core.exceptions import unauthorized_exception
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
|
||||
return create_user(user.username, user.password, user.name, db)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
|
||||
def login(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
):
|
||||
"""
|
||||
Authenticate user and return JWT tokens in the response body.
|
||||
"""
|
||||
@@ -30,39 +52,53 @@ def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annota
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
)
|
||||
|
||||
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username},
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
)
|
||||
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||
|
||||
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
def refresh_token(payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]):
|
||||
def refresh_token(
|
||||
payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]
|
||||
):
|
||||
print("Refreshing token...")
|
||||
refresh_token = payload.refresh_token
|
||||
if not refresh_token:
|
||||
raise unauthorized_exception("Refresh token missing in request body")
|
||||
|
||||
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db)
|
||||
user_data = verify_token(
|
||||
refresh_token, expected_token_type=TokenType.REFRESH, db=db
|
||||
)
|
||||
if not user_data:
|
||||
raise unauthorized_exception("Invalid refresh token")
|
||||
|
||||
new_access_token = create_access_token(data={"sub": user_data.username})
|
||||
return {"access_token": new_access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(payload: LogoutRequest, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme)):
|
||||
def logout(
|
||||
payload: LogoutRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
access_token: str = Depends(oauth2_scheme),
|
||||
):
|
||||
try:
|
||||
refresh_token = payload.refresh_token
|
||||
if not refresh_token:
|
||||
raise unauthorized_exception("Refresh token not found in request body")
|
||||
|
||||
blacklist_tokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
db=db
|
||||
)
|
||||
|
||||
blacklist_tokens(access_token=access_token, refresh_token=refresh_token, db=db)
|
||||
|
||||
return {"message": "Logged out successfully"}
|
||||
except JWTError:
|
||||
raise unauthorized_exception("Invalid token")
|
||||
raise unauthorized_exception("Invalid token")
|
||||
|
||||
@@ -5,14 +5,18 @@ from modules.auth.schemas import UserRole
|
||||
from modules.auth.models import User
|
||||
from core.exceptions import forbidden_exception
|
||||
|
||||
|
||||
class RoleChecker:
|
||||
def __init__(self, allowed_roles: list[UserRole]):
|
||||
self.allowed_roles = allowed_roles
|
||||
|
||||
def __call__(self, user: User = Depends(get_current_user)):
|
||||
if user.role not in self.allowed_roles:
|
||||
raise forbidden_exception("You do not have permission to perform this action.")
|
||||
raise forbidden_exception(
|
||||
"You do not have permission to perform this action."
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
admin_only = RoleChecker([UserRole.ADMIN])
|
||||
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
|
||||
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
|
||||
|
||||
@@ -4,10 +4,12 @@ from sqlalchemy import Column, Integer, String, Enum, DateTime
|
||||
from sqlalchemy.orm import relationship
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class UserRole(str, PyEnum):
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(Integer, primary_key=True)
|
||||
|
||||
@@ -2,33 +2,41 @@
|
||||
from enum import Enum as PyEnum
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
scopes: list[str] = []
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class UserRole(str, PyEnum):
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
name: str
|
||||
|
||||
|
||||
class UserPatch(BaseModel):
|
||||
name: str | None = None
|
||||
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
uuid: str
|
||||
username: str
|
||||
|
||||
@@ -18,6 +18,7 @@ from modules.auth.schemas import TokenData
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
ACCESS = "access"
|
||||
REFRESH = "refresh"
|
||||
@@ -25,11 +26,13 @@ class TokenType(str, Enum):
|
||||
|
||||
password_hasher = PasswordHasher()
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password with Argon2 (and optional pepper)."""
|
||||
peppered_password = password + settings.PEPPER # Prepend/append pepper
|
||||
return password_hasher.hash(peppered_password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hashed version using Argon2."""
|
||||
peppered_password = plain_password + settings.PEPPER
|
||||
@@ -38,6 +41,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
except VerifyMismatchError:
|
||||
return False
|
||||
|
||||
|
||||
def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||
"""
|
||||
Authenticate a user by checking username/password against the database.
|
||||
@@ -45,41 +49,46 @@ def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||
"""
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
|
||||
|
||||
# If user not found or password doesn't match
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
# expire = datetime.now(timezone.utc) + timedelta(seconds=5)
|
||||
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
|
||||
return jwt.encode(
|
||||
to_encode,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM
|
||||
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
|
||||
|
||||
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
|
||||
return jwt.encode(
|
||||
to_encode,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM
|
||||
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
|
||||
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
|
||||
|
||||
def verify_token(
|
||||
token: str, expected_token_type: TokenType, db: Session
|
||||
) -> TokenData | None:
|
||||
"""Verify a JWT token and return TokenData if valid.
|
||||
|
||||
Parameters
|
||||
@@ -96,24 +105,32 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
|
||||
TokenData | None
|
||||
TokenData instance if the token is valid, None otherwise.
|
||||
"""
|
||||
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
||||
is_blacklisted = (
|
||||
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
|
||||
is not None
|
||||
)
|
||||
if is_blacklisted:
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
token_type: str = payload.get("token_type")
|
||||
|
||||
|
||||
if username is None or token_type != expected_token_type:
|
||||
return None
|
||||
|
||||
|
||||
return TokenData(username=username)
|
||||
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
|
||||
|
||||
def get_current_user(
|
||||
db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
@@ -121,26 +138,28 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
|
||||
)
|
||||
|
||||
# Check if the token is blacklisted
|
||||
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
||||
is_blacklisted = (
|
||||
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
|
||||
is not None
|
||||
)
|
||||
if is_blacklisted:
|
||||
raise credentials_exception
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM]
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
user: User = db.query(User).filter(User.username == username).first()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
|
||||
"""Blacklist both access and refresh tokens.
|
||||
|
||||
@@ -154,7 +173,9 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
|
||||
Database session to perform the operation.
|
||||
"""
|
||||
for token in [access_token, refresh_token]:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||
|
||||
# Add the token to the blacklist
|
||||
@@ -163,10 +184,13 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def blacklist_token(token: str, db: Session) -> None:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||
|
||||
|
||||
# Add the token to the blacklist
|
||||
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
|
||||
db.add(blacklisted_token)
|
||||
|
||||
@@ -20,11 +20,13 @@ def create_user(username: str, password: str, name: str, db: Session) -> UserRes
|
||||
existing_user = db.query(User).filter(User.username == username).first()
|
||||
if existing_user:
|
||||
raise conflict_exception("Username already exists")
|
||||
|
||||
|
||||
hashed_password = hash_password(password)
|
||||
user_uuid = str(uuid.uuid4())
|
||||
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid)
|
||||
user = User(
|
||||
username=username, hashed_password=hashed_password, name=name, uuid=user_uuid
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user) # Loads the generated ID
|
||||
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic
|
||||
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic
|
||||
|
||||
@@ -6,50 +6,63 @@ from typing import List, Optional
|
||||
from modules.auth.dependencies import get_current_user
|
||||
from core.database import get_db
|
||||
from modules.auth.models import User
|
||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate, CalendarEventResponse
|
||||
from modules.calendar.service import create_calendar_event, get_calendar_event_by_id, get_calendar_events, update_calendar_event, delete_calendar_event
|
||||
from modules.calendar.schemas import (
|
||||
CalendarEventCreate,
|
||||
CalendarEventUpdate,
|
||||
CalendarEventResponse,
|
||||
)
|
||||
from modules.calendar.service import (
|
||||
create_calendar_event,
|
||||
get_calendar_event_by_id,
|
||||
get_calendar_events,
|
||||
update_calendar_event,
|
||||
delete_calendar_event,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/calendar", tags=["calendar"])
|
||||
|
||||
@router.post("/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@router.post(
|
||||
"/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
def create_event(
|
||||
event: CalendarEventCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return create_calendar_event(db, user.id, event)
|
||||
|
||||
|
||||
@router.get("/events", response_model=List[CalendarEventResponse])
|
||||
def get_events(
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
start: Optional[datetime] = None,
|
||||
end: Optional[datetime] = None
|
||||
end: Optional[datetime] = None,
|
||||
):
|
||||
return get_calendar_events(db, user.id, start, end)
|
||||
|
||||
|
||||
@router.get("/events/{event_id}", response_model=CalendarEventResponse)
|
||||
def get_event_by_id(
|
||||
event_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
event = get_calendar_event_by_id(db, user.id, event_id)
|
||||
return event
|
||||
|
||||
|
||||
@router.patch("/events/{event_id}", response_model=CalendarEventResponse)
|
||||
def update_event(
|
||||
event_id: int,
|
||||
event: CalendarEventUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
return update_calendar_event(db, user.id, event_id, event)
|
||||
|
||||
|
||||
@router.delete("/events/{event_id}", status_code=204)
|
||||
def delete_event(
|
||||
event_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
delete_calendar_event(db, user.id, event_id)
|
||||
delete_calendar_event(db, user.id, event_id)
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
# modules/calendar/models.py
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Boolean # Add Boolean
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
JSON,
|
||||
Boolean,
|
||||
) # Add Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class CalendarEvent(Base):
|
||||
__tablename__ = "calendar_events"
|
||||
|
||||
@@ -12,10 +21,12 @@ class CalendarEvent(Base):
|
||||
start = Column(DateTime, nullable=False)
|
||||
end = Column(DateTime)
|
||||
location = Column(String)
|
||||
all_day = Column(Boolean, default=False) # Add all_day column
|
||||
all_day = Column(Boolean, default=False) # Add all_day column
|
||||
tags = Column(JSON)
|
||||
color = Column(String) # hex code for color
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) # <-- Relationship
|
||||
color = Column(String) # hex code for color
|
||||
user_id = Column(
|
||||
Integer, ForeignKey("users.id"), nullable=False
|
||||
) # <-- Relationship
|
||||
|
||||
# Bi-directional relationship (for eager loading)
|
||||
user = relationship("User", back_populates="calendar_events")
|
||||
user = relationship("User", back_populates="calendar_events")
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# modules/calendar/schemas.py
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, field_validator # Add field_validator
|
||||
from typing import List, Optional # Add List and Optional
|
||||
from pydantic import BaseModel, field_validator # Add field_validator
|
||||
from typing import List, Optional # Add List and Optional
|
||||
|
||||
|
||||
# Base schema for common fields, including tags
|
||||
class CalendarEventBase(BaseModel):
|
||||
@@ -10,21 +11,23 @@ class CalendarEventBase(BaseModel):
|
||||
start: datetime
|
||||
end: Optional[datetime] = None
|
||||
location: Optional[str] = None
|
||||
color: Optional[str] = None # Assuming color exists
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags
|
||||
color: Optional[str] = None # Assuming color exists
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags
|
||||
|
||||
@field_validator('tags', mode='before')
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def tags_validate_null_string(cls, v):
|
||||
if v == "Null":
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
# Schema for creating an event (inherits from Base)
|
||||
class CalendarEventCreate(CalendarEventBase):
|
||||
pass
|
||||
|
||||
|
||||
# Schema for updating an event (all fields optional)
|
||||
class CalendarEventUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
@@ -33,23 +36,24 @@ class CalendarEventUpdate(BaseModel):
|
||||
end: Optional[datetime] = None
|
||||
location: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags for update
|
||||
all_day: Optional[bool] = None # Add all_day field
|
||||
tags: Optional[List[str]] = None # Add optional tags for update
|
||||
|
||||
@field_validator('tags', mode='before')
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def tags_validate_null_string(cls, v):
|
||||
if v == "Null":
|
||||
return None
|
||||
return v
|
||||
|
||||
|
||||
# Schema for the response (inherits from Base, adds ID and user_id)
|
||||
class CalendarEventResponse(CalendarEventBase):
|
||||
id: int
|
||||
user_id: int
|
||||
tags: List[str] # Keep as List[str], remove default []
|
||||
tags: List[str] # Keep as List[str], remove default []
|
||||
|
||||
@field_validator('tags', mode='before')
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def tags_validate_none_to_list(cls, v):
|
||||
# If the value from the source object (e.g., ORM model) is None,
|
||||
@@ -59,4 +63,4 @@ class CalendarEventResponse(CalendarEventBase):
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
from_attributes = True
|
||||
|
||||
@@ -1,25 +1,34 @@
|
||||
# modules/calendar/service.py
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import or_ # Import or_
|
||||
from sqlalchemy import or_ # Import or_
|
||||
from datetime import datetime
|
||||
from modules.calendar.models import CalendarEvent
|
||||
from core.exceptions import not_found_exception
|
||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate # Import schemas
|
||||
from modules.calendar.schemas import (
|
||||
CalendarEventCreate,
|
||||
CalendarEventUpdate,
|
||||
) # Import schemas
|
||||
|
||||
|
||||
def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate):
|
||||
# Ensure tags is None if not provided or empty list, matching model
|
||||
tags_to_store = event_data.tags if event_data.tags else None
|
||||
event = CalendarEvent(
|
||||
**event_data.model_dump(exclude={'tags'}), # Use model_dump and exclude tags initially
|
||||
tags=tags_to_store, # Set tags separately
|
||||
user_id=user_id
|
||||
**event_data.model_dump(
|
||||
exclude={"tags"}
|
||||
), # Use model_dump and exclude tags initially
|
||||
tags=tags_to_store, # Set tags separately
|
||||
user_id=user_id,
|
||||
)
|
||||
db.add(event)
|
||||
db.commit()
|
||||
db.refresh(event)
|
||||
return event
|
||||
|
||||
def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: datetime | None):
|
||||
|
||||
def get_calendar_events(
|
||||
db: Session, user_id: int, start: datetime | None, end: datetime | None
|
||||
):
|
||||
"""
|
||||
Retrieves calendar events for a user, optionally filtered by a date range.
|
||||
|
||||
@@ -46,9 +55,13 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
|
||||
query = query.filter(
|
||||
or_(
|
||||
# Case 1: Event has duration and overlaps
|
||||
(CalendarEvent.end is not None) & (CalendarEvent.start < end) & (CalendarEvent.end > start),
|
||||
(CalendarEvent.end is not None)
|
||||
& (CalendarEvent.start < end)
|
||||
& (CalendarEvent.end > start),
|
||||
# Case 2: Event is a point event within the range
|
||||
(CalendarEvent.end is None) & (CalendarEvent.start >= start) & (CalendarEvent.start < end)
|
||||
(CalendarEvent.end is None)
|
||||
& (CalendarEvent.start >= start)
|
||||
& (CalendarEvent.start < end),
|
||||
)
|
||||
)
|
||||
# If only start is provided, filter events starting on or after start
|
||||
@@ -60,37 +73,41 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
|
||||
elif end:
|
||||
# Includes events with duration ending <= end (or starting before end if end is None)
|
||||
# Includes point events occurring < end
|
||||
query = query.filter(
|
||||
query = query.filter(
|
||||
or_(
|
||||
# Event ends before the specified end time
|
||||
(CalendarEvent.end is not None) & (CalendarEvent.end <= end),
|
||||
# Point event occurs before the specified end time
|
||||
(CalendarEvent.end is None) & (CalendarEvent.start < end)
|
||||
(CalendarEvent.end is None) & (CalendarEvent.start < end),
|
||||
)
|
||||
)
|
||||
# Alternative interpretation for "ending before end": include events that *start* before end
|
||||
# query = query.filter(CalendarEvent.start < end)
|
||||
# Alternative interpretation for "ending before end": include events that *start* before end
|
||||
# query = query.filter(CalendarEvent.start < end)
|
||||
|
||||
return query.order_by(CalendarEvent.start).all() # Order by start time
|
||||
|
||||
return query.order_by(CalendarEvent.start).all() # Order by start time
|
||||
|
||||
def get_calendar_event_by_id(db: Session, user_id: int, event_id: int):
|
||||
event = db.query(CalendarEvent).filter(
|
||||
CalendarEvent.id == event_id,
|
||||
CalendarEvent.user_id == user_id
|
||||
).first()
|
||||
event = (
|
||||
db.query(CalendarEvent)
|
||||
.filter(CalendarEvent.id == event_id, CalendarEvent.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not event:
|
||||
raise not_found_exception()
|
||||
return event
|
||||
|
||||
def update_calendar_event(db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate):
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
|
||||
def update_calendar_event(
|
||||
db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate
|
||||
):
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
# Use model_dump with exclude_unset=True to only update provided fields
|
||||
update_data = event_data.model_dump(exclude_unset=True)
|
||||
|
||||
for key, value in update_data.items():
|
||||
# Ensure tags is handled correctly (set to None if empty list provided)
|
||||
if key == 'tags' and isinstance(value, list) and not value:
|
||||
if key == "tags" and isinstance(value, list) and not value:
|
||||
setattr(event, key, None)
|
||||
else:
|
||||
setattr(event, key, value)
|
||||
@@ -99,7 +116,8 @@ def update_calendar_event(db: Session, user_id: int, event_id: int, event_data:
|
||||
db.refresh(event)
|
||||
return event
|
||||
|
||||
|
||||
def delete_calendar_event(db: Session, user_id: int, event_id: int):
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||
db.delete(event)
|
||||
db.commit()
|
||||
db.commit()
|
||||
|
||||
@@ -7,13 +7,27 @@ from core.database import get_db
|
||||
|
||||
from modules.auth.dependencies import get_current_user
|
||||
from modules.auth.models import User
|
||||
|
||||
# Import the new service functions and Enum
|
||||
from modules.nlp.service import process_request, ask_ai, save_chat_message, get_chat_history, MessageSender
|
||||
from modules.nlp.service import (
|
||||
process_request,
|
||||
ask_ai,
|
||||
save_chat_message,
|
||||
get_chat_history,
|
||||
MessageSender,
|
||||
)
|
||||
|
||||
# Import the response schema and the new ChatMessage model for response type hinting
|
||||
from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse
|
||||
from modules.calendar.service import create_calendar_event, get_calendar_events, update_calendar_event, delete_calendar_event
|
||||
from modules.calendar.service import (
|
||||
create_calendar_event,
|
||||
get_calendar_events,
|
||||
update_calendar_event,
|
||||
delete_calendar_event,
|
||||
)
|
||||
from modules.calendar.models import CalendarEvent
|
||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate
|
||||
|
||||
# Import TODO services, schemas, and model
|
||||
from modules.todo import service as todo_service
|
||||
from modules.todo.models import Todo
|
||||
@@ -21,17 +35,20 @@ from modules.todo.schemas import TodoCreate, TodoUpdate
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
id: int
|
||||
sender: MessageSender # Use the enum directly
|
||||
sender: MessageSender # Use the enum directly
|
||||
text: str
|
||||
timestamp: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True # Allow Pydantic to work with ORM models
|
||||
|
||||
from_attributes = True # Allow Pydantic to work with ORM models
|
||||
|
||||
|
||||
router = APIRouter(prefix="/nlp", tags=["nlp"])
|
||||
|
||||
|
||||
# Helper to format calendar events (expects list of CalendarEvent models)
|
||||
def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
|
||||
if not events:
|
||||
@@ -39,12 +56,15 @@ def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
|
||||
formatted = ["Here are the events:"]
|
||||
for event in events:
|
||||
# Access attributes directly from the model instance
|
||||
start_str = event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
|
||||
start_str = (
|
||||
event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
|
||||
)
|
||||
end_str = event.end.strftime("%H:%M") if event.end else ""
|
||||
title = event.title or "Untitled Event"
|
||||
formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})")
|
||||
return formatted
|
||||
|
||||
|
||||
# Helper to format TODO items (expects list of Todo models)
|
||||
def format_todos(todos: List[Todo]) -> List[str]:
|
||||
if not todos:
|
||||
@@ -54,19 +74,28 @@ def format_todos(todos: List[Todo]) -> List[str]:
|
||||
status = "[X]" if todo.complete else "[ ]"
|
||||
date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else ""
|
||||
remind_str = " (Reminder)" if todo.remind else ""
|
||||
formatted.append(f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})")
|
||||
formatted.append(
|
||||
f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})"
|
||||
)
|
||||
return formatted
|
||||
|
||||
|
||||
# Update the response model for the endpoint
|
||||
@router.post("/process-command", response_model=ProcessCommandResponse)
|
||||
def process_command(request_data: ProcessCommandRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
def process_command(
|
||||
request_data: ProcessCommandRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Process the user command, save messages, execute action, save response, and return user-friendly responses.
|
||||
"""
|
||||
user_input = request_data.user_input
|
||||
|
||||
# --- Save User Message ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.USER, text=user_input)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.USER, text=user_input
|
||||
)
|
||||
# ------------------------
|
||||
|
||||
command_data = process_request(user_input)
|
||||
@@ -74,11 +103,13 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
|
||||
params = command_data["params"]
|
||||
response_text = command_data["response_text"]
|
||||
|
||||
responses = [response_text] # Start with the initial response
|
||||
responses = [response_text] # Start with the initial response
|
||||
|
||||
# --- Save Initial AI Response ---
|
||||
# Save the first response generated by process_request
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=response_text
|
||||
)
|
||||
# -----------------------------
|
||||
|
||||
if intent == "error":
|
||||
@@ -97,139 +128,233 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
|
||||
ai_answer = ask_ai(**params)
|
||||
responses.append(ai_answer)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "get_calendar_events":
|
||||
events: List[CalendarEvent] = get_calendar_events(db, current_user.id, **params)
|
||||
events: List[CalendarEvent] = get_calendar_events(
|
||||
db, current_user.id, **params
|
||||
)
|
||||
formatted_responses = format_calendar_events(events)
|
||||
responses.extend(formatted_responses)
|
||||
# --- Save Additional AI Responses ---
|
||||
for resp in formatted_responses:
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
|
||||
)
|
||||
# ----------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "add_calendar_event":
|
||||
event_data = CalendarEventCreate(**params)
|
||||
created_event = create_calendar_event(db, current_user.id, event_data)
|
||||
start_str = created_event.start.strftime("%Y-%m-%d %H:%M") if created_event.start else "No start time"
|
||||
start_str = (
|
||||
created_event.start.strftime("%Y-%m-%d %H:%M")
|
||||
if created_event.start
|
||||
else "No start time"
|
||||
)
|
||||
title = created_event.title or "Untitled Event"
|
||||
add_response = f"Added: {title} starting at {start_str}."
|
||||
responses.append(add_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=add_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "update_calendar_event":
|
||||
event_id = params.pop('event_id', None)
|
||||
event_id = params.pop("event_id", None)
|
||||
if event_id is None:
|
||||
# Save the error message before raising
|
||||
error_msg = "Event ID is required for update."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
event_data = CalendarEventUpdate(**params)
|
||||
updated_event = update_calendar_event(db, current_user.id, event_id, event_data=event_data)
|
||||
updated_event = update_calendar_event(
|
||||
db, current_user.id, event_id, event_data=event_data
|
||||
)
|
||||
title = updated_event.title or "Untitled Event"
|
||||
update_response = f"Updated event ID {updated_event.id}: {title}."
|
||||
responses.append(update_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=update_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "delete_calendar_event":
|
||||
event_id = params.get('event_id')
|
||||
event_id = params.get("event_id")
|
||||
if event_id is None:
|
||||
# Save the error message before raising
|
||||
error_msg = "Event ID is required for delete."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
delete_calendar_event(db, current_user.id, event_id)
|
||||
delete_response = f"Deleted event ID {event_id}."
|
||||
responses.append(delete_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=delete_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
# --- Add TODO Cases ---
|
||||
# --- Add TODO Cases ---
|
||||
case "get_todos":
|
||||
todos: List[Todo] = todo_service.get_todos(db, user=current_user, **params)
|
||||
todos: List[Todo] = todo_service.get_todos(
|
||||
db, user=current_user, **params
|
||||
)
|
||||
formatted_responses = format_todos(todos)
|
||||
responses.extend(formatted_responses)
|
||||
# --- Save Additional AI Responses ---
|
||||
for resp in formatted_responses:
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
|
||||
)
|
||||
# ----------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "add_todo":
|
||||
todo_data = TodoCreate(**params)
|
||||
created_todo = todo_service.create_todo(db, todo=todo_data, user=current_user)
|
||||
add_response = f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
|
||||
created_todo = todo_service.create_todo(
|
||||
db, todo=todo_data, user=current_user
|
||||
)
|
||||
add_response = (
|
||||
f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
|
||||
)
|
||||
responses.append(add_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=add_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "update_todo":
|
||||
todo_id = params.pop('todo_id', None)
|
||||
todo_id = params.pop("todo_id", None)
|
||||
if todo_id is None:
|
||||
error_msg = "TODO ID is required for update."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
todo_data = TodoUpdate(**params)
|
||||
updated_todo = todo_service.update_todo(db, todo_id=todo_id, todo_update=todo_data, user=current_user)
|
||||
update_response = f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
|
||||
if 'complete' in params:
|
||||
status = "complete" if params['complete'] else "incomplete"
|
||||
updated_todo = todo_service.update_todo(
|
||||
db, todo_id=todo_id, todo_update=todo_data, user=current_user
|
||||
)
|
||||
update_response = (
|
||||
f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
|
||||
)
|
||||
if "complete" in params:
|
||||
status = "complete" if params["complete"] else "incomplete"
|
||||
update_response += f" Marked as {status}."
|
||||
responses.append(update_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=update_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
case "delete_todo":
|
||||
todo_id = params.get('todo_id')
|
||||
todo_id = params.get("todo_id")
|
||||
if todo_id is None:
|
||||
error_msg = "TODO ID is required for delete."
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
deleted_todo = todo_service.delete_todo(db, todo_id=todo_id, user=current_user)
|
||||
delete_response = f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
|
||||
deleted_todo = todo_service.delete_todo(
|
||||
db, todo_id=todo_id, user=current_user
|
||||
)
|
||||
delete_response = (
|
||||
f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
|
||||
)
|
||||
responses.append(delete_response)
|
||||
# --- Save Additional AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=delete_response,
|
||||
)
|
||||
# ---------------------------------
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
# --- End TODO Cases ---
|
||||
|
||||
case _:
|
||||
print(f"Warning: Unhandled intent '{intent}' reached api.py match statement.")
|
||||
case _:
|
||||
print(
|
||||
f"Warning: Unhandled intent '{intent}' reached api.py match statement."
|
||||
)
|
||||
# The initial response_text was already saved
|
||||
return ProcessCommandResponse(responses=responses)
|
||||
|
||||
except HTTPException as http_exc:
|
||||
# Don't save again if already saved before raising
|
||||
if http_exc.status_code != 400 or ('event_id' not in http_exc.detail.lower()):
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=http_exc.detail)
|
||||
if http_exc.status_code != 400 or ("event_id" not in http_exc.detail.lower()):
|
||||
save_chat_message(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
sender=MessageSender.AI,
|
||||
text=http_exc.detail,
|
||||
)
|
||||
raise http_exc
|
||||
except Exception as e:
|
||||
print(f"Error executing intent '{intent}': {e}")
|
||||
error_response = "Sorry, I encountered an error while trying to perform that action."
|
||||
error_response = (
|
||||
"Sorry, I encountered an error while trying to perform that action."
|
||||
)
|
||||
# --- Save Final Error AI Response ---
|
||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_response)
|
||||
save_chat_message(
|
||||
db, user_id=current_user.id, sender=MessageSender.AI, text=error_response
|
||||
)
|
||||
# ----------------------------------
|
||||
return ProcessCommandResponse(responses=[error_response])
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[ChatMessageResponse])
|
||||
def read_chat_history(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
||||
def read_chat_history(
|
||||
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieves the last 50 chat messages for the current user."""
|
||||
history = get_chat_history(db, user_id=current_user.id, limit=50)
|
||||
return history
|
||||
# -------------------------------------
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
\
|
||||
# /home/cdp/code/MAIA/backend/modules/nlp/models.py
|
||||
from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum
|
||||
from sqlalchemy.orm import relationship
|
||||
@@ -7,10 +6,12 @@ import enum
|
||||
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class MessageSender(enum.Enum):
|
||||
USER = "user"
|
||||
AI = "ai"
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = "chat_messages"
|
||||
|
||||
@@ -20,4 +21,4 @@ class ChatMessage(Base):
|
||||
text = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
owner = relationship("User") # Relationship to the User model
|
||||
owner = relationship("User") # Relationship to the User model
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
|
||||
class ProcessCommandRequest(BaseModel):
|
||||
user_input: str
|
||||
|
||||
|
||||
class ProcessCommandResponse(BaseModel):
|
||||
responses: List[str]
|
||||
# Optional: Keep details if needed for specific frontend logic beyond display
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# modules/nlp/service.py
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc # Import desc for ordering
|
||||
from sqlalchemy import desc # Import desc for ordering
|
||||
from google import genai
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import List # Import List
|
||||
from typing import List # Import List
|
||||
|
||||
# Import the new model and Enum
|
||||
from .models import ChatMessage, MessageSender
|
||||
@@ -14,7 +14,8 @@ from core.config import settings
|
||||
client = genai.Client(api_key=settings.GOOGLE_API_KEY)
|
||||
|
||||
### Base prompt for MAIA, used for inital user requests
|
||||
SYSTEM_PROMPT = """
|
||||
SYSTEM_PROMPT = (
|
||||
"""
|
||||
You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text.
|
||||
|
||||
Available functions/intents:
|
||||
@@ -109,8 +110,11 @@ MAIA:
|
||||
"response_text": "Okay, I've deleted task 2 from your list."
|
||||
}
|
||||
|
||||
The datetime right now is """+str(datetime.now(timezone.utc))+""".
|
||||
The datetime right now is """
|
||||
+ str(datetime.now(timezone.utc))
|
||||
+ """.
|
||||
"""
|
||||
)
|
||||
|
||||
### Prompt for MAIA to forward user request to AI
|
||||
SYSTEM_FORWARD_PROMPT = f"""
|
||||
@@ -123,6 +127,7 @@ Here is the user request:
|
||||
|
||||
# --- Chat History Service Functions ---
|
||||
|
||||
|
||||
def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str):
|
||||
"""Saves a chat message to the database."""
|
||||
db_message = ChatMessage(user_id=user_id, sender=sender, text=text)
|
||||
@@ -131,16 +136,21 @@ def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: st
|
||||
db.refresh(db_message)
|
||||
return db_message
|
||||
|
||||
|
||||
def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]:
|
||||
"""Retrieves the last 'limit' chat messages for a user."""
|
||||
return db.query(ChatMessage)\
|
||||
.filter(ChatMessage.user_id == user_id)\
|
||||
.order_by(desc(ChatMessage.timestamp))\
|
||||
.limit(limit)\
|
||||
.all()[::-1] # Reverse to get oldest first for display order
|
||||
return (
|
||||
db.query(ChatMessage)
|
||||
.filter(ChatMessage.user_id == user_id)
|
||||
.order_by(desc(ChatMessage.timestamp))
|
||||
.limit(limit)
|
||||
.all()[::-1]
|
||||
) # Reverse to get oldest first for display order
|
||||
|
||||
|
||||
# --- Existing NLP Service Functions ---
|
||||
|
||||
|
||||
def process_request(request: str):
|
||||
"""
|
||||
Process the user request using the Google GenAI API.
|
||||
@@ -152,7 +162,7 @@ def process_request(request: str):
|
||||
config={
|
||||
"temperature": 0.3, # Less creativity, more factual
|
||||
"response_mime_type": "application/json",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Parse the JSON response
|
||||
@@ -160,7 +170,9 @@ def process_request(request: str):
|
||||
parsed_response = json.loads(response.text)
|
||||
# Validate required fields
|
||||
if not all(k in parsed_response for k in ("intent", "params", "response_text")):
|
||||
raise ValueError("AI response missing required fields (intent, params, response_text)")
|
||||
raise ValueError(
|
||||
"AI response missing required fields (intent, params, response_text)"
|
||||
)
|
||||
return parsed_response
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Error parsing AI response: {e}")
|
||||
@@ -169,9 +181,10 @@ def process_request(request: str):
|
||||
return {
|
||||
"intent": "error",
|
||||
"params": {},
|
||||
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?"
|
||||
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?",
|
||||
}
|
||||
|
||||
|
||||
def ask_ai(request: str):
|
||||
"""
|
||||
Ask the AI a question.
|
||||
@@ -179,6 +192,6 @@ def ask_ai(request: str):
|
||||
"""
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents=SYSTEM_FORWARD_PROMPT+request,
|
||||
contents=SYSTEM_FORWARD_PROMPT + request,
|
||||
)
|
||||
return response.text
|
||||
return response.text
|
||||
|
||||
@@ -5,58 +5,65 @@ from typing import List
|
||||
|
||||
from . import service, schemas
|
||||
from core.database import get_db
|
||||
from modules.auth.dependencies import get_current_user # Corrected import
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
from modules.auth.dependencies import get_current_user # Corrected import
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/todos",
|
||||
tags=["todos"],
|
||||
dependencies=[Depends(get_current_user)], # Corrected dependency
|
||||
dependencies=[Depends(get_current_user)], # Corrected dependency
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED)
|
||||
def create_todo_endpoint(
|
||||
todo: schemas.TodoCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
return service.create_todo(db=db, todo=todo, user=current_user)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[schemas.Todo])
|
||||
def read_todos_endpoint(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit)
|
||||
return todos
|
||||
|
||||
|
||||
@router.get("/{todo_id}", response_model=schemas.Todo)
|
||||
def read_todo_endpoint(
|
||||
todo_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user)
|
||||
if db_todo is None:
|
||||
raise HTTPException(status_code=404, detail="Todo not found")
|
||||
return db_todo
|
||||
|
||||
|
||||
@router.put("/{todo_id}", response_model=schemas.Todo)
|
||||
def update_todo_endpoint(
|
||||
todo_id: int,
|
||||
todo_update: schemas.TodoUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
return service.update_todo(db=db, todo_id=todo_id, todo_update=todo_update, user=current_user)
|
||||
return service.update_todo(
|
||||
db=db, todo_id=todo_id, todo_update=todo_update, user=current_user
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{todo_id}", response_model=schemas.Todo)
|
||||
def delete_todo_endpoint(
|
||||
todo_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
||||
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||
):
|
||||
return service.delete_todo(db=db, todo_id=todo_id, user=current_user)
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from core.database import Base
|
||||
|
||||
|
||||
class Todo(Base):
|
||||
__tablename__ = "todos"
|
||||
|
||||
@@ -13,4 +14,6 @@ class Todo(Base):
|
||||
complete = Column(Boolean, default=False)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||
|
||||
owner = relationship("User") # Add relationship if needed, assuming User model exists in auth.models
|
||||
owner = relationship(
|
||||
"User"
|
||||
) # Add relationship if needed, assuming User model exists in auth.models
|
||||
|
||||
@@ -3,21 +3,25 @@ from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import datetime
|
||||
|
||||
|
||||
class TodoBase(BaseModel):
|
||||
task: str
|
||||
date: Optional[datetime.datetime] = None
|
||||
remind: bool = False
|
||||
complete: bool = False
|
||||
|
||||
|
||||
class TodoCreate(TodoBase):
|
||||
pass
|
||||
|
||||
|
||||
class TodoUpdate(BaseModel):
|
||||
task: Optional[str] = None
|
||||
date: Optional[datetime.datetime] = None
|
||||
remind: Optional[bool] = None
|
||||
complete: Optional[bool] = None
|
||||
|
||||
|
||||
class Todo(TodoBase):
|
||||
id: int
|
||||
owner_id: int
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# backend/modules/todo/service.py
|
||||
from sqlalchemy.orm import Session
|
||||
from . import models, schemas
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
from modules.auth.models import User # Assuming User model is in auth.models
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
|
||||
db_todo = models.Todo(**todo.dict(), owner_id=user.id)
|
||||
db.add(db_todo)
|
||||
@@ -11,17 +12,34 @@ def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
|
||||
db.refresh(db_todo)
|
||||
return db_todo
|
||||
|
||||
|
||||
def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100):
|
||||
return db.query(models.Todo).filter(models.Todo.owner_id == user.id).offset(skip).limit(limit).all()
|
||||
return (
|
||||
db.query(models.Todo)
|
||||
.filter(models.Todo.owner_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_todo(db: Session, todo_id: int, user: User):
|
||||
db_todo = db.query(models.Todo).filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id).first()
|
||||
db_todo = (
|
||||
db.query(models.Todo)
|
||||
.filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id)
|
||||
.first()
|
||||
)
|
||||
if db_todo is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found"
|
||||
)
|
||||
return db_todo
|
||||
|
||||
|
||||
def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User):
|
||||
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
|
||||
db_todo = get_todo(
|
||||
db=db, todo_id=todo_id, user=user
|
||||
) # Reuse get_todo to check ownership and existence
|
||||
update_data = todo_update.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(db_todo, key, value)
|
||||
@@ -29,8 +47,11 @@ def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user
|
||||
db.refresh(db_todo)
|
||||
return db_todo
|
||||
|
||||
|
||||
def delete_todo(db: Session, todo_id: int, user: User):
|
||||
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
|
||||
db_todo = get_todo(
|
||||
db=db, todo_id=todo_id, user=user
|
||||
) # Reuse get_todo to check ownership and existence
|
||||
db.delete(db_todo)
|
||||
db.commit()
|
||||
return db_todo
|
||||
|
||||
@@ -11,37 +11,52 @@ from modules.auth.models import User
|
||||
|
||||
router = APIRouter(prefix="/user", tags=["user"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def me(
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Get the current user. Requires user to be logged in.
|
||||
Returns the user object.
|
||||
"""
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.get("/{username}", response_model=UserResponse)
|
||||
def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def get_user(
|
||||
username: str,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Get a user by username.
|
||||
Returns the user object.
|
||||
"""
|
||||
if current_user.username != username:
|
||||
raise forbidden_exception("You can only view your own profile")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise not_found_exception("User not found")
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/{username}", response_model=UserResponse)
|
||||
def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def update_user(
|
||||
username: str,
|
||||
user_data: UserPatch,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Update a user by username.
|
||||
Returns the updated user object.
|
||||
"""
|
||||
if current_user.username != username:
|
||||
raise forbidden_exception("You can only update your own profile")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise not_found_exception("User not found")
|
||||
@@ -60,19 +75,24 @@ def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depe
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/{username}", response_model=UserResponse)
|
||||
def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||
def delete_user(
|
||||
username: str,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> UserResponse:
|
||||
"""
|
||||
Delete a user by username.
|
||||
Returns the deleted user object.
|
||||
"""
|
||||
if current_user.username != username:
|
||||
raise forbidden_exception("You can only delete your own profile")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
raise not_found_exception("User not found")
|
||||
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
return user
|
||||
return user
|
||||
|
||||
@@ -12,6 +12,7 @@ from core.database import get_db, get_sessionmaker
|
||||
|
||||
fake = Faker()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def postgres_container() -> Generator[PostgresContainer, None, None]:
|
||||
"""Fixture to create a PostgreSQL container for testing."""
|
||||
@@ -21,7 +22,8 @@ def postgres_container() -> Generator[PostgresContainer, None, None]:
|
||||
print(f"Postgres container started at {settings.DB_URL}")
|
||||
yield postgres
|
||||
print("Postgres container stopped.")
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db(postgres_container) -> Generator[Session, None, None]:
|
||||
"""Function-scoped database session with rollback"""
|
||||
@@ -34,25 +36,28 @@ def db(postgres_container) -> Generator[Session, None, None]:
|
||||
session.rollback()
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db: Session) -> Generator[TestClient, None, None]:
|
||||
"""Function-scoped test client with dependency override"""
|
||||
from main import app
|
||||
|
||||
|
||||
# Override the database dependency
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
pass # Don't close session here
|
||||
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None:
|
||||
from main import app
|
||||
app.dependency_overrides[dependency] = lambda: mocked_response
|
||||
|
||||
app.dependency_overrides[dependency] = lambda: mocked_response
|
||||
|
||||
@@ -5,17 +5,24 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.config import settings
|
||||
from modules.auth.models import User
|
||||
from modules.auth.security import authenticate_user, create_access_token, create_refresh_token, hash_password
|
||||
from modules.auth.security import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
hash_password,
|
||||
)
|
||||
from modules.auth.schemas import UserRole
|
||||
from tests.conftest import fake
|
||||
from typing import Optional # Import Optional
|
||||
from typing import Optional # Import Optional
|
||||
|
||||
|
||||
def create_user(db: Session, is_admin: bool = False, username: Optional[str] = None) -> User:
|
||||
def create_user(
|
||||
db: Session, is_admin: bool = False, username: Optional[str] = None
|
||||
) -> User:
|
||||
unhashed_password = fake.password()
|
||||
_user = User(
|
||||
name=fake.name(),
|
||||
username=username or fake.user_name(), # Use provided username or generate one
|
||||
username=username or fake.user_name(), # Use provided username or generate one
|
||||
hashed_password=hash_password(unhashed_password),
|
||||
uuid=uuid_pkg.uuid4(),
|
||||
role=UserRole.ADMIN if is_admin else UserRole.USER,
|
||||
@@ -24,14 +31,18 @@ def create_user(db: Session, is_admin: bool = False, username: Optional[str] = N
|
||||
db.add(_user)
|
||||
db.commit()
|
||||
db.refresh(_user)
|
||||
return _user, unhashed_password # return for testing
|
||||
return _user, unhashed_password # return for testing
|
||||
|
||||
|
||||
def login(db: Session, username: str, password: str) -> str:
|
||||
user = authenticate_user(username, password, db)
|
||||
if not user:
|
||||
raise Exception("Incorrect username or password")
|
||||
|
||||
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username},
|
||||
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
)
|
||||
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||
|
||||
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||
@@ -40,4 +51,4 @@ def login(db: Session, username: str, password: str) -> str:
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"max_age": max_age,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,71 +7,93 @@ from tests.helpers import generators
|
||||
|
||||
# Test admin routes require admin privileges
|
||||
|
||||
|
||||
def test_read_admin_unauthorized(client: TestClient) -> None:
|
||||
"""Test accessing admin route without authentication."""
|
||||
response = client.get("/api/admin/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_read_admin_forbidden(db: Session, client: TestClient) -> None:
|
||||
"""Test accessing admin route as a non-admin user."""
|
||||
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
||||
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
|
||||
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"})
|
||||
response = client.get(
|
||||
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_read_admin_success(db: Session, client: TestClient) -> None:
|
||||
"""Test accessing admin route as an admin user."""
|
||||
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
|
||||
admin_user, password = generators.create_user(
|
||||
db, is_admin=True
|
||||
) # Use is_admin=True
|
||||
login_rsp = generators.login(db, admin_user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
|
||||
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"})
|
||||
response = client.get(
|
||||
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == {"message": "Admin route"}
|
||||
|
||||
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
||||
|
||||
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
||||
def test_clear_db_soft(mock_cleardb_delay, db: Session, client: TestClient) -> None:
|
||||
"""Test soft clearing the database as admin."""
|
||||
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
|
||||
admin_user, password = generators.create_user(
|
||||
db, is_admin=True
|
||||
) # Use is_admin=True
|
||||
login_rsp = generators.login(db, admin_user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
|
||||
response = client.post(
|
||||
"/api/admin/cleardb",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"hard": False}
|
||||
json={"hard": False},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == {"message": "Clearing database in the background", "hard": False}
|
||||
assert response.json() == {
|
||||
"message": "Clearing database in the background",
|
||||
"hard": False,
|
||||
}
|
||||
mock_cleardb_delay.assert_called_once_with(False)
|
||||
|
||||
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
||||
|
||||
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
||||
def test_clear_db_hard(mock_cleardb_delay, db: Session, client: TestClient) -> None:
|
||||
"""Test hard clearing the database as admin."""
|
||||
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
|
||||
admin_user, password = generators.create_user(
|
||||
db, is_admin=True
|
||||
) # Use is_admin=True
|
||||
login_rsp = generators.login(db, admin_user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
|
||||
response = client.post(
|
||||
"/api/admin/cleardb",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"hard": True}
|
||||
json={"hard": True},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == {"message": "Clearing database in the background", "hard": True}
|
||||
assert response.json() == {
|
||||
"message": "Clearing database in the background",
|
||||
"hard": True,
|
||||
}
|
||||
mock_cleardb_delay.assert_called_once_with(True)
|
||||
|
||||
|
||||
def test_clear_db_forbidden(db: Session, client: TestClient) -> None:
|
||||
"""Test clearing the database as a non-admin user."""
|
||||
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
||||
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
|
||||
response = client.post(
|
||||
"/api/admin/cleardb",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json={"hard": False}
|
||||
json={"hard": False},
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
@@ -34,6 +34,7 @@ def test_register(client: TestClient) -> None:
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
|
||||
|
||||
def test_login(db: Session, client: TestClient) -> None:
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
|
||||
@@ -51,17 +52,21 @@ def test_login(db: Session, client: TestClient) -> None:
|
||||
assert "token_type" in response_data
|
||||
assert response_data["token_type"] == "bearer"
|
||||
|
||||
|
||||
def test_refresh_token(db: Session, client: TestClient) -> None:
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
rsp = generators.login(db, user.username, unhashed_password)
|
||||
access_token = rsp["access_token"]
|
||||
refresh_token = rsp["refresh_token"]
|
||||
|
||||
time.sleep(1) # Sleep to ensure tokens won't be identical
|
||||
time.sleep(1) # Sleep to ensure tokens won't be identical
|
||||
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -70,7 +75,10 @@ def test_refresh_token(db: Session, client: TestClient) -> None:
|
||||
assert "access_token" in response_data
|
||||
assert "token_type" in response_data
|
||||
assert response_data["token_type"] == "bearer"
|
||||
assert response_data["access_token"] != access_token # Ensure the token is refreshed
|
||||
assert (
|
||||
response_data["access_token"] != access_token
|
||||
) # Ensure the token is refreshed
|
||||
|
||||
|
||||
def test_logout(db: Session, client: TestClient) -> None:
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
@@ -80,15 +88,20 @@ def test_logout(db: Session, client: TestClient) -> None:
|
||||
|
||||
response = client.post(
|
||||
"/api/auth/logout",
|
||||
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Verify that the token is blacklisted
|
||||
blacklisted_token = db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
|
||||
blacklisted_token = (
|
||||
db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
|
||||
)
|
||||
assert blacklisted_token is not None
|
||||
|
||||
|
||||
# Verify that we can't still actually do anything
|
||||
response = client.get(
|
||||
"/api/user/me",
|
||||
@@ -98,7 +111,10 @@ def test_logout(db: Session, client: TestClient) -> None:
|
||||
|
||||
response = client.post(
|
||||
"/api/auth/refresh",
|
||||
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@@ -106,7 +122,9 @@ def test_logout(db: Session, client: TestClient) -> None:
|
||||
|
||||
def test_get_me(db: Session, client: TestClient) -> None:
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||
access_token = generators.login(db, user.username, unhashed_password)[
|
||||
"access_token"
|
||||
]
|
||||
|
||||
response = client.get(
|
||||
"/api/user/me",
|
||||
@@ -119,14 +137,18 @@ def test_get_me(db: Session, client: TestClient) -> None:
|
||||
assert response_data["uuid"] == user.uuid
|
||||
assert response_data["username"] == user.username
|
||||
|
||||
|
||||
def test_get_me_unauthorized(client: TestClient) -> None:
|
||||
### This test should fail (unauthorized) because the user isn't logged in
|
||||
response = client.get("/api/user/me")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_get_user(db: Session, client: TestClient) -> None:
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||
access_token = generators.login(db, user.username, unhashed_password)[
|
||||
"access_token"
|
||||
]
|
||||
|
||||
response = client.get(
|
||||
f"/api/user/{user.username}",
|
||||
@@ -139,11 +161,14 @@ def test_get_user(db: Session, client: TestClient) -> None:
|
||||
assert response_data["uuid"] == user.uuid
|
||||
assert response_data["username"] == user.username
|
||||
|
||||
|
||||
def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
|
||||
### This test should fail (unauthorized) because the user isn't us
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
user2, _ = generators.create_user(db)
|
||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||
access_token = generators.login(db, user.username, unhashed_password)[
|
||||
"access_token"
|
||||
]
|
||||
|
||||
response = client.get(
|
||||
f"/api/user/{user2.username}",
|
||||
@@ -151,11 +176,14 @@ def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_update_user(db: Session, client: TestClient) -> None:
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
new_name = fake.name()
|
||||
|
||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||
access_token = generators.login(db, user.username, unhashed_password)[
|
||||
"access_token"
|
||||
]
|
||||
response = client.patch(
|
||||
f"/api/user/{user.username}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
@@ -168,7 +196,9 @@ def test_update_user(db: Session, client: TestClient) -> None:
|
||||
|
||||
def test_delete_user(db: Session, client: TestClient) -> None:
|
||||
user, unhashed_password = generators.create_user(db)
|
||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||
access_token = generators.login(db, user.username, unhashed_password)[
|
||||
"access_token"
|
||||
]
|
||||
response = client.delete(
|
||||
f"/api/user/{user.username}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
@@ -179,6 +209,7 @@ def test_delete_user(db: Session, client: TestClient) -> None:
|
||||
deleted_user = db.query(User).filter(User.username == user.username).first()
|
||||
assert deleted_user is None
|
||||
|
||||
|
||||
def test_get_user_forbidden(db: Session, client: TestClient) -> None:
|
||||
"""Test getting another user's profile (should be forbidden)."""
|
||||
user1, password_user1 = generators.create_user(db, username="user1_get_forbidden")
|
||||
@@ -195,9 +226,12 @@ def test_get_user_forbidden(db: Session, client: TestClient) -> None:
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_update_user_forbidden(db: Session, client: TestClient) -> None:
|
||||
"""Test updating another user's profile (should be forbidden)."""
|
||||
user1, password_user1 = generators.create_user(db, username="user1_update_forbidden")
|
||||
user1, password_user1 = generators.create_user(
|
||||
db, username="user1_update_forbidden"
|
||||
)
|
||||
user2, _ = generators.create_user(db, username="user2_update_forbidden")
|
||||
new_name = fake.name()
|
||||
|
||||
@@ -213,9 +247,12 @@ def test_update_user_forbidden(db: Session, client: TestClient) -> None:
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_delete_user_forbidden(db: Session, client: TestClient) -> None:
|
||||
"""Test deleting another user's profile (should be forbidden)."""
|
||||
user1, password_user1 = generators.create_user(db, username="user1_delete_forbidden")
|
||||
user1, password_user1 = generators.create_user(
|
||||
db, username="user1_delete_forbidden"
|
||||
)
|
||||
user2, _ = generators.create_user(db, username="user2_delete_forbidden")
|
||||
|
||||
# Log in as user1
|
||||
|
||||
@@ -4,9 +4,10 @@ from sqlalchemy.orm import Session
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from tests.helpers import generators
|
||||
from modules.calendar.models import CalendarEvent # Assuming model exists
|
||||
from modules.calendar.models import CalendarEvent # Assuming model exists
|
||||
from tests.conftest import fake
|
||||
|
||||
|
||||
# Helper function to create an event payload
|
||||
def create_event_payload(start_offset_days=0, end_offset_days=1):
|
||||
start_time = datetime.utcnow() + timedelta(days=start_offset_days)
|
||||
@@ -14,19 +15,22 @@ def create_event_payload(start_offset_days=0, end_offset_days=1):
|
||||
return {
|
||||
"title": fake.sentence(nb_words=3),
|
||||
"description": fake.text(),
|
||||
"start": start_time.isoformat(), # Rename start_time to start
|
||||
"end": end_time.isoformat(), # Rename end_time to end
|
||||
"start": start_time.isoformat(), # Rename start_time to start
|
||||
"end": end_time.isoformat(), # Rename end_time to end
|
||||
"all_day": fake.boolean(),
|
||||
}
|
||||
|
||||
|
||||
# --- Test Create Event ---
|
||||
|
||||
|
||||
def test_create_event_unauthorized(client: TestClient) -> None:
|
||||
"""Test creating an event without authentication."""
|
||||
payload = create_event_payload()
|
||||
response = client.post("/api/calendar/events", json=payload)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_create_event_success(db: Session, client: TestClient) -> None:
|
||||
"""Test creating a calendar event successfully."""
|
||||
user, password = generators.create_user(db)
|
||||
@@ -37,9 +41,11 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
|
||||
response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload
|
||||
json=payload,
|
||||
)
|
||||
assert response.status_code == status.HTTP_201_CREATED # Change expected status to 201
|
||||
assert (
|
||||
response.status_code == status.HTTP_201_CREATED
|
||||
) # Change expected status to 201
|
||||
data = response.json()
|
||||
assert data["title"] == payload["title"]
|
||||
assert data["description"] == payload["description"]
|
||||
@@ -56,13 +62,16 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
|
||||
assert event_in_db.user_id == user.id
|
||||
assert event_in_db.title == payload["title"]
|
||||
|
||||
|
||||
# --- Test Get Events ---
|
||||
|
||||
|
||||
def test_get_events_unauthorized(client: TestClient) -> None:
|
||||
"""Test getting events without authentication."""
|
||||
response = client.get("/api/calendar/events")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_get_events_success(db: Session, client: TestClient) -> None:
|
||||
"""Test getting all calendar events for a user."""
|
||||
user, password = generators.create_user(db)
|
||||
@@ -71,21 +80,31 @@ def test_get_events_success(db: Session, client: TestClient) -> None:
|
||||
|
||||
# Create a couple of events for the user
|
||||
payload1 = create_event_payload(0, 1)
|
||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1)
|
||||
client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload1,
|
||||
)
|
||||
payload2 = create_event_payload(2, 3)
|
||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2)
|
||||
client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload2,
|
||||
)
|
||||
|
||||
# Create an event for another user (should not be returned)
|
||||
other_user, other_password = generators.create_user(db)
|
||||
other_login_rsp = generators.login(db, other_user.username, other_password)
|
||||
other_access_token = other_login_rsp["access_token"]
|
||||
other_payload = create_event_payload(4, 5)
|
||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {other_access_token}"}, json=other_payload)
|
||||
|
||||
client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {other_access_token}"},
|
||||
json=other_payload,
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
"/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@@ -103,12 +122,24 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
|
||||
access_token = login_rsp["access_token"]
|
||||
|
||||
# Create events
|
||||
payload1 = create_event_payload(0, 1) # Today -> Tomorrow
|
||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1)
|
||||
payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days
|
||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2)
|
||||
payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days
|
||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload3)
|
||||
payload1 = create_event_payload(0, 1) # Today -> Tomorrow
|
||||
client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload1,
|
||||
)
|
||||
payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days
|
||||
client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload2,
|
||||
)
|
||||
payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days
|
||||
client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload3,
|
||||
)
|
||||
|
||||
# Filter for events starting within the next week
|
||||
start_filter = datetime.utcnow().isoformat()
|
||||
@@ -117,11 +148,11 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
|
||||
response = client.get(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
params={"start": start_filter, "end": end_filter}
|
||||
params={"start": start_filter, "end": end_filter},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) == 2 # Should get event 1 and 2
|
||||
assert len(data) == 2 # Should get event 1 and 2
|
||||
assert data[0]["title"] == payload1["title"]
|
||||
assert data[1]["title"] == payload2["title"]
|
||||
|
||||
@@ -130,40 +161,50 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
|
||||
response = client.get(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
params={"start": start_filter_late}
|
||||
params={"start": start_filter_late},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) == 1 # Should get event 3
|
||||
assert len(data) == 1 # Should get event 3
|
||||
assert data[0]["title"] == payload3["title"]
|
||||
|
||||
|
||||
# --- Test Get Event By ID ---
|
||||
|
||||
|
||||
def test_get_event_by_id_unauthorized(db: Session, client: TestClient) -> None:
|
||||
"""Test getting a specific event without authentication."""
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload,
|
||||
)
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
response = client.get(f"/api/calendar/events/{event_id}")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
|
||||
"""Test getting a specific event successfully."""
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload,
|
||||
)
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
response = client.get(
|
||||
f"/api/calendar/events/{event_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@@ -171,6 +212,7 @@ def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
|
||||
assert data["title"] == payload["title"]
|
||||
assert data["user_id"] == user.id
|
||||
|
||||
|
||||
def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
|
||||
"""Test getting a non-existent event."""
|
||||
user, password = generators.create_user(db)
|
||||
@@ -180,10 +222,11 @@ def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
|
||||
|
||||
response = client.get(
|
||||
f"/api/calendar/events/{non_existent_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
|
||||
"""Test getting another user's event."""
|
||||
user1, password_user1 = generators.create_user(db)
|
||||
@@ -193,7 +236,11 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
|
||||
login_rsp1 = generators.login(db, user1.username, password_user1)
|
||||
access_token1 = login_rsp1["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token1}"},
|
||||
json=payload,
|
||||
)
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
# Log in as user2 and try to get user1's event
|
||||
@@ -202,45 +249,60 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
|
||||
|
||||
response = client.get(
|
||||
f"/api/calendar/events/{event_id}",
|
||||
headers={"Authorization": f"Bearer {access_token2}"}
|
||||
headers={"Authorization": f"Bearer {access_token2}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match
|
||||
assert (
|
||||
response.status_code == status.HTTP_404_NOT_FOUND
|
||||
) # Service layer returns 404 if user_id doesn't match
|
||||
|
||||
|
||||
# --- Test Update Event ---
|
||||
|
||||
|
||||
def test_update_event_unauthorized(db: Session, client: TestClient) -> None:
|
||||
"""Test updating an event without authentication."""
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload,
|
||||
)
|
||||
event_id = create_response.json()["id"]
|
||||
update_payload = {"title": "Updated Title"}
|
||||
|
||||
response = client.patch(f"/api/calendar/events/{event_id}", json=update_payload)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_update_event_success(db: Session, client: TestClient) -> None:
|
||||
"""Test updating an event successfully."""
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
||||
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload,
|
||||
)
|
||||
assert (
|
||||
create_response.status_code == status.HTTP_201_CREATED
|
||||
) # Ensure creation check uses 201
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
update_payload = {
|
||||
"title": "Updated Title",
|
||||
"description": "Updated description.",
|
||||
"all_day": not payload["all_day"] # Toggle all_day
|
||||
"all_day": not payload["all_day"], # Toggle all_day
|
||||
}
|
||||
|
||||
response = client.patch(
|
||||
f"/api/calendar/events/{event_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=update_payload
|
||||
json=update_payload,
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
@@ -248,7 +310,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None:
|
||||
assert data["title"] == update_payload["title"]
|
||||
assert data["description"] == update_payload["description"]
|
||||
assert data["all_day"] == update_payload["all_day"]
|
||||
assert data["start"] == payload["start"] # Check correct field name 'start'
|
||||
assert data["start"] == payload["start"] # Check correct field name 'start'
|
||||
assert data["user_id"] == user.id
|
||||
|
||||
# Verify in DB
|
||||
@@ -258,6 +320,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None:
|
||||
assert event_in_db.description == update_payload["description"]
|
||||
assert event_in_db.all_day == update_payload["all_day"]
|
||||
|
||||
|
||||
def test_update_event_not_found(db: Session, client: TestClient) -> None:
|
||||
"""Test updating a non-existent event."""
|
||||
user, password = generators.create_user(db)
|
||||
@@ -269,10 +332,11 @@ def test_update_event_not_found(db: Session, client: TestClient) -> None:
|
||||
response = client.patch(
|
||||
f"/api/calendar/events/{non_existent_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=update_payload
|
||||
json=update_payload,
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def test_update_event_forbidden(db: Session, client: TestClient) -> None:
|
||||
"""Test updating another user's event."""
|
||||
user1, password_user1 = generators.create_user(db)
|
||||
@@ -282,7 +346,11 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
|
||||
login_rsp1 = generators.login(db, user1.username, password_user1)
|
||||
access_token1 = login_rsp1["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token1}"},
|
||||
json=payload,
|
||||
)
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
# Log in as user2 and try to update user1's event
|
||||
@@ -293,32 +361,47 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
|
||||
response = client.patch(
|
||||
f"/api/calendar/events/{event_id}",
|
||||
headers={"Authorization": f"Bearer {access_token2}"},
|
||||
json=update_payload
|
||||
json=update_payload,
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match
|
||||
assert (
|
||||
response.status_code == status.HTTP_404_NOT_FOUND
|
||||
) # Service layer returns 404 if user_id doesn't match
|
||||
|
||||
|
||||
# --- Test Delete Event ---
|
||||
|
||||
|
||||
def test_delete_event_unauthorized(db: Session, client: TestClient) -> None:
|
||||
"""Test deleting an event without authentication."""
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload,
|
||||
)
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
response = client.delete(f"/api/calendar/events/{event_id}")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_delete_event_success(db: Session, client: TestClient) -> None:
|
||||
"""Test deleting an event successfully."""
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
access_token = login_rsp["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
||||
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
json=payload,
|
||||
)
|
||||
assert (
|
||||
create_response.status_code == status.HTTP_201_CREATED
|
||||
) # Ensure creation check uses 201
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
# Verify event exists before delete
|
||||
@@ -327,7 +410,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
|
||||
|
||||
response = client.delete(
|
||||
f"/api/calendar/events/{event_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||
|
||||
@@ -338,7 +421,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
|
||||
# Try getting the deleted event (should be 404)
|
||||
get_response = client.get(
|
||||
f"/api/calendar/events/{event_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
assert get_response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@@ -352,7 +435,7 @@ def test_delete_event_not_found(db: Session, client: TestClient) -> None:
|
||||
|
||||
response = client.delete(
|
||||
f"/api/calendar/events/{non_existent_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"}
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
# The service layer raises NotFound, which should result in 404
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -367,7 +450,11 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
|
||||
login_rsp1 = generators.login(db, user1.username, password_user1)
|
||||
access_token1 = login_rsp1["access_token"]
|
||||
payload = create_event_payload()
|
||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
|
||||
create_response = client.post(
|
||||
"/api/calendar/events",
|
||||
headers={"Authorization": f"Bearer {access_token1}"},
|
||||
json=payload,
|
||||
)
|
||||
event_id = create_response.json()["id"]
|
||||
|
||||
# Log in as user2 and try to delete user1's event
|
||||
@@ -376,7 +463,7 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
|
||||
|
||||
response = client.delete(
|
||||
f"/api/calendar/events/{event_id}",
|
||||
headers={"Authorization": f"Bearer {access_token2}"}
|
||||
headers={"Authorization": f"Bearer {access_token2}"},
|
||||
)
|
||||
# The service layer raises NotFound if user_id doesn't match, resulting in 404
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
@@ -385,4 +472,3 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
|
||||
event_in_db = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first()
|
||||
assert event_in_db is not None
|
||||
assert event_in_db.user_id == user1.id
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from fastapi.testclient import TestClient
|
||||
|
||||
# No database needed for this simple test
|
||||
|
||||
|
||||
def test_health_check(client: TestClient):
|
||||
"""Test the health check endpoint."""
|
||||
response = client.get("/api/health")
|
||||
|
||||
@@ -7,24 +7,37 @@ from datetime import datetime
|
||||
|
||||
from tests.helpers import generators
|
||||
from modules.nlp.schemas import ProcessCommandResponse
|
||||
from modules.nlp.models import MessageSender, ChatMessage # Import necessary models/enums
|
||||
from modules.nlp.models import (
|
||||
MessageSender,
|
||||
ChatMessage,
|
||||
) # Import necessary models/enums
|
||||
|
||||
|
||||
# --- Mocks ---
|
||||
# Mock the external AI call and internal service functions
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_nlp_services():
|
||||
with patch("modules.nlp.api.process_request") as mock_process, \
|
||||
patch("modules.nlp.api.ask_ai") as mock_ask, \
|
||||
patch("modules.nlp.api.save_chat_message") as mock_save, \
|
||||
patch("modules.nlp.api.get_chat_history") as mock_get_history, \
|
||||
patch("modules.nlp.api.create_calendar_event") as mock_create_event, \
|
||||
patch("modules.nlp.api.get_calendar_events") as mock_get_events, \
|
||||
patch("modules.nlp.api.update_calendar_event") as mock_update_event, \
|
||||
patch("modules.nlp.api.delete_calendar_event") as mock_delete_event, \
|
||||
patch("modules.nlp.api.todo_service.create_todo") as mock_create_todo, \
|
||||
patch("modules.nlp.api.todo_service.get_todos") as mock_get_todos, \
|
||||
patch("modules.nlp.api.todo_service.update_todo") as mock_update_todo, \
|
||||
patch("modules.nlp.api.todo_service.delete_todo") as mock_delete_todo:
|
||||
with patch("modules.nlp.api.process_request") as mock_process, patch(
|
||||
"modules.nlp.api.ask_ai"
|
||||
) as mock_ask, patch("modules.nlp.api.save_chat_message") as mock_save, patch(
|
||||
"modules.nlp.api.get_chat_history"
|
||||
) as mock_get_history, patch(
|
||||
"modules.nlp.api.create_calendar_event"
|
||||
) as mock_create_event, patch(
|
||||
"modules.nlp.api.get_calendar_events"
|
||||
) as mock_get_events, patch(
|
||||
"modules.nlp.api.update_calendar_event"
|
||||
) as mock_update_event, patch(
|
||||
"modules.nlp.api.delete_calendar_event"
|
||||
) as mock_delete_event, patch(
|
||||
"modules.nlp.api.todo_service.create_todo"
|
||||
) as mock_create_todo, patch(
|
||||
"modules.nlp.api.todo_service.get_todos"
|
||||
) as mock_get_todos, patch(
|
||||
"modules.nlp.api.todo_service.update_todo"
|
||||
) as mock_update_todo, patch(
|
||||
"modules.nlp.api.todo_service.delete_todo"
|
||||
) as mock_delete_todo:
|
||||
mocks = {
|
||||
"process_request": mock_process,
|
||||
"ask_ai": mock_ask,
|
||||
@@ -41,21 +54,24 @@ def mock_nlp_services():
|
||||
}
|
||||
yield mocks
|
||||
|
||||
|
||||
# --- Helper Function ---
|
||||
def _login_user(db: Session, client: TestClient):
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
return user, login_rsp["access_token"], login_rsp["refresh_token"]
|
||||
|
||||
|
||||
# --- Tests for /process-command ---
|
||||
|
||||
|
||||
def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
user_input = "What is the capital of France?"
|
||||
mock_nlp_services["process_request"].return_value = {
|
||||
"intent": "ask_ai",
|
||||
"params": {"request": user_input},
|
||||
"response_text": "Let me check that for you."
|
||||
"response_text": "Let me check that for you.",
|
||||
}
|
||||
mock_nlp_services["ask_ai"].return_value = "The capital of France is Paris."
|
||||
|
||||
@@ -63,25 +79,45 @@ def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_servic
|
||||
"/api/nlp/process-command",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"user_input": user_input}
|
||||
json={"user_input": user_input},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == ProcessCommandResponse(responses=["Let me check that for you.", "The capital of France is Paris."]).model_dump()
|
||||
assert (
|
||||
response.json()
|
||||
== ProcessCommandResponse(
|
||||
responses=["Let me check that for you.", "The capital of France is Paris."]
|
||||
).model_dump()
|
||||
)
|
||||
# Verify save calls: user message, initial AI response, final AI answer
|
||||
assert mock_nlp_services["save_chat_message"].call_count == 3
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you.")
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="The capital of France is Paris.")
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||
db, user_id=user.id, sender=MessageSender.USER, text=user_input
|
||||
)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||
db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you."
|
||||
)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||
db,
|
||||
user_id=user.id,
|
||||
sender=MessageSender.AI,
|
||||
text="The capital of France is Paris.",
|
||||
)
|
||||
mock_nlp_services["ask_ai"].assert_called_once_with(request=user_input)
|
||||
|
||||
def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_services):
|
||||
|
||||
def test_process_command_get_calendar(
|
||||
client: TestClient, db: Session, mock_nlp_services
|
||||
):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
user_input = "What are my events today?"
|
||||
mock_nlp_services["process_request"].return_value = {
|
||||
"intent": "get_calendar_events",
|
||||
"params": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T23:59:59Z"}, # Example params
|
||||
"response_text": "Okay, fetching your events."
|
||||
"params": {
|
||||
"start": "2024-01-01T00:00:00Z",
|
||||
"end": "2024-01-01T23:59:59Z",
|
||||
}, # Example params
|
||||
"response_text": "Okay, fetching your events.",
|
||||
}
|
||||
# Mock the actual event model returned by the service
|
||||
mock_event = MagicMock()
|
||||
@@ -94,26 +130,32 @@ def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_
|
||||
"/api/nlp/process-command",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"user_input": user_input}
|
||||
json={"user_input": user_input},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
expected_responses = [
|
||||
"Okay, fetching your events.",
|
||||
"Here are the events:",
|
||||
"- Team Meeting (2024-01-01 10:00 - 11:00)"
|
||||
"- Team Meeting (2024-01-01 10:00 - 11:00)",
|
||||
]
|
||||
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump()
|
||||
assert mock_nlp_services["save_chat_message"].call_count == 4 # User, Initial AI, Header, Event
|
||||
assert (
|
||||
response.json()
|
||||
== ProcessCommandResponse(responses=expected_responses).model_dump()
|
||||
)
|
||||
assert (
|
||||
mock_nlp_services["save_chat_message"].call_count == 4
|
||||
) # User, Initial AI, Header, Event
|
||||
mock_nlp_services["get_calendar_events"].assert_called_once()
|
||||
|
||||
|
||||
def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
user_input = "Add buy milk to my list"
|
||||
mock_nlp_services["process_request"].return_value = {
|
||||
"intent": "add_todo",
|
||||
"params": {"task": "buy milk"},
|
||||
"response_text": "Adding it now."
|
||||
"response_text": "Adding it now.",
|
||||
}
|
||||
# Mock the actual Todo model returned by the service
|
||||
mock_todo = MagicMock()
|
||||
@@ -125,81 +167,119 @@ def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_serv
|
||||
"/api/nlp/process-command",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"user_input": user_input}
|
||||
json={"user_input": user_input},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
expected_responses = ["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."]
|
||||
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump()
|
||||
assert mock_nlp_services["save_chat_message"].call_count == 3 # User, Initial AI, Confirmation AI
|
||||
assert (
|
||||
response.json()
|
||||
== ProcessCommandResponse(responses=expected_responses).model_dump()
|
||||
)
|
||||
assert (
|
||||
mock_nlp_services["save_chat_message"].call_count == 3
|
||||
) # User, Initial AI, Confirmation AI
|
||||
mock_nlp_services["create_todo"].assert_called_once()
|
||||
|
||||
def test_process_command_clarification(client: TestClient, db: Session, mock_nlp_services):
|
||||
|
||||
def test_process_command_clarification(
|
||||
client: TestClient, db: Session, mock_nlp_services
|
||||
):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
user_input = "Delete the event"
|
||||
clarification_text = "Which event do you mean? Please provide the ID."
|
||||
mock_nlp_services["process_request"].return_value = {
|
||||
"intent": "clarification_needed",
|
||||
"params": {"request": user_input},
|
||||
"response_text": clarification_text
|
||||
"response_text": clarification_text,
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/nlp/process-command",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"user_input": user_input}
|
||||
json={"user_input": user_input},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == ProcessCommandResponse(responses=[clarification_text]).model_dump()
|
||||
assert (
|
||||
response.json()
|
||||
== ProcessCommandResponse(responses=[clarification_text]).model_dump()
|
||||
)
|
||||
# Verify save calls: user message, clarification AI response
|
||||
assert mock_nlp_services["save_chat_message"].call_count == 2
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=clarification_text)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||
db, user_id=user.id, sender=MessageSender.USER, text=user_input
|
||||
)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||
db, user_id=user.id, sender=MessageSender.AI, text=clarification_text
|
||||
)
|
||||
# Ensure no action services were called
|
||||
mock_nlp_services["delete_calendar_event"].assert_not_called()
|
||||
|
||||
def test_process_command_error_intent(client: TestClient, db: Session, mock_nlp_services):
|
||||
|
||||
def test_process_command_error_intent(
|
||||
client: TestClient, db: Session, mock_nlp_services
|
||||
):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
user_input = "Gibberish request"
|
||||
error_text = "Sorry, I didn't understand that."
|
||||
mock_nlp_services["process_request"].return_value = {
|
||||
"intent": "error",
|
||||
"params": {},
|
||||
"response_text": error_text
|
||||
"response_text": error_text,
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/nlp/process-command",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"user_input": user_input}
|
||||
json={"user_input": user_input},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
|
||||
assert (
|
||||
response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
|
||||
)
|
||||
# Verify save calls: user message, error AI response
|
||||
assert mock_nlp_services["save_chat_message"].call_count == 2
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=error_text)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||
db, user_id=user.id, sender=MessageSender.USER, text=user_input
|
||||
)
|
||||
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||
db, user_id=user.id, sender=MessageSender.AI, text=error_text
|
||||
)
|
||||
|
||||
|
||||
# --- Tests for /history ---
|
||||
|
||||
|
||||
def test_get_history(client: TestClient, db: Session, mock_nlp_services):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
|
||||
# Mock the history data returned by the service
|
||||
mock_history = [
|
||||
ChatMessage(id=1, user_id=user.id, sender=MessageSender.USER, text="Hello", timestamp=datetime.now()),
|
||||
ChatMessage(id=2, user_id=user.id, sender=MessageSender.AI, text="Hi there!", timestamp=datetime.now())
|
||||
ChatMessage(
|
||||
id=1,
|
||||
user_id=user.id,
|
||||
sender=MessageSender.USER,
|
||||
text="Hello",
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
ChatMessage(
|
||||
id=2,
|
||||
user_id=user.id,
|
||||
sender=MessageSender.AI,
|
||||
text="Hi there!",
|
||||
timestamp=datetime.now(),
|
||||
),
|
||||
]
|
||||
mock_nlp_services["get_chat_history"].return_value = mock_history
|
||||
|
||||
response = client.get(
|
||||
"/api/nlp/history",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -208,11 +288,15 @@ def test_get_history(client: TestClient, db: Session, mock_nlp_services):
|
||||
assert len(response_data) == 2
|
||||
assert response_data[0]["text"] == "Hello"
|
||||
assert response_data[1]["text"] == "Hi there!"
|
||||
mock_nlp_services["get_chat_history"].assert_called_once_with(db, user_id=user.id, limit=50)
|
||||
mock_nlp_services["get_chat_history"].assert_called_once_with(
|
||||
db, user_id=user.id, limit=50
|
||||
)
|
||||
|
||||
|
||||
def test_get_history_unauthorized(client: TestClient):
|
||||
response = client.get("/api/nlp/history")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
# Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.)
|
||||
# Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete)
|
||||
|
||||
@@ -5,14 +5,17 @@ from datetime import date
|
||||
|
||||
from tests.helpers import generators
|
||||
|
||||
|
||||
# Helper Function
|
||||
def _login_user(db: Session, client: TestClient):
|
||||
user, password = generators.create_user(db)
|
||||
login_rsp = generators.login(db, user.username, password)
|
||||
return user, login_rsp["access_token"], login_rsp["refresh_token"]
|
||||
|
||||
|
||||
# --- Test CRUD Operations ---
|
||||
|
||||
|
||||
def test_create_todo(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
today_date = date.today()
|
||||
@@ -20,14 +23,14 @@ def test_create_todo(client: TestClient, db: Session):
|
||||
todo_data = {
|
||||
"task": "Test TODO",
|
||||
"date": f"{today_date.isoformat()}T00:00:00",
|
||||
"remind": True
|
||||
"remind": True,
|
||||
}
|
||||
|
||||
response = client.post(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json=todo_data
|
||||
json=todo_data,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
@@ -35,50 +38,66 @@ def test_create_todo(client: TestClient, db: Session):
|
||||
assert data["task"] == todo_data["task"]
|
||||
assert data["date"] == todo_data["date"]
|
||||
assert data["remind"] == todo_data["remind"]
|
||||
assert data["complete"] is False # Default
|
||||
assert data["complete"] is False # Default
|
||||
assert "id" in data
|
||||
assert data["owner_id"] == user.id
|
||||
|
||||
|
||||
def test_read_todos(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
# Create some todos for the user
|
||||
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 1"})
|
||||
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 2"})
|
||||
client.post(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"task": "Todo 1"},
|
||||
)
|
||||
client.post(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"task": "Todo 2"},
|
||||
)
|
||||
|
||||
# Create a todo for another user
|
||||
other_user, other_password = generators.create_user(db)
|
||||
other_login_rsp = generators.login(db, other_user.username, other_password)
|
||||
other_access_token = other_login_rsp["access_token"]
|
||||
other_refresh_token = other_login_rsp["refresh_token"]
|
||||
client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"})
|
||||
|
||||
client.post(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {other_access_token}"},
|
||||
cookies={"refresh_token": other_refresh_token},
|
||||
json={"task": "Other User Todo"},
|
||||
)
|
||||
|
||||
response = client.get(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) == 2 # Should only get todos for the logged-in user
|
||||
assert len(data) == 2 # Should only get todos for the logged-in user
|
||||
assert data[0]["task"] == "Todo 1"
|
||||
assert data[1]["task"] == "Todo 2"
|
||||
|
||||
|
||||
def test_read_single_todo(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
create_response = client.post(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"task": "Specific Todo"}
|
||||
json={"task": "Specific Todo"},
|
||||
)
|
||||
todo_id = create_response.json()["id"]
|
||||
|
||||
response = client.get(
|
||||
f"/api/todos/{todo_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -87,15 +106,17 @@ def test_read_single_todo(client: TestClient, db: Session):
|
||||
assert data["task"] == "Specific Todo"
|
||||
assert data["owner_id"] == user.id
|
||||
|
||||
|
||||
def test_read_single_todo_not_found(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
response = client.get(
|
||||
"/api/todos/9999", # Non-existent ID
|
||||
"/api/todos/9999", # Non-existent ID
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def test_read_single_todo_forbidden(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
|
||||
@@ -104,16 +125,26 @@ def test_read_single_todo_forbidden(client: TestClient, db: Session):
|
||||
other_login_rsp = generators.login(db, other_user.username, other_password)
|
||||
other_access_token = other_login_rsp["access_token"]
|
||||
other_refresh_token = other_login_rsp["refresh_token"]
|
||||
other_create_response = client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"})
|
||||
other_create_response = client.post(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {other_access_token}"},
|
||||
cookies={"refresh_token": other_refresh_token},
|
||||
json={"task": "Other User Todo"},
|
||||
)
|
||||
other_todo_id = other_create_response.json()["id"]
|
||||
|
||||
# Try to access the other user's todo
|
||||
response = client.get(
|
||||
f"/api/todos/{other_todo_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"}, # Using the first user's token
|
||||
cookies={"refresh_token": refresh_token}
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}"
|
||||
}, # Using the first user's token
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND # Service raises 404 if not found for *this* user
|
||||
assert (
|
||||
response.status_code == status.HTTP_404_NOT_FOUND
|
||||
) # Service raises 404 if not found for *this* user
|
||||
|
||||
|
||||
def test_update_todo(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
@@ -121,7 +152,7 @@ def test_update_todo(client: TestClient, db: Session):
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"task": "Update Me"}
|
||||
json={"task": "Update Me"},
|
||||
)
|
||||
todo_id = create_response.json()["id"]
|
||||
|
||||
@@ -130,7 +161,7 @@ def test_update_todo(client: TestClient, db: Session):
|
||||
f"/api/todos/{todo_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json=update_data
|
||||
json=update_data,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -144,7 +175,7 @@ def test_update_todo(client: TestClient, db: Session):
|
||||
get_response = client.get(
|
||||
f"/api/todos/{todo_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
assert get_response.json()["task"] == update_data["task"]
|
||||
assert get_response.json()["complete"] == update_data["complete"]
|
||||
@@ -154,55 +185,60 @@ def test_update_todo_not_found(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
update_data = {"task": "Updated Task", "complete": True}
|
||||
response = client.put(
|
||||
"/api/todos/9999", # Non-existent ID
|
||||
"/api/todos/9999", # Non-existent ID
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json=update_data
|
||||
json=update_data,
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def test_delete_todo(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
create_response = client.post(
|
||||
"/api/todos/",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token},
|
||||
json={"task": "Delete Me"}
|
||||
json={"task": "Delete Me"},
|
||||
)
|
||||
todo_id = create_response.json()["id"]
|
||||
|
||||
response = client.delete(
|
||||
f"/api/todos/{todo_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item
|
||||
assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item
|
||||
assert response.json()["id"] == todo_id
|
||||
|
||||
# Verify deletion by trying to read
|
||||
get_response = client.get(
|
||||
f"/api/todos/{todo_id}",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
assert get_response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def test_delete_todo_not_found(client: TestClient, db: Session):
|
||||
user, access_token, refresh_token = _login_user(db, client)
|
||||
response = client.delete(
|
||||
"/api/todos/9999", # Non-existent ID
|
||||
"/api/todos/9999", # Non-existent ID
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
cookies={"refresh_token": refresh_token}
|
||||
cookies={"refresh_token": refresh_token},
|
||||
)
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
# --- Test Authentication/Authorization ---
|
||||
|
||||
|
||||
def test_create_todo_unauthorized(client: TestClient):
|
||||
response = client.post("/api/todos/", json={"task": "No Auth"})
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
def test_read_todos_unauthorized(client: TestClient):
|
||||
response = client.get("/api/todos/")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
Reference in New Issue
Block a user