# File: MAIA/backend/tests/test_nlp.py
# Snapshot: 2025-04-23 01:00:56 +02:00 (303 lines, 11 KiB, Python)

import pytest
from fastapi import status
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from unittest.mock import patch, MagicMock
from datetime import datetime
from tests.helpers import generators
from modules.nlp.schemas import ProcessCommandResponse
from modules.nlp.models import (
MessageSender,
ChatMessage,
) # Import necessary models/enums
# --- Mocks ---
# Mock the external AI call and internal service functions
@pytest.fixture(autouse=True)
def mock_nlp_services():
    """Replace every service the NLP API module calls with a mock.

    Applied automatically to every test in this module. Each dotted path in
    ``targets`` is patched via ``unittest.mock.patch`` for the duration of the
    test and restored afterwards, newest patch first.

    Yields:
        dict: short service name -> the mock installed in its place, so tests
        can set ``return_value`` and assert on calls.
    """
    # Short key used by the tests -> dotted path of the name as seen by the API.
    targets = {
        "process_request": "modules.nlp.api.process_request",
        "ask_ai": "modules.nlp.api.ask_ai",
        "save_chat_message": "modules.nlp.api.save_chat_message",
        "get_chat_history": "modules.nlp.api.get_chat_history",
        "create_calendar_event": "modules.nlp.api.create_calendar_event",
        "get_calendar_events": "modules.nlp.api.get_calendar_events",
        "update_calendar_event": "modules.nlp.api.update_calendar_event",
        "delete_calendar_event": "modules.nlp.api.delete_calendar_event",
        "create_todo": "modules.nlp.api.todo_service.create_todo",
        "get_todos": "modules.nlp.api.todo_service.get_todos",
        "update_todo": "modules.nlp.api.todo_service.update_todo",
        "delete_todo": "modules.nlp.api.todo_service.delete_todo",
    }
    installed = {}
    running = []
    try:
        for key, dotted_path in targets.items():
            patcher = patch(dotted_path)
            installed[key] = patcher.start()
            # Only record patchers that actually started, so a failure midway
            # still unwinds exactly the patches that were applied.
            running.append(patcher)
        yield installed
    finally:
        # Undo newest-first, mirroring the unwinding of nested ``with`` blocks.
        for patcher in reversed(running):
            patcher.stop()
# --- Helper Function ---
def _login_user(db: Session, client: TestClient):
    """Create a fresh user and log them in through the test helpers.

    ``client`` is accepted for call-site symmetry but is not used here; the
    login goes through ``generators.login`` with ``db``.

    Returns:
        tuple: ``(user, access_token, refresh_token)``.
    """
    new_user, plain_password = generators.create_user(db)
    tokens = generators.login(db, new_user.username, plain_password)
    access = tokens["access_token"]
    refresh = tokens["refresh_token"]
    return new_user, access, refresh
# --- Tests for /process-command ---
def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services):
    """An ``ask_ai`` intent returns the acknowledgement followed by the answer.

    Also checks that three chat messages are saved (the user's input, the
    acknowledgement and the answer) and that the original question is passed
    to ``ask_ai`` as ``request``.
    """
    user, access_token, refresh_token = _login_user(db, client)
    question = "What is the capital of France?"
    acknowledgement = "Let me check that for you."
    answer = "The capital of France is Paris."

    mock_nlp_services["process_request"].return_value = {
        "intent": "ask_ai",
        "params": {"request": question},
        "response_text": acknowledgement,
    }
    mock_nlp_services["ask_ai"].return_value = answer

    auth = {
        "headers": {"Authorization": f"Bearer {access_token}"},
        "cookies": {"refresh_token": refresh_token},
    }
    response = client.post(
        "/api/nlp/process-command", json={"user_input": question}, **auth
    )

    assert response.status_code == status.HTTP_200_OK
    expected = ProcessCommandResponse(responses=[acknowledgement, answer]).model_dump()
    assert response.json() == expected

    saver = mock_nlp_services["save_chat_message"]
    assert saver.call_count == 3
    for sender, text in (
        (MessageSender.USER, question),
        (MessageSender.AI, acknowledgement),
        (MessageSender.AI, answer),
    ):
        saver.assert_any_call(db, user_id=user.id, sender=sender, text=text)
    mock_nlp_services["ask_ai"].assert_called_once_with(request=question)
