303 lines
11 KiB
Python
303 lines
11 KiB
Python
import pytest
|
|
from fastapi import status
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy.orm import Session
|
|
from unittest.mock import patch, MagicMock
|
|
from datetime import datetime
|
|
|
|
from tests.helpers import generators
|
|
from modules.nlp.schemas import ProcessCommandResponse
|
|
from modules.nlp.models import (
|
|
MessageSender,
|
|
ChatMessage,
|
|
) # Import necessary models/enums
|
|
|
|
|
|
# --- Mocks ---
|
|
# Mock the external AI call and internal service functions
|
|
@pytest.fixture(autouse=True)
def mock_nlp_services():
    """Patch every service that ``modules.nlp.api`` calls, for every test.

    Yields a dict mapping a short service name (e.g. ``"ask_ai"``,
    ``"create_todo"``) to the mock installed in its place, so each test can
    set return values and inspect calls. Patches are undone in reverse order
    of installation when the fixture is torn down.
    """
    # Short name -> dotted path patched inside the API module's namespace.
    patch_targets = {
        "process_request": "modules.nlp.api.process_request",
        "ask_ai": "modules.nlp.api.ask_ai",
        "save_chat_message": "modules.nlp.api.save_chat_message",
        "get_chat_history": "modules.nlp.api.get_chat_history",
        "create_calendar_event": "modules.nlp.api.create_calendar_event",
        "get_calendar_events": "modules.nlp.api.get_calendar_events",
        "update_calendar_event": "modules.nlp.api.update_calendar_event",
        "delete_calendar_event": "modules.nlp.api.delete_calendar_event",
        "create_todo": "modules.nlp.api.todo_service.create_todo",
        "get_todos": "modules.nlp.api.todo_service.get_todos",
        "update_todo": "modules.nlp.api.todo_service.update_todo",
        "delete_todo": "modules.nlp.api.todo_service.delete_todo",
    }
    started = []
    try:
        installed = {}
        for short_name, dotted_path in patch_targets.items():
            patcher = patch(dotted_path)
            installed[short_name] = patcher.start()
            started.append(patcher)
        yield installed
    finally:
        # Unwind like nested ``with`` blocks would: last patched, first restored.
        for patcher in reversed(started):
            patcher.stop()
|
|
|
|
|
|
# --- Helper Function ---
|
|
def _login_user(db: Session, client: TestClient):
    """Create a new user, log them in, and return ``(user, access_token, refresh_token)``.

    ``client`` is not used by this helper; the login goes through
    ``generators.login`` with the database session directly.
    """
    new_user, plain_password = generators.create_user(db)
    tokens = generators.login(db, new_user.username, plain_password)
    access_token = tokens["access_token"]
    refresh_token = tokens["refresh_token"]
    return new_user, access_token, refresh_token
|
|
|
|
|
|
# --- Tests for /process-command ---
|
|
|
|
|
|
def test_process_command_ask_ai(client: TestClient, db: Session, mock_nlp_services):
    """An ``ask_ai`` intent returns the acknowledgement followed by the AI answer."""
    user, access_token, refresh_token = _login_user(db, client)
    question = "What is the capital of France?"
    acknowledgement = "Let me check that for you."
    answer = "The capital of France is Paris."

    mock_nlp_services["process_request"].return_value = {
        "intent": "ask_ai",
        "params": {"request": question},
        "response_text": acknowledgement,
    }
    mock_nlp_services["ask_ai"].return_value = answer

    response = client.post(
        "/api/nlp/process-command",
        headers={"Authorization": f"Bearer {access_token}"},
        cookies={"refresh_token": refresh_token},
        json={"user_input": question},
    )

    assert response.status_code == status.HTTP_200_OK
    expected = ProcessCommandResponse(responses=[acknowledgement, answer])
    assert response.json() == expected.model_dump()

    # Three messages are persisted: the user's input, the initial AI reply,
    # and the final AI answer.
    save_mock = mock_nlp_services["save_chat_message"]
    assert save_mock.call_count == 3
    for sender, text in (
        (MessageSender.USER, question),
        (MessageSender.AI, acknowledgement),
        (MessageSender.AI, answer),
    ):
        save_mock.assert_any_call(db, user_id=user.id, sender=sender, text=text)

    mock_nlp_services["ask_ai"].assert_called_once_with(request=question)
