[V1.0] Working application, added notifications.
Ready to upload to store.
@@ -2,6 +2,9 @@ DB_HOST = "db"
DB_USER = "maia"
DB_PASSWORD = "maia"
DB_NAME = "maia"

REDIS_URL = "redis://redis:6379"

+PEPPER = "LsD7%"
+JWT_SECRET_KEY="1c8cf3ca6972b365f8108dad247e61abdcb6faff5a6c8ba00cb6fa17396702bf"
+GOOGLE_API_KEY="AIzaSyBrte_mETZJce8qE6cRTSz_fHOjdjlShBk"
@@ -2,7 +2,6 @@ import os
import sys
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool
-from sqlalchemy import create_engine  # Add create_engine import
@@ -1,30 +0,0 @@
"""Initial migration with existing tables

Revision ID: 69069d6184b3
Revises:
Create Date: 2025-04-21 01:14:33.233195

"""

from typing import Sequence, Union


# revision identifiers, used by Alembic.
revision: str = "69069d6184b3"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,30 +0,0 @@
"""Add todo table

Revision ID: 9a82960db482
Revises: 69069d6184b3
Create Date: 2025-04-21 20:33:27.028529

"""

from typing import Sequence, Union


# revision identifiers, used by Alembic.
revision: str = "9a82960db482"
down_revision: Union[str, None] = "69069d6184b3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
@@ -1,94 +0,0 @@
"""Add all_day column to calendar_events

Revision ID: a34d847510da
Revises: 9a82960db482
Create Date: 2025-04-26 11:09:35.400748

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = 'a34d847510da'
down_revision: Union[str, None] = '9a82960db482'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('calendar_events')
    op.drop_table('users')
    op.drop_index('ix_todos_id', table_name='todos')
    op.drop_index('ix_todos_task', table_name='todos')
    op.drop_table('todos')
    op.drop_table('token_blacklist')
    op.drop_index('ix_chat_messages_id', table_name='chat_messages')
    op.drop_index('ix_chat_messages_user_id', table_name='chat_messages')
    op.drop_table('chat_messages')
    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('chat_messages',
        sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column('sender', postgresql.ENUM('USER', 'AI', name='messagesender'), autoincrement=False, nullable=False),
        sa.Column('text', sa.TEXT(), autoincrement=False, nullable=False),
        sa.Column('timestamp', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], name='chat_messages_user_id_fkey'),
        sa.PrimaryKeyConstraint('id', name='chat_messages_pkey')
    )
    op.create_index('ix_chat_messages_user_id', 'chat_messages', ['user_id'], unique=False)
    op.create_index('ix_chat_messages_id', 'chat_messages', ['id'], unique=False)
    op.create_table('token_blacklist',
        sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('token', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('expires_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
        sa.PrimaryKeyConstraint('id', name='token_blacklist_pkey'),
        sa.UniqueConstraint('token', name='token_blacklist_token_key')
    )
    op.create_table('todos',
        sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('task', sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column('date', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
        sa.Column('remind', sa.BOOLEAN(), autoincrement=False, nullable=True),
        sa.Column('complete', sa.BOOLEAN(), autoincrement=False, nullable=True),
        sa.Column('owner_id', sa.INTEGER(), autoincrement=False, nullable=True),
        sa.ForeignKeyConstraint(['owner_id'], ['users.id'], name='todos_owner_id_fkey'),
        sa.PrimaryKeyConstraint('id', name='todos_pkey')
    )
    op.create_index('ix_todos_task', 'todos', ['task'], unique=False)
    op.create_index('ix_todos_id', 'todos', ['id'], unique=False)
    op.create_table('users',
        sa.Column('id', sa.INTEGER(), server_default=sa.text("nextval('users_id_seq'::regclass)"), autoincrement=True, nullable=False),
        sa.Column('uuid', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('username', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('role', postgresql.ENUM('ADMIN', 'USER', name='userrole'), autoincrement=False, nullable=False),
        sa.Column('hashed_password', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.PrimaryKeyConstraint('id', name='users_pkey'),
        sa.UniqueConstraint('username', name='users_username_key'),
        sa.UniqueConstraint('uuid', name='users_uuid_key'),
        postgresql_ignore_search_path=False
    )
    op.create_table('calendar_events',
        sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column('title', sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column('description', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('start', postgresql.TIMESTAMP(), autoincrement=False, nullable=False),
        sa.Column('end', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
        sa.Column('location', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('tags', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True),
        sa.Column('color', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('user_id', sa.INTEGER(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], name='calendar_events_user_id_fkey'),
        sa.PrimaryKeyConstraint('id', name='calendar_events_pkey')
    )
    # ### end Alembic commands ###
Binary file not shown.
Binary file not shown.
@@ -1,6 +1,7 @@
# core/celery_app.py
from celery import Celery
from core.config import settings

celery_app = Celery(
    "worker",
    broker=settings.REDIS_URL,
@@ -8,5 +9,15 @@ celery_app = Celery(
    include=[
        "modules.auth.tasks",
        "modules.admin.tasks",
+        "modules.calendar.tasks",  # Add calendar tasks
    ],
)

+# Optional: Configure Celery Beat if you need periodic tasks later
+# celery_app.conf.beat_schedule = {
+#     'check-something-every-5-minutes': {
+#         'task': 'your_app.tasks.check_something',
+#         'schedule': timedelta(minutes=5),
+#     },
+# }
+celery_app.conf.timezone = "UTC"  # Recommended to use UTC
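Note: the commented-out beat schedule above would also need the timedelta import if it is ever enabled. A minimal sketch using the placeholder names from the comment itself (nothing here is wired up by this commit):

from datetime import timedelta

celery_app.conf.beat_schedule = {
    "check-something-every-5-minutes": {
        "task": "your_app.tasks.check_something",  # placeholder task path from the comment above
        "schedule": timedelta(minutes=5),
    },
}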
@@ -27,6 +27,7 @@ class Settings(BaseSettings):

    # Other settings
    GOOGLE_API_KEY: str
+    EXPO_PUSH_API_URL: str = "https://exp.host/--/api/v2/push/send"

    class Config:
        # Tell pydantic-settings to load variables from a .env file
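Note: the Config body is cut off in this hunk. With pydantic-settings, the conventional class-based form is roughly the following sketch; the exact filename and options used in the repo are not shown here and are assumed:

class Config:
    env_file = ".env"          # assumed .env filename
    env_file_encoding = "utf-8"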
@@ -11,9 +11,10 @@ _SessionLocal = None

settings.DB_URL = f"postgresql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}"


def get_engine():
    global _engine
-    if (_engine is None):
+    if _engine is None:
        if not settings.DB_URL:
            raise ValueError("DB_URL is not set in Settings.")
        print(f"Connecting to database at {settings.DB_URL}")
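Note: get_db itself is not part of this diff, but the new Celery tasks below drive it manually (db = next(db_gen), then next(db_gen, None) to close). That usage implies the standard FastAPI generator shape, sketched here under the assumption that _SessionLocal is initialised lazily like _engine:

def get_db():
    db = _SessionLocal()  # assumed to be a configured sessionmaker by this point
    try:
        yield db
    finally:
        db.close()  # runs when the caller exhausts the generator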
@@ -47,7 +47,7 @@ services:
    image: postgres:15  # Use a specific version
    container_name: MAIA-DB
    volumes:
-      - ./db:/var/lib/postgresql/data  # Persist data using a named volume
+      - db:/var/lib/postgresql/data  # Persist data using a named volume
    environment:
      - POSTGRES_USER=${DB_USER}
      - POSTGRES_PASSWORD=${DB_PASSWORD}
@@ -63,11 +63,17 @@ services:
    image: redis:7  # Use a specific version
    container_name: MAIA-Redis
    volumes:
-      - ./redis_data:/data
+      - redis_data:/data
    networks:
      - maia_network
    restart: unless-stopped

+volumes:
+  db:  # Named volume for PostgreSQL data
+    driver: local
+  redis_data:  # Named volume for Redis data
+    driver: local

# ----- Network Definition -----
networks:
  maia_network:  # Define a custom bridge network
@@ -30,7 +30,6 @@ app.add_middleware(
        "https://maia.depaoli.id.au",
        "http://localhost:8081",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
Binary file not shown.
@@ -1,10 +1,12 @@
# modules/admin/api.py
-from typing import Annotated
-from fastapi import APIRouter, Depends
+from typing import Annotated, Optional
+from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from core.database import get_db
from modules.auth.dependencies import admin_only
+from modules.auth.models import User
+from modules.notifications.service import send_push_notification
from .tasks import cleardb

router = APIRouter(prefix="/admin", tags=["admin"], dependencies=[Depends(admin_only)])
@@ -14,6 +16,13 @@ class ClearDbRequest(BaseModel):
    hard: bool


+class SendNotificationRequest(BaseModel):
+    username: str
+    title: str
+    body: str
+    data: Optional[dict] = None


@router.get("/")
def read_admin():
    return {"message": "Admin route"}
@@ -29,3 +38,43 @@ def clear_db(payload: ClearDbRequest, db: Annotated[Session, Depends(get_db)]):
    hard = payload.hard
    cleardb.delay(hard)
    return {"message": "Clearing database in the background", "hard": hard}


@router.post("/send-notification", status_code=status.HTTP_200_OK)
async def send_user_notification(
    payload: SendNotificationRequest,
    db: Annotated[Session, Depends(get_db)],
):
    """
    Admin endpoint to send a push notification to a specific user by username.
    """
    target_user = db.query(User).filter(User.username == payload.username).first()

    if not target_user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"User with username '{payload.username}' not found.",
        )

    if not target_user.expo_push_token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"User '{payload.username}' does not have a registered push token.",
        )

    success = await send_push_notification(
        push_token=target_user.expo_push_token,
        title=payload.title,
        body=payload.body,
        data=payload.data,
    )

    if not success:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to send push notification via Expo service.",
        )

    return {
        "message": f"Push notification sent successfully to user '{payload.username}'"
    }
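Note: a quick way to exercise the new endpoint, sketched with a hypothetical base URL and admin JWT (the admin_only dependency on the router must accept the token):

import httpx

resp = httpx.post(
    "http://localhost:8000/admin/send-notification",  # hypothetical base URL
    headers={"Authorization": "Bearer <admin-jwt>"},   # placeholder credential
    json={
        "username": "alice",  # placeholder user
        "title": "Hello",
        "body": "Test push",
        "data": {"source": "admin-panel"},
    },
)
print(resp.status_code, resp.json())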
Binary file not shown.
@@ -1,6 +1,6 @@
# modules/auth/models.py
from core.database import Base
-from sqlalchemy import Column, Integer, String, Enum, DateTime
+from sqlalchemy import Column, Integer, String, Enum, DateTime, Text
from sqlalchemy.orm import relationship
from enum import Enum as PyEnum

@@ -18,6 +18,7 @@ class User(Base):
    name = Column(String)
    role = Column(Enum(UserRole), nullable=False, default=UserRole.USER)
    hashed_password = Column(String)
+    expo_push_token = Column(Text, nullable=True)
    calendar_events = relationship("CalendarEvent", back_populates="user")
Binary file not shown.
Binary file not shown.
BIN backend/modules/calendar/__pycache__/tasks.cpython-312.pyc (new binary file)
Binary file not shown.
@@ -7,7 +7,7 @@ from sqlalchemy import (
    ForeignKey,
    JSON,
    Boolean,
-)  # Add Boolean
+)
from sqlalchemy.orm import relationship
from core.database import Base

@@ -18,15 +18,12 @@ class CalendarEvent(Base):
    id = Column(Integer, primary_key=True)
    title = Column(String, nullable=False)
    description = Column(String)
-    start = Column(DateTime, nullable=False)
-    end = Column(DateTime)
+    start = Column(DateTime(timezone=True), nullable=False)
+    end = Column(DateTime(timezone=True))
    location = Column(String)
-    all_day = Column(Boolean, default=False)  # Add all_day column
+    all_day = Column(Boolean, default=False)
    tags = Column(JSON)
-    color = Column(String)  # hex code for color
-    user_id = Column(
-        Integer, ForeignKey("users.id"), nullable=False
-    )  # <-- Relationship
+    color = Column(String)
+    user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

    # Bi-directional relationship (for eager loading)
    user = relationship("User", back_populates="calendar_events")
@@ -7,7 +7,13 @@ from core.exceptions import not_found_exception
from modules.calendar.schemas import (
    CalendarEventCreate,
    CalendarEventUpdate,
-)  # Import schemas
+)

+# Import the celery app instance instead of the task functions directly
+from core.celery_app import celery_app

+# Keep task imports if cancel_event_notifications is still called directly and synchronously
+from modules.calendar.tasks import cancel_event_notifications


def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCreate):
@@ -23,6 +29,11 @@ def create_calendar_event(db: Session, user_id: int, event_data: CalendarEventCr
    db.add(event)
    db.commit()
    db.refresh(event)
+    # Schedule notifications using send_task
+    celery_app.send_task(
+        "modules.calendar.tasks.schedule_event_notifications",  # Task name as string
+        args=[event.id],
+    )
    return event


@@ -114,10 +125,17 @@ def update_calendar_event(

    db.commit()
    db.refresh(event)
+    # Re-schedule notifications using send_task
+    celery_app.send_task(
+        "modules.calendar.tasks.schedule_event_notifications", args=[event.id]
+    )
    return event


def delete_calendar_event(db: Session, user_id: int, event_id: int):
    event = get_calendar_event_by_id(db, user_id, event_id)  # Reuse get_by_id for check
+    # Cancel any scheduled notifications before deleting
+    # Run synchronously here or make cancel_event_notifications an async task
+    cancel_event_notifications(event_id)
    db.delete(event)
    db.commit()
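Note: the comment above leaves open making cancel_event_notifications a Celery task. If that route were taken, the change would look roughly like this sketch (the decorator is already stubbed out in tasks.py below):

# in modules/calendar/tasks.py
@shared_task
def cancel_event_notifications(event_id: int):
    ...

# in delete_calendar_event, dispatch instead of calling inline
cancel_event_notifications.delay(event_id)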
backend/modules/calendar/tasks.py (new file, 233 lines)
@@ -0,0 +1,233 @@
# backend/modules/calendar/tasks.py
import logging
import asyncio
from datetime import datetime, timedelta, time, timezone

from celery import shared_task
from celery.exceptions import Ignore

from core.celery_app import celery_app
from core.database import get_db
from modules.calendar.models import CalendarEvent
from modules.notifications.service import send_push_notification
from modules.auth.models import User  # User model lives in modules/auth/models.py

logger = logging.getLogger(__name__)

# Key prefix for storing scheduled task IDs in Redis (or Celery backend)
SCHEDULED_TASK_KEY_PREFIX = "calendar_event_tasks:"


def get_scheduled_task_key(event_id: int) -> str:
    return f"{SCHEDULED_TASK_KEY_PREFIX}{event_id}"


@shared_task(bind=True)
def schedule_event_notifications(self, event_id: int):
    """Schedules reminder notifications for a calendar event."""
    db_gen = get_db()
    db = next(db_gen)
    try:
        event = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first()
        if not event:
            logger.warning(
                f"Calendar event {event_id} not found for scheduling notifications."
            )
            raise Ignore()  # Don't retry if event doesn't exist

        user = db.query(User).filter(User.id == event.user_id).first()
        if not user or not user.expo_push_token:
            logger.warning(
                f"User {event.user_id} or their push token not found for event {event_id}. Skipping notification scheduling."
            )
            # Cancel any potentially existing tasks for this event if user/token is now invalid
            cancel_event_notifications(event_id)
            raise Ignore()  # Don't retry if user/token missing

        # Cancel any existing notifications for this event first
        cancel_event_notifications(event_id)  # Run synchronously within this task

        scheduled_task_ids = []
        now_utc = datetime.now(timezone.utc)

        if event.all_day:
            # Schedule one notification at 6:00 AM in the event's original timezone (or UTC if naive)
            event_start_date = event.start.date()
            notification_time_local = datetime.combine(
                event_start_date, time(6, 0), tzinfo=event.start.tzinfo
            )
            # Convert scheduled time to UTC for Celery ETA
            notification_time_utc = notification_time_local.astimezone(timezone.utc)

            if notification_time_utc > now_utc:
                task = send_event_notification.apply_async(
                    args=[event.id, user.expo_push_token, "all_day"],
                    eta=notification_time_utc,
                )
                scheduled_task_ids.append(task.id)
                logger.info(
                    f"Scheduled all-day notification for event {event_id} at {notification_time_utc} (Task ID: {task.id})"
                )
            else:
                logger.info(
                    f"All-day notification time {notification_time_utc} for event {event_id} is in the past. Skipping."
                )

        else:
            # Ensure event start time is timezone-aware (assume UTC if naive)
            event_start_utc = event.start
            if event_start_utc.tzinfo is None:
                event_start_utc = event_start_utc.replace(tzinfo=timezone.utc)
            else:
                event_start_utc = event_start_utc.astimezone(timezone.utc)

            times_before = {
                "1_hour": timedelta(hours=1),
                "30_min": timedelta(minutes=30),
            }

            for label, delta in times_before.items():
                notification_time_utc = event_start_utc - delta
                if notification_time_utc > now_utc:
                    task = send_event_notification.apply_async(
                        args=[event.id, user.expo_push_token, label],
                        eta=notification_time_utc,
                    )
                    scheduled_task_ids.append(task.id)
                    logger.info(
                        f"Scheduled {label} notification for event {event_id} at {notification_time_utc} (Task ID: {task.id})"
                    )
                else:
                    logger.info(
                        f"{label} notification time {notification_time_utc} for event {event_id} is in the past. Skipping."
                    )

        # Store the new task IDs using Celery backend (Redis)
        if scheduled_task_ids:
            key = get_scheduled_task_key(event_id)
            # Store as a simple comma-separated string
            celery_app.backend.set(key, ",".join(scheduled_task_ids))
            logger.debug(f"Stored task IDs for event {event_id}: {scheduled_task_ids}")

    except Ignore:
        raise  # propagate so Celery records the task as ignored, not failed
    except Exception as e:
        logger.exception(f"Error scheduling notifications for event {event_id}: {e}")
        # Optional: Add retry logic if appropriate
        # self.retry(exc=e, countdown=60)
    finally:
        next(db_gen, None)  # Ensure db session is closed


# Note: This task calls an async function. Ensure your Celery worker
# is configured to handle async tasks (e.g., using gevent, eventlet, or uvicorn worker).
@shared_task(bind=True)
def send_event_notification(
    self, event_id: int, user_push_token: str, notification_type: str
):
    """Sends a single reminder notification for a calendar event."""
    db_gen = get_db()
    db = next(db_gen)
    try:
        event = db.query(CalendarEvent).filter(CalendarEvent.id == event_id).first()
        if not event:
            logger.warning(
                f"Calendar event {event_id} not found for sending {notification_type} notification."
            )
            raise Ignore()  # Don't retry if event doesn't exist

        # Double-check user and token validity at the time of sending
        user = db.query(User).filter(User.id == event.user_id).first()
        if not user or user.expo_push_token != user_push_token:
            logger.warning(
                f"User {event.user_id} token mismatch or user not found for event {event_id} at notification time. Skipping."
            )
            raise Ignore()

        title = f"Upcoming: {event.title}"
        if notification_type == "all_day":
            body = f"Today: {event.title}"
            if event.description:
                body += f" - {event.description[:50]}"  # Add part of description
        elif notification_type == "1_hour":
            local_start_time = event.start.astimezone().strftime(
                "%I:%M %p"
            )  # Convert to local time for display
            body = f"Starts at {local_start_time} (in 1 hour)"
        elif notification_type == "30_min":
            local_start_time = event.start.astimezone().strftime("%I:%M %p")
            body = f"Starts at {local_start_time} (in 30 mins)"
        else:
            body = "Check your calendar for details."  # Fallback

        logger.info(
            f"Sending {notification_type} notification for event {event_id} to token {user_push_token[:10]}..."
        )
        try:
            # Call the async notification service
            success = asyncio.run(
                send_push_notification(
                    push_token=user_push_token,
                    title=title,
                    body=body,
                    data={"eventId": event.id, "type": "calendar_reminder"},
                )
            )
            if not success:
                logger.error(
                    f"Failed to send {notification_type} notification for event {event_id} via service."
                )
                # Optional: self.retry(countdown=60)  # Retry sending if failed
            else:
                logger.info(
                    f"Successfully sent {notification_type} notification for event {event_id}."
                )
        except Exception as e:
            logger.exception(
                f"Error calling send_push_notification for event {event_id}: {e}"
            )
            # Optional: self.retry(exc=e, countdown=60)

    except Ignore:
        raise  # propagate so Celery records the task as ignored, not failed
    except Exception as e:
        logger.exception(
            f"General error sending {notification_type} notification for event {event_id}: {e}"
        )
        # Optional: self.retry(exc=e, countdown=60)
    finally:
        next(db_gen, None)  # Ensure db session is closed


# This is run synchronously when called, or can be called as a task itself
# @shared_task  # Uncomment if you want to call this asynchronously e.g., .delay()
def cancel_event_notifications(event_id: int):
    """Cancels all scheduled reminder notifications for a calendar event."""
    key = get_scheduled_task_key(event_id)
    try:
        task_ids_bytes = celery_app.backend.get(key)

        if task_ids_bytes:
            # Decode from bytes (assuming Redis backend)
            task_ids_str = task_ids_bytes.decode("utf-8")
            task_ids = task_ids_str.split(",")
            logger.info(f"Cancelling scheduled tasks for event {event_id}: {task_ids}")
            revoked_count = 0
            for task_id in task_ids:
                if task_id:  # Ensure not empty string
                    try:
                        celery_app.control.revoke(
                            task_id.strip(), terminate=True, signal="SIGKILL"
                        )
                        revoked_count += 1
                    except Exception as revoke_err:
                        logger.error(
                            f"Error revoking task {task_id} for event {event_id}: {revoke_err}"
                        )
            # Delete the key from Redis after attempting revocation
            celery_app.backend.delete(key)
            logger.debug(
                f"Revoked {revoked_count} tasks and removed task ID key {key} from backend for event {event_id}."
            )
        else:
            logger.debug(
                f"No scheduled tasks found in backend to cancel for event {event_id} (key: {key})."
            )
    except Exception as e:
        logger.exception(f"Error cancelling notifications for event {event_id}: {e}")
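Note: both tasks above can be driven from application code; a sketch with placeholder values (the Expo token format and event id are illustrative only):

from datetime import datetime, timezone

# enqueue the scheduler for an event that already exists in the DB
schedule_event_notifications.delay(42)

# or queue a single reminder directly with an absolute UTC eta, as the scheduler does
send_event_notification.apply_async(
    args=[42, "ExponentPushToken[xxxxxxxxxxxx]", "30_min"],  # placeholder token
    eta=datetime(2025, 5, 1, 9, 30, tzinfo=timezone.utc),
)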
backend/modules/notifications/__init__.py (new file, 0 lines)
Binary file not shown.
Binary file not shown.
backend/modules/notifications/service.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import httpx
import logging
from typing import Optional, Dict, Any

from core.config import settings

logger = logging.getLogger(__name__)


async def send_push_notification(
    push_token: str, title: str, body: str, data: Optional[Dict[str, Any]] = None
) -> bool:
    """
    Sends a push notification to a specific Expo push token.

    Args:
        push_token: The recipient's Expo push token.
        title: The title of the notification.
        body: The main message content of the notification.
        data: Optional dictionary containing extra data to send with the notification.

    Returns:
        True if the notification was sent successfully (according to Expo API), False otherwise.
    """
    if not push_token:
        logger.warning("Attempted to send notification but no push token provided.")
        return False

    message = {
        "to": push_token,
        "sound": "default",
        "title": title,
        "body": body,
        "priority": "high",
        "channelId": "default",
    }
    if data:
        message["data"] = data

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                settings.EXPO_PUSH_API_URL,
                headers={
                    "Accept": "application/json",
                    "Accept-Encoding": "gzip, deflate",
                    "Content-Type": "application/json",
                },
                json=message,
                timeout=10.0,
            )
            response.raise_for_status()  # Raise exception for 4xx/5xx responses

            response_data = response.json()
            logger.debug(f"Expo push API response: {response_data}")

            # Check for top-level errors first
            if "errors" in response_data:
                error_messages = [
                    err.get("message", "Unknown error")
                    for err in response_data["errors"]
                ]
                logger.error(
                    f"Expo API returned errors for {push_token[:10]}...: {'; '.join(error_messages)}"
                )
                return False

            # Check the status in the data field
            receipt = response_data.get("data")

            # For a single message, Expo returns one ticket object in 'data'
            if receipt:
                status = receipt.get("status")

                if status == "ok":
                    logger.info(
                        f"Successfully sent push notification to token: {push_token[:10]}..."
                    )
                    return True
                else:
                    # Log details if the status is not 'ok'
                    error_details = receipt.get("details")
                    error_message = receipt.get("message")
                    logger.error(
                        f"Failed to send push notification to {push_token[:10]}... "
                        f"Expo status: {status}, Message: {error_message}, Details: {error_details}"
                    )
                    return False
            else:
                # Log if 'data' is missing, not a list, or an empty list
                logger.error(
                    f"Unexpected Expo API response format or empty 'data' field for {push_token[:10]}... "
                    f"Response: {response_data}"
                )
                return False

        except httpx.HTTPStatusError as e:
            logger.error(
                f"HTTP error sending push notification to {push_token[:10]}...: {e.response.status_code} - {e.response.text}"
            )
            return False
        except httpx.RequestError as e:
            logger.error(
                f"Network error sending push notification to {push_token[:10]}...: {e}"
            )
            return False
        except Exception as e:
            logger.exception(
                f"Unexpected error sending push notification to {push_token[:10]}...: {e}"
            )
            return False
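Note: standalone usage of the service, sketched with a placeholder token (real tokens look like ExponentPushToken[...]):

import asyncio
from modules.notifications.service import send_push_notification

ok = asyncio.run(
    send_push_notification(
        push_token="ExponentPushToken[xxxxxxxxxxxx]",  # placeholder
        title="Reminder",
        body="Team standup in 30 minutes",
        data={"type": "calendar_reminder"},
    )
)
print("sent:", ok)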
@@ -14,6 +14,4 @@ class Todo(Base):
    complete = Column(Boolean, default=False)
    owner_id = Column(Integer, ForeignKey("users.id"))

-    owner = relationship(
-        "User"
-    )
+    owner = relationship("User")
Binary file not shown.
@@ -1,6 +1,7 @@
-from typing import Annotated
-from fastapi import APIRouter, Depends
+from typing import Annotated, Optional
+from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from pydantic import BaseModel

from core.database import get_db
from core.exceptions import not_found_exception, forbidden_exception
@@ -11,6 +12,41 @@ from modules.auth.models import User

router = APIRouter(prefix="/user", tags=["user"])


# --- Pydantic Schema for Push Token --- #
class PushTokenData(BaseModel):
    token: str
    device_name: Optional[str] = None
    token_type: str  # Expecting 'expo'


@router.post("/push-token", status_code=status.HTTP_200_OK)
def save_push_token(
    token_data: PushTokenData,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
):
    """
    Save the Expo push token for the current user.
    Requires user to be logged in.
    """
    if token_data.token_type != "expo":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid token_type. Only 'expo' is supported.",
        )

    # Update the user's push token
    current_user.expo_push_token = token_data.token
    # Optionally, you could store device_name somewhere if needed, perhaps in a separate table
    # For now, we just update the token on the user model

    db.add(current_user)
    db.commit()
    db.refresh(current_user)

    return {"message": "Push token saved successfully"}


@router.get("/me", response_model=UserResponse)
def me(
    db: Annotated[Session, Depends(get_db)],
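Note: registering a device token against this endpoint would look roughly like the following, with a hypothetical base URL and user JWT:

import httpx

httpx.post(
    "http://localhost:8000/user/push-token",  # hypothetical base URL
    headers={"Authorization": "Bearer <user-jwt>"},  # placeholder credential
    json={
        "token": "ExponentPushToken[xxxxxxxxxxxx]",  # placeholder
        "device_name": "Pixel 8",                    # optional
        "token_type": "expo",                        # anything else is rejected with 400
    },
)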
@@ -14,4 +14,5 @@ python-multipart
redis
SQLAlchemy
starlette
uvicorn
+eventlet
@@ -47,8 +47,12 @@ click-plugins==1.1.1
    # via celery
click-repl==0.3.0
    # via celery
+dnspython==2.7.0
+    # via eventlet
ecdsa==0.19.1
    # via python-jose
+eventlet==0.39.1
+    # via -r requirements.in
fastapi==0.115.12
    # via -r requirements.in
gevent==25.4.1
@@ -61,6 +65,7 @@ google-genai==1.11.0
    # via -r requirements.in
greenlet==3.2.0
    # via
+    #   eventlet
    #   gevent
    #   sqlalchemy
h11==0.14.0