[REFORMAT] Ran black reformat

This commit is contained in:
c-d-p
2025-04-23 01:00:56 +02:00
parent d5d0a24403
commit 1553004efc
38 changed files with 1005 additions and 384 deletions

View File

@@ -3,9 +3,24 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError
from modules.auth.models import User
from modules.auth.schemas import UserCreate, UserResponse, Token, RefreshTokenRequest, LogoutRequest
from modules.auth.schemas import (
UserCreate,
UserResponse,
Token,
RefreshTokenRequest,
LogoutRequest,
)
from modules.auth.services import create_user
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens
from modules.auth.security import (
TokenType,
get_current_user,
oauth2_scheme,
create_access_token,
create_refresh_token,
verify_token,
authenticate_user,
blacklist_tokens,
)
from sqlalchemy.orm import Session
from typing import Annotated
from core.database import get_db
@@ -15,12 +30,19 @@ from core.exceptions import unauthorized_exception
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
@router.post(
"/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED
)
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
return create_user(user.username, user.password, user.name, db)
@router.post("/login", response_model=Token)
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
def login(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Annotated[Session, Depends(get_db)],
):
"""
Authenticate user and return JWT tokens in the response body.
"""
@@ -30,39 +52,53 @@ def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annota
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
)
refresh_token = create_refresh_token(data={"sub": user.username})
return {"access_token": access_token, "refresh_token": refresh_token, "token_type": "bearer"}
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
}
@router.post("/refresh")
def refresh_token(payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]):
def refresh_token(
payload: RefreshTokenRequest, db: Annotated[Session, Depends(get_db)]
):
print("Refreshing token...")
refresh_token = payload.refresh_token
if not refresh_token:
raise unauthorized_exception("Refresh token missing in request body")
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db)
user_data = verify_token(
refresh_token, expected_token_type=TokenType.REFRESH, db=db
)
if not user_data:
raise unauthorized_exception("Invalid refresh token")
new_access_token = create_access_token(data={"sub": user_data.username})
return {"access_token": new_access_token, "token_type": "bearer"}
@router.post("/logout")
def logout(payload: LogoutRequest, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme)):
def logout(
payload: LogoutRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
access_token: str = Depends(oauth2_scheme),
):
try:
refresh_token = payload.refresh_token
if not refresh_token:
raise unauthorized_exception("Refresh token not found in request body")
blacklist_tokens(
access_token=access_token,
refresh_token=refresh_token,
db=db
)
blacklist_tokens(access_token=access_token, refresh_token=refresh_token, db=db)
return {"message": "Logged out successfully"}
except JWTError:
raise unauthorized_exception("Invalid token")
raise unauthorized_exception("Invalid token")

View File

@@ -5,14 +5,18 @@ from modules.auth.schemas import UserRole
from modules.auth.models import User
from core.exceptions import forbidden_exception
class RoleChecker:
def __init__(self, allowed_roles: list[UserRole]):
self.allowed_roles = allowed_roles
def __call__(self, user: User = Depends(get_current_user)):
if user.role not in self.allowed_roles:
raise forbidden_exception("You do not have permission to perform this action.")
raise forbidden_exception(
"You do not have permission to perform this action."
)
return user
admin_only = RoleChecker([UserRole.ADMIN])
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])

View File

@@ -4,10 +4,12 @@ from sqlalchemy import Column, Integer, String, Enum, DateTime
from sqlalchemy.orm import relationship
from enum import Enum as PyEnum
class UserRole(str, PyEnum):
ADMIN = "admin"
USER = "user"
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)

View File

@@ -2,33 +2,41 @@
from enum import Enum as PyEnum
from pydantic import BaseModel
class Token(BaseModel):
access_token: str
token_type: str
refresh_token: str | None = None
class TokenData(BaseModel):
username: str | None = None
scopes: list[str] = []
class RefreshTokenRequest(BaseModel):
refresh_token: str
class LogoutRequest(BaseModel):
refresh_token: str
class UserRole(str, PyEnum):
ADMIN = "admin"
USER = "user"
class UserCreate(BaseModel):
username: str
password: str
name: str
class UserPatch(BaseModel):
name: str | None = None
class UserResponse(BaseModel):
uuid: str
username: str

View File

@@ -18,6 +18,7 @@ from modules.auth.schemas import TokenData
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
class TokenType(str, Enum):
ACCESS = "access"
REFRESH = "refresh"
@@ -25,11 +26,13 @@ class TokenType(str, Enum):
password_hasher = PasswordHasher()
def hash_password(password: str) -> str:
"""Hash a password with Argon2 (and optional pepper)."""
peppered_password = password + settings.PEPPER # Prepend/append pepper
return password_hasher.hash(peppered_password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hashed version using Argon2."""
peppered_password = plain_password + settings.PEPPER
@@ -38,6 +41,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
except VerifyMismatchError:
return False
def authenticate_user(username: str, password: str, db: Session) -> User | None:
"""
Authenticate a user by checking username/password against the database.
@@ -45,41 +49,46 @@ def authenticate_user(username: str, password: str, db: Session) -> User | None:
"""
# Get user from database
user = db.query(User).filter(User.username == username).first()
# If user not found or password doesn't match
if not user or not verify_password(password, user.hashed_password):
return None
return user
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
# expire = datetime.now(timezone.utc) + timedelta(seconds=5)
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
return jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
expire = datetime.now(timezone.utc) + timedelta(
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
)
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
return jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM
)
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
def verify_token(
token: str, expected_token_type: TokenType, db: Session
) -> TokenData | None:
"""Verify a JWT token and return TokenData if valid.
Parameters
@@ -96,24 +105,32 @@ def verify_token(token: str, expected_token_type: TokenType, db: Session) -> Tok
TokenData | None
TokenData instance if the token is valid, None otherwise.
"""
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
is_blacklisted = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
is not None
)
if is_blacklisted:
return None
try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
username: str = payload.get("sub")
token_type: str = payload.get("token_type")
if username is None or token_type != expected_token_type:
return None
return TokenData(username=username)
except JWTError:
return None
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
def get_current_user(
db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@@ -121,26 +138,28 @@ def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depen
)
# Check if the token is blacklisted
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
is_blacklisted = (
db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first()
is not None
)
if is_blacklisted:
raise credentials_exception
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user: User = db.query(User).filter(User.username == username).first()
if user is None:
raise credentials_exception
return user
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
"""Blacklist both access and refresh tokens.
@@ -154,7 +173,9 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
Database session to perform the operation.
"""
for token in [access_token, refresh_token]:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist
@@ -163,10 +184,13 @@ def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None
db.commit()
def blacklist_token(token: str, db: Session) -> None:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
db.add(blacklisted_token)

View File

@@ -20,11 +20,13 @@ def create_user(username: str, password: str, name: str, db: Session) -> UserRes
existing_user = db.query(User).filter(User.username == username).first()
if existing_user:
raise conflict_exception("Username already exists")
hashed_password = hash_password(password)
user_uuid = str(uuid.uuid4())
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid)
user = User(
username=username, hashed_password=hashed_password, name=name, uuid=user_uuid
)
db.add(user)
db.commit()
db.refresh(user) # Loads the generated ID
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic