|
|
|
|
|
""" |
|
|
Run locally: `SKIP_API_TESTS=1 python -m tests.test_memory_manager` |
|
|
""" |
|
|
|
|
|
|
|
|
import os |
|
|
import unittest |
|
|
from datetime import datetime, timezone |
|
|
from unittest.mock import AsyncMock, MagicMock, patch |
|
|
|
|
|
from src.core.memory_manager import MemoryManager |
|
|
from src.data.connection import ActionFailed |
|
|
from src.models.account import Account |
|
|
from src.models.medical import MedicalMemory, SemanticSearchResult |
|
|
from src.models.patient import Patient |
|
|
from src.models.session import Message, Session |
|
|
from src.utils.embeddings import EmbeddingClient |
|
|
from src.utils.rotator import APIKeyRotator |
|
|
|
|
|
|
|
|
# Opt-out switch read once at import time: "true", "1" or "yes" (any case) in
# the SKIP_API_TESTS environment variable skips every test decorated with
# ``unittest.skipIf(SKIP_API_TESTS, ...)``. When unset, those tests run.
SKIP_API_TESTS = os.environ.get('SKIP_API_TESTS', 'false').lower() in {'true', '1', 'yes'}
|
|
|
|
|
|
|
|
class TestMemoryManager(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``MemoryManager`` with every collaborator mocked out.

    ``setUp`` patches the four repositories, the summariser functions and
    ``nvidia_chat`` inside ``src.core.memory_manager``. It also hands the manager
    a spec'd ``EmbeddingClient`` mock, so each test decides what those
    dependencies return or raise.
    """

    def _start_patch(self, target, **kwargs):
        """Start ``patch(target, **kwargs)``, schedule its stop, and return the mock.

        ``addCleanup`` callbacks run even when ``setUp`` raises part-way through,
        whereas ``tearDown`` does not. Registering each patcher here means a
        partially-built fixture can never leak patches into later tests.
        """
        patcher = patch(target, **kwargs)
        mocked = patcher.start()
        self.addCleanup(patcher.stop)
        return mocked

    def setUp(self):
        """Set up mocks and the MemoryManager instance before each test."""
        # Repository layer used by the manager.
        self.mock_account_repo = self._start_patch('src.core.memory_manager.account_repo')
        self.mock_patient_repo = self._start_patch('src.core.memory_manager.patient_repo')
        self.mock_session_repo = self._start_patch('src.core.memory_manager.session_repo')
        self.mock_memory_repo = self._start_patch('src.core.memory_manager.memory_repo')

        # LLM helpers are awaited by the manager (the tests use assert_awaited_*),
        # so they must be AsyncMock rather than plain MagicMock.
        self.mock_summarise_title = self._start_patch(
            'src.core.memory_manager.summariser.summarise_title_with_nvidia', new_callable=AsyncMock)
        self.mock_summarise_gemini = self._start_patch(
            'src.core.memory_manager.summariser.summarise_qa_with_gemini', new_callable=AsyncMock)
        self.mock_summarise_nvidia = self._start_patch(
            'src.core.memory_manager.summariser.summarise_qa_with_nvidia', new_callable=AsyncMock)
        self.mock_nvidia_chat = self._start_patch(
            'src.core.memory_manager.nvidia_chat', new_callable=AsyncMock)

        # spec= makes attribute typos on these mocks fail loudly.
        self.mock_embedder = MagicMock(spec=EmbeddingClient)
        self.mock_gemini_rotator = MagicMock(spec=APIKeyRotator)
        self.mock_nvidia_rotator = MagicMock(spec=APIKeyRotator)

        self.manager = MemoryManager(embedder=self.mock_embedder, max_sessions_per_user=20)

        # Fixed ObjectId-style identifiers shared by the tests.
        self.user_id = "60c72b2f9b1d8b3b3c9d8b1a"
        self.patient_id = "60c72b2f9b1d8b3b3c9d8b1b"
        self.session_id = "60c72b2f9b1d8b3b3c9d8b1c"
        self.now = datetime.now(timezone.utc)

    # ------------------------------------------------------------------
    # Account management
    # ------------------------------------------------------------------

    def test_create_account_success(self):
        """create_account forwards to the repo (specialty defaults to None) and returns its id."""
        self.mock_account_repo.create_account.return_value = self.user_id
        result = self.manager.create_account(name="Dr. Test", role="Doctor")
        self.assertEqual(result, self.user_id)
        self.mock_account_repo.create_account.assert_called_once_with(name="Dr. Test", role="Doctor", specialty=None)

    def test_create_account_failure(self):
        """create_account returns None when the repo raises ActionFailed."""
        self.mock_account_repo.create_account.side_effect = ActionFailed("DB error")
        result = self.manager.create_account()
        self.assertIsNone(result)

    def test_get_account_success(self):
        """get_account returns whatever the repo returns for the id."""
        mock_account = MagicMock(spec=Account)
        self.mock_account_repo.get_account.return_value = mock_account
        result = self.manager.get_account(self.user_id)
        self.assertEqual(result, mock_account)
        self.mock_account_repo.get_account.assert_called_once_with(self.user_id)

    def test_get_account_failure(self):
        """get_account returns None when the repo raises ActionFailed."""
        self.mock_account_repo.get_account.side_effect = ActionFailed("DB error")
        result = self.manager.get_account(self.user_id)
        self.assertIsNone(result)

    def test_get_all_accounts_success(self):
        """get_all_accounts queries the repo with limit=50."""
        self.mock_account_repo.get_all_accounts.return_value = [MagicMock(spec=Account)]
        result = self.manager.get_all_accounts()
        self.assertEqual(len(result), 1)
        self.mock_account_repo.get_all_accounts.assert_called_once_with(limit=50)

    def test_get_all_accounts_failure(self):
        """get_all_accounts returns [] when the repo raises ActionFailed."""
        self.mock_account_repo.get_all_accounts.side_effect = ActionFailed("DB error")
        result = self.manager.get_all_accounts()
        self.assertEqual(result, [])

    def test_search_accounts_success(self):
        """search_accounts queries the repo with limit=10."""
        self.mock_account_repo.search_accounts.return_value = [MagicMock(spec=Account)]
        result = self.manager.search_accounts("query")
        self.assertEqual(len(result), 1)
        self.mock_account_repo.search_accounts.assert_called_once_with("query", limit=10)

    def test_search_accounts_failure(self):
        """search_accounts returns [] when the repo raises ActionFailed."""
        self.mock_account_repo.search_accounts.side_effect = ActionFailed("DB error")
        result = self.manager.search_accounts("query")
        self.assertEqual(result, [])

    # ------------------------------------------------------------------
    # Patient management
    # ------------------------------------------------------------------

    def test_create_patient_success(self):
        """create_patient forwards its kwargs and returns the repo's id."""
        self.mock_patient_repo.create_patient.return_value = self.patient_id
        result = self.manager.create_patient(name="John Doe", age=40)
        self.assertEqual(result, self.patient_id)
        self.mock_patient_repo.create_patient.assert_called_once_with(name="John Doe", age=40)

    def test_create_patient_failure(self):
        """create_patient returns None when the repo raises ActionFailed."""
        self.mock_patient_repo.create_patient.side_effect = ActionFailed("DB error")
        result = self.manager.create_patient(name="John Doe")
        self.assertIsNone(result)

    def test_get_patient_by_id_success(self):
        """get_patient_by_id returns whatever the repo returns for the id."""
        mock_patient = MagicMock(spec=Patient)
        self.mock_patient_repo.get_patient_by_id.return_value = mock_patient
        result = self.manager.get_patient_by_id(self.patient_id)
        self.assertEqual(result, mock_patient)
        self.mock_patient_repo.get_patient_by_id.assert_called_once_with(self.patient_id)

    def test_get_patient_by_id_failure(self):
        """get_patient_by_id returns None when the repo raises ActionFailed."""
        self.mock_patient_repo.get_patient_by_id.side_effect = ActionFailed("DB error")
        result = self.manager.get_patient_by_id(self.patient_id)
        self.assertIsNone(result)

    def test_update_patient_profile_success(self):
        """update_patient_profile passes the updates through and returns the repo's count."""
        self.mock_patient_repo.update_patient_profile.return_value = 1
        updates = {"age": 41}
        result = self.manager.update_patient_profile(self.patient_id, updates)
        self.assertEqual(result, 1)
        self.mock_patient_repo.update_patient_profile.assert_called_once_with(self.patient_id, updates)

    def test_update_patient_profile_failure(self):
        """update_patient_profile returns 0 when the repo raises ActionFailed."""
        self.mock_patient_repo.update_patient_profile.side_effect = ActionFailed("DB error")
        result = self.manager.update_patient_profile(self.patient_id, {})
        self.assertEqual(result, 0)

    # ------------------------------------------------------------------
    # Session management
    # ------------------------------------------------------------------

    def test_create_session_success(self):
        """create_session uses the default "New Chat" title."""
        mock_session = MagicMock(spec=Session)
        self.mock_session_repo.create_session.return_value = mock_session
        result = self.manager.create_session(self.user_id, self.patient_id)
        self.assertEqual(result, mock_session)
        self.mock_session_repo.create_session.assert_called_once_with(self.user_id, self.patient_id, "New Chat")

    def test_create_session_failure(self):
        """create_session returns None when the repo raises ActionFailed."""
        self.mock_session_repo.create_session.side_effect = ActionFailed("DB error")
        result = self.manager.create_session(self.user_id, self.patient_id)
        self.assertIsNone(result)

    def test_get_user_sessions_success(self):
        """get_user_sessions uses max_sessions_per_user (20 here) as the limit."""
        self.mock_session_repo.get_user_sessions.return_value = [MagicMock(spec=Session)]
        result = self.manager.get_user_sessions(self.user_id)
        self.assertEqual(len(result), 1)
        self.mock_session_repo.get_user_sessions.assert_called_once_with(self.user_id, limit=20)

    def test_delete_session_success(self):
        """delete_session returns the repo's result."""
        self.mock_session_repo.delete_session.return_value = True
        result = self.manager.delete_session(self.session_id)
        self.assertTrue(result)
        self.mock_session_repo.delete_session.assert_called_once_with(self.session_id)

    def test_delete_session_failure(self):
        """delete_session returns False when the repo raises ActionFailed."""
        self.mock_session_repo.delete_session.side_effect = ActionFailed("DB error")
        result = self.manager.delete_session(self.session_id)
        self.assertFalse(result)

    # ------------------------------------------------------------------
    # Medical exchange processing and context retrieval
    # ------------------------------------------------------------------

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_process_medical_exchange_success(self):
        """A full exchange stores both messages, summarises, embeds and saves a memory."""
        question = "What are the side effects?"
        answer = "Common side effects include..."
        summary = "q: side effects a: common ones are..."
        embedding = [0.1, 0.2, 0.3]

        self.mock_summarise_gemini.return_value = summary
        self.mock_embedder.embed.return_value = [embedding]
        self.manager._update_session_title_if_first_message = AsyncMock()

        result = await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, question, answer,
            self.mock_gemini_rotator, self.mock_nvidia_rotator
        )

        self.assertEqual(result, summary)
        self.assertEqual(self.mock_session_repo.add_message.call_count, 2)
        self.mock_session_repo.add_message.assert_any_call(self.session_id, question, sent_by_user=True)
        self.mock_session_repo.add_message.assert_any_call(self.session_id, answer, sent_by_user=False)
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_embedder.embed.assert_called_once_with([summary])
        self.mock_memory_repo.create_memory.assert_called_once_with(
            patient_id=self.patient_id,
            doctor_id=self.user_id,
            session_id=self.session_id,
            summary=summary,
            embedding=embedding
        )
        self.manager._update_session_title_if_first_message.assert_awaited_once()

    async def test_process_medical_exchange_db_failure(self):
        """process_medical_exchange returns None when storing a message fails."""
        self.mock_session_repo.add_message.side_effect = ActionFailed("DB write failed")
        result = await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, "q", "a",
            self.mock_gemini_rotator, self.mock_nvidia_rotator
        )
        self.assertIsNone(result)

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_process_medical_exchange_embedding_failure(self):
        """An embedding error still saves the memory, just without an embedding."""
        self.mock_embedder.embed.side_effect = Exception("Embedding model down")
        self.mock_summarise_gemini.return_value = "summary"
        self.manager._update_session_title_if_first_message = AsyncMock()

        await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, "q", "a",
            self.mock_gemini_rotator, self.mock_nvidia_rotator
        )

        # The memory is still persisted; only the embedding is missing.
        self.mock_memory_repo.create_memory.assert_called_once()
        call_kwargs = self.mock_memory_repo.create_memory.call_args.kwargs
        self.assertIsNone(call_kwargs.get("embedding"))

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_get_enhanced_context_full(self):
        """With STM, LTM and conversation history present, all three sections appear."""
        question = "Is this medication safe?"

        mock_stm = [MagicMock(spec=MedicalMemory, summary="STM summary 1")]
        mock_ltm = [MagicMock(spec=SemanticSearchResult, summary="LTM summary 1")]
        mock_messages = [MagicMock(spec=Message, sent_by_user=True, content="Previous question")]
        mock_session = MagicMock(spec=Session, messages=mock_messages)

        self.mock_memory_repo.get_recent_memories.return_value = mock_stm
        # The relevance filter (via nvidia_chat) keeps the single STM summary.
        self.mock_nvidia_chat.return_value = "STM summary 1"
        self.mock_embedder.embed.return_value = [[0.5]]
        self.mock_memory_repo.search_memories_semantic.return_value = mock_ltm
        self.mock_session_repo.get_session.return_value = mock_session

        context = await self.manager.get_enhanced_context(
            self.session_id, self.patient_id, question, self.mock_nvidia_rotator
        )

        self.assertIn("Recent relevant medical context (STM)", context)
        self.assertIn("STM summary 1", context)
        self.assertIn("Semantically relevant medical history (LTM)", context)
        self.assertIn("LTM summary 1", context)
        self.assertIn("Current conversation", context)
        self.assertIn("User: Previous question", context)

        self.mock_memory_repo.get_recent_memories.assert_called_once_with(self.patient_id, limit=3)
        self.mock_nvidia_chat.assert_awaited_once()
        self.mock_memory_repo.search_memories_semantic.assert_called_once()
        self.mock_session_repo.get_session.assert_called_once_with(self.session_id)

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_get_enhanced_context_no_ltm(self):
        """Empty LTM results and an empty session omit those sections."""
        self.mock_memory_repo.get_recent_memories.return_value = [MagicMock(spec=MedicalMemory, summary="STM")]
        self.mock_nvidia_chat.return_value = "STM"
        self.mock_embedder.embed.return_value = [[0.5]]
        self.mock_memory_repo.search_memories_semantic.return_value = []
        self.mock_session_repo.get_session.return_value = MagicMock(spec=Session, messages=[])

        context = await self.manager.get_enhanced_context(
            self.session_id, self.patient_id, "question", self.mock_nvidia_rotator
        )

        self.assertIn("Recent relevant medical context (STM)", context)
        self.assertNotIn("Semantically relevant medical history (LTM)", context)
        self.assertNotIn("Current conversation", context)

    # ------------------------------------------------------------------
    # Internal helpers: session titles, summaries, relevance filtering
    # ------------------------------------------------------------------

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_update_session_title_if_first_message_success(self):
        """A session holding exactly two messages gets an LLM-generated title."""
        question = "My leg hurts, what should I do?"
        mock_session = MagicMock(spec=Session, messages=[1, 2])
        self.manager.get_session = MagicMock(return_value=mock_session)
        self.manager.update_session_title = MagicMock()
        self.mock_summarise_title.return_value = "Leg Pain Inquiry"

        await self.manager._update_session_title_if_first_message(
            self.session_id, question, self.mock_nvidia_rotator
        )

        self.manager.get_session.assert_called_once_with(self.session_id)
        self.mock_summarise_title.assert_awaited_once_with(question, self.mock_nvidia_rotator, max_words=5)
        self.manager.update_session_title.assert_called_once_with(self.session_id, "Leg Pain Inquiry")

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_update_session_title_not_first_message(self):
        """With more than two messages, the title is left unchanged."""
        mock_session = MagicMock(spec=Session, messages=[1, 2, 3])
        self.manager.get_session = MagicMock(return_value=mock_session)
        self.manager.update_session_title = MagicMock()

        await self.manager._update_session_title_if_first_message(
            self.session_id, "question", self.mock_nvidia_rotator
        )

        self.manager.get_session.assert_called_once_with(self.session_id)
        self.mock_summarise_title.assert_not_awaited()
        self.manager.update_session_title.assert_not_called()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_generate_summary_gemini_success(self):
        """Gemini's summary is used and NVIDIA is not consulted."""
        self.mock_summarise_gemini.return_value = "Gemini summary"
        result = await self.manager._generate_summary("q", "a", self.mock_gemini_rotator, self.mock_nvidia_rotator)
        self.assertEqual(result, "Gemini summary")
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_summarise_nvidia.assert_not_awaited()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_generate_summary_gemini_fails_nvidia_success(self):
        """When Gemini yields None, the NVIDIA summary is used instead."""
        self.mock_summarise_gemini.return_value = None
        self.mock_summarise_nvidia.return_value = "NVIDIA summary"
        result = await self.manager._generate_summary("q", "a", self.mock_gemini_rotator, self.mock_nvidia_rotator)
        self.assertEqual(result, "NVIDIA summary")
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_summarise_nvidia.assert_awaited_once()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_generate_summary_all_fail(self):
        """When both providers yield None, a plain Question/Answer string is returned."""
        self.mock_summarise_gemini.return_value = None
        self.mock_summarise_nvidia.return_value = None
        result = await self.manager._generate_summary("question", "answer", self.mock_gemini_rotator, self.mock_nvidia_rotator)
        self.assertEqual(result, "Question: question\nAnswer: answer")

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_filter_summaries_for_relevance_success(self):
        """Only the summaries echoed back by the LLM (one per line) are kept."""
        summaries = ["Summary A", "Summary B", "Summary C"]
        self.mock_nvidia_chat.return_value = "Summary A\nSummary C"
        result = await self.manager._filter_summaries_for_relevance("question", summaries, self.mock_nvidia_rotator)
        self.assertEqual(result, ["Summary A", "Summary C"])
        self.mock_nvidia_chat.assert_awaited_once()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_filter_summaries_for_relevance_api_fails(self):
        """If the LLM call raises, every summary is returned unfiltered."""
        summaries = ["Summary A", "Summary B"]
        self.mock_nvidia_chat.side_effect = Exception("API error")
        result = await self.manager._filter_summaries_for_relevance("question", summaries, self.mock_nvidia_rotator)
        self.assertEqual(result, summaries)
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly, e.g. `python -m tests.test_memory_manager`.
    unittest.main()
|
|
|