# modules/auth/security.py
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from sqlalchemy.orm import Session

from core.database import get_db
from core.config import settings
from modules.auth.models import TokenBlacklist, User
from modules.auth.schemas import TokenData

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")


class TokenType(str, Enum):
    """Value written to a JWT's ``token_type`` claim to tell token kinds apart."""

    ACCESS = "access"
    REFRESH = "refresh"


password_hasher = PasswordHasher()


def hash_password(password: str) -> str:
    """Hash a password with Argon2 after appending ``settings.PEPPER``."""
    peppered_password = password + settings.PEPPER
    return password_hasher.hash(peppered_password)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its Argon2 hash.

    The same pepper used by :func:`hash_password` is appended before
    verification. Returns ``False`` on a mismatch; any other argon2 error
    (e.g. a malformed stored hash) propagates to the caller.
    """
    peppered_password = plain_password + settings.PEPPER
    try:
        return password_hasher.verify(hashed_password, peppered_password)
    except VerifyMismatchError:
        return False


def authenticate_user(username: str, password: str, db: Session) -> User | None:
    """
    Authenticate a user by checking username/password against the database.
    Returns User object if valid, None otherwise.
    """
    user = db.query(User).filter(User.username == username).first()
    if not user or not verify_password(password, user.hashed_password):
        return None
    return user


def _decode_token(token: str) -> dict:
    """Verify *token*'s signature with the configured key/algorithm and return its claims.

    Raises
    ------
    JWTError
        If the token cannot be decoded or verified (python-jose also rejects
        expired tokens by default).
    """
    return jwt.decode(
        token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
    )


def _is_blacklisted(token: str, db: Session) -> bool:
    """Return True if *token* has a row in the TokenBlacklist table."""
    return (
        db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
        is not None
    )


def _create_token(
    data: dict,
    token_type: TokenType,
    expires_delta: timedelta | None,
    default_lifetime: timedelta,
) -> str:
    """Sign a JWT carrying *data* plus ``exp`` (aware UTC) and ``token_type`` claims.

    ``expires_delta`` overrides ``default_lifetime`` whenever it is given.
    The check is ``is not None`` rather than truthiness, so an explicit
    ``timedelta(0)`` is honoured instead of silently becoming the default.
    """
    lifetime = expires_delta if expires_delta is not None else default_lifetime
    # Build a new dict so the caller's ``data`` is never mutated.
    to_encode = {
        **data,
        "exp": datetime.now(timezone.utc) + lifetime,
        "token_type": token_type.value,
    }
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )


def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed access token.

    Parameters
    ----------
    data: dict
        Claims to embed (e.g. ``sub``). The dict itself is not modified.
    expires_delta: timedelta | None
        Token lifetime; defaults to ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``.
    """
    return _create_token(
        data,
        TokenType.ACCESS,
        expires_delta,
        timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
    )


def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed refresh token.

    Parameters
    ----------
    data: dict
        Claims to embed (e.g. ``sub``). The dict itself is not modified.
    expires_delta: timedelta | None
        Token lifetime; defaults to ``settings.REFRESH_TOKEN_EXPIRE_DAYS``.
    """
    return _create_token(
        data,
        TokenType.REFRESH,
        expires_delta,
        timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
    )


def verify_token(
    token: str, expected_token_type: TokenType, db: Session
) -> TokenData | None:
    """Validate *token* and return its TokenData, or None if it is unusable.

    A token is rejected (None) when any of the following holds:
    - it is blacklisted;
    - it fails JWT decoding/verification;
    - it has no ``sub`` claim;
    - its ``token_type`` claim differs from *expected_token_type*.
    """
    if _is_blacklisted(token, db):
        return None
    try:
        payload = _decode_token(token)
    except JWTError:
        return None
    username = payload.get("sub")
    # TokenType is a str-Enum, so it compares equal to the decoded plain string.
    if username is None or payload.get("token_type") != expected_token_type:
        return None
    return TokenData(username=username)


def get_current_user(
    db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
) -> User:
    """FastAPI dependency resolving the bearer token to its User.

    Only *access* tokens are accepted. Previously the ``token_type`` claim
    was not checked, so a refresh token could be used as a bearer token.

    Raises
    ------
    HTTPException
        401, if any of the following holds:
        - the token is blacklisted, invalid, or lacks ``sub``;
        - the token is not an access token;
        - no user with that username exists.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    token_data = verify_token(token, TokenType.ACCESS, db)
    if token_data is None:
        raise credentials_exception
    user = db.query(User).filter(User.username == token_data.username).first()
    if user is None:
        raise credentials_exception
    return user


def _blacklist_entry(token: str) -> TokenBlacklist:
    """Decode *token* and build (but do not add) its TokenBlacklist row.

    ``expires_at`` is an aware UTC datetime, matching the UTC ``exp`` written
    by the token creators (the old code produced a naive local-time value).
    NOTE(review): confirm the ``expires_at`` column and any cleanup job
    compare against UTC.

    Raises
    ------
    JWTError
        If the token cannot be decoded/verified.
    """
    payload = _decode_token(token)
    expires_at = datetime.fromtimestamp(payload.get("exp"), tz=timezone.utc)
    return TokenBlacklist(token=token, expires_at=expires_at)


def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
    """Blacklist both access and refresh tokens in a single commit.

    Parameters
    ----------
    access_token: str
        The access token to blacklist
    refresh_token: str
        The refresh token to blacklist
    db: Session
        Database session to perform the operation.

    Raises
    ------
    JWTError
        If either token cannot be decoded. In that case nothing is added to
        the session.
    """
    # Decode both tokens before touching the session. Otherwise a bad second
    # token would leave the first row staged as a pending, uncommitted add.
    entries = [_blacklist_entry(token) for token in (access_token, refresh_token)]
    db.add_all(entries)
    db.commit()


def blacklist_token(token: str, db: Session) -> None:
    """Blacklist a single token and commit.

    Raises
    ------
    JWTError
        If the token cannot be decoded.
    """
    db.add(_blacklist_entry(token))
    db.commit()