def test_process_command_get_calendar(
    client: TestClient, db: Session, mock_nlp_services
):
    """A ``get_calendar_events`` intent lists the events the service returns.

    The reply is the acknowledgement, a header line, then one formatted line
    per event. Four chat messages are expected to be saved: user input,
    acknowledgement, header and the event line.
    """
    _user, access_token, refresh_token = _login_user(db, client)
    user_input = "What are my events today?"

    nlp_result = {
        "intent": "get_calendar_events",
        # Illustrative range only; the calendar service itself is mocked.
        "params": {
            "start": "2024-01-01T00:00:00Z",
            "end": "2024-01-01T23:59:59Z",
        },
        "response_text": "Okay, fetching your events.",
    }
    mock_nlp_services["process_request"].return_value = nlp_result

    # Stand-in for an event object returned by the calendar service.
    meeting = MagicMock()
    meeting.title = "Team Meeting"
    meeting.start = datetime(2024, 1, 1, 10, 0, 0)
    meeting.end = datetime(2024, 1, 1, 11, 0, 0)
    mock_nlp_services["get_calendar_events"].return_value = [meeting]

    auth = {
        "headers": {"Authorization": f"Bearer {access_token}"},
        "cookies": {"refresh_token": refresh_token},
    }
    response = client.post(
        "/api/nlp/process-command", json={"user_input": user_input}, **auth
    )

    assert response.status_code == status.HTTP_200_OK
    expected = ProcessCommandResponse(
        responses=[
            "Okay, fetching your events.",
            "Here are the events:",
            "- Team Meeting (2024-01-01 10:00 - 11:00)",
        ]
    ).model_dump()
    assert response.json() == expected
    # User, acknowledgement, header, one event line.
    assert mock_nlp_services["save_chat_message"].call_count == 4
    mock_nlp_services["get_calendar_events"].assert_called_once()
def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services):
    """An ``add_todo`` intent creates the todo and confirms it with its ID.

    Three chat messages are expected to be saved: user input, acknowledgement
    and the confirmation line.
    """
    _user, access_token, refresh_token = _login_user(db, client)
    user_input = "Add buy milk to my list"

    mock_nlp_services["process_request"].return_value = {
        "intent": "add_todo",
        "params": {"task": "buy milk"},
        "response_text": "Adding it now.",
    }

    # Stand-in for the Todo object returned by the todo service.
    created = MagicMock()
    created.task = "buy milk"
    created.id = 1
    mock_nlp_services["create_todo"].return_value = created

    auth = {
        "headers": {"Authorization": f"Bearer {access_token}"},
        "cookies": {"refresh_token": refresh_token},
    }
    response = client.post(
        "/api/nlp/process-command", json={"user_input": user_input}, **auth
    )

    assert response.status_code == status.HTTP_200_OK
    expected = ProcessCommandResponse(
        responses=["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."]
    ).model_dump()
    assert response.json() == expected
    # User, acknowledgement, confirmation.
    assert mock_nlp_services["save_chat_message"].call_count == 3
    mock_nlp_services["create_todo"].assert_called_once()
