Spaces:
Sleeping
Sleeping
import sys
import time
import unittest
from datetime import datetime, timedelta, timezone
from unittest.mock import patch

from bson import ObjectId
from pymongo.errors import ConnectionFailure, WriteError

from src.data.connection import ActionFailed, Collections, get_collection
from src.data.repositories import session as session_repo
from src.models.session import Message, Session
from src.utils.logger import logger

from ..base_test import BaseMongoTest
class TestSessionRepository(BaseMongoTest):
    """Test class for the 'happy path' and edge cases of session repository functions."""

    def setUp(self):
        """Set up a clean test environment before each test.

        Re-initialises the session collection provided by ``BaseMongoTest``
        with ``drop=True`` and generates fresh account/patient ids.
        """
        super().setUp()
        self.test_collection = self._collections[Collections.SESSION]
        session_repo.init(collection_name=self.test_collection, drop=True)
        self.account_id = str(ObjectId())
        self.patient_id = str(ObjectId())

    def test_init_functionality(self):
        """Test that init sets up the collection and indexes correctly."""
        self.assertIn(self.test_collection, self.db.list_collection_names())
        index_info = get_collection(self.test_collection).index_information()
        self.assertIn("messages._id_1", index_info)

    def test_create_and_get_session(self):
        """Test chat session creation and retrieval by ID."""
        session = session_repo.create_session(
            self.account_id, self.patient_id, "Test Chat", collection_name=self.test_collection
        )
        self.assertIsInstance(session, Session)
        self.assertEqual(session.title, "Test Chat")

        retrieved = session_repo.get_session(session.id, collection_name=self.test_collection)
        self.assertIsNotNone(retrieved)
        self.assertIsInstance(retrieved, Session)
        self.assertEqual(retrieved.id, session.id)  # type: ignore
        self.assertEqual(retrieved.account_id, self.account_id)  # type: ignore

        # A well-formed but unknown id is expected to yield None rather than raise.
        self.assertIsNone(session_repo.get_session(str(ObjectId()), collection_name=self.test_collection))

    def test_add_and_get_messages(self):
        """Test adding messages and retrieving them as Message models."""
        session = session_repo.create_session(
            self.account_id, self.patient_id, "Msg Test", collection_name=self.test_collection
        )
        session_repo.add_message(session.id, "User message 1", True, collection_name=self.test_collection)
        # Small delay so the second message gets a distinct timestamp.
        time.sleep(0.01)
        session_repo.add_message(session.id, "AI response 1", False, collection_name=self.test_collection)

        messages = session_repo.get_session_messages(session.id, collection_name=self.test_collection)
        self.assertEqual(len(messages), 2)
        self.assertIsInstance(messages[0], Message)
        # Messages are expected newest first, so the AI reply leads.
        self.assertEqual(messages[0].content, "AI response 1")
        self.assertEqual(messages[1].content, "User message 1")
        # The first message added to the session is expected to carry id 0.
        self.assertEqual(messages[1].id, 0)

        limited = session_repo.get_session_messages(session.id, limit=1, collection_name=self.test_collection)
        self.assertEqual(len(limited), 1)

    def test_list_sessions(self):
        """Test listing sessions for a patient and user returns lists of Session models."""
        session_repo.create_session(self.account_id, self.patient_id, "First", collection_name=self.test_collection)
        # Delay so the second session sorts strictly after the first.
        time.sleep(0.01)
        s2 = session_repo.create_session(self.account_id, self.patient_id, "Second", collection_name=self.test_collection)

        patient_sessions = session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
        self.assertEqual(len(patient_sessions), 2)
        self.assertIsInstance(patient_sessions[0], Session)
        # Most recent session is expected first.
        self.assertEqual(patient_sessions[0].id, s2.id)

        user_sessions = session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
        self.assertEqual(len(user_sessions), 2)
        self.assertEqual(user_sessions[0].id, s2.id)

    def test_update_session_title(self):
        """Test updating a session's title and its timestamp."""
        session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
        original_doc = self.get_doc_by_id(Collections.SESSION, session.id)
        self.assertIsNotNone(original_doc)

        # BSON datetimes only have millisecond precision; without a delay the
        # create and the update can land in the same millisecond, making the
        # strict "updated_at increased" assertion below fail spuriously.
        time.sleep(0.01)
        success = session_repo.update_session_title(session.id, "New", collection_name=self.test_collection)
        self.assertTrue(success)

        updated_doc = self.get_doc_by_id(Collections.SESSION, session.id)
        self.assertIsNotNone(updated_doc)
        self.assertEqual(updated_doc["title"], "New")  # type: ignore
        self.assertLess(original_doc["updated_at"], updated_doc["updated_at"])  # type: ignore

        # Updating a non-existent session reports failure instead of raising.
        self.assertFalse(session_repo.update_session_title(str(ObjectId()), "Ghost", collection_name=self.test_collection))

    def test_delete_session(self):
        """Test deleting a session."""
        session = session_repo.create_session(self.account_id, self.patient_id, "To Delete", collection_name=self.test_collection)
        self.assertTrue(session_repo.delete_session(session.id, collection_name=self.test_collection))
        self.assertIsNone(session_repo.get_session(session.id, collection_name=self.test_collection))
        # Deleting an unknown id reports failure instead of raising.
        self.assertFalse(session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection))

    def test_prune_old_sessions(self):
        """Test deleting sessions older than a specified number of days."""
        old_session = session_repo.create_session(self.account_id, self.patient_id, "Old", collection_name=self.test_collection)
        new_session = session_repo.create_session(self.account_id, self.patient_id, "New", collection_name=self.test_collection)

        # Backdate one session past the 30-day retention window.
        old_date = datetime.now(timezone.utc) - timedelta(days=31)
        get_collection(self.test_collection).update_one(
            {"_id": ObjectId(old_session.id)}, {"$set": {"updated_at": old_date}}
        )
        self.assertEqual(get_collection(self.test_collection).count_documents({}), 2)

        deleted_count = session_repo.prune_old_sessions(days=30, collection_name=self.test_collection)
        self.assertEqual(deleted_count, 1)

        # Verify it was the backdated session that got removed, not the fresh one.
        self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
        self.assertIsNone(session_repo.get_session(old_session.id, collection_name=self.test_collection))
        self.assertIsNotNone(session_repo.get_session(new_session.id, collection_name=self.test_collection))
