working auth + users systems

This commit is contained in:
c-d-p
2025-04-16 21:32:57 +02:00
parent 516adc606d
commit 18ddb2f332
56 changed files with 943 additions and 0 deletions

2
backend/.env Normal file
View File

@@ -0,0 +1,2 @@
PEPPER = "LsD7%"
JWT_SECRET_KEY="1c8cf3ca6972b365f8108dad247e61abdcb6faff5a6c8ba00cb6fa17396702bf"

1
backend/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/env

7
backend/.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

2
backend/TODO Normal file
View File

@@ -0,0 +1,2 @@
Pedantic:
- Shouldn't really return a 409 Conflict when user made with same username, could be used to enumerate users.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,10 @@
# core/celery_app.py
from celery import Celery
from core.config import settings
celery = Celery(
"maia",
broker=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0",
backend=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/1",
include=["modules.auth.tasks"], # List all task modules here
)

21
backend/core/config.py Normal file
View File

@@ -0,0 +1,21 @@
# core/config.py
from pydantic_settings import BaseSettings
from os import getenv
from dotenv import load_dotenv
load_dotenv() # Load .env file
class Settings(BaseSettings):
DB_URL: str = "postgresql://maia:maia@localhost:5432/maia"
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
JWT_ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
PEPPER: str = getenv("PEPPER", "")
JWT_SECRET_KEY: str = getenv("JWT_SECRET_KEY", "")
settings = Settings()

36
backend/core/database.py Normal file
View File

@@ -0,0 +1,36 @@
# core/database.py
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session, declarative_base
from typing import Generator
from core.config import settings
Base = declarative_base() # Used for models
_engine = None
_SessionLocal = None
def get_engine():
global _engine
if _engine is None:
if not settings.DB_URL:
raise ValueError("DB_URL is not set in Settings.")
print(f"Connecting to database at {settings.DB_URL}")
_engine = create_engine(settings.DB_URL)
Base.metadata.create_all(_engine) # Create tables here
return _engine
def get_sessionmaker():
global _SessionLocal
if _SessionLocal is None:
engine = get_engine()
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return _SessionLocal
def get_db() -> Generator[Session, None, None]:
SessionLocal = get_sessionmaker()
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -0,0 +1,27 @@
from fastapi import HTTPException
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_403_FORBIDDEN,
HTTP_404_NOT_FOUND,
HTTP_500_INTERNAL_SERVER_ERROR,
HTTP_409_CONFLICT,
)
def bad_request_exception(detail: str = "Bad Request"):
return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
def unauthorized_exception(detail: str = "Unauthorized"):
return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
def forbidden_exception(detail: str = "Forbidden"):
return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
def not_found_exception(detail: str = "Not Found"):
return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
def internal_server_error_exception(detail: str = "Internal Server Error"):
return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
def conflict_exception(detail: str = "Conflict"):
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)

View File

@@ -0,0 +1,23 @@
# docker-compose.yml
services:
postgres:
image: postgres:14
environment:
POSTGRES_USER: maia
POSTGRES_PASSWORD: maia
POSTGRES_DB: maia
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
redis:
image: redis:7
ports:
- "6379:6379"
volumes:
- redis_data:/data
volumes:
postgres_data:
redis_data:

25
backend/main.py Normal file
View File

@@ -0,0 +1,25 @@
# main.py
from fastapi import FastAPI, Depends
from core.database import get_engine, Base
from modules.auth.api import router as auth_router
from modules.user.api import router as user_router
from modules.admin.api import router as admin_router
from modules.auth.dependencies import admin_only
import logging
from modules.auth.security import get_current_user
logging.getLogger('passlib').setLevel(logging.ERROR) # fix bc package logging is broken
# Create DB tables (remove in production; use migrations instead)
def lifespan(app):
# Base.metadata.drop_all(bind=get_engine())
Base.metadata.create_all(bind=get_engine())
yield
app = FastAPI(lifespan=lifespan)
# Include all module routers
app.include_router(auth_router, prefix="/api/auth")
app.include_router(user_router, prefix="/api/user")
app.include_router(admin_router, prefix="/api/admin", dependencies=[Depends(admin_only)])