|
|
|
|
|
|
def test_process_command_get_calendar(
    client: TestClient, db: Session, mock_nlp_services
):
    """A ``get_calendar_events`` intent returns a header plus one line per event."""
    user, access_token, refresh_token = _login_user(db, client)
    question = "What are my events today?"
    mock_nlp_services["process_request"].return_value = {
        "intent": "get_calendar_events",
        # Illustrative window only: the calendar service is mocked and returns
        # the fixed list below regardless of these values.
        "params": {
            "start": "2024-01-01T00:00:00Z",
            "end": "2024-01-01T23:59:59Z",
        },
        "response_text": "Okay, fetching your events.",
    }

    # Stand-in for an event object from the calendar service; only the
    # attributes set here have concrete values.
    meeting = MagicMock()
    meeting.title = "Team Meeting"
    meeting.start = datetime(2024, 1, 1, 10, 0, 0)
    meeting.end = datetime(2024, 1, 1, 11, 0, 0)
    mock_nlp_services["get_calendar_events"].return_value = [meeting]

    auth_headers = {"Authorization": f"Bearer {access_token}"}
    response = client.post(
        "/api/nlp/process-command",
        headers=auth_headers,
        cookies={"refresh_token": refresh_token},
        json={"user_input": question},
    )

    assert response.status_code == status.HTTP_200_OK
    assert (
        response.json()
        == ProcessCommandResponse(
            responses=[
                "Okay, fetching your events.",
                "Here are the events:",
                "- Team Meeting (2024-01-01 10:00 - 11:00)",
            ]
        ).model_dump()
    )
    # Persisted: user input, initial AI reply, header line, one event line.
    assert mock_nlp_services["save_chat_message"].call_count == 4
    mock_nlp_services["get_calendar_events"].assert_called_once()
|
|
|
|
|
|
def test_process_command_add_todo(client: TestClient, db: Session, mock_nlp_services):
    """An ``add_todo`` intent creates the todo and confirms it together with its ID."""
    user, access_token, refresh_token = _login_user(db, client)
    request_text = "Add buy milk to my list"
    mock_nlp_services["process_request"].return_value = {
        "intent": "add_todo",
        "params": {"task": "buy milk"},
        "response_text": "Adding it now.",
    }

    # Stand-in for the Todo the service would create; only task and id are set.
    created_todo = MagicMock()
    created_todo.task, created_todo.id = "buy milk", 1
    mock_nlp_services["create_todo"].return_value = created_todo

    auth_headers = {"Authorization": f"Bearer {access_token}"}
    response = client.post(
        "/api/nlp/process-command",
        headers=auth_headers,
        cookies={"refresh_token": refresh_token},
        json={"user_input": request_text},
    )

    assert response.status_code == status.HTTP_200_OK
    expected = ProcessCommandResponse(
        responses=["Adding it now.", "Added TODO: 'buy milk' (ID: 1)."]
    )
    assert response.json() == expected.model_dump()
    # Persisted: user input, initial AI reply, confirmation AI reply.
    assert mock_nlp_services["save_chat_message"].call_count == 3
    mock_nlp_services["create_todo"].assert_called_once()
|
|
|
|
|
|
def test_process_command_clarification(
    client: TestClient, db: Session, mock_nlp_services
):
    """A ``clarification_needed`` intent only relays the clarifying question.

    Exactly two messages are persisted (the user input and the clarification),
    and none of the handlers that act on an intent (AI answering, calendar or
    todo create/update/delete) may be invoked.
    """
    user, access_token, refresh_token = _login_user(db, client)
    user_input = "Delete the event"
    clarification_text = "Which event do you mean? Please provide the ID."
    mock_nlp_services["process_request"].return_value = {
        "intent": "clarification_needed",
        "params": {"request": user_input},
        "response_text": clarification_text,
    }

    response = client.post(
        "/api/nlp/process-command",
        headers={"Authorization": f"Bearer {access_token}"},
        cookies={"refresh_token": refresh_token},
        json={"user_input": user_input},
    )

    assert response.status_code == status.HTTP_200_OK
    assert (
        response.json()
        == ProcessCommandResponse(responses=[clarification_text]).model_dump()
    )

    # Verify save calls: user message, clarification AI response.
    assert mock_nlp_services["save_chat_message"].call_count == 2
    mock_nlp_services["save_chat_message"].assert_any_call(
        db, user_id=user.id, sender=MessageSender.USER, text=user_input
    )
    mock_nlp_services["save_chat_message"].assert_any_call(
        db, user_id=user.id, sender=MessageSender.AI, text=clarification_text
    )

    # Ensure no action services were called. Every handler is checked, not
    # only the delete the user hinted at, so that a clarification dispatched
    # to any other handler is also caught.
    for handler_name in (
        "ask_ai",
        "create_calendar_event",
        "update_calendar_event",
        "delete_calendar_event",
        "create_todo",
        "update_todo",
        "delete_todo",
    ):
        mock_nlp_services[handler_name].assert_not_called()
