Files
MAIA/backend/modules/auth/security.py
2025-04-26 12:43:19 +02:00

176 lines
5.2 KiB
Python

# modules/auth/security.py
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from sqlalchemy.orm import Session
from core.database import get_db
from core.config import settings
from modules.auth.models import TokenBlacklist, User
from modules.auth.schemas import TokenData
# Bearer-token extractor used as a FastAPI dependency by get_current_user;
# tokenUrl names the login endpoint that issues the tokens.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
class TokenType(str, Enum):
    """Kinds of JWT issued by this module, stored in the ``token_type`` claim."""

    # Written into tokens produced by create_access_token.
    ACCESS = "access"
    # Written into tokens produced by create_refresh_token.
    REFRESH = "refresh"
# Module-level Argon2 hasher (constructed with the library's default
# parameters) shared by hash_password and verify_password.
password_hasher = PasswordHasher()
def hash_password(password: str) -> str:
    """Return the Argon2 hash of ``password`` with the configured pepper appended.

    The pepper (``settings.PEPPER``) is concatenated after the plain password
    before hashing, so verification must append the same suffix.
    """
    return password_hasher.hash(password + settings.PEPPER)
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` (plus pepper) against an Argon2 hash.

    Returns the hasher's verification result on success and ``False`` when
    the hasher reports a mismatch. Other hasher errors are not caught here.
    """
    candidate = plain_password + settings.PEPPER
    try:
        outcome = password_hasher.verify(hashed_password, candidate)
    except VerifyMismatchError:
        # Wrong password: report it as a plain boolean instead of raising.
        return False
    return outcome
def authenticate_user(username: str, password: str, db: Session) -> User | None:
    """Look up ``username`` and check ``password`` against the stored hash.

    Returns the matching ``User`` when the credentials are valid, and ``None``
    when no such user exists or the password does not verify.
    """
    candidate = db.query(User).filter(User.username == username).first()
    # Unknown user: return before the password hasher is ever invoked.
    if not candidate:
        return None
    if verify_password(password, candidate.hashed_password):
        return candidate
    return None
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT access token.

    Parameters
    ----------
    data: dict
        Claims to embed in the token. The dict is copied, not mutated.
    expires_delta: timedelta | None
        Token lifetime measured from now (UTC). When ``None``, the lifetime
        is ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns
    -------
    The JWT encoded with ``settings.JWT_SECRET_KEY`` and
    ``settings.JWT_ALGORITHM``.
    """
    # Test against None explicitly: timedelta(0) is falsy, so a truthiness
    # check would silently replace an explicit zero lifetime with the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode.update(
        {
            "exp": datetime.now(timezone.utc) + expires_delta,
            "token_type": TokenType.ACCESS,
        }
    )
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Create a signed JWT refresh token.

    Parameters
    ----------
    data: dict
        Claims to embed in the token. The dict is copied, not mutated.
    expires_delta: timedelta | None
        Token lifetime measured from now (UTC). When ``None``, the lifetime
        is ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days.

    Returns
    -------
    The JWT encoded with ``settings.JWT_SECRET_KEY`` and
    ``settings.JWT_ALGORITHM``.
    """
    # Test against None explicitly: timedelta(0) is falsy, so a truthiness
    # check would silently replace an explicit zero lifetime with the default.
    if expires_delta is None:
        expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    to_encode = data.copy()
    to_encode.update(
        {
            "exp": datetime.now(timezone.utc) + expires_delta,
            "token_type": TokenType.REFRESH,
        }
    )
    return jwt.encode(
        to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
    )
def verify_token(
    token: str, expected_token_type: TokenType, db: Session
) -> TokenData | None:
    """Validate ``token`` and return its subject as ``TokenData``.

    Returns ``None`` when the token is present in the blacklist table, fails
    JWT decoding, has no ``sub`` claim, or carries a ``token_type`` claim
    different from ``expected_token_type``.
    """
    # The blacklist lookup runs first, keyed on the raw token string.
    blacklist_row = (
        db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
    )
    if blacklist_row is not None:
        return None

    try:
        claims = jwt.decode(
            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        return None

    subject = claims.get("sub")
    kind = claims.get("token_type")
    if subject is None or kind != expected_token_type:
        return None
    return TokenData(username=subject)
def get_current_user(
    db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
) -> User:
    """Resolve the bearer token of the current request to a ``User``.

    Parameters
    ----------
    db: Session
        Database session injected via ``get_db``.
    token: str
        Bearer token extracted by ``oauth2_scheme``.

    Raises
    ------
    HTTPException
        401 when the token is blacklisted, cannot be decoded, lacks a
        subject, is not an access token, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Delegate to verify_token so the blacklist, signature and subject checks
    # live in one place. Requiring TokenType.ACCESS stops a refresh token
    # from being accepted as a bearer credential.
    token_data = verify_token(token, TokenType.ACCESS, db)
    if token_data is None:
        raise credentials_exception
    user = db.query(User).filter(User.username == token_data.username).first()
    if user is None:
        raise credentials_exception
    return user
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
    """Blacklist both access and refresh tokens.

    Parameters
    ----------
    access_token: str
        The access token to blacklist
    refresh_token: str
        The refresh token to blacklist
    db: Session
        Database session to perform the operation.

    Raises
    ------
    JWTError
        If either token cannot be decoded. Nothing is added to the session
        in that case.
    """
    entries = []
    for token in (access_token, refresh_token):
        payload = jwt.decode(
            token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
        )
        # "exp" is a UTC epoch timestamp; without tz= the result would be a
        # naive datetime in the server's local zone.
        # NOTE(review): confirm the expires_at column and any cleanup query
        # compare in UTC.
        expires_at = datetime.fromtimestamp(payload.get("exp"), tz=timezone.utc)
        entries.append(TokenBlacklist(token=token, expires_at=expires_at))
    # Stage rows only after both tokens decoded, so a failure on the second
    # token does not leave the first pending in the caller's session.
    db.add_all(entries)
    db.commit()
def blacklist_token(token: str, db: Session) -> None:
    """Blacklist a single token.

    Parameters
    ----------
    token: str
        The encoded JWT to blacklist; its ``exp`` claim becomes the entry's
        ``expires_at``.
    db: Session
        Database session to perform the operation.

    Raises
    ------
    JWTError
        If the token cannot be decoded.
    """
    payload = jwt.decode(
        token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
    )
    # "exp" is a UTC epoch timestamp; without tz= the result would be a naive
    # datetime in the server's local zone.
    # NOTE(review): confirm the expires_at column and any cleanup query
    # compare in UTC.
    expires_at = datetime.fromtimestamp(payload.get("exp"), tz=timezone.utc)
    db.add(TokenBlacklist(token=token, expires_at=expires_at))
    db.commit()