View File

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,30 @@
# modules/admin/api.py
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from core.database import Base, get_db
from modules.auth.models import User, UserRole
router = APIRouter()
@router.get("/")
def read_admin():
return {"message": "Admin route"}
@router.get("/cleardb")
def clear_db(db: Annotated[Session, Depends(get_db)]):
"""
Clear the database.
"""
tables = Base.metadata.tables.keys()
for table in tables:
# delete all tables that isn't the users table
if table != "users":
table = Base.metadata.tables[table]
db.execute(table.delete())
# delete all non-admin accounts
db.query(User).filter(User.role != UserRole.ADMIN).delete()
db.commit()
return {"message": "Database cleared"}

View File

@@ -0,0 +1,4 @@
# modules/admin/services.py
## temp

View File

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,74 @@
# modules/auth/api.py
from fastapi import APIRouter, Cookie, Depends, HTTPException, status, Request, Response
from fastapi.security import OAuth2PasswordRequestForm
from jose import JWTError
from modules.auth.models import User
from modules.auth.schemas import UserCreate, UserResponse, Token
from modules.auth.services import create_user
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens
from sqlalchemy.orm import Session
from typing import Annotated, Optional
from core.database import get_db
from datetime import timedelta
from core.config import settings # Assuming settings is defined in core.config
from core.exceptions import unauthorized_exception
router = APIRouter()
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
return create_user(user.username, user.password, user.name, db)
@router.post("/login", response_model=Token)
def login(response: Response, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
"""
Authenticate user and return JWT token.
"""
user = authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
)
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
refresh_token = create_refresh_token(data={"sub": user.username})
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
response.set_cookie(
key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="Lax", max_age=max_age
)
return {"access_token": access_token, "token_type": "bearer"}
@router.post("/refresh")
def refresh_token(request: Request, db: Annotated[Session, Depends(get_db)]):
refresh_token = request.cookies.get("refresh_token")
if not refresh_token:
raise unauthorized_exception("Refresh token missing")
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db)
if not user_data:
raise unauthorized_exception("Invalid refresh token")
new_access_token = create_access_token(data={"sub": user_data.username})
return {"access_token": new_access_token, "token_type": "bearer"}
@router.post("/logout")
def logout(response: Response, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme), refresh_token: Optional[str] = Cookie(None, alias="refresh_token")):
try:
if not refresh_token:
raise unauthorized_exception("Refresh token not found")
blacklist_tokens(
access_token=access_token,
refresh_token=refresh_token,
db=db
)
response.delete_cookie(key="refresh_token")
return {"message": "Logged out successfully"}
except JWTError:
raise unauthorized_exception("Invalid token")

View File

@@ -0,0 +1,18 @@
# modules/auth/dependencies.py
from fastapi import Depends, HTTPException, status
from modules.auth.security import get_current_user
from modules.auth.schemas import UserRole
from modules.auth.models import User
from core.exceptions import forbidden_exception
class RoleChecker:
def __init__(self, allowed_roles: list[UserRole]):
self.allowed_roles = allowed_roles
def __call__(self, user: User = Depends(get_current_user)):
if user.role not in self.allowed_roles:
forbidden_exception("You do not have permission to perform this action.")
return user
admin_only = RoleChecker([UserRole.ADMIN])
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])

View File

@@ -0,0 +1,25 @@
# modules/auth/models.py
from core.database import Base
from sqlalchemy import CheckConstraint, Column, Integer, String, Enum, DateTime
from enum import Enum as PyEnum
class UserRole(str, PyEnum):
ADMIN = "admin"
USER = "user"
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True)
uuid = Column(String, unique=True)
username = Column(String, unique=True)
hashed_password = Column(String)
role = Column(Enum(UserRole), nullable=False, default=UserRole.USER)
name = Column(String)
class TokenBlacklist(Base):
__tablename__ = "token_blacklist"
id = Column(Integer, primary_key=True)
token = Column(String, unique=True)
expires_at = Column(DateTime)

View File