|
|
|
|
|
|
def test_process_command_error_intent(
    client: TestClient, db: Session, mock_nlp_services
):
    """An ``error`` intent returns the error text and dispatches to no handler.

    Exactly two messages are persisted (the user input and the error reply),
    and none of the handlers that act on an intent (AI answering, calendar or
    todo create/update/delete) may be invoked.
    """
    user, access_token, refresh_token = _login_user(db, client)
    user_input = "Gibberish request"
    error_text = "Sorry, I didn't understand that."
    mock_nlp_services["process_request"].return_value = {
        "intent": "error",
        "params": {},
        "response_text": error_text,
    }

    response = client.post(
        "/api/nlp/process-command",
        headers={"Authorization": f"Bearer {access_token}"},
        cookies={"refresh_token": refresh_token},
        json={"user_input": user_input},
    )

    assert response.status_code == status.HTTP_200_OK
    assert (
        response.json() == ProcessCommandResponse(responses=[error_text]).model_dump()
    )

    # Verify save calls: user message, error AI response.
    assert mock_nlp_services["save_chat_message"].call_count == 2
    mock_nlp_services["save_chat_message"].assert_any_call(
        db, user_id=user.id, sender=MessageSender.USER, text=user_input
    )
    mock_nlp_services["save_chat_message"].assert_any_call(
        db, user_id=user.id, sender=MessageSender.AI, text=error_text
    )

    # An unrecognised request must not reach any action service.
    for handler_name in (
        "ask_ai",
        "create_calendar_event",
        "update_calendar_event",
        "delete_calendar_event",
        "create_todo",
        "update_todo",
        "delete_todo",
    ):
        mock_nlp_services[handler_name].assert_not_called()
|
|
|
|
|
|
# --- Tests for /history ---
|
|
|
|
|
|
def test_get_history(client: TestClient, db: Session, mock_nlp_services):
    """``GET /api/nlp/history`` returns the history service's messages for the caller."""
    user, access_token, refresh_token = _login_user(db, client)

    # Two-message conversation handed back by the mocked history service.
    conversation = [
        ChatMessage(
            id=message_id,
            user_id=user.id,
            sender=sender,
            text=text,
            timestamp=datetime.now(),
        )
        for message_id, sender, text in (
            (1, MessageSender.USER, "Hello"),
            (2, MessageSender.AI, "Hi there!"),
        )
    ]
    mock_nlp_services["get_chat_history"].return_value = conversation

    response = client.get(
        "/api/nlp/history",
        headers={"Authorization": f"Bearer {access_token}"},
        cookies={"refresh_token": refresh_token},
    )

    assert response.status_code == status.HTTP_200_OK
    # Only message texts are compared; timestamps are deliberately not checked.
    payload = response.json()
    assert len(payload) == 2
    assert payload[0]["text"] == "Hello"
    assert payload[1]["text"] == "Hi there!"
    mock_nlp_services["get_chat_history"].assert_called_once_with(
        db, user_id=user.id, limit=50
    )
|
|
|
|
|
|
def test_get_history_unauthorized(client: TestClient):
    """Requesting chat history without any credentials is rejected with 401."""
    anonymous_response = client.get("/api/nlp/history")
    assert anonymous_response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
# Add more tests for other intents (update/delete calendar/todo, unknown intent, etc.)
|
|
# Add tests for error handling within the API endpoint (e.g., missing IDs for update/delete)
|