def test_process_command_clarification(
    client: TestClient, db: Session, mock_nlp_services
):
    """A ``clarification_needed`` intent only asks the user a follow-up question.

    The endpoint must return just the clarification text, save exactly the
    user message and that reply, and must not invoke any action service:
    the request is ambiguous, so nothing should be acted on.
    """
    user, access_token, refresh_token = _login_user(db, client)
    user_input = "Delete the event"
    clarification_text = "Which event do you mean? Please provide the ID."
    mock_nlp_services["process_request"].return_value = {
        "intent": "clarification_needed",
        "params": {"request": user_input},
        "response_text": clarification_text,
    }
    response = client.post(
        "/api/nlp/process-command",
        headers={"Authorization": f"Bearer {access_token}"},
        cookies={"refresh_token": refresh_token},
        json={"user_input": user_input},
    )
    assert response.status_code == status.HTTP_200_OK
    assert (
        response.json()
        == ProcessCommandResponse(responses=[clarification_text]).model_dump()
    )
    # Verify save calls: user message, clarification AI response
    assert mock_nlp_services["save_chat_message"].call_count == 2
    mock_nlp_services["save_chat_message"].assert_any_call(
        db, user_id=user.id, sender=MessageSender.USER, text=user_input
    )
    mock_nlp_services["save_chat_message"].assert_any_call(
        db, user_id=user.id, sender=MessageSender.AI, text=clarification_text
    )
    # Ensure no action service was called. Checking only the one the user
    # hinted at (delete_calendar_event) would miss a regression that routes the
    # ambiguous request to a different action.
    for service in (
        "ask_ai",
        "create_calendar_event",
        "get_calendar_events",
        "update_calendar_event",
        "delete_calendar_event",
        "create_todo",
        "get_todos",
        "update_todo",
        "delete_todo",
    ):
        mock_nlp_services[service].assert_not_called()
def test_process_command_error_intent(
    client: TestClient, db: Session, mock_nlp_services
):
    """An ``error`` intent relays the NLP error text as the only reply.

    Exactly two chat messages are expected to be saved: the user's input and
    the error reply.
    """
    user, access_token, refresh_token = _login_user(db, client)
    gibberish = "Gibberish request"
    apology = "Sorry, I didn't understand that."

    mock_nlp_services["process_request"].return_value = {
        "intent": "error",
        "params": {},
        "response_text": apology,
    }

    auth = {
        "headers": {"Authorization": f"Bearer {access_token}"},
        "cookies": {"refresh_token": refresh_token},
    }
    response = client.post(
        "/api/nlp/process-command", json={"user_input": gibberish}, **auth
    )

    assert response.status_code == status.HTTP_200_OK
    expected = ProcessCommandResponse(responses=[apology]).model_dump()
    assert response.json() == expected

    saver = mock_nlp_services["save_chat_message"]
    assert saver.call_count == 2
    for sender, text in ((MessageSender.USER, gibberish), (MessageSender.AI, apology)):
        saver.assert_any_call(db, user_id=user.id, sender=sender, text=text)
# --- Tests for /history ---
def test_get_history(client: TestClient, db: Session, mock_nlp_services):
    """``GET /api/nlp/history`` returns the stored messages for the caller.

    The history service is mocked to return two messages. The test checks the
    response body's texts and that the service was asked for the caller's
    messages with ``limit=50``.
    """
    user, access_token, refresh_token = _login_user(db, client)

    # Messages the mocked history service hands back, oldest first.
    history = [
        ChatMessage(
            id=msg_id,
            user_id=user.id,
            sender=sender,
            text=text,
            timestamp=datetime.now(),
        )
        for msg_id, sender, text in (
            (1, MessageSender.USER, "Hello"),
            (2, MessageSender.AI, "Hi there!"),
        )
    ]
    mock_nlp_services["get_chat_history"].return_value = history

    response = client.get(
        "/api/nlp/history",
        cookies={"refresh_token": refresh_token},
        headers={"Authorization": f"Bearer {access_token}"},
    )

    assert response.status_code == status.HTTP_200_OK
    # Timestamps are generated at runtime, so only the message texts are compared.
    payload = response.json()
    assert len(payload) == 2
    assert [entry["text"] for entry in payload] == ["Hello", "Hi there!"]
    mock_nlp_services["get_chat_history"].assert_called_once_with(
        db, user_id=user.id, limit=50
    )
def test_get_history_unauthorized(client: TestClient):
    """Requesting the chat history without credentials is rejected with 401."""
    anonymous_response = client.get("/api/nlp/history")
    assert anonymous_response.status_code == status.HTTP_401_UNAUTHORIZED
# Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.)
# Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete)