Files
MAIA/backend/modules/auth/security.py
2025-04-21 20:09:41 +02:00

174 lines
5.8 KiB
Python

# modules/auth/security.py
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Annotated

from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from core.config import settings
from core.database import get_db
from modules.auth.models import TokenBlacklist, User
from modules.auth.schemas import TokenData
# Security dependency that supplies the raw bearer token to get_current_user
# (via Depends). tokenUrl names the login route where clients obtain tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
class TokenType(str, Enum):
    """Kinds of JWT issued by this module.

    The member value is written into each token's ``token_type`` claim and
    compared against it when a token is verified. Because the enum mixes in
    ``str``, members compare equal to their plain string values.
    """

    # Short-lived token (lifetime configured in minutes).
    ACCESS = "access"
    # Long-lived token (lifetime configured in days).
    REFRESH = "refresh"
# Module-wide Argon2 hasher, constructed with the library's default parameters
# (no arguments); used by hash_password and verify_password.
password_hasher = PasswordHasher()
def hash_password(password: str) -> str:
    """Hash a password with Argon2 after appending the configured pepper.

    Parameters
    ----------
    password: str
        The plaintext password to hash.

    Returns
    -------
    str
        The encoded Argon2 hash string produced by ``password_hasher``.
    """
    # The pepper is appended (not prepended); verify_password must build the
    # same string before checking a candidate password.
    to_hash = password + settings.PEPPER
    encoded = password_hasher.hash(to_hash)
    return encoded
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored Argon2 hash.

    The configured pepper is appended to ``plain_password`` before the check,
    matching how ``hash_password`` builds its input.

    Parameters
    ----------
    plain_password: str
        The candidate password supplied by the client.
    hashed_password: str
        The stored Argon2 hash to verify against.

    Returns
    -------
    bool
        True if the password matches. False if it does not match, or if the
        stored hash cannot be verified (e.g. it is malformed or not an Argon2
        hash).
    """
    peppered_password = plain_password + settings.PEPPER
    try:
        return password_hasher.verify(hashed_password, peppered_password)
    except (VerificationError, InvalidHashError):
        # VerifyMismatchError (wrong password) is a subclass of
        # VerificationError, which also covers other verification failures.
        # InvalidHashError means the stored value is not a parseable Argon2
        # hash. In every case the credentials are not valid, so report failure
        # instead of letting the exception propagate to the caller.
        return False
# Precomputed Argon2 hash that is verified against when the username does not
# exist. Unknown-user and wrong-password logins then both pay for one Argon2
# verification, so response latency does not reveal which usernames exist.
_DUMMY_PASSWORD_HASH = password_hasher.hash("timing-equalization-placeholder")


def authenticate_user(username: str, password: str, db: Session) -> User | None:
    """Authenticate a user by checking username/password against the database.

    Parameters
    ----------
    username: str
        Username to look up.
    password: str
        Plaintext password supplied by the client.
    db: Session
        Database session used to load the user.

    Returns
    -------
    User | None
        The matching ``User`` if the password is correct, None otherwise.
    """
    user = db.query(User).filter(User.username == username).first()
    if user is None:
        # Do equivalent hashing work for a non-existent user; the result is
        # deliberately ignored because there is no account to log into.
        verify_password(password, _DUMMY_PASSWORD_HASH)
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
def _create_token(data: dict, token_type: TokenType, lifetime: timedelta) -> str:
    """Return a signed JWT built from ``data`` that expires ``lifetime`` from now.

    A shallow copy of ``data`` is extended with the ``exp`` and ``token_type``
    claims, so the caller's dict is not modified. Any ``exp`` or
    ``token_type`` keys already present in ``data`` are overwritten.
    """
    claims = data.copy()
    claims.update(
        {
            "exp": datetime.now(timezone.utc) + lifetime,
            "token_type": token_type.value,
        }
    )
    return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)


def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed access token.

    Parameters
    ----------
    data: dict
        Claims to embed; ``verify_token`` reads the username from ``sub``.
    expires_delta: timedelta | None
        Token lifetime. When None, ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``
        minutes is used. An explicit ``timedelta(0)`` is honoured instead of
        being mistaken for "use the default".

    Returns
    -------
    str
        The encoded JWT, with its ``token_type`` claim set to ``"access"``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    return _create_token(data, TokenType.ACCESS, expires_delta)


def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed refresh token.

    Parameters
    ----------
    data: dict
        Claims to embed; ``verify_token`` reads the username from ``sub``.
    expires_delta: timedelta | None
        Token lifetime. When None, ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days
        is used. An explicit ``timedelta(0)`` is honoured instead of being
        mistaken for "use the default".

    Returns
    -------
    str
        The encoded JWT, with its ``token_type`` claim set to ``"refresh"``.
    """
    if expires_delta is None:
        expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    return _create_token(data, TokenType.REFRESH, expires_delta)
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
    """Verify a JWT token and return TokenData if valid.

    A token is valid when it decodes with the configured key and algorithm,
    carries a ``sub`` claim, has a ``token_type`` claim equal to
    ``expected_token_type``, and is not present in the blacklist table.

    Parameters
    ----------
    token: str
        The JWT token to be verified.
    expected_token_type: TokenType
        The expected type of token (access or refresh).
    db: Session
        Database session used for the blacklist lookup.

    Returns
    -------
    TokenData | None
        TokenData instance if the token is valid, None otherwise.
    """
    try:
        payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except JWTError:
        # Rejected by the JWT library (e.g. bad signature, malformed, expired).
        return None

    username = payload.get("sub")
    if username is None or payload.get("token_type") != expected_token_type:
        return None

    # The blacklist lookup needs a database round trip, so it is done only for
    # tokens that already passed the cheap in-process checks above.
    if db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None:
        return None

    return TokenData(username=username)
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
    """Resolve the bearer token of the request to the current ``User``.

    The token must pass ``verify_token`` as an *access* token. A refresh token
    is rejected here, so a long-lived refresh token cannot stand in for a
    short-lived access token on protected routes.

    Raises
    ------
    HTTPException
        401 with a ``WWW-Authenticate: Bearer`` header if the token is invalid,
        is not an access token, is blacklisted, or names an unknown user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # verify_token performs the decode, token-type and blacklist checks.
    token_data = verify_token(token, TokenType.ACCESS, db)
    if token_data is None:
        raise credentials_exception

    user = db.query(User).filter(User.username == token_data.username).first()
    if user is None:
        raise credentials_exception
    return user
def _build_blacklist_entry(token: str) -> TokenBlacklist:
    """Decode ``token`` and build its (not yet added) blacklist row.

    Raises
    ------
    JWTError
        If ``jwt.decode`` rejects the token.
    """
    payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    # "exp" is a POSIX timestamp. Interpret it explicitly as UTC (an aware
    # datetime) rather than naive local time, matching the timezone.utc
    # expiries computed when tokens are created.
    # NOTE(review): confirm the TokenBlacklist.expires_at column and any
    # cleanup job expect UTC.
    expires_at = datetime.fromtimestamp(payload.get("exp"), tz=timezone.utc)
    return TokenBlacklist(token=token, expires_at=expires_at)


def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
    """Blacklist both access and refresh tokens in a single commit.

    Both tokens are decoded before anything is added to the session. If
    either token fails to decode, the session is left untouched instead of
    holding a half-applied change.

    Parameters
    ----------
    access_token: str
        The access token to blacklist.
    refresh_token: str
        The refresh token to blacklist.
    db: Session
        Database session to perform the operation.

    Raises
    ------
    JWTError
        If either token cannot be decoded.
    """
    entries = [_build_blacklist_entry(token) for token in (access_token, refresh_token)]
    db.add_all(entries)
    db.commit()


def blacklist_token(token: str, db: Session) -> None:
    """Blacklist a single token and commit.

    Parameters
    ----------
    token: str
        The token to blacklist.
    db: Session
        Database session to perform the operation.

    Raises
    ------
    JWTError
        If the token cannot be decoded.
    """
    db.add(_build_blacklist_entry(token))
    db.commit()