@@ -0,0 +1,33 @@
# modules/auth/schemas.py
from enum import Enum as PyEnum
from pydantic import BaseModel
class Token(BaseModel):
access_token: str
token_type: str
refresh_token: str | None = None
class TokenData(BaseModel):
username: str | None = None
scopes: list[str] = []
class UserRole(str, PyEnum):
ADMIN = "admin"
USER = "user"
class UserCreate(BaseModel):
username: str
password: str
name: str
class UserPatch(BaseModel):
name: str | None = None
class UserResponse(BaseModel):
uuid: str
username: str
name: str
role: UserRole
class Config:
from_attributes = True

View File

@@ -0,0 +1,172 @@
# modules/auth/security.py
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
from sqlalchemy.orm import Session
from core.database import get_db
from core.config import settings
from modules.auth.models import TokenBlacklist, User
from modules.auth.schemas import TokenData
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
class TokenType(str, Enum):
ACCESS = "access"
REFRESH = "refresh"
password_hasher = PasswordHasher()
def hash_password(password: str) -> str:
"""Hash a password with Argon2 (and optional pepper)."""
peppered_password = password + settings.PEPPER # Prepend/append pepper
return password_hasher.hash(peppered_password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hashed version using Argon2."""
peppered_password = plain_password + settings.PEPPER
try:
return password_hasher.verify(hashed_password, peppered_password)
except VerifyMismatchError:
return False
def authenticate_user(username: str, password: str, db: Session) -> User | None:
"""
Authenticate a user by checking username/password against the database.
Returns User object if valid, None otherwise.
"""
# Get user from database
user = db.query(User).filter(User.username == username).first()
# If user not found or password doesn't match
if not user or not verify_password(password, user.hashed_password):
return None
return user
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
return jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
)
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
return jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
)
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
"""Verify a JWT token and return TokenData if valid.
Parameters
----------
token: str
The JWT token to be verified.
expected_token_type: TokenType
The expected type of token (access or refresh)
db: Session
Database session to fetch user data.
Returns
-------
TokenData | None
TokenData instance if the token is valid, None otherwise.
"""
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
if is_blacklisted:
return None
try:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
username: str = payload.get("sub")
token_type: str = payload.get("token_type")
if username is None or token_type != expected_token_type:
return None
return TokenData(username=username)
except JWTError:
return None
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
# Check if the token is blacklisted
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
if is_blacklisted:
raise credentials_exception
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user: User = db.query(User).filter(User.username == username).first()
if user is None:
raise credentials_exception
return user
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
"""Blacklist both access and refresh tokens.
Parameters
----------
access_token: str
The access token to blacklist
refresh_token: str
The refresh token to blacklist
db: Session
Database session to perform the operation.
"""
for token in [access_token, refresh_token]:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
db.add(blacklisted_token)
db.commit()
def blacklist_token(token: str, db: Session) -> None:
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
expires_at = datetime.fromtimestamp(payload.get("exp"))
# Add the token to the blacklist
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
db.add(blacklisted_token)
db.commit()

View File

@@ -0,0 +1,30 @@
# modules/auth/services.py
from sqlalchemy.orm import Session
from modules.auth.models import User
from modules.auth.schemas import UserResponse
from modules.auth.security import hash_password
from core.exceptions import conflict_exception
import uuid
def create_user(username: str, password: str, name: str, db: Session) -> UserResponse:
"""
Create a new user in the database.
Hashes the password before storing it.
Returns the created user object.
"""
if db is None:
raise ValueError("Database session is required")
# Check if the user already exists
existing_user = db.query(User).filter(User.username == username).first()
if existing_user:
raise conflict_exception("Username already exists")
hashed_password = hash_password(password)
user_uuid = str(uuid.uuid4())
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid)
db.add(user)
db.commit()
db.refresh(user) # Loads the generated ID
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic

View File

View File

Binary file not shown.

View File

@@ -0,0 +1,78 @@
# modules/user/api.py
from typing import Annotated
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from core.database import get_db
from core.exceptions import unauthorized_exception, not_found_exception, forbidden_exception
from modules.auth.schemas import UserPatch, UserResponse
from modules.auth.dependencies import get_current_user
from modules.auth.models import User
router = APIRouter()
@router.get("/me", response_model=UserResponse)
def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
"""
Get the current user. Requires user to be logged in.
Returns the user object.
"""
return current_user
@router.get("/{username}", response_model=UserResponse)
def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
"""
Get a user by username.
Returns the user object.
"""
if current_user.username != username:
raise forbidden_exception("You can only view your own profile")
user = db.query(User).filter(User.username == username).first()
if not user:
raise not_found_exception("User not found")
return user
@router.patch("/{username}", response_model=UserResponse)
def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
"""
Update a user by username.
Returns the updated user object.
"""
if current_user.username != username:
raise forbidden_exception("You can only update your own profile")
user = db.query(User).filter(User.username == username).first()
if not user:
raise not_found_exception("User not found")
# Define fields that should not be updated
non_updateable_fields = {"uuid", "role", "username"}
print("BEFORE: ", user_data.model_dump(exclude_unset=True))
# Update only allowed fields
for key, value in user_data.model_dump(exclude_unset=True).items():
if key not in non_updateable_fields:
setattr(user, key, value)
print("AFTER:", user_data.model_dump(exclude_unset=True))
db.commit()
db.refresh(user)
return user
@router.delete("/{username}", response_model=UserResponse)
def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
"""
Delete a user by username.
Returns the deleted user object.
"""
if current_user.username != username:
raise forbidden_exception("You can only delete your own profile")
user = db.query(User).filter(User.username == username).first()
if not user:
raise not_found_exception("User not found")
db.delete(user)
db.commit()
return user

45
backend/requirements.txt Normal file
View File

@@ -0,0 +1,45 @@
amqp==5.3.1
annotated-types==0.7.0
anyio==4.9.0
bcrypt==4.3.0
billiard==4.2.1
celery==5.5.1
cffi==1.17.1
click==8.1.8
click-didyoumean==0.3.1
click-plugins==1.1.1
click-repl==0.3.0
cryptography==44.0.2
ecdsa==0.19.1
fastapi==0.115.12
greenlet==3.1.1
h11==0.14.0
idna==3.10
iniconfig==2.1.0
kombu==5.5.2
packaging==24.2
passlib==1.7.4
pluggy==1.5.0
prompt_toolkit==3.0.50
psycopg2-binary==2.9.10
pyasn1==0.4.8
pycparser==2.22
pydantic==2.11.3
pydantic_core==2.33.1
pytest==8.3.5
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-jose==3.4.0
python-multipart==0.0.20
redis==5.2.1
rsa==4.9
six==1.17.0
sniffio==1.3.1
SQLAlchemy==2.0.40
starlette==0.46.2
typing-inspection==0.4.0
typing_extensions==4.13.2
tzdata==2025.2
uvicorn==0.34.1
vine==5.1.0
wcwidth==0.2.13

View File

Binary file not shown.

58
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,58 @@
# conftest.py
from typing import Generator, Callable, Any
import pytest
from testcontainers.postgres import PostgresContainer
from fastapi.testclient import TestClient
from core.config import settings
from faker import Faker
from sqlalchemy.orm import Session
from core.database import get_db, get_sessionmaker
fake = Faker()
@pytest.fixture(scope="session")
def postgres_container() -> Generator[PostgresContainer, None, None]:
"""Fixture to create a PostgreSQL container for testing."""
print("Starting Postgres container...")
with PostgresContainer("postgres:latest") as postgres:
settings.DB_URL = postgres.get_connection_url()
print(f"Postgres container started at {settings.DB_URL}")
yield postgres
print("Postgres container stopped.")
@pytest.fixture(scope="function")
def db(postgres_container) -> Generator[Session, None, None]:
"""Function-scoped database session with rollback"""
SessionLocal = get_sessionmaker()
session = SessionLocal()
session.begin_nested() # Enable nested transaction
try:
yield session
finally:
session.rollback()
session.close()
@pytest.fixture(scope="function")
def client(db: Session) -> Generator[TestClient, None, None]:
"""Function-scoped test client with dependency override"""
from main import app
# Override the database dependency
def override_get_db():
try:
yield db
finally:
pass # Don't close session here
app.dependency_overrides[get_db] = override_get_db
with TestClient(app) as test_client:
yield test_client
app.dependency_overrides.clear()
def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None:
from main import app
app.dependency_overrides[dependency] = lambda: mocked_response

View File

@@ -0,0 +1,42 @@
from datetime import timedelta
import uuid as uuid_pkg
from sqlalchemy.orm import Session
from core.config import settings
from modules.auth.models import User
from modules.auth.security import authenticate_user, create_access_token, create_refresh_token, hash_password
from modules.auth.schemas import UserRole
from tests.conftest import fake
def create_user(db: Session, is_admin: bool = False) -> User:
unhashed_password = fake.password()
_user = User(
name=fake.name(),
username=fake.user_name(),
hashed_password=hash_password(unhashed_password),
uuid=uuid_pkg.uuid4(),
role=UserRole.ADMIN if is_admin else UserRole.USER,
)
db.add(_user)
db.commit()
db.refresh(_user)
return _user, unhashed_password # return for testing
def login(db: Session, username: str, password: str) -> str:
user = authenticate_user(username, password, db)
if not user:
raise Exception("Incorrect username or password")
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
refresh_token = create_refresh_token(data={"sub": user.username})
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
return {
"access_token": access_token,
"refresh_token": refresh_token,
"max_age": max_age,
}

180
backend/tests/test_auth.py Normal file
View File

@@ -0,0 +1,180 @@
# Main test file for the authentication process.
# uses conftest -> db_session as an in-memory db.
# Goes through the whole authentication process:
# 1. Register a user
# 2. Login the user
# 3. Refresh the token
# 4. Logout the user
# 5. Verify that the user is logged out
# 6. Verify that the user cannot refresh the token
# 7. Verify that the user cannot login again
# 8. Verify that the user cannot register again
# 9. Verify that the user cannot access protected routes (/admin)
import time
from fastapi import status
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from modules.auth.models import TokenBlacklist, User
from tests.conftest import fake
from .helpers import generators
def test_register(client: TestClient) -> None:
response = client.post(
"/api/auth/register",
json={
"username": fake.user_name(),
"password": fake.password(),
"name": fake.name(),
},
)
assert response.status_code == status.HTTP_201_CREATED
def test_login(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
response = client.post(
"/api/auth/login",
data={
"username": user.username,
"password": unhashed_password,
},
)
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert "access_token" in response_data
assert "token_type" in response_data
assert response_data["token_type"] == "bearer"
def test_refresh_token(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
rsp = generators.login(db, user.username, unhashed_password)
access_token = rsp["access_token"]
refresh_token = rsp["refresh_token"]
time.sleep(1) # Sleep to ensure tokens won't be identical
response = client.post(
"/api/auth/refresh",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert "access_token" in response_data
assert "token_type" in response_data
assert response_data["token_type"] == "bearer"
assert response_data["access_token"] != access_token # Ensure the token is refreshed
def test_logout(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
rsp = generators.login(db, user.username, unhashed_password)
access_token = rsp["access_token"]
refresh_token = rsp["refresh_token"]
response = client.post(
"/api/auth/logout",
headers={"Authorization": f"Bearer {access_token}"},
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_200_OK
# Verify that the token is blacklisted
blacklisted_token = db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
assert blacklisted_token is not None
# Verify that we can't still actually do anything
response = client.get(
"/api/user/me",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
response = client.post(
"/api/auth/refresh",
cookies={"refresh_token": refresh_token},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_me(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
response = client.get(
"/api/user/me",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert response_data["uuid"] == user.uuid
assert response_data["username"] == user.username
def test_get_me_unauthorized(client: TestClient) -> None:
### This test should fail (unauthorized) because the user isn't logged in
response = client.get("/api/user/me")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_get_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
response = client.get(
f"/api/user/{user.username}",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert response_data["uuid"] == user.uuid
assert response_data["username"] == user.username
def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
### This test should fail (unauthorized) because the user isn't us
user, unhashed_password = generators.create_user(db)
user2, _ = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
response = client.get(
f"/api/user/{user2.username}",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_update_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
new_name = fake.name()
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
response = client.patch(
f"/api/user/{user.username}",
headers={"Authorization": f"Bearer {access_token}"},
json={"name": new_name},
)
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert response_data["name"] == new_name
def test_delete_user(db: Session, client: TestClient) -> None:
user, unhashed_password = generators.create_user(db)
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
response = client.delete(
f"/api/user/{user.username}",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == status.HTTP_200_OK
# Verify that the user is deleted
deleted_user = db.query(User).filter(User.username == user.username).first()
assert deleted_user is None