# MedicalDiagnosisSystem: tests/core/test_memory_manager.py
# Last commit d753c16 by LiamKhoaLe: "Refactor tests organisation"
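# tests/core/test_memory_manager.py
"""
Unit tests for MemoryManager (src.core.memory_manager), with its repositories,
summariser helpers and LLM calls mocked out.

Run locally: `SKIP_API_TESTS=1 python -m tests.core.test_memory_manager`
(SKIP_API_TESTS=1 skips the tests marked as requiring external APIs.)
"""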
import os
import unittest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from src.core.memory_manager import MemoryManager
from src.data.connection import ActionFailed
from src.models.account import Account
from src.models.medical import MedicalMemory, SemanticSearchResult
from src.models.patient import Patient
from src.models.session import Message, Session
from src.utils.embeddings import EmbeddingClient
from src.utils.rotator import APIKeyRotator
# Opt-out switch for the tests decorated with skipIf(SKIP_API_TESTS, ...):
# export SKIP_API_TESTS as "true", "1" or "yes" (any letter case) to skip them.
SKIP_API_TESTS = os.environ.get('SKIP_API_TESTS', 'false').lower() in {'true', '1', 'yes'}
# Use the modern unittest features for async code
class TestMemoryManager(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``MemoryManager``.

    Each test runs against a fresh ``MemoryManager``. Its module-level
    collaborators in ``src.core.memory_manager`` are replaced with mocks:
    the account/patient/session/memory repositories, the summariser helpers
    and ``nvidia_chat``. Its embedder is a ``MagicMock``. Async collaborators
    are patched with ``AsyncMock`` so the code under test can await them.
    """

    def _start_patch(self, target, **kwargs):
        """Start ``patch(target, **kwargs)``, schedule its undo, and return the mock.

        ``patcher.stop`` is registered with ``addCleanup`` right after
        ``start``. That way the patch is reverted even if a later step of
        ``setUp`` raises. ``tearDown`` is skipped in that case, so relying on
        it would leak the already-started patches into subsequent tests.
        """
        patcher = patch(target, **kwargs)
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return started

    def setUp(self):
        """Install mocks and build the ``MemoryManager`` under test."""
        # 1. Repository dependencies (module-level names in src.core.memory_manager).
        self.mock_account_repo = self._start_patch('src.core.memory_manager.account_repo')
        self.mock_patient_repo = self._start_patch('src.core.memory_manager.patient_repo')
        self.mock_session_repo = self._start_patch('src.core.memory_manager.session_repo')
        self.mock_memory_repo = self._start_patch('src.core.memory_manager.memory_repo')
        # 2. Service dependencies; AsyncMock because the manager awaits these.
        self.mock_summarise_title = self._start_patch(
            'src.core.memory_manager.summariser.summarise_title_with_nvidia', new_callable=AsyncMock
        )
        self.mock_summarise_gemini = self._start_patch(
            'src.core.memory_manager.summariser.summarise_qa_with_gemini', new_callable=AsyncMock
        )
        self.mock_summarise_nvidia = self._start_patch(
            'src.core.memory_manager.summariser.summarise_qa_with_nvidia', new_callable=AsyncMock
        )
        self.mock_nvidia_chat = self._start_patch(
            'src.core.memory_manager.nvidia_chat', new_callable=AsyncMock
        )
        # 3. Collaborators passed in explicitly. spec= makes access to
        #    attributes the real classes don't have raise AttributeError.
        self.mock_embedder = MagicMock(spec=EmbeddingClient)
        self.mock_gemini_rotator = MagicMock(spec=APIKeyRotator)
        self.mock_nvidia_rotator = MagicMock(spec=APIKeyRotator)
        # 4. Instantiate the class under test.
        self.manager = MemoryManager(embedder=self.mock_embedder, max_sessions_per_user=20)
        # 5. Common test data (24-hex-char ids, presumably MongoDB ObjectId strings -- confirm).
        self.user_id = "60c72b2f9b1d8b3b3c9d8b1a"
        self.patient_id = "60c72b2f9b1d8b3b3c9d8b1b"
        self.session_id = "60c72b2f9b1d8b3b3c9d8b1c"
        self.now = datetime.now(timezone.utc)

    # --- Account Management Tests ---

    def test_create_account_success(self):
        self.mock_account_repo.create_account.return_value = self.user_id
        result = self.manager.create_account(name="Dr. Test", role="Doctor")
        self.assertEqual(result, self.user_id)
        self.mock_account_repo.create_account.assert_called_once_with(name="Dr. Test", role="Doctor", specialty=None)

    def test_create_account_failure(self):
        self.mock_account_repo.create_account.side_effect = ActionFailed("DB error")
        result = self.manager.create_account()
        self.assertIsNone(result)

    def test_get_account_success(self):
        mock_account = MagicMock(spec=Account)
        self.mock_account_repo.get_account.return_value = mock_account
        result = self.manager.get_account(self.user_id)
        self.assertEqual(result, mock_account)
        self.mock_account_repo.get_account.assert_called_once_with(self.user_id)

    def test_get_account_failure(self):
        self.mock_account_repo.get_account.side_effect = ActionFailed("DB error")
        result = self.manager.get_account(self.user_id)
        self.assertIsNone(result)

    def test_get_all_accounts_success(self):
        self.mock_account_repo.get_all_accounts.return_value = [MagicMock(spec=Account)]
        result = self.manager.get_all_accounts()
        self.assertEqual(len(result), 1)
        self.mock_account_repo.get_all_accounts.assert_called_once_with(limit=50)

    def test_get_all_accounts_failure(self):
        self.mock_account_repo.get_all_accounts.side_effect = ActionFailed("DB error")
        result = self.manager.get_all_accounts()
        self.assertEqual(result, [])

    def test_search_accounts_success(self):
        self.mock_account_repo.search_accounts.return_value = [MagicMock(spec=Account)]
        result = self.manager.search_accounts("query")
        self.assertEqual(len(result), 1)
        self.mock_account_repo.search_accounts.assert_called_once_with("query", limit=10)

    def test_search_accounts_failure(self):
        self.mock_account_repo.search_accounts.side_effect = ActionFailed("DB error")
        result = self.manager.search_accounts("query")
        self.assertEqual(result, [])

    # --- Patient Management Tests ---

    def test_create_patient_success(self):
        self.mock_patient_repo.create_patient.return_value = self.patient_id
        result = self.manager.create_patient(name="John Doe", age=40)
        self.assertEqual(result, self.patient_id)
        self.mock_patient_repo.create_patient.assert_called_once_with(name="John Doe", age=40)

    def test_create_patient_failure(self):
        self.mock_patient_repo.create_patient.side_effect = ActionFailed("DB error")
        result = self.manager.create_patient(name="John Doe")
        self.assertIsNone(result)

    def test_get_patient_by_id_success(self):
        mock_patient = MagicMock(spec=Patient)
        self.mock_patient_repo.get_patient_by_id.return_value = mock_patient
        result = self.manager.get_patient_by_id(self.patient_id)
        self.assertEqual(result, mock_patient)
        self.mock_patient_repo.get_patient_by_id.assert_called_once_with(self.patient_id)

    def test_get_patient_by_id_failure(self):
        self.mock_patient_repo.get_patient_by_id.side_effect = ActionFailed("DB error")
        result = self.manager.get_patient_by_id(self.patient_id)
        self.assertIsNone(result)

    def test_update_patient_profile_success(self):
        self.mock_patient_repo.update_patient_profile.return_value = 1
        updates = {"age": 41}
        result = self.manager.update_patient_profile(self.patient_id, updates)
        self.assertEqual(result, 1)
        self.mock_patient_repo.update_patient_profile.assert_called_once_with(self.patient_id, updates)

    def test_update_patient_profile_failure(self):
        self.mock_patient_repo.update_patient_profile.side_effect = ActionFailed("DB error")
        result = self.manager.update_patient_profile(self.patient_id, {})
        self.assertEqual(result, 0)

    # --- Session Management Tests ---

    def test_create_session_success(self):
        mock_session = MagicMock(spec=Session)
        self.mock_session_repo.create_session.return_value = mock_session
        result = self.manager.create_session(self.user_id, self.patient_id)
        self.assertEqual(result, mock_session)
        self.mock_session_repo.create_session.assert_called_once_with(self.user_id, self.patient_id, "New Chat")

    def test_create_session_failure(self):
        self.mock_session_repo.create_session.side_effect = ActionFailed("DB error")
        result = self.manager.create_session(self.user_id, self.patient_id)
        self.assertIsNone(result)

    def test_get_user_sessions_success(self):
        self.mock_session_repo.get_user_sessions.return_value = [MagicMock(spec=Session)]
        result = self.manager.get_user_sessions(self.user_id)
        self.assertEqual(len(result), 1)
        self.mock_session_repo.get_user_sessions.assert_called_once_with(self.user_id, limit=20)

    def test_delete_session_success(self):
        self.mock_session_repo.delete_session.return_value = True
        result = self.manager.delete_session(self.session_id)
        self.assertTrue(result)
        self.mock_session_repo.delete_session.assert_called_once_with(self.session_id)

    def test_delete_session_failure(self):
        self.mock_session_repo.delete_session.side_effect = ActionFailed("DB error")
        result = self.manager.delete_session(self.session_id)
        self.assertFalse(result)

    # --- Core Business Logic (Async) Tests ---
    # NOTE(review): the summariser, nvidia_chat and embedder that these tests
    # touch are all mocked in setUp, so the SKIP_API_TESTS guard looks
    # defensive rather than required. Confirm before removing it.

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_process_medical_exchange_success(self):
        """Both messages are stored, summarised, embedded and saved as a memory."""
        question = "What are the side effects?"
        answer = "Common side effects include..."
        summary = "q: side effects a: common ones are..."
        embedding = [0.1, 0.2, 0.3]
        # Configure mocks
        self.mock_summarise_gemini.return_value = summary
        self.mock_embedder.embed.return_value = [embedding]
        self.manager._update_session_title_if_first_message = AsyncMock()
        # Call the method
        result = await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, question, answer,
            self.mock_gemini_rotator, self.mock_nvidia_rotator
        )
        # Assertions
        self.assertEqual(result, summary)
        self.assertEqual(self.mock_session_repo.add_message.call_count, 2)
        self.mock_session_repo.add_message.assert_any_call(self.session_id, question, sent_by_user=True)
        self.mock_session_repo.add_message.assert_any_call(self.session_id, answer, sent_by_user=False)
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_embedder.embed.assert_called_once_with([summary])
        self.mock_memory_repo.create_memory.assert_called_once_with(
            patient_id=self.patient_id,
            doctor_id=self.user_id,
            session_id=self.session_id,
            summary=summary,
            embedding=embedding
        )
        self.manager._update_session_title_if_first_message.assert_awaited_once()

    async def test_process_medical_exchange_db_failure(self):
        """A failing message write makes process_medical_exchange return None."""
        self.mock_session_repo.add_message.side_effect = ActionFailed("DB write failed")
        result = await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, "q", "a",
            self.mock_gemini_rotator, self.mock_nvidia_rotator
        )
        self.assertIsNone(result)

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_process_medical_exchange_embedding_failure(self):
        """An embedding error still results in create_memory, with embedding=None."""
        self.mock_embedder.embed.side_effect = Exception("Embedding model down")
        self.mock_summarise_gemini.return_value = "summary"
        self.manager._update_session_title_if_first_message = AsyncMock()
        await self.manager.process_medical_exchange(
            self.session_id, self.patient_id, self.user_id, "q", "a",
            self.mock_gemini_rotator, self.mock_nvidia_rotator
        )
        # create_memory must still be called, but with embedding=None.
        self.mock_memory_repo.create_memory.assert_called_once()
        self.assertIsNone(self.mock_memory_repo.create_memory.call_args.kwargs.get("embedding"))

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_get_enhanced_context_full(self):
        """STM, LTM and current-conversation sections all appear in the context."""
        question = "Is this medication safe?"
        # Mock data
        mock_stm = [MagicMock(spec=MedicalMemory, summary="STM summary 1")]
        mock_ltm = [MagicMock(spec=SemanticSearchResult, summary="LTM summary 1")]
        mock_messages = [MagicMock(spec=Message, sent_by_user=True, content="Previous question")]
        mock_session = MagicMock(spec=Session, messages=mock_messages)
        # Configure mocks
        self.mock_memory_repo.get_recent_memories.return_value = mock_stm
        self.mock_nvidia_chat.return_value = "STM summary 1"
        self.mock_embedder.embed.return_value = [[0.5]]
        self.mock_memory_repo.search_memories_semantic.return_value = mock_ltm
        self.mock_session_repo.get_session.return_value = mock_session
        # Call method
        context = await self.manager.get_enhanced_context(
            self.session_id, self.patient_id, question, self.mock_nvidia_rotator
        )
        # Assertions
        self.assertIn("Recent relevant medical context (STM)", context)
        self.assertIn("STM summary 1", context)
        self.assertIn("Semantically relevant medical history (LTM)", context)
        self.assertIn("LTM summary 1", context)
        self.assertIn("Current conversation", context)
        self.assertIn("User: Previous question", context)
        self.mock_memory_repo.get_recent_memories.assert_called_once_with(self.patient_id, limit=3)
        self.mock_nvidia_chat.assert_awaited_once()
        self.mock_memory_repo.search_memories_semantic.assert_called_once()
        self.mock_session_repo.get_session.assert_called_once_with(self.session_id)

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_get_enhanced_context_no_ltm(self):
        """Sections without data (LTM, conversation) are left out of the context."""
        # Configure mocks for only STM and session context
        self.mock_memory_repo.get_recent_memories.return_value = [MagicMock(spec=MedicalMemory, summary="STM")]
        self.mock_nvidia_chat.return_value = "STM"
        self.mock_embedder.embed.return_value = [[0.5]]
        self.mock_memory_repo.search_memories_semantic.return_value = []  # No LTM results
        self.mock_session_repo.get_session.return_value = MagicMock(spec=Session, messages=[])
        context = await self.manager.get_enhanced_context(
            self.session_id, self.patient_id, "question", self.mock_nvidia_rotator
        )
        self.assertIn("Recent relevant medical context (STM)", context)
        self.assertNotIn("Semantically relevant medical history (LTM)", context)
        self.assertNotIn("Current conversation", context)  # No messages

    # --- Private Helper (Async) Tests ---

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_update_session_title_if_first_message_success(self):
        """A title is summarised and saved when the session holds exactly two messages."""
        question = "My leg hurts, what should I do?"
        mock_session = MagicMock(spec=Session, messages=[1, 2])  # Length is 2
        self.manager.get_session = MagicMock(return_value=mock_session)
        self.manager.update_session_title = MagicMock()
        self.mock_summarise_title.return_value = "Leg Pain Inquiry"
        await self.manager._update_session_title_if_first_message(
            self.session_id, question, self.mock_nvidia_rotator
        )
        self.manager.get_session.assert_called_once_with(self.session_id)
        self.mock_summarise_title.assert_awaited_once_with(question, self.mock_nvidia_rotator, max_words=5)
        self.manager.update_session_title.assert_called_once_with(self.session_id, "Leg Pain Inquiry")

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_update_session_title_not_first_message(self):
        """No title is generated when the session does not hold exactly two messages."""
        mock_session = MagicMock(spec=Session, messages=[1, 2, 3])  # Length is not 2
        self.manager.get_session = MagicMock(return_value=mock_session)
        self.manager.update_session_title = MagicMock()
        await self.manager._update_session_title_if_first_message(
            self.session_id, "question", self.mock_nvidia_rotator
        )
        self.manager.get_session.assert_called_once_with(self.session_id)
        self.mock_summarise_title.assert_not_awaited()
        self.manager.update_session_title.assert_not_called()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_generate_summary_gemini_success(self):
        """Gemini's summary is used and the NVIDIA summariser is never awaited."""
        self.mock_summarise_gemini.return_value = "Gemini summary"
        result = await self.manager._generate_summary("q", "a", self.mock_gemini_rotator, self.mock_nvidia_rotator)
        self.assertEqual(result, "Gemini summary")
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_summarise_nvidia.assert_not_awaited()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_generate_summary_gemini_fails_nvidia_success(self):
        """NVIDIA is used as a fallback when Gemini returns None."""
        self.mock_summarise_gemini.return_value = None  # Gemini fails
        self.mock_summarise_nvidia.return_value = "NVIDIA summary"
        result = await self.manager._generate_summary("q", "a", self.mock_gemini_rotator, self.mock_nvidia_rotator)
        self.assertEqual(result, "NVIDIA summary")
        self.mock_summarise_gemini.assert_awaited_once()
        self.mock_summarise_nvidia.assert_awaited_once()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_generate_summary_all_fail(self):
        """With both summarisers returning None, a plain Question/Answer text is returned."""
        self.mock_summarise_gemini.return_value = None
        self.mock_summarise_nvidia.return_value = None
        result = await self.manager._generate_summary("question", "answer", self.mock_gemini_rotator, self.mock_nvidia_rotator)
        self.assertEqual(result, "Question: question\nAnswer: answer")

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_filter_summaries_for_relevance_success(self):
        """Only the summaries echoed back by the LLM are kept."""
        summaries = ["Summary A", "Summary B", "Summary C"]
        self.mock_nvidia_chat.return_value = "Summary A\nSummary C"
        result = await self.manager._filter_summaries_for_relevance("question", summaries, self.mock_nvidia_rotator)
        self.assertEqual(result, ["Summary A", "Summary C"])
        self.mock_nvidia_chat.assert_awaited_once()

    @unittest.skipIf(SKIP_API_TESTS, "Skipping tests that require external APIs")
    async def test_filter_summaries_for_relevance_api_fails(self):
        """When the LLM call raises, every summary is returned as a fallback."""
        summaries = ["Summary A", "Summary B"]
        self.mock_nvidia_chat.side_effect = Exception("API error")
        result = await self.manager._filter_summaries_for_relevance("question", summaries, self.mock_nvidia_rotator)
        # Should return all summaries as a fallback
        self.assertEqual(result, summaries)
# Allow running this module directly (see the module docstring for the command).
if __name__ == '__main__':
    unittest.main()