# modules/nlp/api.py from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from typing import List from core.database import get_db from modules.auth.dependencies import get_current_user from modules.auth.models import User # Import the new service functions and Enum from modules.nlp.service import process_request, ask_ai, save_chat_message, get_chat_history, MessageSender # Import the response schema and the new ChatMessage model for response type hinting from modules.nlp.schemas import ProcessCommandRequest, ProcessCommandResponse from modules.nlp.models import ChatMessage # Import ChatMessage model from modules.calendar.service import create_calendar_event, get_calendar_events, update_calendar_event, delete_calendar_event from modules.calendar.models import CalendarEvent from modules.calendar.schemas import CalendarEventCreate, CalendarEventUpdate # Import TODO services, schemas, and model from modules.todo import service as todo_service from modules.todo.models import Todo from modules.todo.schemas import TodoCreate, TodoUpdate router = APIRouter(prefix="/nlp", tags=["nlp"]) # Helper to format calendar events (expects list of CalendarEvent models) def format_calendar_events(events: List[CalendarEvent]) -> List[str]: if not events: return ["You have no events matching that criteria."] formatted = ["Here are the events:"] for event in events: # Access attributes directly from the model instance start_str = event.start.strftime("%Y-%m-%d %H:%M") if event.start else "No start time" end_str = event.end.strftime("%H:%M") if event.end else "" title = event.title or "Untitled Event" formatted.append(f"- {title} ({start_str}{' - ' + end_str if end_str else ''})") return formatted # Helper to format TODO items (expects list of Todo models) def format_todos(todos: List[Todo]) -> List[str]: if not todos: return ["Your TODO list is empty."] formatted = ["Here is your TODO list:"] for todo in todos: status = "[X]" if todo.complete else "[ ]" date_str = f" (Due: {todo.date.strftime('%Y-%m-%d')})" if todo.date else "" remind_str = " (Reminder)" if todo.remind else "" formatted.append(f"- {status} {todo.task}{date_str}{remind_str} (ID: {todo.id})") return formatted # Update the response model for the endpoint @router.post("/process-command", response_model=ProcessCommandResponse) def process_command(request_data: ProcessCommandRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): """ Process the user command, save messages, execute action, save response, and return user-friendly responses. """ user_input = request_data.user_input # --- Save User Message --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.USER, text=user_input) # ------------------------ command_data = process_request(user_input) intent = command_data["intent"] params = command_data["params"] response_text = command_data["response_text"] responses = [response_text] # Start with the initial response # --- Save Initial AI Response --- # Save the first response generated by process_request save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text) # ----------------------------- if intent == "error": # Don't raise HTTPException here if we want to save the error message # Instead, return the error response directly # save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text) # Already saved above return ProcessCommandResponse(responses=responses) if intent == "clarification_needed" or intent == "unknown": # save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=response_text) # Already saved above return ProcessCommandResponse(responses=responses) try: match intent: case "ask_ai": ai_answer = ask_ai(**params) responses.append(ai_answer) # --- Save Additional AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=ai_answer) # --------------------------------- return ProcessCommandResponse(responses=responses) case "get_calendar_events": events: List[CalendarEvent] = get_calendar_events(db, current_user.id, **params) formatted_responses = format_calendar_events(events) responses.extend(formatted_responses) # --- Save Additional AI Responses --- for resp in formatted_responses: save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp) # ---------------------------------- return ProcessCommandResponse(responses=responses) case "add_calendar_event": event_data = CalendarEventCreate(**params) created_event = create_calendar_event(db, current_user.id, event_data) start_str = created_event.start.strftime("%Y-%m-%d %H:%M") if created_event.start else "No start time" title = created_event.title or "Untitled Event" add_response = f"Added: {title} starting at {start_str}." responses.append(add_response) # --- Save Additional AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response) # --------------------------------- return ProcessCommandResponse(responses=responses) case "update_calendar_event": event_id = params.pop('event_id', None) if event_id is None: # Save the error message before raising error_msg = "Event ID is required for update." save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) raise HTTPException(status_code=400, detail=error_msg) event_data = CalendarEventUpdate(**params) updated_event = update_calendar_event(db, current_user.id, event_id, event_data=event_data) title = updated_event.title or "Untitled Event" update_response = f"Updated event ID {updated_event.id}: {title}." responses.append(update_response) # --- Save Additional AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response) # --------------------------------- return ProcessCommandResponse(responses=responses) case "delete_calendar_event": event_id = params.get('event_id') if event_id is None: # Save the error message before raising error_msg = "Event ID is required for delete." save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) raise HTTPException(status_code=400, detail=error_msg) delete_calendar_event(db, current_user.id, event_id) delete_response = f"Deleted event ID {event_id}." responses.append(delete_response) # --- Save Additional AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response) # --------------------------------- return ProcessCommandResponse(responses=responses) # --- Add TODO Cases --- case "get_todos": todos: List[Todo] = todo_service.get_todos(db, user=current_user, **params) formatted_responses = format_todos(todos) responses.extend(formatted_responses) # --- Save Additional AI Responses --- for resp in formatted_responses: save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=resp) # ---------------------------------- return ProcessCommandResponse(responses=responses) case "add_todo": todo_data = TodoCreate(**params) created_todo = todo_service.create_todo(db, todo=todo_data, user=current_user) add_response = f"Added TODO: '{created_todo.task}' (ID: {created_todo.id})." responses.append(add_response) # --- Save Additional AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=add_response) # --------------------------------- return ProcessCommandResponse(responses=responses) case "update_todo": todo_id = params.pop('todo_id', None) if todo_id is None: error_msg = "TODO ID is required for update." save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) raise HTTPException(status_code=400, detail=error_msg) todo_data = TodoUpdate(**params) updated_todo = todo_service.update_todo(db, todo_id=todo_id, todo_update=todo_data, user=current_user) update_response = f"Updated TODO ID {updated_todo.id}: '{updated_todo.task}'." if 'complete' in params: status = "complete" if params['complete'] else "incomplete" update_response += f" Marked as {status}." responses.append(update_response) # --- Save Additional AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=update_response) # --------------------------------- return ProcessCommandResponse(responses=responses) case "delete_todo": todo_id = params.get('todo_id') if todo_id is None: error_msg = "TODO ID is required for delete." save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_msg) raise HTTPException(status_code=400, detail=error_msg) deleted_todo = todo_service.delete_todo(db, todo_id=todo_id, user=current_user) delete_response = f"Deleted TODO ID {deleted_todo.id}: '{deleted_todo.task}'." responses.append(delete_response) # --- Save Additional AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=delete_response) # --------------------------------- return ProcessCommandResponse(responses=responses) # --- End TODO Cases --- case _: print(f"Warning: Unhandled intent '{intent}' reached api.py match statement.") # The initial response_text was already saved return ProcessCommandResponse(responses=responses) except HTTPException as http_exc: # Don't save again if already saved before raising if http_exc.status_code != 400 or ('event_id' not in http_exc.detail.lower()): save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=http_exc.detail) raise http_exc except Exception as e: print(f"Error executing intent '{intent}': {e}") error_response = "Sorry, I encountered an error while trying to perform that action." # --- Save Final Error AI Response --- save_chat_message(db, user_id=current_user.id, sender=MessageSender.AI, text=error_response) # ---------------------------------- return ProcessCommandResponse(responses=[error_response]) # --- New Endpoint for Chat History --- # Define a Pydantic schema for the response (optional but good practice) from pydantic import BaseModel from datetime import datetime class ChatMessageResponse(BaseModel): id: int sender: MessageSender # Use the enum directly text: str timestamp: datetime class Config: from_attributes = True # Allow Pydantic to work with ORM models @router.get("/history", response_model=List[ChatMessageResponse]) def read_chat_history(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)): """Retrieves the last 50 chat messages for the current user.""" history = get_chat_history(db, user_id=current_user.id, limit=50) return history # -------------------------------------