class TestSessionRepositoryExceptions(BaseMongoTest):
    """Test class for the exception handling of all session repository functions."""

    def setUp(self):
        """Set up a clean test environment before each test.

        Re-initialises the session collection provided by ``BaseMongoTest``
        with ``drop=True`` and generates fresh account/patient ids.
        """
        super().setUp()
        self.test_collection = self._collections[Collections.SESSION]
        session_repo.init(collection_name=self.test_collection, drop=True)
        self.account_id = str(ObjectId())
        self.patient_id = str(ObjectId())

    # Without this decorator nothing supplies ``mock_get_collection`` and the
    # test errors with TypeError. NOTE(review): assumes the repository module
    # binds ``get_collection`` at module level (e.g. ``from src.data.connection
    # import get_collection``) — confirm against src/data/repositories/session.py.
    @patch.object(session_repo, "get_collection")
    def test_write_error_raises_action_failed(self, mock_get_collection):
        """Test that a WriteError during an operation is raised as ActionFailed."""
        mock_collection = mock_get_collection.return_value
        mock_collection.update_one.side_effect = WriteError("Simulated schema validation error")
        # A session document with no messages, so add_message reaches update_one.
        mock_collection.find_one.return_value = {"messages": []}
        with self.assertRaises(ActionFailed):
            session_repo.add_message(str(ObjectId()), "content", True, collection_name=self.test_collection)

    def test_invalid_id_raises_action_failed(self):
        """Test that functions raise ActionFailed when given a malformed ObjectId string."""
        with self.assertRaises(ActionFailed):
            session_repo.create_session("bad-id", self.patient_id, "t", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.get_user_sessions("bad-id", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.list_patient_sessions("bad-id", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.get_session("bad-id", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.get_session_messages("bad-id", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.update_session_title("bad-id", "t", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.delete_session("bad-id", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.add_message("bad-id", "t", True, collection_name=self.test_collection)

    # Same patch target as above; every repository call should see the
    # simulated connection failure when it asks for its collection.
    @patch.object(session_repo, "get_collection")
    def test_all_functions_raise_on_connection_error(self, mock_get_collection):
        """Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
        mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
        with self.assertRaises(ActionFailed):
            session_repo.init(collection_name=self.test_collection, drop=True)
        with self.assertRaises(ActionFailed):
            session_repo.create_session(self.account_id, self.patient_id, "t", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.get_user_sessions(self.account_id, collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.list_patient_sessions(self.patient_id, collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.get_session(str(ObjectId()), collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.get_session_messages(str(ObjectId()), collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.update_session_title(str(ObjectId()), "t", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.delete_session(str(ObjectId()), collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.prune_old_sessions(collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            session_repo.add_message(str(ObjectId()), "t", True, collection_name=self.test_collection)
if __name__ == "__main__":
    logger().info("Starting MongoDB repository integration tests...")
    # unittest.main() calls sys.exit() by default, which made the completion
    # log below unreachable. Run with exit=False, log, then exit with a status
    # that still reflects whether the suite passed.
    program = unittest.main(verbosity=2, exit=False)
    logger().info("Tests completed.")
    sys.exit(0 if program.result.wasSuccessful() else 1)