# modules/auth/security.py
import logging
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
from sqlalchemy.orm import Session

from core.database import get_db
from core.config import settings
from modules.auth.models import TokenBlacklist, User
from modules.auth.schemas import TokenData

logger = logging.getLogger(__name__)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")


class TokenType(str, Enum):
    """Values written to (and expected in) a JWT's ``token_type`` claim."""

    ACCESS = "access"
    REFRESH = "refresh"


password_hasher = PasswordHasher()


def hash_password(password: str) -> str:
    """Hash a password with Argon2 after appending the server-side pepper.

    ``settings.PEPPER`` is appended (not prepended) to the password before
    hashing; an empty pepper leaves the password unchanged. ``verify_password``
    appends the same value, so changing ``settings.PEPPER`` makes previously
    stored hashes stop verifying.
    """
    return password_hasher.hash(password + settings.PEPPER)


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password (plus pepper) against its Argon2 hash.

    Parameters
    ----------
    plain_password: str
        The candidate password as supplied by the user.
    hashed_password: str
        The stored Argon2 hash.

    Returns
    -------
    bool
        True on a match. False on a mismatch, and also when the stored hash is
        malformed or cannot be verified for another reason (logged as a
        warning), so a corrupt row yields a failed login instead of an
        unhandled server error.
    """
    peppered_password = plain_password + settings.PEPPER
    try:
        return password_hasher.verify(hashed_password, peppered_password)
    except VerifyMismatchError:
        # The normal "wrong password" case.
        return False
    except (InvalidHashError, VerificationError):
        # The stored hash itself is unusable; surface it in logs but do not 500.
        logger.warning("Stored password hash could not be verified; treating as mismatch")
        return False


def authenticate_user(username: str, password: str, db: Session) -> User | None:
    """Authenticate a user by checking username/password against the database.

    Returns the ``User`` when the username exists and the password verifies,
    otherwise ``None``.
    """
    user = db.query(User).filter(User.username == username).first()
    if user is None or not verify_password(password, user.hashed_password):
        return None
    return user


def _create_token(data: dict, lifetime: timedelta, token_type: TokenType) -> str:
    """Encode ``data`` as a signed JWT with ``exp`` and ``token_type`` claims.

    ``data`` is copied, not mutated; any ``exp``/``token_type`` keys it already
    contains are overwritten.
    """
    to_encode = data.copy()
    to_encode.update({"exp": datetime.now(timezone.utc) + lifetime, "token_type": token_type})
    return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)


def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed access token.

    Parameters
    ----------
    data: dict
        Claims to embed; ``verify_token`` reads the username from ``"sub"``.
    expires_delta: timedelta | None
        Token lifetime. When omitted (or a zero timedelta) the lifetime is
        ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.
    """
    lifetime = expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    return _create_token(data, lifetime, TokenType.ACCESS)


def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed refresh token.

    Parameters
    ----------
    data: dict
        Claims to embed; ``verify_token`` reads the username from ``"sub"``.
    expires_delta: timedelta | None
        Token lifetime. When omitted (or a zero timedelta) the lifetime is
        ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days.
    """
    lifetime = expires_delta or timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    return _create_token(data, lifetime, TokenType.REFRESH)


def _is_blacklisted(token: str, db: Session) -> bool:
    """Return True if ``token`` has a row in the ``TokenBlacklist`` table."""
    return db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None


def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
    """Verify a JWT token and return TokenData if valid.

    A token is valid when it is not blacklisted, its signature and expiry check
    out, it carries a ``sub`` claim, and its ``token_type`` claim equals
    ``expected_token_type``.

    Parameters
    ----------
    token: str
        The JWT token to be verified.
    expected_token_type: TokenType
        The expected type of token (access or refresh).
    db: Session
        Database session used for the blacklist lookup.

    Returns
    -------
    TokenData | None
        TokenData instance if the token is valid, None otherwise.
    """
    if _is_blacklisted(token, db):
        return None
    try:
        payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError:
        return None
    username: str | None = payload.get("sub")
    token_type: str | None = payload.get("token_type")
    # TokenType is a str-Enum, so it compares equal to the decoded plain string.
    if username is None or token_type != expected_token_type:
        return None
    return TokenData(username=username)


def get_current_user(
    db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
) -> User:
    """FastAPI dependency that resolves the bearer token to a ``User``.

    The token must be a non-blacklisted, correctly signed, unexpired *access*
    token whose ``sub`` names an existing user. Refresh tokens are rejected, so
    a longer-lived refresh token cannot be used to call protected endpoints.

    Raises
    ------
    HTTPException
        401 with a ``WWW-Authenticate: Bearer`` header if any check fails.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    token_data = verify_token(token, TokenType.ACCESS, db)
    if token_data is None:
        raise credentials_exception
    user: User | None = db.query(User).filter(User.username == token_data.username).first()
    if user is None:
        raise credentials_exception
    return user


def _blacklist_entry(token: str) -> TokenBlacklist:
    """Build (but do not add or commit) a ``TokenBlacklist`` row for ``token``.

    The signature is still verified, but expiry is not: logging out with an
    already-expired access token must not abort the operation.

    Raises
    ------
    JWTError
        If the signature is invalid or the token has no ``exp`` claim.
    """
    payload = jwt.decode(
        token,
        settings.JWT_SECRET_KEY,
        algorithms=[settings.JWT_ALGORITHM],
        options={"verify_exp": False},
    )
    exp = payload.get("exp")
    if exp is None:
        raise JWTError("Token has no 'exp' claim")
    # `exp` was issued from a UTC-aware datetime; keep the stored value in UTC
    # instead of the server's local time zone.
    # NOTE(review): confirm TokenBlacklist.expires_at accepts aware datetimes.
    expires_at = datetime.fromtimestamp(exp, tz=timezone.utc)
    return TokenBlacklist(token=token, expires_at=expires_at)


def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
    """Blacklist both access and refresh tokens in a single commit.

    Both tokens are decoded before anything is added to the session, so an
    invalid token leaves the session untouched.

    Parameters
    ----------
    access_token: str
        The access token to blacklist.
    refresh_token: str
        The refresh token to blacklist.
    db: Session
        Database session to perform the operation.

    Raises
    ------
    JWTError
        If either token's signature is invalid or it lacks an ``exp`` claim.
    """
    entries = [_blacklist_entry(token) for token in (access_token, refresh_token)]
    db.add_all(entries)
    db.commit()


def blacklist_token(token: str, db: Session) -> None:
    """Blacklist a single token and commit.

    Raises
    ------
    JWTError
        If the token's signature is invalid or it lacks an ``exp`` claim.
    """
    db.add(_blacklist_entry(token))
    db.commit()