working auth + users systems
This commit is contained in:
2
backend/.env
Normal file
2
backend/.env
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
PEPPER = "LsD7%"
|
||||||
|
JWT_SECRET_KEY="1c8cf3ca6972b365f8108dad247e61abdcb6faff5a6c8ba00cb6fa17396702bf"
|
||||||
1
backend/.gitignore
vendored
Normal file
1
backend/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
/env
|
||||||
7
backend/.vscode/settings.json
vendored
Normal file
7
backend/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"python.testing.pytestArgs": [
|
||||||
|
"tests"
|
||||||
|
],
|
||||||
|
"python.testing.unittestEnabled": false,
|
||||||
|
"python.testing.pytestEnabled": true
|
||||||
|
}
|
||||||
2
backend/TODO
Normal file
2
backend/TODO
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
Pedantic:
|
||||||
|
- Shouldn't really return a 409 Conflict when user made with same username, could be used to enumerate users.
|
||||||
BIN
backend/__pycache__/main.cpython-312.pyc
Normal file
BIN
backend/__pycache__/main.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/core/__pycache__/celery_app.cpython-312.pyc
Normal file
BIN
backend/core/__pycache__/celery_app.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/core/__pycache__/config.cpython-312.pyc
Normal file
BIN
backend/core/__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/core/__pycache__/database.cpython-312.pyc
Normal file
BIN
backend/core/__pycache__/database.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/core/__pycache__/exceptions.cpython-312.pyc
Normal file
BIN
backend/core/__pycache__/exceptions.cpython-312.pyc
Normal file
Binary file not shown.
10
backend/core/celery_app.py
Normal file
10
backend/core/celery_app.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# core/celery_app.py
|
||||||
|
from celery import Celery
|
||||||
|
from core.config import settings
|
||||||
|
|
||||||
|
celery = Celery(
|
||||||
|
"maia",
|
||||||
|
broker=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0",
|
||||||
|
backend=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/1",
|
||||||
|
include=["modules.auth.tasks"], # List all task modules here
|
||||||
|
)
|
||||||
21
backend/core/config.py
Normal file
21
backend/core/config.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# core/config.py
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from os import getenv
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv() # Load .env file
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
DB_URL: str = "postgresql://maia:maia@localhost:5432/maia"
|
||||||
|
|
||||||
|
REDIS_HOST: str = "localhost"
|
||||||
|
REDIS_PORT: int = 6379
|
||||||
|
|
||||||
|
JWT_ALGORITHM: str = "HS256"
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
|
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||||
|
|
||||||
|
PEPPER: str = getenv("PEPPER", "")
|
||||||
|
JWT_SECRET_KEY: str = getenv("JWT_SECRET_KEY", "")
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
36
backend/core/database.py
Normal file
36
backend/core/database.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# core/database.py
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker, Session, declarative_base
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from core.config import settings
|
||||||
|
|
||||||
|
Base = declarative_base() # Used for models
|
||||||
|
|
||||||
|
_engine = None
|
||||||
|
_SessionLocal = None
|
||||||
|
|
||||||
|
def get_engine():
|
||||||
|
global _engine
|
||||||
|
if _engine is None:
|
||||||
|
if not settings.DB_URL:
|
||||||
|
raise ValueError("DB_URL is not set in Settings.")
|
||||||
|
print(f"Connecting to database at {settings.DB_URL}")
|
||||||
|
_engine = create_engine(settings.DB_URL)
|
||||||
|
Base.metadata.create_all(_engine) # Create tables here
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
def get_sessionmaker():
|
||||||
|
global _SessionLocal
|
||||||
|
if _SessionLocal is None:
|
||||||
|
engine = get_engine()
|
||||||
|
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
return _SessionLocal
|
||||||
|
|
||||||
|
def get_db() -> Generator[Session, None, None]:
|
||||||
|
SessionLocal = get_sessionmaker()
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
27
backend/core/exceptions.py
Normal file
27
backend/core/exceptions.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from fastapi import HTTPException
|
||||||
|
from starlette.status import (
|
||||||
|
HTTP_400_BAD_REQUEST,
|
||||||
|
HTTP_401_UNAUTHORIZED,
|
||||||
|
HTTP_403_FORBIDDEN,
|
||||||
|
HTTP_404_NOT_FOUND,
|
||||||
|
HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
HTTP_409_CONFLICT,
|
||||||
|
)
|
||||||
|
|
||||||
|
def bad_request_exception(detail: str = "Bad Request"):
|
||||||
|
return HTTPException(status_code=HTTP_400_BAD_REQUEST, detail=detail)
|
||||||
|
|
||||||
|
def unauthorized_exception(detail: str = "Unauthorized"):
|
||||||
|
return HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=detail)
|
||||||
|
|
||||||
|
def forbidden_exception(detail: str = "Forbidden"):
|
||||||
|
return HTTPException(status_code=HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
def not_found_exception(detail: str = "Not Found"):
|
||||||
|
return HTTPException(status_code=HTTP_404_NOT_FOUND, detail=detail)
|
||||||
|
|
||||||
|
def internal_server_error_exception(detail: str = "Internal Server Error"):
|
||||||
|
return HTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail=detail)
|
||||||
|
|
||||||
|
def conflict_exception(detail: str = "Conflict"):
|
||||||
|
return HTTPException(status_code=HTTP_409_CONFLICT, detail=detail)
|
||||||
23
backend/docker-compose.yml
Normal file
23
backend/docker-compose.yml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# docker-compose.yml
|
||||||
|
services:
|
||||||
|
postgres:
|
||||||
|
image: postgres:14
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: maia
|
||||||
|
POSTGRES_PASSWORD: maia
|
||||||
|
POSTGRES_DB: maia
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7
|
||||||
|
ports:
|
||||||
|
- "6379:6379"
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
redis_data:
|
||||||
25
backend/main.py
Normal file
25
backend/main.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# main.py
|
||||||
|
from fastapi import FastAPI, Depends
|
||||||
|
from core.database import get_engine, Base
|
||||||
|
from modules.auth.api import router as auth_router
|
||||||
|
from modules.user.api import router as user_router
|
||||||
|
from modules.admin.api import router as admin_router
|
||||||
|
from modules.auth.dependencies import admin_only
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from modules.auth.security import get_current_user
|
||||||
|
|
||||||
|
logging.getLogger('passlib').setLevel(logging.ERROR) # fix bc package logging is broken
|
||||||
|
|
||||||
|
# Create DB tables (remove in production; use migrations instead)
|
||||||
|
def lifespan(app):
|
||||||
|
# Base.metadata.drop_all(bind=get_engine())
|
||||||
|
Base.metadata.create_all(bind=get_engine())
|
||||||
|
yield
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
# Include all module routers
|
||||||
|
app.include_router(auth_router, prefix="/api/auth")
|
||||||
|
app.include_router(user_router, prefix="/api/user")
|
||||||
|
app.include_router(admin_router, prefix="/api/admin", dependencies=[Depends(admin_only)])
|
||||||
0
backend/modules/__init__.py
Normal file
0
backend/modules/__init__.py
Normal file
BIN
backend/modules/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/modules/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/admin/__pycache__/api.cpython-312.pyc
Normal file
BIN
backend/modules/admin/__pycache__/api.cpython-312.pyc
Normal file
Binary file not shown.
30
backend/modules/admin/api.py
Normal file
30
backend/modules/admin/api.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# modules/admin/api.py
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from core.database import Base, get_db
|
||||||
|
from modules.auth.models import User, UserRole
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
def read_admin():
|
||||||
|
return {"message": "Admin route"}
|
||||||
|
|
||||||
|
@router.get("/cleardb")
|
||||||
|
def clear_db(db: Annotated[Session, Depends(get_db)]):
|
||||||
|
"""
|
||||||
|
Clear the database.
|
||||||
|
"""
|
||||||
|
tables = Base.metadata.tables.keys()
|
||||||
|
for table in tables:
|
||||||
|
# delete all tables that isn't the users table
|
||||||
|
if table != "users":
|
||||||
|
table = Base.metadata.tables[table]
|
||||||
|
db.execute(table.delete())
|
||||||
|
|
||||||
|
# delete all non-admin accounts
|
||||||
|
db.query(User).filter(User.role != UserRole.ADMIN).delete()
|
||||||
|
db.commit()
|
||||||
|
return {"message": "Database cleared"}
|
||||||
4
backend/modules/admin/services.py
Normal file
4
backend/modules/admin/services.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# modules/admin/services.py
|
||||||
|
|
||||||
|
|
||||||
|
## temp
|
||||||
0
backend/modules/auth/__init__.py
Normal file
0
backend/modules/auth/__init__.py
Normal file
BIN
backend/modules/auth/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/auth/__pycache__/api.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/api.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/auth/__pycache__/dependencies.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/dependencies.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/auth/__pycache__/models.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/models.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/auth/__pycache__/schemas.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/schemas.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/auth/__pycache__/security.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/security.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/auth/__pycache__/service.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/service.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/auth/__pycache__/services.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/services.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/modules/auth/__pycache__/utils.cpython-312.pyc
Normal file
BIN
backend/modules/auth/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
74
backend/modules/auth/api.py
Normal file
74
backend/modules/auth/api.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# modules/auth/api.py
|
||||||
|
from fastapi import APIRouter, Cookie, Depends, HTTPException, status, Request, Response
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from jose import JWTError
|
||||||
|
from modules.auth.models import User
|
||||||
|
from modules.auth.schemas import UserCreate, UserResponse, Token
|
||||||
|
from modules.auth.services import create_user
|
||||||
|
from modules.auth.security import TokenType, get_current_user, oauth2_scheme, create_access_token, create_refresh_token, verify_token, authenticate_user, blacklist_tokens
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
from core.database import get_db
|
||||||
|
from datetime import timedelta
|
||||||
|
from core.config import settings # Assuming settings is defined in core.config
|
||||||
|
from core.exceptions import unauthorized_exception
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
def register(user: UserCreate, db: Annotated[Session, Depends(get_db)]):
|
||||||
|
return create_user(user.username, user.password, user.name, db)
|
||||||
|
|
||||||
|
@router.post("/login", response_model=Token)
|
||||||
|
def login(response: Response, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
|
||||||
|
"""
|
||||||
|
Authenticate user and return JWT token.
|
||||||
|
"""
|
||||||
|
user = authenticate_user(form_data.username, form_data.password, db)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect username or password",
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||||
|
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||||
|
|
||||||
|
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||||
|
|
||||||
|
response.set_cookie(
|
||||||
|
key="refresh_token", value=refresh_token, httponly=True, secure=True, samesite="Lax", max_age=max_age
|
||||||
|
)
|
||||||
|
return {"access_token": access_token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
@router.post("/refresh")
|
||||||
|
def refresh_token(request: Request, db: Annotated[Session, Depends(get_db)]):
|
||||||
|
refresh_token = request.cookies.get("refresh_token")
|
||||||
|
if not refresh_token:
|
||||||
|
raise unauthorized_exception("Refresh token missing")
|
||||||
|
|
||||||
|
|
||||||
|
user_data = verify_token(refresh_token, expected_token_type=TokenType.REFRESH, db=db)
|
||||||
|
if not user_data:
|
||||||
|
raise unauthorized_exception("Invalid refresh token")
|
||||||
|
|
||||||
|
|
||||||
|
new_access_token = create_access_token(data={"sub": user_data.username})
|
||||||
|
return {"access_token": new_access_token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
@router.post("/logout")
|
||||||
|
def logout(response: Response, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)], access_token: str = Depends(oauth2_scheme), refresh_token: Optional[str] = Cookie(None, alias="refresh_token")):
|
||||||
|
try:
|
||||||
|
if not refresh_token:
|
||||||
|
raise unauthorized_exception("Refresh token not found")
|
||||||
|
|
||||||
|
blacklist_tokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
db=db
|
||||||
|
)
|
||||||
|
response.delete_cookie(key="refresh_token")
|
||||||
|
|
||||||
|
return {"message": "Logged out successfully"}
|
||||||
|
except JWTError:
|
||||||
|
raise unauthorized_exception("Invalid token")
|
||||||
18
backend/modules/auth/dependencies.py
Normal file
18
backend/modules/auth/dependencies.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# modules/auth/dependencies.py
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from modules.auth.security import get_current_user
|
||||||
|
from modules.auth.schemas import UserRole
|
||||||
|
from modules.auth.models import User
|
||||||
|
from core.exceptions import forbidden_exception
|
||||||
|
|
||||||
|
class RoleChecker:
|
||||||
|
def __init__(self, allowed_roles: list[UserRole]):
|
||||||
|
self.allowed_roles = allowed_roles
|
||||||
|
|
||||||
|
def __call__(self, user: User = Depends(get_current_user)):
|
||||||
|
if user.role not in self.allowed_roles:
|
||||||
|
forbidden_exception("You do not have permission to perform this action.")
|
||||||
|
return user
|
||||||
|
|
||||||
|
admin_only = RoleChecker([UserRole.ADMIN])
|
||||||
|
any_user = RoleChecker([UserRole.ADMIN, UserRole.USER])
|
||||||
25
backend/modules/auth/models.py
Normal file
25
backend/modules/auth/models.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# modules/auth/models.py
|
||||||
|
from core.database import Base
|
||||||
|
from sqlalchemy import CheckConstraint, Column, Integer, String, Enum, DateTime
|
||||||
|
from enum import Enum as PyEnum
|
||||||
|
|
||||||
|
class UserRole(str, PyEnum):
|
||||||
|
ADMIN = "admin"
|
||||||
|
USER = "user"
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
uuid = Column(String, unique=True)
|
||||||
|
username = Column(String, unique=True)
|
||||||
|
hashed_password = Column(String)
|
||||||
|
role = Column(Enum(UserRole), nullable=False, default=UserRole.USER)
|
||||||
|
|
||||||
|
name = Column(String)
|
||||||
|
|
||||||
|
class TokenBlacklist(Base):
|
||||||
|
__tablename__ = "token_blacklist"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
token = Column(String, unique=True)
|
||||||
|
expires_at = Column(DateTime)
|
||||||
33
backend/modules/auth/schemas.py
Normal file
33
backend/modules/auth/schemas.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# modules/auth/schemas.py
|
||||||
|
from enum import Enum as PyEnum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
refresh_token: str | None = None
|
||||||
|
|
||||||
|
class TokenData(BaseModel):
|
||||||
|
username: str | None = None
|
||||||
|
scopes: list[str] = []
|
||||||
|
|
||||||
|
class UserRole(str, PyEnum):
|
||||||
|
ADMIN = "admin"
|
||||||
|
USER = "user"
|
||||||
|
|
||||||
|
class UserCreate(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class UserPatch(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
uuid: str
|
||||||
|
username: str
|
||||||
|
name: str
|
||||||
|
role: UserRole
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
172
backend/modules/auth/security.py
Normal file
172
backend/modules/auth/security.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# modules/auth/security.py
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from argon2 import PasswordHasher
|
||||||
|
from argon2.exceptions import VerifyMismatchError
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.database import get_db
|
||||||
|
from core.config import settings
|
||||||
|
from modules.auth.models import TokenBlacklist, User
|
||||||
|
from modules.auth.schemas import TokenData
|
||||||
|
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
|
||||||
|
|
||||||
|
class TokenType(str, Enum):
|
||||||
|
ACCESS = "access"
|
||||||
|
REFRESH = "refresh"
|
||||||
|
|
||||||
|
|
||||||
|
password_hasher = PasswordHasher()
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""Hash a password with Argon2 (and optional pepper)."""
|
||||||
|
peppered_password = password + settings.PEPPER # Prepend/append pepper
|
||||||
|
return password_hasher.hash(peppered_password)
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
"""Verify a password against its hashed version using Argon2."""
|
||||||
|
peppered_password = plain_password + settings.PEPPER
|
||||||
|
try:
|
||||||
|
return password_hasher.verify(hashed_password, peppered_password)
|
||||||
|
except VerifyMismatchError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def authenticate_user(username: str, password: str, db: Session) -> User | None:
|
||||||
|
"""
|
||||||
|
Authenticate a user by checking username/password against the database.
|
||||||
|
Returns User object if valid, None otherwise.
|
||||||
|
"""
|
||||||
|
# Get user from database
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
|
||||||
|
# If user not found or password doesn't match
|
||||||
|
if not user or not verify_password(password, user.hashed_password):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||||
|
to_encode = data.copy()
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
|
||||||
|
return jwt.encode(
|
||||||
|
to_encode,
|
||||||
|
settings.JWT_SECRET_KEY,
|
||||||
|
algorithm=settings.JWT_ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_refresh_token(data: dict, expires_delta: timedelta | None = None):
|
||||||
|
to_encode = data.copy()
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
|
||||||
|
return jwt.encode(
|
||||||
|
to_encode,
|
||||||
|
settings.JWT_SECRET_KEY,
|
||||||
|
algorithm=settings.JWT_ALGORITHM
|
||||||
|
)
|
||||||
|
|
||||||
|
def verify_token(token: str, expected_token_type: TokenType, db: Session) -> TokenData | None:
|
||||||
|
"""Verify a JWT token and return TokenData if valid.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
token: str
|
||||||
|
The JWT token to be verified.
|
||||||
|
expected_token_type: TokenType
|
||||||
|
The expected type of token (access or refresh)
|
||||||
|
db: Session
|
||||||
|
Database session to fetch user data.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
TokenData | None
|
||||||
|
TokenData instance if the token is valid, None otherwise.
|
||||||
|
"""
|
||||||
|
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
||||||
|
if is_blacklisted:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
token_type: str = payload.get("token_type")
|
||||||
|
|
||||||
|
if username is None or token_type != expected_token_type:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return TokenData(username=username)
|
||||||
|
|
||||||
|
except JWTError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_current_user(db: Annotated[Session, Depends(get_db)], token: str = Depends(oauth2_scheme)) -> User:
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the token is blacklisted
|
||||||
|
is_blacklisted = db.query(TokenBlacklist).filter(TokenBlacklist.token == token).first() is not None
|
||||||
|
if is_blacklisted:
|
||||||
|
raise credentials_exception
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.JWT_SECRET_KEY,
|
||||||
|
algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise credentials_exception
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
user: User = db.query(User).filter(User.username == username).first()
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
||||||
|
|
||||||
|
def blacklist_tokens(access_token: str, refresh_token: str, db: Session) -> None:
|
||||||
|
"""Blacklist both access and refresh tokens.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
access_token: str
|
||||||
|
The access token to blacklist
|
||||||
|
refresh_token: str
|
||||||
|
The refresh token to blacklist
|
||||||
|
db: Session
|
||||||
|
Database session to perform the operation.
|
||||||
|
"""
|
||||||
|
for token in [access_token, refresh_token]:
|
||||||
|
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||||
|
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||||
|
|
||||||
|
# Add the token to the blacklist
|
||||||
|
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
|
||||||
|
db.add(blacklisted_token)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
def blacklist_token(token: str, db: Session) -> None:
|
||||||
|
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
|
||||||
|
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||||
|
|
||||||
|
# Add the token to the blacklist
|
||||||
|
blacklisted_token = TokenBlacklist(token=token, expires_at=expires_at)
|
||||||
|
db.add(blacklisted_token)
|
||||||
|
db.commit()
|
||||||
30
backend/modules/auth/services.py
Normal file
30
backend/modules/auth/services.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# modules/auth/services.py
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from modules.auth.models import User
|
||||||
|
from modules.auth.schemas import UserResponse
|
||||||
|
from modules.auth.security import hash_password
|
||||||
|
from core.exceptions import conflict_exception
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(username: str, password: str, name: str, db: Session) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Create a new user in the database.
|
||||||
|
Hashes the password before storing it.
|
||||||
|
Returns the created user object.
|
||||||
|
"""
|
||||||
|
if db is None:
|
||||||
|
raise ValueError("Database session is required")
|
||||||
|
|
||||||
|
# Check if the user already exists
|
||||||
|
existing_user = db.query(User).filter(User.username == username).first()
|
||||||
|
if existing_user:
|
||||||
|
raise conflict_exception("Username already exists")
|
||||||
|
|
||||||
|
hashed_password = hash_password(password)
|
||||||
|
user_uuid = str(uuid.uuid4())
|
||||||
|
user = User(username=username, hashed_password=hashed_password, name=name, uuid=user_uuid)
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user) # Loads the generated ID
|
||||||
|
return UserResponse.model_validate(user) # Converts SQLAlchemy model -> Pydantic
|
||||||
0
backend/modules/auth/tasks.py
Normal file
0
backend/modules/auth/tasks.py
Normal file
0
backend/modules/user/__init__.py
Normal file
0
backend/modules/user/__init__.py
Normal file
BIN
backend/modules/user/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/modules/user/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/modules/user/__pycache__/api.cpython-312.pyc
Normal file
BIN
backend/modules/user/__pycache__/api.cpython-312.pyc
Normal file
Binary file not shown.
78
backend/modules/user/api.py
Normal file
78
backend/modules/user/api.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# modules/user/api.py
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.database import get_db
|
||||||
|
from core.exceptions import unauthorized_exception, not_found_exception, forbidden_exception
|
||||||
|
from modules.auth.schemas import UserPatch, UserResponse
|
||||||
|
from modules.auth.dependencies import get_current_user
|
||||||
|
from modules.auth.models import User
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
def me(db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Get the current user. Requires user to be logged in.
|
||||||
|
Returns the user object.
|
||||||
|
"""
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
@router.get("/{username}", response_model=UserResponse)
|
||||||
|
def get_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Get a user by username.
|
||||||
|
Returns the user object.
|
||||||
|
"""
|
||||||
|
if current_user.username != username:
|
||||||
|
raise forbidden_exception("You can only view your own profile")
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
if not user:
|
||||||
|
raise not_found_exception("User not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
@router.patch("/{username}", response_model=UserResponse)
|
||||||
|
def update_user(username: str, user_data: UserPatch, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Update a user by username.
|
||||||
|
Returns the updated user object.
|
||||||
|
"""
|
||||||
|
if current_user.username != username:
|
||||||
|
raise forbidden_exception("You can only update your own profile")
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
if not user:
|
||||||
|
raise not_found_exception("User not found")
|
||||||
|
|
||||||
|
# Define fields that should not be updated
|
||||||
|
non_updateable_fields = {"uuid", "role", "username"}
|
||||||
|
|
||||||
|
print("BEFORE: ", user_data.model_dump(exclude_unset=True))
|
||||||
|
# Update only allowed fields
|
||||||
|
for key, value in user_data.model_dump(exclude_unset=True).items():
|
||||||
|
if key not in non_updateable_fields:
|
||||||
|
setattr(user, key, value)
|
||||||
|
|
||||||
|
print("AFTER:", user_data.model_dump(exclude_unset=True))
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@router.delete("/{username}", response_model=UserResponse)
|
||||||
|
def delete_user(username: str, db: Annotated[Session, Depends(get_db)], current_user: Annotated[User, Depends(get_current_user)]) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Delete a user by username.
|
||||||
|
Returns the deleted user object.
|
||||||
|
"""
|
||||||
|
if current_user.username != username:
|
||||||
|
raise forbidden_exception("You can only delete your own profile")
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
if not user:
|
||||||
|
raise not_found_exception("User not found")
|
||||||
|
|
||||||
|
db.delete(user)
|
||||||
|
db.commit()
|
||||||
|
return user
|
||||||
45
backend/requirements.txt
Normal file
45
backend/requirements.txt
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
amqp==5.3.1
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.9.0
|
||||||
|
bcrypt==4.3.0
|
||||||
|
billiard==4.2.1
|
||||||
|
celery==5.5.1
|
||||||
|
cffi==1.17.1
|
||||||
|
click==8.1.8
|
||||||
|
click-didyoumean==0.3.1
|
||||||
|
click-plugins==1.1.1
|
||||||
|
click-repl==0.3.0
|
||||||
|
cryptography==44.0.2
|
||||||
|
ecdsa==0.19.1
|
||||||
|
fastapi==0.115.12
|
||||||
|
greenlet==3.1.1
|
||||||
|
h11==0.14.0
|
||||||
|
idna==3.10
|
||||||
|
iniconfig==2.1.0
|
||||||
|
kombu==5.5.2
|
||||||
|
packaging==24.2
|
||||||
|
passlib==1.7.4
|
||||||
|
pluggy==1.5.0
|
||||||
|
prompt_toolkit==3.0.50
|
||||||
|
psycopg2-binary==2.9.10
|
||||||
|
pyasn1==0.4.8
|
||||||
|
pycparser==2.22
|
||||||
|
pydantic==2.11.3
|
||||||
|
pydantic_core==2.33.1
|
||||||
|
pytest==8.3.5
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
python-dotenv==1.1.0
|
||||||
|
python-jose==3.4.0
|
||||||
|
python-multipart==0.0.20
|
||||||
|
redis==5.2.1
|
||||||
|
rsa==4.9
|
||||||
|
six==1.17.0
|
||||||
|
sniffio==1.3.1
|
||||||
|
SQLAlchemy==2.0.40
|
||||||
|
starlette==0.46.2
|
||||||
|
typing-inspection==0.4.0
|
||||||
|
typing_extensions==4.13.2
|
||||||
|
tzdata==2025.2
|
||||||
|
uvicorn==0.34.1
|
||||||
|
vine==5.1.0
|
||||||
|
wcwidth==0.2.13
|
||||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
BIN
backend/tests/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/tests/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.5.pyc
Normal file
BIN
backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.5.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.5.pyc
Normal file
BIN
backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.5.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/test_conf.cpython-312-pytest-8.3.5.pyc
Normal file
BIN
backend/tests/__pycache__/test_conf.cpython-312-pytest-8.3.5.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
58
backend/tests/conftest.py
Normal file
58
backend/tests/conftest.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# conftest.py
|
||||||
|
from typing import Generator, Callable, Any
|
||||||
|
import pytest
|
||||||
|
from testcontainers.postgres import PostgresContainer
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from core.config import settings
|
||||||
|
from faker import Faker
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.database import get_db, get_sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
fake = Faker()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def postgres_container() -> Generator[PostgresContainer, None, None]:
|
||||||
|
"""Fixture to create a PostgreSQL container for testing."""
|
||||||
|
print("Starting Postgres container...")
|
||||||
|
with PostgresContainer("postgres:latest") as postgres:
|
||||||
|
settings.DB_URL = postgres.get_connection_url()
|
||||||
|
print(f"Postgres container started at {settings.DB_URL}")
|
||||||
|
yield postgres
|
||||||
|
print("Postgres container stopped.")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def db(postgres_container) -> Generator[Session, None, None]:
|
||||||
|
"""Function-scoped database session with rollback"""
|
||||||
|
SessionLocal = get_sessionmaker()
|
||||||
|
session = SessionLocal()
|
||||||
|
session.begin_nested() # Enable nested transaction
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
session.rollback()
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def client(db: Session) -> Generator[TestClient, None, None]:
|
||||||
|
"""Function-scoped test client with dependency override"""
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
# Override the database dependency
|
||||||
|
def override_get_db():
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
pass # Don't close session here
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
with TestClient(app) as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
def override_dependency(dependency: Callable[..., Any], mocked_response: Any) -> None:
|
||||||
|
from main import app
|
||||||
|
app.dependency_overrides[dependency] = lambda: mocked_response
|
||||||
BIN
backend/tests/helpers/__pycache__/generators.cpython-312.pyc
Normal file
BIN
backend/tests/helpers/__pycache__/generators.cpython-312.pyc
Normal file
Binary file not shown.
42
backend/tests/helpers/generators.py
Normal file
42
backend/tests/helpers/generators.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
import uuid as uuid_pkg
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.config import settings
|
||||||
|
from modules.auth.models import User
|
||||||
|
from modules.auth.security import authenticate_user, create_access_token, create_refresh_token, hash_password
|
||||||
|
from modules.auth.schemas import UserRole
|
||||||
|
from tests.conftest import fake
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(db: Session, is_admin: bool = False) -> User:
|
||||||
|
unhashed_password = fake.password()
|
||||||
|
_user = User(
|
||||||
|
name=fake.name(),
|
||||||
|
username=fake.user_name(),
|
||||||
|
hashed_password=hash_password(unhashed_password),
|
||||||
|
uuid=uuid_pkg.uuid4(),
|
||||||
|
role=UserRole.ADMIN if is_admin else UserRole.USER,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(_user)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(_user)
|
||||||
|
return _user, unhashed_password # return for testing
|
||||||
|
|
||||||
|
def login(db: Session, username: str, password: str) -> str:
|
||||||
|
user = authenticate_user(username, password, db)
|
||||||
|
if not user:
|
||||||
|
raise Exception("Incorrect username or password")
|
||||||
|
|
||||||
|
access_token = create_access_token(data={"sub": user.username}, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||||
|
refresh_token = create_refresh_token(data={"sub": user.username})
|
||||||
|
|
||||||
|
max_age = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"max_age": max_age,
|
||||||
|
}
|
||||||
180
backend/tests/test_auth.py
Normal file
180
backend/tests/test_auth.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
# Main test file for the authentication process.
|
||||||
|
# uses conftest -> db_session as an in-memory db.
|
||||||
|
|
||||||
|
# Goes through the whole authentication process:
|
||||||
|
# 1. Register a user
|
||||||
|
# 2. Login the user
|
||||||
|
# 3. Refresh the token
|
||||||
|
# 4. Logout the user
|
||||||
|
# 5. Verify that the user is logged out
|
||||||
|
# 6. Verify that the user cannot refresh the token
|
||||||
|
# 7. Verify that the user cannot login again
|
||||||
|
# 8. Verify that the user cannot register again
|
||||||
|
# 9. Verify that the user cannot access protected routes (/admin)
|
||||||
|
|
||||||
|
import time
|
||||||
|
from fastapi import status
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from modules.auth.models import TokenBlacklist, User
|
||||||
|
from tests.conftest import fake
|
||||||
|
|
||||||
|
from .helpers import generators
|
||||||
|
|
||||||
|
|
||||||
|
def test_register(client: TestClient) -> None:
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/register",
|
||||||
|
json={
|
||||||
|
"username": fake.user_name(),
|
||||||
|
"password": fake.password(),
|
||||||
|
"name": fake.name(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_201_CREATED
|
||||||
|
|
||||||
|
def test_login(db: Session, client: TestClient) -> None:
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
data={
|
||||||
|
"username": user.username,
|
||||||
|
"password": unhashed_password,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
assert "access_token" in response_data
|
||||||
|
assert "token_type" in response_data
|
||||||
|
assert response_data["token_type"] == "bearer"
|
||||||
|
|
||||||
|
def test_refresh_token(db: Session, client: TestClient) -> None:
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
rsp = generators.login(db, user.username, unhashed_password)
|
||||||
|
access_token = rsp["access_token"]
|
||||||
|
refresh_token = rsp["refresh_token"]
|
||||||
|
|
||||||
|
time.sleep(1) # Sleep to ensure tokens won't be identical
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
cookies={"refresh_token": refresh_token},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
assert "access_token" in response_data
|
||||||
|
assert "token_type" in response_data
|
||||||
|
assert response_data["token_type"] == "bearer"
|
||||||
|
assert response_data["access_token"] != access_token # Ensure the token is refreshed
|
||||||
|
|
||||||
|
def test_logout(db: Session, client: TestClient) -> None:
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
rsp = generators.login(db, user.username, unhashed_password)
|
||||||
|
access_token = rsp["access_token"]
|
||||||
|
refresh_token = rsp["refresh_token"]
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/logout",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
cookies={"refresh_token": refresh_token},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# Verify that the token is blacklisted
|
||||||
|
blacklisted_token = db.query(TokenBlacklist).filter(TokenBlacklist.token == access_token).first()
|
||||||
|
assert blacklisted_token is not None
|
||||||
|
|
||||||
|
# Verify that we can't still actually do anything
|
||||||
|
response = client.get(
|
||||||
|
"/api/user/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/auth/refresh",
|
||||||
|
cookies={"refresh_token": refresh_token},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_me(db: Session, client: TestClient) -> None:
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/user/me",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
assert response_data["uuid"] == user.uuid
|
||||||
|
assert response_data["username"] == user.username
|
||||||
|
|
||||||
|
def test_get_me_unauthorized(client: TestClient) -> None:
|
||||||
|
### This test should fail (unauthorized) because the user isn't logged in
|
||||||
|
response = client.get("/api/user/me")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_get_user(db: Session, client: TestClient) -> None:
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
f"/api/user/{user.username}",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
assert response_data["uuid"] == user.uuid
|
||||||
|
assert response_data["username"] == user.username
|
||||||
|
|
||||||
|
def test_get_user_unauthorized(db: Session, client: TestClient) -> None:
|
||||||
|
### This test should fail (unauthorized) because the user isn't us
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
user2, _ = generators.create_user(db)
|
||||||
|
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
f"/api/user/{user2.username}",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
def test_update_user(db: Session, client: TestClient) -> None:
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
new_name = fake.name()
|
||||||
|
|
||||||
|
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||||
|
response = client.patch(
|
||||||
|
f"/api/user/{user.username}",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
json={"name": new_name},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
response_data = response.json()
|
||||||
|
assert response_data["name"] == new_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_user(db: Session, client: TestClient) -> None:
|
||||||
|
user, unhashed_password = generators.create_user(db)
|
||||||
|
access_token = generators.login(db, user.username, unhashed_password)["access_token"]
|
||||||
|
response = client.delete(
|
||||||
|
f"/api/user/{user.username}",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
# Verify that the user is deleted
|
||||||
|
deleted_user = db.query(User).filter(User.username == user.username).first()
|
||||||
|
assert deleted_user is None
|
||||||
|
|
||||||
Reference in New Issue
Block a user