[REFORMAT] Ran black reformat
This commit is contained in:
@@ -77,9 +77,7 @@ def run_migrations_online() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
connection=connection, target_metadata=target_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
context.run_migrations()
|
context.run_migrations()
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ Revises:
|
|||||||
Create Date: 2025-04-21 01:14:33.233195
|
Create Date: 2025-04-21 01:14:33.233195
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = '69069d6184b3'
|
revision: str = "69069d6184b3"
|
||||||
down_revision: Union[str, None] = None
|
down_revision: Union[str, None] = None
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|||||||
@@ -5,13 +5,13 @@ Revises: 69069d6184b3
|
|||||||
Create Date: 2025-04-21 20:33:27.028529
|
Create Date: 2025-04-21 20:33:27.028529
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = '9a82960db482'
|
revision: str = "9a82960db482"
|
||||||
down_revision: Union[str, None] = '69069d6184b3'
|
down_revision: Union[str, None] = "69069d6184b3"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ celery_app = Celery(
|
|||||||
"worker",
|
"worker",
|
||||||
broker=settings.REDIS_URL,
|
broker=settings.REDIS_URL,
|
||||||
backend=settings.REDIS_URL,
|
backend=settings.REDIS_URL,
|
||||||
include=["modules.auth.tasks", "modules.admin.tasks"] # Add paths to modules containing tasks
|
include=[
|
||||||
|
"modules.auth.tasks",
|
||||||
|
"modules.admin.tasks",
|
||||||
|
], # Add paths to modules containing tasks
|
||||||
# Add other modules with tasks here, e.g., "modules.some_other_module.tasks"
|
# Add other modules with tasks here, e.g., "modules.some_other_module.tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ import os
|
|||||||
|
|
||||||
DOTENV_PATH = os.path.join(os.path.dirname(__file__), "../.env")
|
DOTENV_PATH = os.path.join(os.path.dirname(__file__), "../.env")
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
# Database settings - reads DB_URL from environment or .env
|
# Database settings - reads DB_URL from environment or .env
|
||||||
DB_URL: str = "postgresql://maia:maia@localhost:5432/maia"
|
DB_URL: str = "postgresql://maia:maia@localhost:5432/maia"
|
||||||
|
|
||||||
# Redis settings - reads REDIS_URL from environment or .env, also used for Celery.
|
# Redis settings - reads REDIS_URL from environment or .env, also used for Celery.
|
||||||
REDIS_URL: str ="redis://localhost:6379/0"
|
REDIS_URL: str = "redis://localhost:6379/0"
|
||||||
|
|
||||||
# JWT settings - reads from environment or .env
|
# JWT settings - reads from environment or .env
|
||||||
JWT_ALGORITHM: str = "HS256"
|
JWT_ALGORITHM: str = "HS256"
|
||||||
@@ -24,8 +25,9 @@ class Settings(BaseSettings):
|
|||||||
class Config:
|
class Config:
|
||||||
# Tell pydantic-settings to load variables from a .env file
|
# Tell pydantic-settings to load variables from a .env file
|
||||||
env_file = DOTENV_PATH
|
env_file = DOTENV_PATH
|
||||||
env_file_encoding = 'utf-8'
|
env_file_encoding = "utf-8"
|
||||||
extra = 'ignore'
|
extra = "ignore"
|
||||||
|
|
||||||
|
|
||||||
# Create a single instance of the settings
|
# Create a single instance of the settings
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ Base = declarative_base() # Used for models
|
|||||||
_engine = None
|
_engine = None
|
||||||
_SessionLocal = None
|
_SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
def get_engine():
|
def get_engine():
|
||||||
global _engine
|
global _engine
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
@@ -20,10 +21,13 @@ def get_engine():
|
|||||||
try:
|
try:
|
||||||
_engine.connect()
|
_engine.connect()
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception("Database connection failed. Is the database server running?")
|
raise Exception(
|
||||||
|
"Database connection failed. Is the database server running?"
|
||||||
|
)
|
||||||
Base.metadata.create_all(_engine) # Create tables here
|
Base.metadata.create_all(_engine) # Create tables here
|
||||||
return _engine
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
def get_sessionmaker():
|
def get_sessionmaker():
|
||||||
global _SessionLocal
|
global _SessionLocal
|
||||||
if _SessionLocal is None:
|
if _SessionLocal is None:
|
||||||
@@ -31,6 +35,7 @@ def get_sessionmaker():
|
|||||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
return _SessionLocal
|
return _SessionLocal
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> Generator[Session, None, None]:
|
def get_db() -> Generator[Session, None, None]:
|
||||||
SessionLocal = get_sessionmaker()
|
SessionLocal = get_sessionmaker()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
|||||||
@@ -8,20 +8,26 @@ from starlette.status import (
|
|||||||
HTTP_409_CONFLICT,
|
HTTP_409_CONFLICT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def bad_request_exception(detail: str = "Bad Request"):
|
def bad_request_exception(detail: str = "Bad Request"):
|
||||||
return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
def unauthorized_exception(detail: str = "Unauthorized"):
|
def unauthorized_exception(detail: str = "Unauthorized"):
|
||||||
return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
|
return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
def forbidden_exception(detail: str = "Forbidden"):
|
def forbidden_exception(detail: str = "Forbidden"):
|
||||||
return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
|
return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
def not_found_exception(detail: str = "Not Found"):
|
def not_found_exception(detail: str = "Not Found"):
|
||||||
return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
|
return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
def internal_server_error_exception(detail: str = "Internal Server Error"):
|
def internal_server_error_exception(detail: str = "Internal Server Error"):
|
||||||
return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
|
return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
def conflict_exception(detail: str = "Conflict"):
|
def conflict_exception(detail: str = "Conflict"):
|
||||||
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
|
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
|
||||||
@@ -11,7 +11,8 @@ import logging
|
|||||||
# import all models to ensure they are registered before create_all
|
# import all models to ensure they are registered before create_all
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger('passlib').setLevel(logging.ERROR) # fix bc package logging is broken
|
logging.getLogger("passlib").setLevel(logging.ERROR) # fix bc package logging is broken
|
||||||
|
|
||||||
|
|
||||||
# Create DB tables (remove in production; use migrations instead)
|
# Create DB tables (remove in production; use migrations instead)
|
||||||
def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]:
|
def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]:
|
||||||
@@ -24,6 +25,7 @@ def lifespan_factory() -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]
|
|||||||
|
|
||||||
return lifespan
|
return lifespan
|
||||||
|
|
||||||
|
|
||||||
lifespan = lifespan_factory()
|
lifespan = lifespan_factory()
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
@@ -41,9 +43,10 @@ app.add_middleware(
|
|||||||
],
|
],
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"]
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Health endpoint
|
# Health endpoint
|
||||||
@app.get("/api/health")
|
@app.get("/api/health")
|
||||||
def health():
|
def health():
|
||||||
|
|||||||
@@ -9,14 +9,17 @@ from .tasks import cleardb
|
|||||||
|
|
||||||
router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)])
|
router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)])
|
||||||
|
|
||||||
|
|
||||||
# Define a Pydantic model for the request body
|
# Define a Pydantic model for the request body
|
||||||
class ClearDbRequest(BaseModel):
|
class ClearDbRequest(BaseModel):
|
||||||
hard: bool
|
hard: bool
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
def read_admin():
|
def read_admin():
|
||||||
return {"message": "Admin route"}
|
return {"message": "Admin route"}
|
||||||
|
|
||||||
|
|
||||||
# Change to POST and use the request body model
|
# Change to POST and use the request body model
|
||||||
@router.post("/cleardb")
|
@router.post("/cleardb")
|
||||||
def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
|
def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from core.celery_app import celery_app
|
from core.celery_app import celery_app
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task
|
@celery_app.task
|
||||||
def cleardb(hard: bool):
|
def cleardb(hard: bool):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,9 +3,24 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
|||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from jose import JWTError
|
from jose import JWTError
|
||||||
from modules.auth.models import User
|
from modules.auth.models import User
|
||||||
from modules.auth.schemas import UserCreate, UserResponse, Token, RefreshTokenRequest, LogoutRequest
|
from modules.auth.schemas import (
|
||||||
|
UserCreate,
|
||||||
|
UserResponse,
|
||||||
|
Token,
|
||||||
|
RefreshTokenRequest,
|
||||||
|
LogoutRequest,
|
||||||
|
)
|
||||||
from modules.auth.services import create_user
|
from modules.auth.services import create_user
|
||||||
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens
|
from modules.auth.security import (
|
||||||
|
TokenType,
|
||||||
|
get_current_user,
|
||||||
|
oauth2_scheme,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
verify_token,
|
||||||
|
authenticate_user,
|
||||||
|
blacklist_tokens,
|
||||||
|
)
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from core.database import get_db
|
from core.database import get_db
|
||||||
@@ -15,12 +30,19 @@ from core.exceptions import unauthorized_exception
|
|||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
|
@router.post(
|
||||||
|
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
|
||||||
|
)
|
||||||
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
|
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
|
||||||
return create_user(user.username, user.password, user.name, db)
|
return create_user(user.username, user.password, user.name, db)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=Token)
|
@router.post("/login", response_model=Token)
|
||||||
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
|
def login(
|
||||||
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Authenticate user and return JWT tokens in the response body.
|
Authenticate user and return JWT tokens in the response body.
|
||||||
"""
|
"""
|
||||||
@@ -31,37 +53,51 @@ def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annota
|
|||||||
detail="Incorrect username or password",
|
detail="Incorrect username or password",
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
access_token = create_access_token(
|
||||||
|
data={"sub": user.username},
|
||||||
|
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||||
|
)
|
||||||
refresh_token = create_refresh_token(data={"sub": user.username})
|
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||||
|
|
||||||
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh")
|
@router.post("/refresh")
|
||||||
def refresh_token(payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]):
|
def refresh_token(
|
||||||
|
payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]
|
||||||
|
):
|
||||||
print("Refreshing token...")
|
print("Refreshing token...")
|
||||||
refresh_token = payload.refresh_token
|
refresh_token = payload.refresh_token
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
raise unauthorized_exception("Refresh token missing in request body")
|
raise unauthorized_exception("Refresh token missing in request body")
|
||||||
|
|
||||||
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db)
|
user_data = verify_token(
|
||||||
|
refresh_token, expected_token_type=TokenType.REFRESH, db=db
|
||||||
|
)
|
||||||
if not user_data:
|
if not user_data:
|
||||||
raise unauthorized_exception("Invalid refresh token")
|
raise unauthorized_exception("Invalid refresh token")
|
||||||
|
|
||||||
new_access_token = create_access_token(data={"sub": user_data.username})
|
new_access_token = create_access_token(data={"sub": user_data.username})
|
||||||
return {"access_token": new_access_token, "token_type": "bearer"}
|
return {"access_token": new_access_token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
def logout(payload: LogoutRequest, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme)):
|
def logout(
|
||||||
|
payload: LogoutRequest,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
access_token: str = Depends(oauth2_scheme),
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
refresh_token = payload.refresh_token
|
refresh_token = payload.refresh_token
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
raise unauthorized_exception("Refresh token not found in request body")
|
raise unauthorized_exception("Refresh token not found in request body")
|
||||||
|
|
||||||
blacklist_tokens(
|
blacklist_tokens(access_token=access_token, refresh_token=refresh_token, db=db)
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"message": "Logged out successfully"}
|
return {"message": "Logged out successfully"}
|
||||||
except JWTError:
|
except JWTError:
|
||||||
|
|||||||
@@ -5,14 +5,18 @@ from modules.auth.schemas import UserRole
|
|||||||
from modules.auth.models import User
|
from modules.auth.models import User
|
||||||
from core.exceptions import forbidden_exception
|
from core.exceptions import forbidden_exception
|
||||||
|
|
||||||
|
|
||||||
class RoleChecker:
|
class RoleChecker:
|
||||||
def __init__(self, allowed_roles: list[UserRole]):
|
def __init__(self, allowed_roles: list[UserRole]):
|
||||||
self.allowed_roles = allowed_roles
|
self.allowed_roles = allowed_roles
|
||||||
|
|
||||||
def __call__(self, user: User = Depends(get_current_user)):
|
def __call__(self, user: User = Depends(get_current_user)):
|
||||||
if user.role not in self.allowed_roles:
|
if user.role not in self.allowed_roles:
|
||||||
raise forbidden_exception("You do not have permission to perform this action.")
|
raise forbidden_exception(
|
||||||
|
"You do not have permission to perform this action."
|
||||||
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
admin_only = RoleChecker([UserRole.ADMIN])
|
admin_only = RoleChecker([UserRole.ADMIN])
|
||||||
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
|
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
|
||||||
@@ -4,10 +4,12 @@ from sqlalchemy import Column, Integer, String, Enum, DateTime
|
|||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from enum import Enum as PyEnum
|
from enum import Enum as PyEnum
|
||||||
|
|
||||||
|
|
||||||
class UserRole(str, PyEnum):
|
class UserRole(str, PyEnum):
|
||||||
ADMIN = "admin"
|
ADMIN = "admin"
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
|
|||||||
@@ -2,33 +2,41 @@
|
|||||||
from enum import Enum as PyEnum
|
from enum import Enum as PyEnum
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
refresh_token: str | None = None
|
refresh_token: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
class TokenData(BaseModel):
|
||||||
username: str | None = None
|
username: str | None = None
|
||||||
scopes: list[str] = []
|
scopes: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
class RefreshTokenRequest(BaseModel):
|
class RefreshTokenRequest(BaseModel):
|
||||||
refresh_token: str
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
class LogoutRequest(BaseModel):
|
class LogoutRequest(BaseModel):
|
||||||
refresh_token: str
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
class UserRole(str, PyEnum):
|
class UserRole(str, PyEnum):
|
||||||
ADMIN = "admin"
|
ADMIN = "admin"
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(BaseModel):
|
class UserCreate(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
password: str
|
password: str
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class UserPatch(BaseModel):
|
class UserPatch(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
uuid: str
|
uuid: str
|
||||||
username: str
|
username: str
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from modules.auth.schemas import TokenData
|
|||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||||
|
|
||||||
|
|
||||||
class TokenType(str, Enum):
|
class TokenType(str, Enum):
|
||||||
ACCESS = "access"
|
ACCESS = "access"
|
||||||
REFRESH = "refresh"
|
REFRESH = "refresh"
|
||||||
@@ -25,11 +26,13 @@ class TokenType(str, Enum):
|
|||||||
|
|
||||||
password_hasher = PasswordHasher()
|
password_hasher = PasswordHasher()
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
"""Hash a password with Argon2 (and optional pepper)."""
|
"""Hash a password with Argon2 (and optional pepper)."""
|
||||||
peppered_password = password + settings.PEPPER # Prepend/append pepper
|
peppered_password = password + settings.PEPPER # Prepend/append pepper
|
||||||
return password_hasher.hash(peppered_password)
|
return password_hasher.hash(peppered_password)
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
"""Verify a password against its hashed version using Argon2."""
|
"""Verify a password against its hashed version using Argon2."""
|
||||||
peppered_password = plain_password + settings.PEPPER
|
peppered_password = plain_password + settings.PEPPER
|
||||||
@@ -38,6 +41,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|||||||
except VerifyMismatchError:
|
except VerifyMismatchError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||||
"""
|
"""
|
||||||
Authenticate a user by checking username/password against the database.
|
Authenticate a user by checking username/password against the database.
|
||||||
@@ -52,34 +56,39 @@ def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
|||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.now(timezone.utc) + expires_delta
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire = datetime.now(timezone.utc) + timedelta(
|
||||||
|
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
|
)
|
||||||
# expire = datetime.now(timezone.utc) + timedelta(seconds=5)
|
# expire = datetime.now(timezone.utc) + timedelta(seconds=5)
|
||||||
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
|
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
|
||||||
return jwt.encode(
|
return jwt.encode(
|
||||||
to_encode,
|
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||||
settings.JWT_SECRET_KEY,
|
|
||||||
algorithm=settings.JWT_ALGORITHM
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
|
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.now(timezone.utc) + expires_delta
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
expire = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
|
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
|
||||||
return jwt.encode(
|
return jwt.encode(
|
||||||
to_encode,
|
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
|
||||||
settings.JWT_SECRET_KEY,
|
|
||||||
algorithm=settings.JWT_ALGORITHM
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
|
|
||||||
|
def verify_token(
|
||||||
|
token: str, expected_token_type: TokenType, db: Session
|
||||||
|
) -> TokenData | None:
|
||||||
"""Verify a JWT token and return TokenData if valid.
|
"""Verify a JWT token and return TokenData if valid.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -96,12 +105,17 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
|
|||||||
TokenData | None
|
TokenData | None
|
||||||
TokenData instance if the token is valid, None otherwise.
|
TokenData instance if the token is valid, None otherwise.
|
||||||
"""
|
"""
|
||||||
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
is_blacklisted = (
|
||||||
|
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
if is_blacklisted:
|
if is_blacklisted:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
token_type: str = payload.get("token_type")
|
token_type: str = payload.get("token_type")
|
||||||
|
|
||||||
@@ -113,7 +127,10 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
|
|||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
|
|
||||||
|
def get_current_user(
|
||||||
|
db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
|
||||||
|
) -> User:
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Could not validate credentials",
|
detail="Could not validate credentials",
|
||||||
@@ -121,14 +138,15 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if the token is blacklisted
|
# Check if the token is blacklisted
|
||||||
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
is_blacklisted = (
|
||||||
|
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
if is_blacklisted:
|
if is_blacklisted:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token,
|
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||||
settings.JWT_SECRET_KEY,
|
|
||||||
algorithms=[settings.JWT_ALGORITHM]
|
|
||||||
)
|
)
|
||||||
username: str = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
if username is None:
|
if username is None:
|
||||||
@@ -141,6 +159,7 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
|
|||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
|
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
|
||||||
"""Blacklist both access and refresh tokens.
|
"""Blacklist both access and refresh tokens.
|
||||||
|
|
||||||
@@ -154,7 +173,9 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
|
|||||||
Database session to perform the operation.
|
Database session to perform the operation.
|
||||||
"""
|
"""
|
||||||
for token in [access_token, refresh_token]:
|
for token in [access_token, refresh_token]:
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||||
|
|
||||||
# Add the token to the blacklist
|
# Add the token to the blacklist
|
||||||
@@ -163,8 +184,11 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
|
|||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
def blacklist_token(token: str, db: Session) -> None:
|
def blacklist_token(token: str, db: Session) -> None:
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||||
|
|
||||||
# Add the token to the blacklist
|
# Add the token to the blacklist
|
||||||
|
|||||||
@@ -23,7 +23,9 @@ def create_user(username: str, password: str, name: str, db: Session) -> UserRes
|
|||||||
|
|
||||||
hashed_password = hash_password(password)
|
hashed_password = hash_password(password)
|
||||||
user_uuid = str(uuid.uuid4())
|
user_uuid = str(uuid.uuid4())
|
||||||
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid)
|
user = User(
|
||||||
|
username=username, hashed_password=hashed_password, name=name, uuid=user_uuid
|
||||||
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(user) # Loads the generated ID
|
db.refresh(user) # Loads the generated ID
|
||||||
|
|||||||
@@ -6,50 +6,63 @@ from typing import List, Optional
|
|||||||
from modules.auth.dependencies import get_current_user
|
from modules.auth.dependencies import get_current_user
|
||||||
from core.database import get_db
|
from core.database import get_db
|
||||||
from modules.auth.models import User
|
from modules.auth.models import User
|
||||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate, CalendarEventResponse
|
from modules.calendar.schemas import (
|
||||||
from modules.calendar.service import create_calendar_event, get_calendar_event_by_id, get_calendar_events, update_calendar_event, delete_calendar_event
|
CalendarEventCreate,
|
||||||
|
CalendarEventUpdate,
|
||||||
|
CalendarEventResponse,
|
||||||
|
)
|
||||||
|
from modules.calendar.service import (
|
||||||
|
create_calendar_event,
|
||||||
|
get_calendar_event_by_id,
|
||||||
|
get_calendar_events,
|
||||||
|
update_calendar_event,
|
||||||
|
delete_calendar_event,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/calendar", tags=["calendar"])
|
router = APIRouter(prefix="/calendar", tags=["calendar"])
|
||||||
|
|
||||||
@router.post("/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
|
@router.post(
|
||||||
|
"/events", response_model=CalendarEventResponse, status_code=status.HTTP_201_CREATED
|
||||||
|
)
|
||||||
def create_event(
|
def create_event(
|
||||||
event: CalendarEventCreate,
|
event: CalendarEventCreate,
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
return create_calendar_event(db, user.id, event)
|
return create_calendar_event(db, user.id, event)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/events", response_model=List[CalendarEventResponse])
|
@router.get("/events", response_model=List[CalendarEventResponse])
|
||||||
def get_events(
|
def get_events(
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
start: Optional[datetime] = None,
|
start: Optional[datetime] = None,
|
||||||
end: Optional[datetime] = None
|
end: Optional[datetime] = None,
|
||||||
):
|
):
|
||||||
return get_calendar_events(db, user.id, start, end)
|
return get_calendar_events(db, user.id, start, end)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/events/{event_id}", response_model=CalendarEventResponse)
|
@router.get("/events/{event_id}", response_model=CalendarEventResponse)
|
||||||
def get_event_by_id(
|
def get_event_by_id(
|
||||||
event_id: int,
|
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||||
user: User = Depends(get_current_user),
|
|
||||||
db: Session = Depends(get_db)
|
|
||||||
):
|
):
|
||||||
event = get_calendar_event_by_id(db, user.id, event_id)
|
event = get_calendar_event_by_id(db, user.id, event_id)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/events/{event_id}", response_model=CalendarEventResponse)
|
@router.patch("/events/{event_id}", response_model=CalendarEventResponse)
|
||||||
def update_event(
|
def update_event(
|
||||||
event_id: int,
|
event_id: int,
|
||||||
event: CalendarEventUpdate,
|
event: CalendarEventUpdate,
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
return update_calendar_event(db, user.id, event_id, event)
|
return update_calendar_event(db, user.id, event_id, event)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/events/{event_id}", status_code=204)
|
@router.delete("/events/{event_id}", status_code=204)
|
||||||
def delete_event(
|
def delete_event(
|
||||||
event_id: int,
|
event_id: int, user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||||
user: User = Depends(get_current_user),
|
|
||||||
db: Session = Depends(get_db)
|
|
||||||
):
|
):
|
||||||
delete_calendar_event(db, user.id, event_id)
|
delete_calendar_event(db, user.id, event_id)
|
||||||
@@ -1,8 +1,17 @@
|
|||||||
# modules/calendar/models.py
|
# modules/calendar/models.py
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON, Boolean # Add Boolean
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
JSON,
|
||||||
|
Boolean,
|
||||||
|
) # Add Boolean
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from core.database import Base
|
from core.database import Base
|
||||||
|
|
||||||
|
|
||||||
class CalendarEvent(Base):
|
class CalendarEvent(Base):
|
||||||
__tablename__ = "calendar_events"
|
__tablename__ = "calendar_events"
|
||||||
|
|
||||||
@@ -15,7 +24,9 @@ class CalendarEvent(Base):
|
|||||||
all_day = Column(Boolean, default=False) # Add all_day column
|
all_day = Column(Boolean, default=False) # Add all_day column
|
||||||
tags = Column(JSON)
|
tags = Column(JSON)
|
||||||
color = Column(String) # hex code for color
|
color = Column(String) # hex code for color
|
||||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False) # <-- Relationship
|
user_id = Column(
|
||||||
|
Integer, ForeignKey("users.id"), nullable=False
|
||||||
|
) # <-- Relationship
|
||||||
|
|
||||||
# Bi-directional relationship (for eager loading)
|
# Bi-directional relationship (for eager loading)
|
||||||
user = relationship("User", back_populates="calendar_events")
|
user = relationship("User", back_populates="calendar_events")
|
||||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
|||||||
from pydantic import BaseModel, field_validator # Add field_validator
|
from pydantic import BaseModel, field_validator # Add field_validator
|
||||||
from typing import List, Optional # Add List and Optional
|
from typing import List, Optional # Add List and Optional
|
||||||
|
|
||||||
|
|
||||||
# Base schema for common fields, including tags
|
# Base schema for common fields, including tags
|
||||||
class CalendarEventBase(BaseModel):
|
class CalendarEventBase(BaseModel):
|
||||||
title: str
|
title: str
|
||||||
@@ -14,17 +15,19 @@ class CalendarEventBase(BaseModel):
|
|||||||
all_day: Optional[bool] = None # Add all_day field
|
all_day: Optional[bool] = None # Add all_day field
|
||||||
tags: Optional[List[str]] = None # Add optional tags
|
tags: Optional[List[str]] = None # Add optional tags
|
||||||
|
|
||||||
@field_validator('tags', mode='before')
|
@field_validator("tags", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def tags_validate_null_string(cls, v):
|
def tags_validate_null_string(cls, v):
|
||||||
if v == "Null":
|
if v == "Null":
|
||||||
return None
|
return None
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
# Schema for creating an event (inherits from Base)
|
# Schema for creating an event (inherits from Base)
|
||||||
class CalendarEventCreate(CalendarEventBase):
|
class CalendarEventCreate(CalendarEventBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Schema for updating an event (all fields optional)
|
# Schema for updating an event (all fields optional)
|
||||||
class CalendarEventUpdate(BaseModel):
|
class CalendarEventUpdate(BaseModel):
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
@@ -36,20 +39,21 @@ class CalendarEventUpdate(BaseModel):
|
|||||||
all_day: Optional[bool] = None # Add all_day field
|
all_day: Optional[bool] = None # Add all_day field
|
||||||
tags: Optional[List[str]] = None # Add optional tags for update
|
tags: Optional[List[str]] = None # Add optional tags for update
|
||||||
|
|
||||||
@field_validator('tags', mode='before')
|
@field_validator("tags", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def tags_validate_null_string(cls, v):
|
def tags_validate_null_string(cls, v):
|
||||||
if v == "Null":
|
if v == "Null":
|
||||||
return None
|
return None
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
# Schema for the response (inherits from Base, adds ID and user_id)
|
# Schema for the response (inherits from Base, adds ID and user_id)
|
||||||
class CalendarEventResponse(CalendarEventBase):
|
class CalendarEventResponse(CalendarEventBase):
|
||||||
id: int
|
id: int
|
||||||
user_id: int
|
user_id: int
|
||||||
tags: List[str] # Keep as List[str], remove default []
|
tags: List[str] # Keep as List[str], remove default []
|
||||||
|
|
||||||
@field_validator('tags', mode='before')
|
@field_validator("tags", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def tags_validate_none_to_list(cls, v):
|
def tags_validate_none_to_list(cls, v):
|
||||||
# If the value from the source object (e.g., ORM model) is None,
|
# If the value from the source object (e.g., ORM model) is None,
|
||||||
|
|||||||
@@ -4,22 +4,31 @@ from sqlalchemy import or_ # Import or_
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from modules.calendar.models import CalendarEvent
|
from modules.calendar.models import CalendarEvent
|
||||||
from core.exceptions import not_found_exception
|
from core.exceptions import not_found_exception
|
||||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate # Import schemas
|
from modules.calendar.schemas import (
|
||||||
|
CalendarEventCreate,
|
||||||
|
CalendarEventUpdate,
|
||||||
|
) # Import schemas
|
||||||
|
|
||||||
|
|
||||||
def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate):
|
def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate):
|
||||||
# Ensure tags is None if not provided or empty list, matching model
|
# Ensure tags is None if not provided or empty list, matching model
|
||||||
tags_to_store = event_data.tags if event_data.tags else None
|
tags_to_store = event_data.tags if event_data.tags else None
|
||||||
event = CalendarEvent(
|
event = CalendarEvent(
|
||||||
**event_data.model_dump(exclude={'tags'}), # Use model_dump and exclude tags initially
|
**event_data.model_dump(
|
||||||
|
exclude={"tags"}
|
||||||
|
), # Use model_dump and exclude tags initially
|
||||||
tags=tags_to_store, # Set tags separately
|
tags=tags_to_store, # Set tags separately
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
db.add(event)
|
db.add(event)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(event)
|
db.refresh(event)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def get_calendar_events(db: Session, user_id: int, start: datetime | None, end: datetime | None):
|
|
||||||
|
def get_calendar_events(
|
||||||
|
db: Session, user_id: int, start: datetime | None, end: datetime | None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Retrieves calendar events for a user, optionally filtered by a date range.
|
Retrieves calendar events for a user, optionally filtered by a date range.
|
||||||
|
|
||||||
@@ -46,9 +55,13 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
|
|||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
# Case 1: Event has duration and overlaps
|
# Case 1: Event has duration and overlaps
|
||||||
(CalendarEvent.end is not None) & (CalendarEvent.start < end) & (CalendarEvent.end > start),
|
(CalendarEvent.end is not None)
|
||||||
|
& (CalendarEvent.start < end)
|
||||||
|
& (CalendarEvent.end > start),
|
||||||
# Case 2: Event is a point event within the range
|
# Case 2: Event is a point event within the range
|
||||||
(CalendarEvent.end is None) & (CalendarEvent.start >= start) & (CalendarEvent.start < end)
|
(CalendarEvent.end is None)
|
||||||
|
& (CalendarEvent.start >= start)
|
||||||
|
& (CalendarEvent.start < end),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# If only start is provided, filter events starting on or after start
|
# If only start is provided, filter events starting on or after start
|
||||||
@@ -65,32 +78,36 @@ def get_calendar_events(db: Session, user_id: int, start: datetime | None, end:
|
|||||||
# Event ends before the specified end time
|
# Event ends before the specified end time
|
||||||
(CalendarEvent.end is not None) & (CalendarEvent.end <= end),
|
(CalendarEvent.end is not None) & (CalendarEvent.end <= end),
|
||||||
# Point event occurs before the specified end time
|
# Point event occurs before the specified end time
|
||||||
(CalendarEvent.end is None) & (CalendarEvent.start < end)
|
(CalendarEvent.end is None) & (CalendarEvent.start < end),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Alternative interpretation for "ending before end": include events that *start* before end
|
# Alternative interpretation for "ending before end": include events that *start* before end
|
||||||
# query = query.filter(CalendarEvent.start < end)
|
# query = query.filter(CalendarEvent.start < end)
|
||||||
|
|
||||||
|
|
||||||
return query.order_by(CalendarEvent.start).all() # Order by start time
|
return query.order_by(CalendarEvent.start).all() # Order by start time
|
||||||
|
|
||||||
|
|
||||||
def get_calendar_event_by_id(db: Session, user_id: int, event_id: int):
|
def get_calendar_event_by_id(db: Session, user_id: int, event_id: int):
|
||||||
event = db.query(CalendarEvent).filter(
|
event = (
|
||||||
CalendarEvent.id == event_id,
|
db.query(CalendarEvent)
|
||||||
CalendarEvent.user_id == user_id
|
.filter(CalendarEvent.id == event_id, CalendarEvent.user_id == user_id)
|
||||||
).first()
|
.first()
|
||||||
|
)
|
||||||
if not event:
|
if not event:
|
||||||
raise not_found_exception()
|
raise not_found_exception()
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def update_calendar_event(db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate):
|
|
||||||
|
def update_calendar_event(
|
||||||
|
db: Session, user_id: int, event_id: int, event_data: CalendarEventUpdate
|
||||||
|
):
|
||||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||||
# Use model_dump with exclude_unset=True to only update provided fields
|
# Use model_dump with exclude_unset=True to only update provided fields
|
||||||
update_data = event_data.model_dump(exclude_unset=True)
|
update_data = event_data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
# Ensure tags is handled correctly (set to None if empty list provided)
|
# Ensure tags is handled correctly (set to None if empty list provided)
|
||||||
if key == 'tags' and isinstance(value, list) and not value:
|
if key == "tags" and isinstance(value, list) and not value:
|
||||||
setattr(event, key, None)
|
setattr(event, key, None)
|
||||||
else:
|
else:
|
||||||
setattr(event, key, value)
|
setattr(event, key, value)
|
||||||
@@ -99,6 +116,7 @@ def update_calendar_event(db: Session, user_id: int, event_id: int, event_data:
|
|||||||
db.refresh(event)
|
db.refresh(event)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
||||||
def delete_calendar_event(db: Session, user_id: int, event_id: int):
|
def delete_calendar_event(db: Session, user_id: int, event_id: int):
|
||||||
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
event = get_calendar_event_by_id(db, user_id, event_id) # Reuse get_by_id for check
|
||||||
db.delete(event)
|
db.delete(event)
|
||||||
|
|||||||
@@ -7,13 +7,27 @@ from core.database import get_db
|
|||||||
|
|
||||||
from modules.auth.dependencies import get_current_user
|
from modules.auth.dependencies import get_current_user
|
||||||
from modules.auth.models import User
|
from modules.auth.models import User
|
||||||
|
|
||||||
# Import the new service functions and Enum
|
# Import the new service functions and Enum
|
||||||
from modules.nlp.service import process_request, ask_ai, save_chat_message, get_chat_history, MessageSender
|
from modules.nlp.service import (
|
||||||
|
process_request,
|
||||||
|
ask_ai,
|
||||||
|
save_chat_message,
|
||||||
|
get_chat_history,
|
||||||
|
MessageSender,
|
||||||
|
)
|
||||||
|
|
||||||
# Import the response schema and the new ChatMessage model for response type hinting
|
# Import the response schema and the new ChatMessage model for response type hinting
|
||||||
from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse
|
from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse
|
||||||
from modules.calendar.service import create_calendar_event, get_calendar_events, update_calendar_event, delete_calendar_event
|
from modules.calendar.service import (
|
||||||
|
create_calendar_event,
|
||||||
|
get_calendar_events,
|
||||||
|
update_calendar_event,
|
||||||
|
delete_calendar_event,
|
||||||
|
)
|
||||||
from modules.calendar.models import CalendarEvent
|
from modules.calendar.models import CalendarEvent
|
||||||
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate
|
from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate
|
||||||
|
|
||||||
# Import TODO services, schemas, and model
|
# Import TODO services, schemas, and model
|
||||||
from modules.todo import service as todo_service
|
from modules.todo import service as todo_service
|
||||||
from modules.todo.models import Todo
|
from modules.todo.models import Todo
|
||||||
@@ -21,6 +35,7 @@ from modules.todo.schemas import TodoCreate, TodoUpdate
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageResponse(BaseModel):
|
class ChatMessageResponse(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
sender: MessageSender # Use the enum directly
|
sender: MessageSender # Use the enum directly
|
||||||
@@ -30,8 +45,10 @@ class ChatMessageResponse(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True # Allow Pydantic to work with ORM models
|
from_attributes = True # Allow Pydantic to work with ORM models
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/nlp", tags=["nlp"])
|
router = APIRouter(prefix="/nlp", tags=["nlp"])
|
||||||
|
|
||||||
|
|
||||||
# Helper to format calendar events (expects list of CalendarEvent models)
|
# Helper to format calendar events (expects list of CalendarEvent models)
|
||||||
def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
|
def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
|
||||||
if not events:
|
if not events:
|
||||||
@@ -39,12 +56,15 @@ def format_calendar_events(events: List[CalendarEvent]) -> List[str]:
|
|||||||
formatted = ["Here are the events:"]
|
formatted = ["Here are the events:"]
|
||||||
for event in events:
|
for event in events:
|
||||||
# Access attributes directly from the model instance
|
# Access attributes directly from the model instance
|
||||||
start_str = event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
|
start_str = (
|
||||||
|
event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time"
|
||||||
|
)
|
||||||
end_str = event.end.strftime("%H:%M") if event.end else ""
|
end_str = event.end.strftime("%H:%M") if event.end else ""
|
||||||
title = event.title or "Untitled Event"
|
title = event.title or "Untitled Event"
|
||||||
formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})")
|
formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})")
|
||||||
return formatted
|
return formatted
|
||||||
|
|
||||||
|
|
||||||
# Helper to format TODO items (expects list of Todo models)
|
# Helper to format TODO items (expects list of Todo models)
|
||||||
def format_todos(todos: List[Todo]) -> List[str]:
|
def format_todos(todos: List[Todo]) -> List[str]:
|
||||||
if not todos:
|
if not todos:
|
||||||
@@ -54,19 +74,28 @@ def format_todos(todos: List[Todo]) -> List[str]:
|
|||||||
status = "[X]" if todo.complete else "[ ]"
|
status = "[X]" if todo.complete else "[ ]"
|
||||||
date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else ""
|
date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else ""
|
||||||
remind_str = " (Reminder)" if todo.remind else ""
|
remind_str = " (Reminder)" if todo.remind else ""
|
||||||
formatted.append(f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})")
|
formatted.append(
|
||||||
|
f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})"
|
||||||
|
)
|
||||||
return formatted
|
return formatted
|
||||||
|
|
||||||
|
|
||||||
# Update the response model for the endpoint
|
# Update the response model for the endpoint
|
||||||
@router.post("/process-command", response_model=ProcessCommandResponse)
|
@router.post("/process-command", response_model=ProcessCommandResponse)
|
||||||
def process_command(request_data: ProcessCommandRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
def process_command(
|
||||||
|
request_data: ProcessCommandRequest,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Process the user command, save messages, execute action, save response, and return user-friendly responses.
|
Process the user command, save messages, execute action, save response, and return user-friendly responses.
|
||||||
"""
|
"""
|
||||||
user_input = request_data.user_input
|
user_input = request_data.user_input
|
||||||
|
|
||||||
# --- Save User Message ---
|
# --- Save User Message ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.USER, text=user_input)
|
save_chat_message(
|
||||||
|
db, user_id=current_user.id, sender=MessageSender.USER, text=user_input
|
||||||
|
)
|
||||||
# ------------------------
|
# ------------------------
|
||||||
|
|
||||||
command_data = process_request(user_input)
|
command_data = process_request(user_input)
|
||||||
@@ -78,7 +107,9 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
|
|||||||
|
|
||||||
# --- Save Initial AI Response ---
|
# --- Save Initial AI Response ---
|
||||||
# Save the first response generated by process_request
|
# Save the first response generated by process_request
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text)
|
save_chat_message(
|
||||||
|
db, user_id=current_user.id, sender=MessageSender.AI, text=response_text
|
||||||
|
)
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
if intent == "error":
|
if intent == "error":
|
||||||
@@ -97,139 +128,233 @@ def process_command(request_data: ProcessCommandRequest, current_user: User = De
|
|||||||
ai_answer = ask_ai(**params)
|
ai_answer = ask_ai(**params)
|
||||||
responses.append(ai_answer)
|
responses.append(ai_answer)
|
||||||
# --- Save Additional AI Response ---
|
# --- Save Additional AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer)
|
save_chat_message(
|
||||||
|
db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer
|
||||||
|
)
|
||||||
# ---------------------------------
|
# ---------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
case "get_calendar_events":
|
case "get_calendar_events":
|
||||||
events: List[CalendarEvent] = get_calendar_events(db, current_user.id, **params)
|
events: List[CalendarEvent] = get_calendar_events(
|
||||||
|
db, current_user.id, **params
|
||||||
|
)
|
||||||
formatted_responses = format_calendar_events(events)
|
formatted_responses = format_calendar_events(events)
|
||||||
responses.extend(formatted_responses)
|
responses.extend(formatted_responses)
|
||||||
# --- Save Additional AI Responses ---
|
# --- Save Additional AI Responses ---
|
||||||
for resp in formatted_responses:
|
for resp in formatted_responses:
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
|
save_chat_message(
|
||||||
|
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
|
||||||
|
)
|
||||||
# ----------------------------------
|
# ----------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
case "add_calendar_event":
|
case "add_calendar_event":
|
||||||
event_data = CalendarEventCreate(**params)
|
event_data = CalendarEventCreate(**params)
|
||||||
created_event = create_calendar_event(db, current_user.id, event_data)
|
created_event = create_calendar_event(db, current_user.id, event_data)
|
||||||
start_str = created_event.start.strftime("%Y-%m-%d %H:%M") if created_event.start else "No start time"
|
start_str = (
|
||||||
|
created_event.start.strftime("%Y-%m-%d %H:%M")
|
||||||
|
if created_event.start
|
||||||
|
else "No start time"
|
||||||
|
)
|
||||||
title = created_event.title or "Untitled Event"
|
title = created_event.title or "Untitled Event"
|
||||||
add_response = f"Added: {title} starting at {start_str}."
|
add_response = f"Added: {title} starting at {start_str}."
|
||||||
responses.append(add_response)
|
responses.append(add_response)
|
||||||
# --- Save Additional AI Response ---
|
# --- Save Additional AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=add_response,
|
||||||
|
)
|
||||||
# ---------------------------------
|
# ---------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
case "update_calendar_event":
|
case "update_calendar_event":
|
||||||
event_id = params.pop('event_id', None)
|
event_id = params.pop("event_id", None)
|
||||||
if event_id is None:
|
if event_id is None:
|
||||||
# Save the error message before raising
|
# Save the error message before raising
|
||||||
error_msg = "Event ID is required for update."
|
error_msg = "Event ID is required for update."
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=error_msg,
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail=error_msg)
|
raise HTTPException(status_code=400, detail=error_msg)
|
||||||
event_data = CalendarEventUpdate(**params)
|
event_data = CalendarEventUpdate(**params)
|
||||||
updated_event = update_calendar_event(db, current_user.id, event_id, event_data=event_data)
|
updated_event = update_calendar_event(
|
||||||
|
db, current_user.id, event_id, event_data=event_data
|
||||||
|
)
|
||||||
title = updated_event.title or "Untitled Event"
|
title = updated_event.title or "Untitled Event"
|
||||||
update_response = f"Updated event ID {updated_event.id}: {title}."
|
update_response = f"Updated event ID {updated_event.id}: {title}."
|
||||||
responses.append(update_response)
|
responses.append(update_response)
|
||||||
# --- Save Additional AI Response ---
|
# --- Save Additional AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=update_response,
|
||||||
|
)
|
||||||
# ---------------------------------
|
# ---------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
case "delete_calendar_event":
|
case "delete_calendar_event":
|
||||||
event_id = params.get('event_id')
|
event_id = params.get("event_id")
|
||||||
if event_id is None:
|
if event_id is None:
|
||||||
# Save the error message before raising
|
# Save the error message before raising
|
||||||
error_msg = "Event ID is required for delete."
|
error_msg = "Event ID is required for delete."
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=error_msg,
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail=error_msg)
|
raise HTTPException(status_code=400, detail=error_msg)
|
||||||
delete_calendar_event(db, current_user.id, event_id)
|
delete_calendar_event(db, current_user.id, event_id)
|
||||||
delete_response = f"Deleted event ID {event_id}."
|
delete_response = f"Deleted event ID {event_id}."
|
||||||
responses.append(delete_response)
|
responses.append(delete_response)
|
||||||
# --- Save Additional AI Response ---
|
# --- Save Additional AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=delete_response,
|
||||||
|
)
|
||||||
# ---------------------------------
|
# ---------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
# --- Add TODO Cases ---
|
# --- Add TODO Cases ---
|
||||||
case "get_todos":
|
case "get_todos":
|
||||||
todos: List[Todo] = todo_service.get_todos(db, user=current_user, **params)
|
todos: List[Todo] = todo_service.get_todos(
|
||||||
|
db, user=current_user, **params
|
||||||
|
)
|
||||||
formatted_responses = format_todos(todos)
|
formatted_responses = format_todos(todos)
|
||||||
responses.extend(formatted_responses)
|
responses.extend(formatted_responses)
|
||||||
# --- Save Additional AI Responses ---
|
# --- Save Additional AI Responses ---
|
||||||
for resp in formatted_responses:
|
for resp in formatted_responses:
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp)
|
save_chat_message(
|
||||||
|
db, user_id=current_user.id, sender=MessageSender.AI, text=resp
|
||||||
|
)
|
||||||
# ----------------------------------
|
# ----------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
case "add_todo":
|
case "add_todo":
|
||||||
todo_data = TodoCreate(**params)
|
todo_data = TodoCreate(**params)
|
||||||
created_todo = todo_service.create_todo(db, todo=todo_data, user=current_user)
|
created_todo = todo_service.create_todo(
|
||||||
add_response = f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
|
db, todo=todo_data, user=current_user
|
||||||
|
)
|
||||||
|
add_response = (
|
||||||
|
f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})."
|
||||||
|
)
|
||||||
responses.append(add_response)
|
responses.append(add_response)
|
||||||
# --- Save Additional AI Response ---
|
# --- Save Additional AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=add_response,
|
||||||
|
)
|
||||||
# ---------------------------------
|
# ---------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
case "update_todo":
|
case "update_todo":
|
||||||
todo_id = params.pop('todo_id', None)
|
todo_id = params.pop("todo_id", None)
|
||||||
if todo_id is None:
|
if todo_id is None:
|
||||||
error_msg = "TODO ID is required for update."
|
error_msg = "TODO ID is required for update."
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=error_msg,
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail=error_msg)
|
raise HTTPException(status_code=400, detail=error_msg)
|
||||||
todo_data = TodoUpdate(**params)
|
todo_data = TodoUpdate(**params)
|
||||||
updated_todo = todo_service.update_todo(db, todo_id=todo_id, todo_update=todo_data, user=current_user)
|
updated_todo = todo_service.update_todo(
|
||||||
update_response = f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
|
db, todo_id=todo_id, todo_update=todo_data, user=current_user
|
||||||
if 'complete' in params:
|
)
|
||||||
status = "complete" if params['complete'] else "incomplete"
|
update_response = (
|
||||||
|
f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'."
|
||||||
|
)
|
||||||
|
if "complete" in params:
|
||||||
|
status = "complete" if params["complete"] else "incomplete"
|
||||||
update_response += f" Marked as {status}."
|
update_response += f" Marked as {status}."
|
||||||
responses.append(update_response)
|
responses.append(update_response)
|
||||||
# --- Save Additional AI Response ---
|
# --- Save Additional AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=update_response,
|
||||||
|
)
|
||||||
# ---------------------------------
|
# ---------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
case "delete_todo":
|
case "delete_todo":
|
||||||
todo_id = params.get('todo_id')
|
todo_id = params.get("todo_id")
|
||||||
if todo_id is None:
|
if todo_id is None:
|
||||||
error_msg = "TODO ID is required for delete."
|
error_msg = "TODO ID is required for delete."
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=error_msg,
|
||||||
|
)
|
||||||
raise HTTPException(status_code=400, detail=error_msg)
|
raise HTTPException(status_code=400, detail=error_msg)
|
||||||
deleted_todo = todo_service.delete_todo(db, todo_id=todo_id, user=current_user)
|
deleted_todo = todo_service.delete_todo(
|
||||||
delete_response = f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
|
db, todo_id=todo_id, user=current_user
|
||||||
|
)
|
||||||
|
delete_response = (
|
||||||
|
f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'."
|
||||||
|
)
|
||||||
responses.append(delete_response)
|
responses.append(delete_response)
|
||||||
# --- Save Additional AI Response ---
|
# --- Save Additional AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=delete_response,
|
||||||
|
)
|
||||||
# ---------------------------------
|
# ---------------------------------
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
# --- End TODO Cases ---
|
# --- End TODO Cases ---
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f"Warning: Unhandled intent '{intent}' reached api.py match statement.")
|
print(
|
||||||
|
f"Warning: Unhandled intent '{intent}' reached api.py match statement."
|
||||||
|
)
|
||||||
# The initial response_text was already saved
|
# The initial response_text was already saved
|
||||||
return ProcessCommandResponse(responses=responses)
|
return ProcessCommandResponse(responses=responses)
|
||||||
|
|
||||||
except HTTPException as http_exc:
|
except HTTPException as http_exc:
|
||||||
# Don't save again if already saved before raising
|
# Don't save again if already saved before raising
|
||||||
if http_exc.status_code != 400 or ('event_id' not in http_exc.detail.lower()):
|
if http_exc.status_code != 400 or ("event_id" not in http_exc.detail.lower()):
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=http_exc.detail)
|
save_chat_message(
|
||||||
|
db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text=http_exc.detail,
|
||||||
|
)
|
||||||
raise http_exc
|
raise http_exc
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error executing intent '{intent}': {e}")
|
print(f"Error executing intent '{intent}': {e}")
|
||||||
error_response = "Sorry, I encountered an error while trying to perform that action."
|
error_response = (
|
||||||
|
"Sorry, I encountered an error while trying to perform that action."
|
||||||
|
)
|
||||||
# --- Save Final Error AI Response ---
|
# --- Save Final Error AI Response ---
|
||||||
save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_response)
|
save_chat_message(
|
||||||
|
db, user_id=current_user.id, sender=MessageSender.AI, text=error_response
|
||||||
|
)
|
||||||
# ----------------------------------
|
# ----------------------------------
|
||||||
return ProcessCommandResponse(responses=[error_response])
|
return ProcessCommandResponse(responses=[error_response])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/history", response_model=List[ChatMessageResponse])
|
@router.get("/history", response_model=List[ChatMessageResponse])
|
||||||
def read_chat_history(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
|
def read_chat_history(
|
||||||
|
current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
"""Retrieves the last 50 chat messages for the current user."""
|
"""Retrieves the last 50 chat messages for the current user."""
|
||||||
history = get_chat_history(db, user_id=current_user.id, limit=50)
|
history = get_chat_history(db, user_id=current_user.id, limit=50)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
\
|
|
||||||
# /home/cdp/code/MAIA/backend/modules/nlp/models.py
|
# /home/cdp/code/MAIA/backend/modules/nlp/models.py
|
||||||
from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum
|
from sqlalchemy import Column, Integer, Text, DateTime, ForeignKey, Enum as SQLEnum
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
@@ -7,10 +6,12 @@ import enum
|
|||||||
|
|
||||||
from core.database import Base
|
from core.database import Base
|
||||||
|
|
||||||
|
|
||||||
class MessageSender(enum.Enum):
|
class MessageSender(enum.Enum):
|
||||||
USER = "user"
|
USER = "user"
|
||||||
AI = "ai"
|
AI = "ai"
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(Base):
|
class ChatMessage(Base):
|
||||||
__tablename__ = "chat_messages"
|
__tablename__ = "chat_messages"
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,11 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class ProcessCommandRequest(BaseModel):
|
class ProcessCommandRequest(BaseModel):
|
||||||
user_input: str
|
user_input: str
|
||||||
|
|
||||||
|
|
||||||
class ProcessCommandResponse(BaseModel):
|
class ProcessCommandResponse(BaseModel):
|
||||||
responses: List[str]
|
responses: List[str]
|
||||||
# Optional: Keep details if needed for specific frontend logic beyond display
|
# Optional: Keep details if needed for specific frontend logic beyond display
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ from core.config import settings
|
|||||||
client = genai.Client(api_key=settings.GOOGLE_API_KEY)
|
client = genai.Client(api_key=settings.GOOGLE_API_KEY)
|
||||||
|
|
||||||
### Base prompt for MAIA, used for inital user requests
|
### Base prompt for MAIA, used for inital user requests
|
||||||
SYSTEM_PROMPT = """
|
SYSTEM_PROMPT = (
|
||||||
|
"""
|
||||||
You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text.
|
You are MAIA - My AI Assistant. Your job is to parse user requests into structured JSON commands and generate a user-facing response text.
|
||||||
|
|
||||||
Available functions/intents:
|
Available functions/intents:
|
||||||
@@ -109,8 +110,11 @@ MAIA:
|
|||||||
"response_text": "Okay, I've deleted task 2 from your list."
|
"response_text": "Okay, I've deleted task 2 from your list."
|
||||||
}
|
}
|
||||||
|
|
||||||
The datetime right now is """+str(datetime.now(timezone.utc))+""".
|
The datetime right now is """
|
||||||
|
+ str(datetime.now(timezone.utc))
|
||||||
|
+ """.
|
||||||
"""
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
### Prompt for MAIA to forward user request to AI
|
### Prompt for MAIA to forward user request to AI
|
||||||
SYSTEM_FORWARD_PROMPT = f"""
|
SYSTEM_FORWARD_PROMPT = f"""
|
||||||
@@ -123,6 +127,7 @@ Here is the user request:
|
|||||||
|
|
||||||
# --- Chat History Service Functions ---
|
# --- Chat History Service Functions ---
|
||||||
|
|
||||||
|
|
||||||
def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str):
|
def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: str):
|
||||||
"""Saves a chat message to the database."""
|
"""Saves a chat message to the database."""
|
||||||
db_message = ChatMessage(user_id=user_id, sender=sender, text=text)
|
db_message = ChatMessage(user_id=user_id, sender=sender, text=text)
|
||||||
@@ -131,16 +136,21 @@ def save_chat_message(db: Session, user_id: int, sender: MessageSender, text: st
|
|||||||
db.refresh(db_message)
|
db.refresh(db_message)
|
||||||
return db_message
|
return db_message
|
||||||
|
|
||||||
|
|
||||||
def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]:
|
def get_chat_history(db: Session, user_id: int, limit: int = 50) -> List[ChatMessage]:
|
||||||
"""Retrieves the last 'limit' chat messages for a user."""
|
"""Retrieves the last 'limit' chat messages for a user."""
|
||||||
return db.query(ChatMessage)\
|
return (
|
||||||
.filter(ChatMessage.user_id == user_id)\
|
db.query(ChatMessage)
|
||||||
.order_by(desc(ChatMessage.timestamp))\
|
.filter(ChatMessage.user_id == user_id)
|
||||||
.limit(limit)\
|
.order_by(desc(ChatMessage.timestamp))
|
||||||
.all()[::-1] # Reverse to get oldest first for display order
|
.limit(limit)
|
||||||
|
.all()[::-1]
|
||||||
|
) # Reverse to get oldest first for display order
|
||||||
|
|
||||||
|
|
||||||
# --- Existing NLP Service Functions ---
|
# --- Existing NLP Service Functions ---
|
||||||
|
|
||||||
|
|
||||||
def process_request(request: str):
|
def process_request(request: str):
|
||||||
"""
|
"""
|
||||||
Process the user request using the Google GenAI API.
|
Process the user request using the Google GenAI API.
|
||||||
@@ -152,7 +162,7 @@ def process_request(request: str):
|
|||||||
config={
|
config={
|
||||||
"temperature": 0.3, # Less creativity, more factual
|
"temperature": 0.3, # Less creativity, more factual
|
||||||
"response_mime_type": "application/json",
|
"response_mime_type": "application/json",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse the JSON response
|
# Parse the JSON response
|
||||||
@@ -160,7 +170,9 @@ def process_request(request: str):
|
|||||||
parsed_response = json.loads(response.text)
|
parsed_response = json.loads(response.text)
|
||||||
# Validate required fields
|
# Validate required fields
|
||||||
if not all(k in parsed_response for k in ("intent", "params", "response_text")):
|
if not all(k in parsed_response for k in ("intent", "params", "response_text")):
|
||||||
raise ValueError("AI response missing required fields (intent, params, response_text)")
|
raise ValueError(
|
||||||
|
"AI response missing required fields (intent, params, response_text)"
|
||||||
|
)
|
||||||
return parsed_response
|
return parsed_response
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
print(f"Error parsing AI response: {e}")
|
print(f"Error parsing AI response: {e}")
|
||||||
@@ -169,9 +181,10 @@ def process_request(request: str):
|
|||||||
return {
|
return {
|
||||||
"intent": "error",
|
"intent": "error",
|
||||||
"params": {},
|
"params": {},
|
||||||
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?"
|
"response_text": "Sorry, I had trouble understanding that request or formulating a response. Could you please try rephrasing?",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def ask_ai(request: str):
|
def ask_ai(request: str):
|
||||||
"""
|
"""
|
||||||
Ask the AI a question.
|
Ask the AI a question.
|
||||||
@@ -179,6 +192,6 @@ def ask_ai(request: str):
|
|||||||
"""
|
"""
|
||||||
response = client.models.generate_content(
|
response = client.models.generate_content(
|
||||||
model="gemini-2.0-flash",
|
model="gemini-2.0-flash",
|
||||||
contents=SYSTEM_FORWARD_PROMPT+request,
|
contents=SYSTEM_FORWARD_PROMPT + request,
|
||||||
)
|
)
|
||||||
return response.text
|
return response.text
|
||||||
@@ -15,48 +15,55 @@ router = APIRouter(
|
|||||||
responses={404: {"description": "Not found"}},
|
responses={404: {"description": "Not found"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED)
|
@router.post("/", response_model=schemas.Todo, status_code=status.HTTP_201_CREATED)
|
||||||
def create_todo_endpoint(
|
def create_todo_endpoint(
|
||||||
todo: schemas.TodoCreate,
|
todo: schemas.TodoCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||||
):
|
):
|
||||||
return service.create_todo(db=db, todo=todo, user=current_user)
|
return service.create_todo(db=db, todo=todo, user=current_user)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[schemas.Todo])
|
@router.get("/", response_model=List[schemas.Todo])
|
||||||
def read_todos_endpoint(
|
def read_todos_endpoint(
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||||
):
|
):
|
||||||
todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit)
|
todos = service.get_todos(db=db, user=current_user, skip=skip, limit=limit)
|
||||||
return todos
|
return todos
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{todo_id}", response_model=schemas.Todo)
|
@router.get("/{todo_id}", response_model=schemas.Todo)
|
||||||
def read_todo_endpoint(
|
def read_todo_endpoint(
|
||||||
todo_id: int,
|
todo_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||||
):
|
):
|
||||||
db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user)
|
db_todo = service.get_todo(db=db, todo_id=todo_id, user=current_user)
|
||||||
if db_todo is None:
|
if db_todo is None:
|
||||||
raise HTTPException(status_code=404, detail="Todo not found")
|
raise HTTPException(status_code=404, detail="Todo not found")
|
||||||
return db_todo
|
return db_todo
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{todo_id}", response_model=schemas.Todo)
|
@router.put("/{todo_id}", response_model=schemas.Todo)
|
||||||
def update_todo_endpoint(
|
def update_todo_endpoint(
|
||||||
todo_id: int,
|
todo_id: int,
|
||||||
todo_update: schemas.TodoUpdate,
|
todo_update: schemas.TodoUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||||
):
|
):
|
||||||
return service.update_todo(db=db, todo_id=todo_id, todo_update=todo_update, user=current_user)
|
return service.update_todo(
|
||||||
|
db=db, todo_id=todo_id, todo_update=todo_update, user=current_user
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{todo_id}", response_model=schemas.Todo)
|
@router.delete("/{todo_id}", response_model=schemas.Todo)
|
||||||
def delete_todo_endpoint(
|
def delete_todo_endpoint(
|
||||||
todo_id: int,
|
todo_id: int,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user) # Corrected dependency
|
current_user: User = Depends(get_current_user), # Corrected dependency
|
||||||
):
|
):
|
||||||
return service.delete_todo(db=db, todo_id=todo_id, user=current_user)
|
return service.delete_todo(db=db, todo_id=todo_id, user=current_user)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from sqlalchemy import Column, Integer, String, Boolean, DateTime, ForeignKey
|
|||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from core.database import Base
|
from core.database import Base
|
||||||
|
|
||||||
|
|
||||||
class Todo(Base):
|
class Todo(Base):
|
||||||
__tablename__ = "todos"
|
__tablename__ = "todos"
|
||||||
|
|
||||||
@@ -13,4 +14,6 @@ class Todo(Base):
|
|||||||
complete = Column(Boolean, default=False)
|
complete = Column(Boolean, default=False)
|
||||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||||
|
|
||||||
owner = relationship("User") # Add relationship if needed, assuming User model exists in auth.models
|
owner = relationship(
|
||||||
|
"User"
|
||||||
|
) # Add relationship if needed, assuming User model exists in auth.models
|
||||||
|
|||||||
@@ -3,21 +3,25 @@ from pydantic import BaseModel
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
class TodoBase(BaseModel):
|
class TodoBase(BaseModel):
|
||||||
task: str
|
task: str
|
||||||
date: Optional[datetime.datetime] = None
|
date: Optional[datetime.datetime] = None
|
||||||
remind: bool = False
|
remind: bool = False
|
||||||
complete: bool = False
|
complete: bool = False
|
||||||
|
|
||||||
|
|
||||||
class TodoCreate(TodoBase):
|
class TodoCreate(TodoBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TodoUpdate(BaseModel):
|
class TodoUpdate(BaseModel):
|
||||||
task: Optional[str] = None
|
task: Optional[str] = None
|
||||||
date: Optional[datetime.datetime] = None
|
date: Optional[datetime.datetime] = None
|
||||||
remind: Optional[bool] = None
|
remind: Optional[bool] = None
|
||||||
complete: Optional[bool] = None
|
complete: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
class Todo(TodoBase):
|
class Todo(TodoBase):
|
||||||
id: int
|
id: int
|
||||||
owner_id: int
|
owner_id: int
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from . import models, schemas
|
|||||||
from modules.auth.models import User # Assuming User model is in auth.models
|
from modules.auth.models import User # Assuming User model is in auth.models
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
|
def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
|
||||||
db_todo = models.Todo(**todo.dict(), owner_id=user.id)
|
db_todo = models.Todo(**todo.dict(), owner_id=user.id)
|
||||||
db.add(db_todo)
|
db.add(db_todo)
|
||||||
@@ -11,17 +12,34 @@ def create_todo(db: Session, todo: schemas.TodoCreate, user: User):
|
|||||||
db.refresh(db_todo)
|
db.refresh(db_todo)
|
||||||
return db_todo
|
return db_todo
|
||||||
|
|
||||||
|
|
||||||
def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100):
|
def get_todos(db: Session, user: User, skip: int = 0, limit: int = 100):
|
||||||
return db.query(models.Todo).filter(models.Todo.owner_id == user.id).offset(skip).limit(limit).all()
|
return (
|
||||||
|
db.query(models.Todo)
|
||||||
|
.filter(models.Todo.owner_id == user.id)
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_todo(db: Session, todo_id: int, user: User):
|
def get_todo(db: Session, todo_id: int, user: User):
|
||||||
db_todo = db.query(models.Todo).filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id).first()
|
db_todo = (
|
||||||
|
db.query(models.Todo)
|
||||||
|
.filter(models.Todo.id == todo_id, models.Todo.owner_id == user.id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if db_todo is None:
|
if db_todo is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Todo not found"
|
||||||
|
)
|
||||||
return db_todo
|
return db_todo
|
||||||
|
|
||||||
|
|
||||||
def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User):
|
def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user: User):
|
||||||
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
|
db_todo = get_todo(
|
||||||
|
db=db, todo_id=todo_id, user=user
|
||||||
|
) # Reuse get_todo to check ownership and existence
|
||||||
update_data = todo_update.dict(exclude_unset=True)
|
update_data = todo_update.dict(exclude_unset=True)
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(db_todo, key, value)
|
setattr(db_todo, key, value)
|
||||||
@@ -29,8 +47,11 @@ def update_todo(db: Session, todo_id: int, todo_update: schemas.TodoUpdate, user
|
|||||||
db.refresh(db_todo)
|
db.refresh(db_todo)
|
||||||
return db_todo
|
return db_todo
|
||||||
|
|
||||||
|
|
||||||
def delete_todo(db: Session, todo_id: int, user: User):
|
def delete_todo(db: Session, todo_id: int, user: User):
|
||||||
db_todo = get_todo(db=db, todo_id=todo_id, user=user) # Reuse get_todo to check ownership and existence
|
db_todo = get_todo(
|
||||||
|
db=db, todo_id=todo_id, user=user
|
||||||
|
) # Reuse get_todo to check ownership and existence
|
||||||
db.delete(db_todo)
|
db.delete(db_todo)
|
||||||
db.commit()
|
db.commit()
|
||||||
return db_todo
|
return db_todo
|
||||||
|
|||||||
@@ -11,16 +11,25 @@ from modules.auth.models import User
|
|||||||
|
|
||||||
router = APIRouter(prefix="/user", tags=["user"])
|
router = APIRouter(prefix="/user", tags=["user"])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
def me(
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> UserResponse:
|
||||||
"""
|
"""
|
||||||
Get the current user. Requires user to be logged in.
|
Get the current user. Requires user to be logged in.
|
||||||
Returns the user object.
|
Returns the user object.
|
||||||
"""
|
"""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{username}", response_model=UserResponse)
|
@router.get("/{username}", response_model=UserResponse)
|
||||||
def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
def get_user(
|
||||||
|
username: str,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> UserResponse:
|
||||||
"""
|
"""
|
||||||
Get a user by username.
|
Get a user by username.
|
||||||
Returns the user object.
|
Returns the user object.
|
||||||
@@ -33,8 +42,14 @@ def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_use
|
|||||||
raise not_found_exception("User not found")
|
raise not_found_exception("User not found")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{username}", response_model=UserResponse)
|
@router.patch("/{username}", response_model=UserResponse)
|
||||||
def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
def update_user(
|
||||||
|
username: str,
|
||||||
|
user_data: UserPatch,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> UserResponse:
|
||||||
"""
|
"""
|
||||||
Update a user by username.
|
Update a user by username.
|
||||||
Returns the updated user object.
|
Returns the updated user object.
|
||||||
@@ -60,8 +75,13 @@ def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depe
|
|||||||
db.refresh(user)
|
db.refresh(user)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{username}", response_model=UserResponse)
|
@router.delete("/{username}", response_model=UserResponse)
|
||||||
def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
def delete_user(
|
||||||
|
username: str,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> UserResponse:
|
||||||
"""
|
"""
|
||||||
Delete a user by username.
|
Delete a user by username.
|
||||||
Returns the deleted user object.
|
Returns the deleted user object.
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from core.database import get_db, get_sessionmaker
|
|||||||
|
|
||||||
fake = Faker()
|
fake = Faker()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def postgres_container() -> Generator[PostgresContainer, None, None]:
|
def postgres_container() -> Generator[PostgresContainer, None, None]:
|
||||||
"""Fixture to create a PostgreSQL container for testing."""
|
"""Fixture to create a PostgreSQL container for testing."""
|
||||||
@@ -22,6 +23,7 @@ def postgres_container() -> Generator[PostgresContainer, None, None]:
|
|||||||
yield postgres
|
yield postgres
|
||||||
print("Postgres container stopped.")
|
print("Postgres container stopped.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def db(postgres_container) -> Generator[Session, None, None]:
|
def db(postgres_container) -> Generator[Session, None, None]:
|
||||||
"""Function-scoped database session with rollback"""
|
"""Function-scoped database session with rollback"""
|
||||||
@@ -34,6 +36,7 @@ def db(postgres_container) -> Generator[Session, None, None]:
|
|||||||
session.rollback()
|
session.rollback()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def client(db: Session) -> Generator[TestClient, None, None]:
|
def client(db: Session) -> Generator[TestClient, None, None]:
|
||||||
"""Function-scoped test client with dependency override"""
|
"""Function-scoped test client with dependency override"""
|
||||||
@@ -53,6 +56,8 @@ def client(db: Session) -> Generator[TestClient, None, None]:
|
|||||||
|
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None:
|
def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None:
|
||||||
from main import app
|
from main import app
|
||||||
|
|
||||||
app.dependency_overrides[dependency] = lambda: mocked_response
|
app.dependency_overrides[dependency] = lambda: mocked_response
|
||||||
@@ -5,13 +5,20 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from core.config import settings
|
from core.config import settings
|
||||||
from modules.auth.models import User
|
from modules.auth.models import User
|
||||||
from modules.auth.security import authenticate_user, create_access_token, create_refresh_token, hash_password
|
from modules.auth.security import (
|
||||||
|
authenticate_user,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
hash_password,
|
||||||
|
)
|
||||||
from modules.auth.schemas import UserRole
|
from modules.auth.schemas import UserRole
|
||||||
from tests.conftest import fake
|
from tests.conftest import fake
|
||||||
from typing import Optional # Import Optional
|
from typing import Optional # Import Optional
|
||||||
|
|
||||||
|
|
||||||
def create_user(db: Session, is_admin: bool = False, username: Optional[str] = None) -> User:
|
def create_user(
|
||||||
|
db: Session, is_admin: bool = False, username: Optional[str] = None
|
||||||
|
) -> User:
|
||||||
unhashed_password = fake.password()
|
unhashed_password = fake.password()
|
||||||
_user = User(
|
_user = User(
|
||||||
name=fake.name(),
|
name=fake.name(),
|
||||||
@@ -26,12 +33,16 @@ def create_user(db: Session, is_admin: bool = False, username: Optional[str] = N
|
|||||||
db.refresh(_user)
|
db.refresh(_user)
|
||||||
return _user, unhashed_password # return for testing
|
return _user, unhashed_password # return for testing
|
||||||
|
|
||||||
|
|
||||||
def login(db: Session, username: str, password: str) -> str:
|
def login(db: Session, username: str, password: str) -> str:
|
||||||
user = authenticate_user(username, password, db)
|
user = authenticate_user(username, password, db)
|
||||||
if not user:
|
if not user:
|
||||||
raise Exception("Incorrect username or password")
|
raise Exception("Incorrect username or password")
|
||||||
|
|
||||||
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
access_token = create_access_token(
|
||||||
|
data={"sub": user.username},
|
||||||
|
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||||
|
)
|
||||||
refresh_token = create_refresh_token(data={"sub": user.username})
|
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||||
|
|
||||||
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||||
|
|||||||
@@ -7,62 +7,84 @@ from tests.helpers import generators
|
|||||||
|
|
||||||
# Test admin routes require admin privileges
|
# Test admin routes require admin privileges
|
||||||
|
|
||||||
|
|
||||||
def test_read_admin_unauthorized(client: TestClient) -> None:
|
def test_read_admin_unauthorized(client: TestClient) -> None:
|
||||||
"""Test accessing admin route without authentication."""
|
"""Test accessing admin route without authentication."""
|
||||||
response = client.get("/api/admin/")
|
response = client.get("/api/admin/")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_read_admin_forbidden(db: Session, client: TestClient) -> None:
|
def test_read_admin_forbidden(db: Session, client: TestClient) -> None:
|
||||||
"""Test accessing admin route as a non-admin user."""
|
"""Test accessing admin route as a non-admin user."""
|
||||||
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
|
|
||||||
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"})
|
response = client.get(
|
||||||
|
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
|
||||||
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
def test_read_admin_success(db: Session, client: TestClient) -> None:
|
def test_read_admin_success(db: Session, client: TestClient) -> None:
|
||||||
"""Test accessing admin route as an admin user."""
|
"""Test accessing admin route as an admin user."""
|
||||||
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
|
admin_user, password = generators.create_user(
|
||||||
|
db, is_admin=True
|
||||||
|
) # Use is_admin=True
|
||||||
login_rsp = generators.login(db, admin_user.username, password)
|
login_rsp = generators.login(db, admin_user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
|
|
||||||
response = client.get("/api/admin/", headers={"Authorization": f"Bearer {access_token}"})
|
response = client.get(
|
||||||
|
"/api/admin/", headers={"Authorization": f"Bearer {access_token}"}
|
||||||
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == {"message": "Admin route"}
|
assert response.json() == {"message": "Admin route"}
|
||||||
|
|
||||||
|
|
||||||
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
||||||
def test_clear_db_soft(mock_cleardb_delay, db: Session, client: TestClient) -> None:
|
def test_clear_db_soft(mock_cleardb_delay, db: Session, client: TestClient) -> None:
|
||||||
"""Test soft clearing the database as admin."""
|
"""Test soft clearing the database as admin."""
|
||||||
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
|
admin_user, password = generators.create_user(
|
||||||
|
db, is_admin=True
|
||||||
|
) # Use is_admin=True
|
||||||
login_rsp = generators.login(db, admin_user.username, password)
|
login_rsp = generators.login(db, admin_user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/admin/cleardb",
|
"/api/admin/cleardb",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
json={"hard": False}
|
json={"hard": False},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == {"message": "Clearing database in the background", "hard": False}
|
assert response.json() == {
|
||||||
|
"message": "Clearing database in the background",
|
||||||
|
"hard": False,
|
||||||
|
}
|
||||||
mock_cleardb_delay.assert_called_once_with(False)
|
mock_cleardb_delay.assert_called_once_with(False)
|
||||||
|
|
||||||
|
|
||||||
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
@patch("modules.admin.api.cleardb.delay") # Mock the celery task
|
||||||
def test_clear_db_hard(mock_cleardb_delay, db: Session, client: TestClient) -> None:
|
def test_clear_db_hard(mock_cleardb_delay, db: Session, client: TestClient) -> None:
|
||||||
"""Test hard clearing the database as admin."""
|
"""Test hard clearing the database as admin."""
|
||||||
admin_user, password = generators.create_user(db, is_admin=True) # Use is_admin=True
|
admin_user, password = generators.create_user(
|
||||||
|
db, is_admin=True
|
||||||
|
) # Use is_admin=True
|
||||||
login_rsp = generators.login(db, admin_user.username, password)
|
login_rsp = generators.login(db, admin_user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/admin/cleardb",
|
"/api/admin/cleardb",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
json={"hard": True}
|
json={"hard": True},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == {"message": "Clearing database in the background", "hard": True}
|
assert response.json() == {
|
||||||
|
"message": "Clearing database in the background",
|
||||||
|
"hard": True,
|
||||||
|
}
|
||||||
mock_cleardb_delay.assert_called_once_with(True)
|
mock_cleardb_delay.assert_called_once_with(True)
|
||||||
|
|
||||||
|
|
||||||
def test_clear_db_forbidden(db: Session, client: TestClient) -> None:
|
def test_clear_db_forbidden(db: Session, client: TestClient) -> None:
|
||||||
"""Test clearing the database as a non-admin user."""
|
"""Test clearing the database as a non-admin user."""
|
||||||
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
user, password = generators.create_user(db, is_admin=False) # Use is_admin=False
|
||||||
@@ -72,6 +94,6 @@ def test_clear_db_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/admin/cleardb",
|
"/api/admin/cleardb",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
json={"hard": False}
|
json={"hard": False},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ def test_register(client: TestClient) -> None:
|
|||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_201_CREATED
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
|
||||||
|
|
||||||
def test_login(db: Session, client: TestClient) -> None:
|
def test_login(db: Session, client: TestClient) -> None:
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ def test_login(db: Session, client: TestClient) -> None:
|
|||||||
assert "token_type" in response_data
|
assert "token_type" in response_data
|
||||||
assert response_data["token_type"] == "bearer"
|
assert response_data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
|
||||||
def test_refresh_token(db: Session, client: TestClient) -> None:
|
def test_refresh_token(db: Session, client: TestClient) -> None:
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
rsp = generators.login(db, user.username, unhashed_password)
|
rsp = generators.login(db, user.username, unhashed_password)
|
||||||
@@ -61,7 +63,10 @@ def test_refresh_token(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/auth/refresh",
|
"/api/auth/refresh",
|
||||||
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
json={"refresh_token": refresh_token},
|
json={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
@@ -70,7 +75,10 @@ def test_refresh_token(db: Session, client: TestClient) -> None:
|
|||||||
assert "access_token" in response_data
|
assert "access_token" in response_data
|
||||||
assert "token_type" in response_data
|
assert "token_type" in response_data
|
||||||
assert response_data["token_type"] == "bearer"
|
assert response_data["token_type"] == "bearer"
|
||||||
assert response_data["access_token"] != access_token # Ensure the token is refreshed
|
assert (
|
||||||
|
response_data["access_token"] != access_token
|
||||||
|
) # Ensure the token is refreshed
|
||||||
|
|
||||||
|
|
||||||
def test_logout(db: Session, client: TestClient) -> None:
|
def test_logout(db: Session, client: TestClient) -> None:
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
@@ -80,13 +88,18 @@ def test_logout(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/auth/logout",
|
"/api/auth/logout",
|
||||||
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
json={"refresh_token": refresh_token},
|
json={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
# Verify that the token is blacklisted
|
# Verify that the token is blacklisted
|
||||||
blacklisted_token = db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
|
blacklisted_token = (
|
||||||
|
db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
|
||||||
|
)
|
||||||
assert blacklisted_token is not None
|
assert blacklisted_token is not None
|
||||||
|
|
||||||
# Verify that we can't still actually do anything
|
# Verify that we can't still actually do anything
|
||||||
@@ -98,7 +111,10 @@ def test_logout(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/auth/refresh",
|
"/api/auth/refresh",
|
||||||
headers={"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"},
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
json={"refresh_token": refresh_token},
|
json={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
@@ -106,7 +122,9 @@ def test_logout(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
def test_get_me(db: Session, client: TestClient) -> None:
|
def test_get_me(db: Session, client: TestClient) -> None:
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
access_token = generators.login(db, user.username, unhashed_password)[
|
||||||
|
"access_token"
|
||||||
|
]
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/user/me",
|
"/api/user/me",
|
||||||
@@ -119,14 +137,18 @@ def test_get_me(db: Session, client: TestClient) -> None:
|
|||||||
assert response_data["uuid"] == user.uuid
|
assert response_data["uuid"] == user.uuid
|
||||||
assert response_data["username"] == user.username
|
assert response_data["username"] == user.username
|
||||||
|
|
||||||
|
|
||||||
def test_get_me_unauthorized(client: TestClient) -> None:
|
def test_get_me_unauthorized(client: TestClient) -> None:
|
||||||
### This test should fail (unauthorized) because the user isn't logged in
|
### This test should fail (unauthorized) because the user isn't logged in
|
||||||
response = client.get("/api/user/me")
|
response = client.get("/api/user/me")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_get_user(db: Session, client: TestClient) -> None:
|
def test_get_user(db: Session, client: TestClient) -> None:
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
access_token = generators.login(db, user.username, unhashed_password)[
|
||||||
|
"access_token"
|
||||||
|
]
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/user/{user.username}",
|
f"/api/user/{user.username}",
|
||||||
@@ -139,11 +161,14 @@ def test_get_user(db: Session, client: TestClient) -> None:
|
|||||||
assert response_data["uuid"] == user.uuid
|
assert response_data["uuid"] == user.uuid
|
||||||
assert response_data["username"] == user.username
|
assert response_data["username"] == user.username
|
||||||
|
|
||||||
|
|
||||||
def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
|
def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
|
||||||
### This test should fail (unauthorized) because the user isn't us
|
### This test should fail (unauthorized) because the user isn't us
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
user2, _ = generators.create_user(db)
|
user2, _ = generators.create_user(db)
|
||||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
access_token = generators.login(db, user.username, unhashed_password)[
|
||||||
|
"access_token"
|
||||||
|
]
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/user/{user2.username}",
|
f"/api/user/{user2.username}",
|
||||||
@@ -151,11 +176,14 @@ def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
|
|||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
def test_update_user(db: Session, client: TestClient) -> None:
|
def test_update_user(db: Session, client: TestClient) -> None:
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
new_name = fake.name()
|
new_name = fake.name()
|
||||||
|
|
||||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
access_token = generators.login(db, user.username, unhashed_password)[
|
||||||
|
"access_token"
|
||||||
|
]
|
||||||
response = client.patch(
|
response = client.patch(
|
||||||
f"/api/user/{user.username}",
|
f"/api/user/{user.username}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
@@ -168,7 +196,9 @@ def test_update_user(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
def test_delete_user(db: Session, client: TestClient) -> None:
|
def test_delete_user(db: Session, client: TestClient) -> None:
|
||||||
user, unhashed_password = generators.create_user(db)
|
user, unhashed_password = generators.create_user(db)
|
||||||
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
access_token = generators.login(db, user.username, unhashed_password)[
|
||||||
|
"access_token"
|
||||||
|
]
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
f"/api/user/{user.username}",
|
f"/api/user/{user.username}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
@@ -179,6 +209,7 @@ def test_delete_user(db: Session, client: TestClient) -> None:
|
|||||||
deleted_user = db.query(User).filter(User.username == user.username).first()
|
deleted_user = db.query(User).filter(User.username == user.username).first()
|
||||||
assert deleted_user is None
|
assert deleted_user is None
|
||||||
|
|
||||||
|
|
||||||
def test_get_user_forbidden(db: Session, client: TestClient) -> None:
|
def test_get_user_forbidden(db: Session, client: TestClient) -> None:
|
||||||
"""Test getting another user's profile (should be forbidden)."""
|
"""Test getting another user's profile (should be forbidden)."""
|
||||||
user1, password_user1 = generators.create_user(db, username="user1_get_forbidden")
|
user1, password_user1 = generators.create_user(db, username="user1_get_forbidden")
|
||||||
@@ -195,9 +226,12 @@ def test_get_user_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
def test_update_user_forbidden(db: Session, client: TestClient) -> None:
|
def test_update_user_forbidden(db: Session, client: TestClient) -> None:
|
||||||
"""Test updating another user's profile (should be forbidden)."""
|
"""Test updating another user's profile (should be forbidden)."""
|
||||||
user1, password_user1 = generators.create_user(db, username="user1_update_forbidden")
|
user1, password_user1 = generators.create_user(
|
||||||
|
db, username="user1_update_forbidden"
|
||||||
|
)
|
||||||
user2, _ = generators.create_user(db, username="user2_update_forbidden")
|
user2, _ = generators.create_user(db, username="user2_update_forbidden")
|
||||||
new_name = fake.name()
|
new_name = fake.name()
|
||||||
|
|
||||||
@@ -213,9 +247,12 @@ def test_update_user_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
def test_delete_user_forbidden(db: Session, client: TestClient) -> None:
|
def test_delete_user_forbidden(db: Session, client: TestClient) -> None:
|
||||||
"""Test deleting another user's profile (should be forbidden)."""
|
"""Test deleting another user's profile (should be forbidden)."""
|
||||||
user1, password_user1 = generators.create_user(db, username="user1_delete_forbidden")
|
user1, password_user1 = generators.create_user(
|
||||||
|
db, username="user1_delete_forbidden"
|
||||||
|
)
|
||||||
user2, _ = generators.create_user(db, username="user2_delete_forbidden")
|
user2, _ = generators.create_user(db, username="user2_delete_forbidden")
|
||||||
|
|
||||||
# Log in as user1
|
# Log in as user1
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from tests.helpers import generators
|
|||||||
from modules.calendar.models import CalendarEvent # Assuming model exists
|
from modules.calendar.models import CalendarEvent # Assuming model exists
|
||||||
from tests.conftest import fake
|
from tests.conftest import fake
|
||||||
|
|
||||||
|
|
||||||
# Helper function to create an event payload
|
# Helper function to create an event payload
|
||||||
def create_event_payload(start_offset_days=0, end_offset_days=1):
|
def create_event_payload(start_offset_days=0, end_offset_days=1):
|
||||||
start_time = datetime.utcnow() + timedelta(days=start_offset_days)
|
start_time = datetime.utcnow() + timedelta(days=start_offset_days)
|
||||||
@@ -19,14 +20,17 @@ def create_event_payload(start_offset_days=0, end_offset_days=1):
|
|||||||
"all_day": fake.boolean(),
|
"all_day": fake.boolean(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# --- Test Create Event ---
|
# --- Test Create Event ---
|
||||||
|
|
||||||
|
|
||||||
def test_create_event_unauthorized(client: TestClient) -> None:
|
def test_create_event_unauthorized(client: TestClient) -> None:
|
||||||
"""Test creating an event without authentication."""
|
"""Test creating an event without authentication."""
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
response = client.post("/api/calendar/events", json=payload)
|
response = client.post("/api/calendar/events", json=payload)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_create_event_success(db: Session, client: TestClient) -> None:
|
def test_create_event_success(db: Session, client: TestClient) -> None:
|
||||||
"""Test creating a calendar event successfully."""
|
"""Test creating a calendar event successfully."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
@@ -37,9 +41,11 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
|
|||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/calendar/events",
|
"/api/calendar/events",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
json=payload
|
json=payload,
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_201_CREATED # Change expected status to 201
|
assert (
|
||||||
|
response.status_code == status.HTTP_201_CREATED
|
||||||
|
) # Change expected status to 201
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["title"] == payload["title"]
|
assert data["title"] == payload["title"]
|
||||||
assert data["description"] == payload["description"]
|
assert data["description"] == payload["description"]
|
||||||
@@ -56,13 +62,16 @@ def test_create_event_success(db: Session, client: TestClient) -> None:
|
|||||||
assert event_in_db.user_id == user.id
|
assert event_in_db.user_id == user.id
|
||||||
assert event_in_db.title == payload["title"]
|
assert event_in_db.title == payload["title"]
|
||||||
|
|
||||||
|
|
||||||
# --- Test Get Events ---
|
# --- Test Get Events ---
|
||||||
|
|
||||||
|
|
||||||
def test_get_events_unauthorized(client: TestClient) -> None:
|
def test_get_events_unauthorized(client: TestClient) -> None:
|
||||||
"""Test getting events without authentication."""
|
"""Test getting events without authentication."""
|
||||||
response = client.get("/api/calendar/events")
|
response = client.get("/api/calendar/events")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_get_events_success(db: Session, client: TestClient) -> None:
|
def test_get_events_success(db: Session, client: TestClient) -> None:
|
||||||
"""Test getting all calendar events for a user."""
|
"""Test getting all calendar events for a user."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
@@ -71,21 +80,31 @@ def test_get_events_success(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
# Create a couple of events for the user
|
# Create a couple of events for the user
|
||||||
payload1 = create_event_payload(0, 1)
|
payload1 = create_event_payload(0, 1)
|
||||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1)
|
client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload1,
|
||||||
|
)
|
||||||
payload2 = create_event_payload(2, 3)
|
payload2 = create_event_payload(2, 3)
|
||||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2)
|
client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload2,
|
||||||
|
)
|
||||||
|
|
||||||
# Create an event for another user (should not be returned)
|
# Create an event for another user (should not be returned)
|
||||||
other_user, other_password = generators.create_user(db)
|
other_user, other_password = generators.create_user(db)
|
||||||
other_login_rsp = generators.login(db, other_user.username, other_password)
|
other_login_rsp = generators.login(db, other_user.username, other_password)
|
||||||
other_access_token = other_login_rsp["access_token"]
|
other_access_token = other_login_rsp["access_token"]
|
||||||
other_payload = create_event_payload(4, 5)
|
other_payload = create_event_payload(4, 5)
|
||||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {other_access_token}"}, json=other_payload)
|
client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {other_access_token}"},
|
||||||
|
json=other_payload,
|
||||||
|
)
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/calendar/events",
|
"/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}
|
||||||
headers={"Authorization": f"Bearer {access_token}"}
|
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -104,11 +123,23 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
# Create events
|
# Create events
|
||||||
payload1 = create_event_payload(0, 1) # Today -> Tomorrow
|
payload1 = create_event_payload(0, 1) # Today -> Tomorrow
|
||||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload1)
|
client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload1,
|
||||||
|
)
|
||||||
payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days
|
payload2 = create_event_payload(5, 6) # In 5 days -> In 6 days
|
||||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload2)
|
client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload2,
|
||||||
|
)
|
||||||
payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days
|
payload3 = create_event_payload(10, 11) # In 10 days -> In 11 days
|
||||||
client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload3)
|
client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload3,
|
||||||
|
)
|
||||||
|
|
||||||
# Filter for events starting within the next week
|
# Filter for events starting within the next week
|
||||||
start_filter = datetime.utcnow().isoformat()
|
start_filter = datetime.utcnow().isoformat()
|
||||||
@@ -117,7 +148,7 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
|
|||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/calendar/events",
|
"/api/calendar/events",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
params={"start": start_filter, "end": end_filter}
|
params={"start": start_filter, "end": end_filter},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -130,7 +161,7 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
|
|||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/calendar/events",
|
"/api/calendar/events",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
params={"start": start_filter_late}
|
params={"start": start_filter_late},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -140,30 +171,40 @@ def test_get_events_filtered(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
# --- Test Get Event By ID ---
|
# --- Test Get Event By ID ---
|
||||||
|
|
||||||
|
|
||||||
def test_get_event_by_id_unauthorized(db: Session, client: TestClient) -> None:
|
def test_get_event_by_id_unauthorized(db: Session, client: TestClient) -> None:
|
||||||
"""Test getting a specific event without authentication."""
|
"""Test getting a specific event without authentication."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
create_response = client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
response = client.get(f"/api/calendar/events/{event_id}")
|
response = client.get(f"/api/calendar/events/{event_id}")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
|
def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
|
||||||
"""Test getting a specific event successfully."""
|
"""Test getting a specific event successfully."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
create_response = client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/calendar/events/{event_id}",
|
f"/api/calendar/events/{event_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"}
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -171,6 +212,7 @@ def test_get_event_by_id_success(db: Session, client: TestClient) -> None:
|
|||||||
assert data["title"] == payload["title"]
|
assert data["title"] == payload["title"]
|
||||||
assert data["user_id"] == user.id
|
assert data["user_id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
|
def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
|
||||||
"""Test getting a non-existent event."""
|
"""Test getting a non-existent event."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
@@ -180,10 +222,11 @@ def test_get_event_by_id_not_found(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/calendar/events/{non_existent_id}",
|
f"/api/calendar/events/{non_existent_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"}
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
|
def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
|
||||||
"""Test getting another user's event."""
|
"""Test getting another user's event."""
|
||||||
user1, password_user1 = generators.create_user(db)
|
user1, password_user1 = generators.create_user(db)
|
||||||
@@ -193,7 +236,11 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
login_rsp1 = generators.login(db, user1.username, password_user1)
|
login_rsp1 = generators.login(db, user1.username, password_user1)
|
||||||
access_token1 = login_rsp1["access_token"]
|
access_token1 = login_rsp1["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
|
create_response = client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token1}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
# Log in as user2 and try to get user1's event
|
# Log in as user2 and try to get user1's event
|
||||||
@@ -202,45 +249,60 @@ def test_get_event_by_id_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/calendar/events/{event_id}",
|
f"/api/calendar/events/{event_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token2}"}
|
headers={"Authorization": f"Bearer {access_token2}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match
|
assert (
|
||||||
|
response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
) # Service layer returns 404 if user_id doesn't match
|
||||||
|
|
||||||
|
|
||||||
# --- Test Update Event ---
|
# --- Test Update Event ---
|
||||||
|
|
||||||
|
|
||||||
def test_update_event_unauthorized(db: Session, client: TestClient) -> None:
|
def test_update_event_unauthorized(db: Session, client: TestClient) -> None:
|
||||||
"""Test updating an event without authentication."""
|
"""Test updating an event without authentication."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
create_response = client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
update_payload = {"title": "Updated Title"}
|
update_payload = {"title": "Updated Title"}
|
||||||
|
|
||||||
response = client.patch(f"/api/calendar/events/{event_id}", json=update_payload)
|
response = client.patch(f"/api/calendar/events/{event_id}", json=update_payload)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_update_event_success(db: Session, client: TestClient) -> None:
|
def test_update_event_success(db: Session, client: TestClient) -> None:
|
||||||
"""Test updating an event successfully."""
|
"""Test updating an event successfully."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
create_response = client.post(
|
||||||
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
create_response.status_code == status.HTTP_201_CREATED
|
||||||
|
) # Ensure creation check uses 201
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
update_payload = {
|
update_payload = {
|
||||||
"title": "Updated Title",
|
"title": "Updated Title",
|
||||||
"description": "Updated description.",
|
"description": "Updated description.",
|
||||||
"all_day": not payload["all_day"] # Toggle all_day
|
"all_day": not payload["all_day"], # Toggle all_day
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.patch(
|
response = client.patch(
|
||||||
f"/api/calendar/events/{event_id}",
|
f"/api/calendar/events/{event_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
json=update_payload
|
json=update_payload,
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -258,6 +320,7 @@ def test_update_event_success(db: Session, client: TestClient) -> None:
|
|||||||
assert event_in_db.description == update_payload["description"]
|
assert event_in_db.description == update_payload["description"]
|
||||||
assert event_in_db.all_day == update_payload["all_day"]
|
assert event_in_db.all_day == update_payload["all_day"]
|
||||||
|
|
||||||
|
|
||||||
def test_update_event_not_found(db: Session, client: TestClient) -> None:
|
def test_update_event_not_found(db: Session, client: TestClient) -> None:
|
||||||
"""Test updating a non-existent event."""
|
"""Test updating a non-existent event."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
@@ -269,10 +332,11 @@ def test_update_event_not_found(db: Session, client: TestClient) -> None:
|
|||||||
response = client.patch(
|
response = client.patch(
|
||||||
f"/api/calendar/events/{non_existent_id}",
|
f"/api/calendar/events/{non_existent_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
json=update_payload
|
json=update_payload,
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
def test_update_event_forbidden(db: Session, client: TestClient) -> None:
|
def test_update_event_forbidden(db: Session, client: TestClient) -> None:
|
||||||
"""Test updating another user's event."""
|
"""Test updating another user's event."""
|
||||||
user1, password_user1 = generators.create_user(db)
|
user1, password_user1 = generators.create_user(db)
|
||||||
@@ -282,7 +346,11 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
login_rsp1 = generators.login(db, user1.username, password_user1)
|
login_rsp1 = generators.login(db, user1.username, password_user1)
|
||||||
access_token1 = login_rsp1["access_token"]
|
access_token1 = login_rsp1["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
|
create_response = client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token1}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
# Log in as user2 and try to update user1's event
|
# Log in as user2 and try to update user1's event
|
||||||
@@ -293,32 +361,47 @@ def test_update_event_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
response = client.patch(
|
response = client.patch(
|
||||||
f"/api/calendar/events/{event_id}",
|
f"/api/calendar/events/{event_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token2}"},
|
headers={"Authorization": f"Bearer {access_token2}"},
|
||||||
json=update_payload
|
json=update_payload,
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND # Service layer returns 404 if user_id doesn't match
|
assert (
|
||||||
|
response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
) # Service layer returns 404 if user_id doesn't match
|
||||||
|
|
||||||
|
|
||||||
# --- Test Delete Event ---
|
# --- Test Delete Event ---
|
||||||
|
|
||||||
|
|
||||||
def test_delete_event_unauthorized(db: Session, client: TestClient) -> None:
|
def test_delete_event_unauthorized(db: Session, client: TestClient) -> None:
|
||||||
"""Test deleting an event without authentication."""
|
"""Test deleting an event without authentication."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
create_response = client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
response = client.delete(f"/api/calendar/events/{event_id}")
|
response = client.delete(f"/api/calendar/events/{event_id}")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_delete_event_success(db: Session, client: TestClient) -> None:
|
def test_delete_event_success(db: Session, client: TestClient) -> None:
|
||||||
"""Test deleting an event successfully."""
|
"""Test deleting an event successfully."""
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
access_token = login_rsp["access_token"]
|
access_token = login_rsp["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token}"}, json=payload)
|
create_response = client.post(
|
||||||
assert create_response.status_code == status.HTTP_201_CREATED # Ensure creation check uses 201
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
create_response.status_code == status.HTTP_201_CREATED
|
||||||
|
) # Ensure creation check uses 201
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
# Verify event exists before delete
|
# Verify event exists before delete
|
||||||
@@ -327,7 +410,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
f"/api/calendar/events/{event_id}",
|
f"/api/calendar/events/{event_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"}
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_204_NO_CONTENT
|
assert response.status_code == status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
@@ -338,7 +421,7 @@ def test_delete_event_success(db: Session, client: TestClient) -> None:
|
|||||||
# Try getting the deleted event (should be 404)
|
# Try getting the deleted event (should be 404)
|
||||||
get_response = client.get(
|
get_response = client.get(
|
||||||
f"/api/calendar/events/{event_id}",
|
f"/api/calendar/events/{event_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"}
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
assert get_response.status_code == status.HTTP_404_NOT_FOUND
|
assert get_response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
@@ -352,7 +435,7 @@ def test_delete_event_not_found(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
f"/api/calendar/events/{non_existent_id}",
|
f"/api/calendar/events/{non_existent_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"}
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
)
|
)
|
||||||
# The service layer raises NotFound, which should result in 404
|
# The service layer raises NotFound, which should result in 404
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
@@ -367,7 +450,11 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
login_rsp1 = generators.login(db, user1.username, password_user1)
|
login_rsp1 = generators.login(db, user1.username, password_user1)
|
||||||
access_token1 = login_rsp1["access_token"]
|
access_token1 = login_rsp1["access_token"]
|
||||||
payload = create_event_payload()
|
payload = create_event_payload()
|
||||||
create_response = client.post("/api/calendar/events", headers={"Authorization": f"Bearer {access_token1}"}, json=payload)
|
create_response = client.post(
|
||||||
|
"/api/calendar/events",
|
||||||
|
headers={"Authorization": f"Bearer {access_token1}"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
event_id = create_response.json()["id"]
|
event_id = create_response.json()["id"]
|
||||||
|
|
||||||
# Log in as user2 and try to delete user1's event
|
# Log in as user2 and try to delete user1's event
|
||||||
@@ -376,7 +463,7 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
|
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
f"/api/calendar/events/{event_id}",
|
f"/api/calendar/events/{event_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token2}"}
|
headers={"Authorization": f"Bearer {access_token2}"},
|
||||||
)
|
)
|
||||||
# The service layer raises NotFound if user_id doesn't match, resulting in 404
|
# The service layer raises NotFound if user_id doesn't match, resulting in 404
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
@@ -385,4 +472,3 @@ def test_delete_event_forbidden(db: Session, client: TestClient) -> None:
|
|||||||
event_in_db = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first()
|
event_in_db = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first()
|
||||||
assert event_in_db is not None
|
assert event_in_db is not None
|
||||||
assert event_in_db.user_id == user1.id
|
assert event_in_db.user_id == user1.id
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from fastapi.testclient import TestClient
|
|||||||
|
|
||||||
# No database needed for this simple test
|
# No database needed for this simple test
|
||||||
|
|
||||||
|
|
||||||
def test_health_check(client: TestClient):
|
def test_health_check(client: TestClient):
|
||||||
"""Test the health check endpoint."""
|
"""Test the health check endpoint."""
|
||||||
response = client.get("/api/health")
|
response = client.get("/api/health")
|
||||||
|
|||||||
@@ -7,24 +7,37 @@ from datetime import datetime
|
|||||||
|
|
||||||
from tests.helpers import generators
|
from tests.helpers import generators
|
||||||
from modules.nlp.schemas import ProcessCommandResponse
|
from modules.nlp.schemas import ProcessCommandResponse
|
||||||
from modules.nlp.models import MessageSender, ChatMessage # Import necessary models/enums
|
from modules.nlp.models import (
|
||||||
|
MessageSender,
|
||||||
|
ChatMessage,
|
||||||
|
) # Import necessary models/enums
|
||||||
|
|
||||||
|
|
||||||
# --- Mocks ---
|
# --- Mocks ---
|
||||||
# Mock the external AI call and internal service functions
|
# Mock the external AI call and internal service functions
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_nlp_services():
|
def mock_nlp_services():
|
||||||
with patch("modules.nlp.api.process_request") as mock_process, \
|
with patch("modules.nlp.api.process_request") as mock_process, patch(
|
||||||
patch("modules.nlp.api.ask_ai") as mock_ask, \
|
"modules.nlp.api.ask_ai"
|
||||||
patch("modules.nlp.api.save_chat_message") as mock_save, \
|
) as mock_ask, patch("modules.nlp.api.save_chat_message") as mock_save, patch(
|
||||||
patch("modules.nlp.api.get_chat_history") as mock_get_history, \
|
"modules.nlp.api.get_chat_history"
|
||||||
patch("modules.nlp.api.create_calendar_event") as mock_create_event, \
|
) as mock_get_history, patch(
|
||||||
patch("modules.nlp.api.get_calendar_events") as mock_get_events, \
|
"modules.nlp.api.create_calendar_event"
|
||||||
patch("modules.nlp.api.update_calendar_event") as mock_update_event, \
|
) as mock_create_event, patch(
|
||||||
patch("modules.nlp.api.delete_calendar_event") as mock_delete_event, \
|
"modules.nlp.api.get_calendar_events"
|
||||||
patch("modules.nlp.api.todo_service.create_todo") as mock_create_todo, \
|
) as mock_get_events, patch(
|
||||||
patch("modules.nlp.api.todo_service.get_todos") as mock_get_todos, \
|
"modules.nlp.api.update_calendar_event"
|
||||||
patch("modules.nlp.api.todo_service.update_todo") as mock_update_todo, \
|
) as mock_update_event, patch(
|
||||||
patch("modules.nlp.api.todo_service.delete_todo") as mock_delete_todo:
|
"modules.nlp.api.delete_calendar_event"
|
||||||
|
) as mock_delete_event, patch(
|
||||||
|
"modules.nlp.api.todo_service.create_todo"
|
||||||
|
) as mock_create_todo, patch(
|
||||||
|
"modules.nlp.api.todo_service.get_todos"
|
||||||
|
) as mock_get_todos, patch(
|
||||||
|
"modules.nlp.api.todo_service.update_todo"
|
||||||
|
) as mock_update_todo, patch(
|
||||||
|
"modules.nlp.api.todo_service.delete_todo"
|
||||||
|
) as mock_delete_todo:
|
||||||
mocks = {
|
mocks = {
|
||||||
"process_request": mock_process,
|
"process_request": mock_process,
|
||||||
"ask_ai": mock_ask,
|
"ask_ai": mock_ask,
|
||||||
@@ -41,21 +54,24 @@ def mock_nlp_services():
|
|||||||
}
|
}
|
||||||
yield mocks
|
yield mocks
|
||||||
|
|
||||||
|
|
||||||
# --- Helper Function ---
|
# --- Helper Function ---
|
||||||
def _login_user(db: Session, client: TestClient):
|
def _login_user(db: Session, client: TestClient):
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
return user, login_rsp["access_token"], login_rsp["refresh_token"]
|
return user, login_rsp["access_token"], login_rsp["refresh_token"]
|
||||||
|
|
||||||
|
|
||||||
# --- Tests for /process-command ---
|
# --- Tests for /process-command ---
|
||||||
|
|
||||||
|
|
||||||
def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services):
|
def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
user_input = "What is the capital of France?"
|
user_input = "What is the capital of France?"
|
||||||
mock_nlp_services["process_request"].return_value = {
|
mock_nlp_services["process_request"].return_value = {
|
||||||
"intent": "ask_ai",
|
"intent": "ask_ai",
|
||||||
"params": {"request": user_input},
|
"params": {"request": user_input},
|
||||||
"response_text": "Let me check that for you."
|
"response_text": "Let me check that for you.",
|
||||||
}
|
}
|
||||||
mock_nlp_services["ask_ai"].return_value = "The capital of France is Paris."
|
mock_nlp_services["ask_ai"].return_value = "The capital of France is Paris."
|
||||||
|
|
||||||
@@ -63,25 +79,45 @@ def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_servic
|
|||||||
"/api/nlp/process-command",
|
"/api/nlp/process-command",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"user_input": user_input}
|
json={"user_input": user_input},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == ProcessCommandResponse(responses=["Let me check that for you.", "The capital of France is Paris."]).model_dump()
|
assert (
|
||||||
|
response.json()
|
||||||
|
== ProcessCommandResponse(
|
||||||
|
responses=["Let me check that for you.", "The capital of France is Paris."]
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
# Verify save calls: user message, initial AI response, final AI answer
|
# Verify save calls: user message, initial AI response, final AI answer
|
||||||
assert mock_nlp_services["save_chat_message"].call_count == 3
|
assert mock_nlp_services["save_chat_message"].call_count == 3
|
||||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
|
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you.")
|
db, user_id=user.id, sender=MessageSender.USER, text=user_input
|
||||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text="The capital of France is Paris.")
|
)
|
||||||
|
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||||
|
db, user_id=user.id, sender=MessageSender.AI, text="Let me check that for you."
|
||||||
|
)
|
||||||
|
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||||
|
db,
|
||||||
|
user_id=user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text="The capital of France is Paris.",
|
||||||
|
)
|
||||||
mock_nlp_services["ask_ai"].assert_called_once_with(request=user_input)
|
mock_nlp_services["ask_ai"].assert_called_once_with(request=user_input)
|
||||||
|
|
||||||
def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_services):
|
|
||||||
|
def test_process_command_get_calendar(
|
||||||
|
client: TestClient, db: Session, mock_nlp_services
|
||||||
|
):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
user_input = "What are my events today?"
|
user_input = "What are my events today?"
|
||||||
mock_nlp_services["process_request"].return_value = {
|
mock_nlp_services["process_request"].return_value = {
|
||||||
"intent": "get_calendar_events",
|
"intent": "get_calendar_events",
|
||||||
"params": {"start": "2024-01-01T00:00:00Z", "end": "2024-01-01T23:59:59Z"}, # Example params
|
"params": {
|
||||||
"response_text": "Okay, fetching your events."
|
"start": "2024-01-01T00:00:00Z",
|
||||||
|
"end": "2024-01-01T23:59:59Z",
|
||||||
|
}, # Example params
|
||||||
|
"response_text": "Okay, fetching your events.",
|
||||||
}
|
}
|
||||||
# Mock the actual event model returned by the service
|
# Mock the actual event model returned by the service
|
||||||
mock_event = MagicMock()
|
mock_event = MagicMock()
|
||||||
@@ -94,26 +130,32 @@ def test_process_command_get_calendar(client: TestClient, db: Session, mock_nlp_
|
|||||||
"/api/nlp/process-command",
|
"/api/nlp/process-command",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"user_input": user_input}
|
json={"user_input": user_input},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
"Okay, fetching your events.",
|
"Okay, fetching your events.",
|
||||||
"Here are the events:",
|
"Here are the events:",
|
||||||
"- Team Meeting (2024-01-01 10:00 - 11:00)"
|
"- Team Meeting (2024-01-01 10:00 - 11:00)",
|
||||||
]
|
]
|
||||||
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump()
|
assert (
|
||||||
assert mock_nlp_services["save_chat_message"].call_count == 4 # User, Initial AI, Header, Event
|
response.json()
|
||||||
|
== ProcessCommandResponse(responses=expected_responses).model_dump()
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
mock_nlp_services["save_chat_message"].call_count == 4
|
||||||
|
) # User, Initial AI, Header, Event
|
||||||
mock_nlp_services["get_calendar_events"].assert_called_once()
|
mock_nlp_services["get_calendar_events"].assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services):
|
def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
user_input = "Add buy milk to my list"
|
user_input = "Add buy milk to my list"
|
||||||
mock_nlp_services["process_request"].return_value = {
|
mock_nlp_services["process_request"].return_value = {
|
||||||
"intent": "add_todo",
|
"intent": "add_todo",
|
||||||
"params": {"task": "buy milk"},
|
"params": {"task": "buy milk"},
|
||||||
"response_text": "Adding it now."
|
"response_text": "Adding it now.",
|
||||||
}
|
}
|
||||||
# Mock the actual Todo model returned by the service
|
# Mock the actual Todo model returned by the service
|
||||||
mock_todo = MagicMock()
|
mock_todo = MagicMock()
|
||||||
@@ -125,81 +167,119 @@ def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_serv
|
|||||||
"/api/nlp/process-command",
|
"/api/nlp/process-command",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"user_input": user_input}
|
json={"user_input": user_input},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
expected_responses = ["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."]
|
expected_responses = ["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."]
|
||||||
assert response.json() == ProcessCommandResponse(responses=expected_responses).model_dump()
|
assert (
|
||||||
assert mock_nlp_services["save_chat_message"].call_count == 3 # User, Initial AI, Confirmation AI
|
response.json()
|
||||||
|
== ProcessCommandResponse(responses=expected_responses).model_dump()
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
mock_nlp_services["save_chat_message"].call_count == 3
|
||||||
|
) # User, Initial AI, Confirmation AI
|
||||||
mock_nlp_services["create_todo"].assert_called_once()
|
mock_nlp_services["create_todo"].assert_called_once()
|
||||||
|
|
||||||
def test_process_command_clarification(client: TestClient, db: Session, mock_nlp_services):
|
|
||||||
|
def test_process_command_clarification(
|
||||||
|
client: TestClient, db: Session, mock_nlp_services
|
||||||
|
):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
user_input = "Delete the event"
|
user_input = "Delete the event"
|
||||||
clarification_text = "Which event do you mean? Please provide the ID."
|
clarification_text = "Which event do you mean? Please provide the ID."
|
||||||
mock_nlp_services["process_request"].return_value = {
|
mock_nlp_services["process_request"].return_value = {
|
||||||
"intent": "clarification_needed",
|
"intent": "clarification_needed",
|
||||||
"params": {"request": user_input},
|
"params": {"request": user_input},
|
||||||
"response_text": clarification_text
|
"response_text": clarification_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/nlp/process-command",
|
"/api/nlp/process-command",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"user_input": user_input}
|
json={"user_input": user_input},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == ProcessCommandResponse(responses=[clarification_text]).model_dump()
|
assert (
|
||||||
|
response.json()
|
||||||
|
== ProcessCommandResponse(responses=[clarification_text]).model_dump()
|
||||||
|
)
|
||||||
# Verify save calls: user message, clarification AI response
|
# Verify save calls: user message, clarification AI response
|
||||||
assert mock_nlp_services["save_chat_message"].call_count == 2
|
assert mock_nlp_services["save_chat_message"].call_count == 2
|
||||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
|
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=clarification_text)
|
db, user_id=user.id, sender=MessageSender.USER, text=user_input
|
||||||
|
)
|
||||||
|
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||||
|
db, user_id=user.id, sender=MessageSender.AI, text=clarification_text
|
||||||
|
)
|
||||||
# Ensure no action services were called
|
# Ensure no action services were called
|
||||||
mock_nlp_services["delete_calendar_event"].assert_not_called()
|
mock_nlp_services["delete_calendar_event"].assert_not_called()
|
||||||
|
|
||||||
def test_process_command_error_intent(client: TestClient, db: Session, mock_nlp_services):
|
|
||||||
|
def test_process_command_error_intent(
|
||||||
|
client: TestClient, db: Session, mock_nlp_services
|
||||||
|
):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
user_input = "Gibberish request"
|
user_input = "Gibberish request"
|
||||||
error_text = "Sorry, I didn't understand that."
|
error_text = "Sorry, I didn't understand that."
|
||||||
mock_nlp_services["process_request"].return_value = {
|
mock_nlp_services["process_request"].return_value = {
|
||||||
"intent": "error",
|
"intent": "error",
|
||||||
"params": {},
|
"params": {},
|
||||||
"response_text": error_text
|
"response_text": error_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/nlp/process-command",
|
"/api/nlp/process-command",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"user_input": user_input}
|
json={"user_input": user_input},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
|
assert (
|
||||||
|
response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
|
||||||
|
)
|
||||||
# Verify save calls: user message, error AI response
|
# Verify save calls: user message, error AI response
|
||||||
assert mock_nlp_services["save_chat_message"].call_count == 2
|
assert mock_nlp_services["save_chat_message"].call_count == 2
|
||||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.USER, text=user_input)
|
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||||
mock_nlp_services["save_chat_message"].assert_any_call(db, user_id=user.id, sender=MessageSender.AI, text=error_text)
|
db, user_id=user.id, sender=MessageSender.USER, text=user_input
|
||||||
|
)
|
||||||
|
mock_nlp_services["save_chat_message"].assert_any_call(
|
||||||
|
db, user_id=user.id, sender=MessageSender.AI, text=error_text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# --- Tests for /history ---
|
# --- Tests for /history ---
|
||||||
|
|
||||||
|
|
||||||
def test_get_history(client: TestClient, db: Session, mock_nlp_services):
|
def test_get_history(client: TestClient, db: Session, mock_nlp_services):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
|
|
||||||
# Mock the history data returned by the service
|
# Mock the history data returned by the service
|
||||||
mock_history = [
|
mock_history = [
|
||||||
ChatMessage(id=1, user_id=user.id, sender=MessageSender.USER, text="Hello", timestamp=datetime.now()),
|
ChatMessage(
|
||||||
ChatMessage(id=2, user_id=user.id, sender=MessageSender.AI, text="Hi there!", timestamp=datetime.now())
|
id=1,
|
||||||
|
user_id=user.id,
|
||||||
|
sender=MessageSender.USER,
|
||||||
|
text="Hello",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
),
|
||||||
|
ChatMessage(
|
||||||
|
id=2,
|
||||||
|
user_id=user.id,
|
||||||
|
sender=MessageSender.AI,
|
||||||
|
text="Hi there!",
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
mock_nlp_services["get_chat_history"].return_value = mock_history
|
mock_nlp_services["get_chat_history"].return_value = mock_history
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/nlp/history",
|
"/api/nlp/history",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
@@ -208,11 +288,15 @@ def test_get_history(client: TestClient, db: Session, mock_nlp_services):
|
|||||||
assert len(response_data) == 2
|
assert len(response_data) == 2
|
||||||
assert response_data[0]["text"] == "Hello"
|
assert response_data[0]["text"] == "Hello"
|
||||||
assert response_data[1]["text"] == "Hi there!"
|
assert response_data[1]["text"] == "Hi there!"
|
||||||
mock_nlp_services["get_chat_history"].assert_called_once_with(db, user_id=user.id, limit=50)
|
mock_nlp_services["get_chat_history"].assert_called_once_with(
|
||||||
|
db, user_id=user.id, limit=50
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_history_unauthorized(client: TestClient):
|
def test_get_history_unauthorized(client: TestClient):
|
||||||
response = client.get("/api/nlp/history")
|
response = client.get("/api/nlp/history")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
# Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.)
|
# Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.)
|
||||||
# Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete)
|
# Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete)
|
||||||
|
|||||||
@@ -5,14 +5,17 @@ from datetime import date
|
|||||||
|
|
||||||
from tests.helpers import generators
|
from tests.helpers import generators
|
||||||
|
|
||||||
|
|
||||||
# Helper Function
|
# Helper Function
|
||||||
def _login_user(db: Session, client: TestClient):
|
def _login_user(db: Session, client: TestClient):
|
||||||
user, password = generators.create_user(db)
|
user, password = generators.create_user(db)
|
||||||
login_rsp = generators.login(db, user.username, password)
|
login_rsp = generators.login(db, user.username, password)
|
||||||
return user, login_rsp["access_token"], login_rsp["refresh_token"]
|
return user, login_rsp["access_token"], login_rsp["refresh_token"]
|
||||||
|
|
||||||
|
|
||||||
# --- Test CRUD Operations ---
|
# --- Test CRUD Operations ---
|
||||||
|
|
||||||
|
|
||||||
def test_create_todo(client: TestClient, db: Session):
|
def test_create_todo(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
today_date = date.today()
|
today_date = date.today()
|
||||||
@@ -20,14 +23,14 @@ def test_create_todo(client: TestClient, db: Session):
|
|||||||
todo_data = {
|
todo_data = {
|
||||||
"task": "Test TODO",
|
"task": "Test TODO",
|
||||||
"date": f"{today_date.isoformat()}T00:00:00",
|
"date": f"{today_date.isoformat()}T00:00:00",
|
||||||
"remind": True
|
"remind": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/todos/",
|
"/api/todos/",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json=todo_data
|
json=todo_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_201_CREATED
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
@@ -39,24 +42,39 @@ def test_create_todo(client: TestClient, db: Session):
|
|||||||
assert "id" in data
|
assert "id" in data
|
||||||
assert data["owner_id"] == user.id
|
assert data["owner_id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
def test_read_todos(client: TestClient, db: Session):
|
def test_read_todos(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
# Create some todos for the user
|
# Create some todos for the user
|
||||||
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 1"})
|
client.post(
|
||||||
client.post("/api/todos/", headers={"Authorization": f"Bearer {access_token}"}, cookies={"refresh_token": refresh_token}, json={"task": "Todo 2"})
|
"/api/todos/",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
cookies={"refresh_token": refresh_token},
|
||||||
|
json={"task": "Todo 1"},
|
||||||
|
)
|
||||||
|
client.post(
|
||||||
|
"/api/todos/",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
cookies={"refresh_token": refresh_token},
|
||||||
|
json={"task": "Todo 2"},
|
||||||
|
)
|
||||||
|
|
||||||
# Create a todo for another user
|
# Create a todo for another user
|
||||||
other_user, other_password = generators.create_user(db)
|
other_user, other_password = generators.create_user(db)
|
||||||
other_login_rsp = generators.login(db, other_user.username, other_password)
|
other_login_rsp = generators.login(db, other_user.username, other_password)
|
||||||
other_access_token = other_login_rsp["access_token"]
|
other_access_token = other_login_rsp["access_token"]
|
||||||
other_refresh_token = other_login_rsp["refresh_token"]
|
other_refresh_token = other_login_rsp["refresh_token"]
|
||||||
client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"})
|
client.post(
|
||||||
|
"/api/todos/",
|
||||||
|
headers={"Authorization": f"Bearer {other_access_token}"},
|
||||||
|
cookies={"refresh_token": other_refresh_token},
|
||||||
|
json={"task": "Other User Todo"},
|
||||||
|
)
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/todos/",
|
"/api/todos/",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
@@ -65,20 +83,21 @@ def test_read_todos(client: TestClient, db: Session):
|
|||||||
assert data[0]["task"] == "Todo 1"
|
assert data[0]["task"] == "Todo 1"
|
||||||
assert data[1]["task"] == "Todo 2"
|
assert data[1]["task"] == "Todo 2"
|
||||||
|
|
||||||
|
|
||||||
def test_read_single_todo(client: TestClient, db: Session):
|
def test_read_single_todo(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
create_response = client.post(
|
create_response = client.post(
|
||||||
"/api/todos/",
|
"/api/todos/",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"task": "Specific Todo"}
|
json={"task": "Specific Todo"},
|
||||||
)
|
)
|
||||||
todo_id = create_response.json()["id"]
|
todo_id = create_response.json()["id"]
|
||||||
|
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/todos/{todo_id}",
|
f"/api/todos/{todo_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
@@ -87,15 +106,17 @@ def test_read_single_todo(client: TestClient, db: Session):
|
|||||||
assert data["task"] == "Specific Todo"
|
assert data["task"] == "Specific Todo"
|
||||||
assert data["owner_id"] == user.id
|
assert data["owner_id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
def test_read_single_todo_not_found(client: TestClient, db: Session):
|
def test_read_single_todo_not_found(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
response = client.get(
|
response = client.get(
|
||||||
"/api/todos/9999", # Non-existent ID
|
"/api/todos/9999", # Non-existent ID
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
def test_read_single_todo_forbidden(client: TestClient, db: Session):
|
def test_read_single_todo_forbidden(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
|
|
||||||
@@ -104,16 +125,26 @@ def test_read_single_todo_forbidden(client: TestClient, db: Session):
|
|||||||
other_login_rsp = generators.login(db, other_user.username, other_password)
|
other_login_rsp = generators.login(db, other_user.username, other_password)
|
||||||
other_access_token = other_login_rsp["access_token"]
|
other_access_token = other_login_rsp["access_token"]
|
||||||
other_refresh_token = other_login_rsp["refresh_token"]
|
other_refresh_token = other_login_rsp["refresh_token"]
|
||||||
other_create_response = client.post("/api/todos/", headers={"Authorization": f"Bearer {other_access_token}"}, cookies={"refresh_token": other_refresh_token}, json={"task": "Other User Todo"})
|
other_create_response = client.post(
|
||||||
|
"/api/todos/",
|
||||||
|
headers={"Authorization": f"Bearer {other_access_token}"},
|
||||||
|
cookies={"refresh_token": other_refresh_token},
|
||||||
|
json={"task": "Other User Todo"},
|
||||||
|
)
|
||||||
other_todo_id = other_create_response.json()["id"]
|
other_todo_id = other_create_response.json()["id"]
|
||||||
|
|
||||||
# Try to access the other user's todo
|
# Try to access the other user's todo
|
||||||
response = client.get(
|
response = client.get(
|
||||||
f"/api/todos/{other_todo_id}",
|
f"/api/todos/{other_todo_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"}, # Using the first user's token
|
headers={
|
||||||
cookies={"refresh_token": refresh_token}
|
"Authorization": f"Bearer {access_token}"
|
||||||
|
}, # Using the first user's token
|
||||||
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND # Service raises 404 if not found for *this* user
|
assert (
|
||||||
|
response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
) # Service raises 404 if not found for *this* user
|
||||||
|
|
||||||
|
|
||||||
def test_update_todo(client: TestClient, db: Session):
|
def test_update_todo(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
@@ -121,7 +152,7 @@ def test_update_todo(client: TestClient, db: Session):
|
|||||||
"/api/todos/",
|
"/api/todos/",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"task": "Update Me"}
|
json={"task": "Update Me"},
|
||||||
)
|
)
|
||||||
todo_id = create_response.json()["id"]
|
todo_id = create_response.json()["id"]
|
||||||
|
|
||||||
@@ -130,7 +161,7 @@ def test_update_todo(client: TestClient, db: Session):
|
|||||||
f"/api/todos/{todo_id}",
|
f"/api/todos/{todo_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json=update_data
|
json=update_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
@@ -144,7 +175,7 @@ def test_update_todo(client: TestClient, db: Session):
|
|||||||
get_response = client.get(
|
get_response = client.get(
|
||||||
f"/api/todos/{todo_id}",
|
f"/api/todos/{todo_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert get_response.json()["task"] == update_data["task"]
|
assert get_response.json()["task"] == update_data["task"]
|
||||||
assert get_response.json()["complete"] == update_data["complete"]
|
assert get_response.json()["complete"] == update_data["complete"]
|
||||||
@@ -157,24 +188,25 @@ def test_update_todo_not_found(client: TestClient, db: Session):
|
|||||||
"/api/todos/9999", # Non-existent ID
|
"/api/todos/9999", # Non-existent ID
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json=update_data
|
json=update_data,
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
def test_delete_todo(client: TestClient, db: Session):
|
def test_delete_todo(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
create_response = client.post(
|
create_response = client.post(
|
||||||
"/api/todos/",
|
"/api/todos/",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token},
|
cookies={"refresh_token": refresh_token},
|
||||||
json={"task": "Delete Me"}
|
json={"task": "Delete Me"},
|
||||||
)
|
)
|
||||||
todo_id = create_response.json()["id"]
|
todo_id = create_response.json()["id"]
|
||||||
|
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
f"/api/todos/{todo_id}",
|
f"/api/todos/{todo_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item
|
assert response.status_code == status.HTTP_200_OK # Delete returns the deleted item
|
||||||
@@ -184,25 +216,29 @@ def test_delete_todo(client: TestClient, db: Session):
|
|||||||
get_response = client.get(
|
get_response = client.get(
|
||||||
f"/api/todos/{todo_id}",
|
f"/api/todos/{todo_id}",
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert get_response.status_code == status.HTTP_404_NOT_FOUND
|
assert get_response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
def test_delete_todo_not_found(client: TestClient, db: Session):
|
def test_delete_todo_not_found(client: TestClient, db: Session):
|
||||||
user, access_token, refresh_token = _login_user(db, client)
|
user, access_token, refresh_token = _login_user(db, client)
|
||||||
response = client.delete(
|
response = client.delete(
|
||||||
"/api/todos/9999", # Non-existent ID
|
"/api/todos/9999", # Non-existent ID
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
cookies={"refresh_token": refresh_token}
|
cookies={"refresh_token": refresh_token},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
# --- Test Authentication/Authorization ---
|
# --- Test Authentication/Authorization ---
|
||||||
|
|
||||||
|
|
||||||
def test_create_todo_unauthorized(client: TestClient):
|
def test_create_todo_unauthorized(client: TestClient):
|
||||||
response = client.post("/api/todos/", json={"task": "No Auth"})
|
response = client.post("/api/todos/", json={"task": "No Auth"})
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
def test_read_todos_unauthorized(client: TestClient):
|
def test_read_todos_unauthorized(client: TestClient):
|
||||||
response = client.get("/api/todos/")
|
response = client.get("/api/todos/")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|||||||
Reference in New Issue
Block a user