Spaces:
Sleeping
Sleeping
| import time | |
| import unittest | |
| from datetime import datetime | |
| from unittest.mock import patch | |
| from bson import ObjectId | |
| from pymongo.errors import ConnectionFailure | |
| from src.data.connection import ActionFailed, Collections, get_collection | |
| from src.data.repositories import account as account_repo | |
| from src.models.account import Account | |
| from src.utils.logger import logger | |
| from ..base_test import BaseMongoTest | |
class TestAccountRepository(BaseMongoTest):
    """Test class for the 'happy path' and edge cases of account repository functions."""

    def setUp(self):
        """Set up the test environment before each test."""
        super().setUp()
        # Per-test collection name supplied by BaseMongoTest; every test starts
        # from a freshly (re)initialised, empty collection.
        self.test_collection = self._collections[Collections.ACCOUNT]
        account_repo.init(collection_name=self.test_collection, drop=True)

    def test_init_functionality(self):
        """Test the init function's ability to create, drop, and preserve collections."""
        self.assertIn(self.test_collection, self.db.list_collection_names())
        account_repo.create_account(
            "Persist Test", "Doctor", collection_name=self.test_collection
        )
        # drop=False must keep the existing document ...
        account_repo.init(collection_name=self.test_collection, drop=False)
        self.assertEqual(get_collection(self.test_collection).count_documents({}), 1)
        # ... while drop=True must leave the collection empty.
        account_repo.init(collection_name=self.test_collection, drop=True)
        self.assertEqual(get_collection(self.test_collection).count_documents({}), 0)

    def test_create_account(self):
        """Test successful account creation, including optional fields."""
        name, role = "Test Doctor", "Doctor"
        account_id = account_repo.create_account(
            name=name, role=role, collection_name=self.test_collection
        )
        self.assertIsInstance(account_id, str)
        doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
        self.assertIsNotNone(doc)
        self.assertEqual(doc["name"], name)  # type: ignore

        # Optional field: specialty.
        spec_id = account_repo.create_account(
            "Spec", "Nurse", specialty="Cardiology", collection_name=self.test_collection
        )
        spec_doc = self.get_doc_by_id(Collections.ACCOUNT, spec_id)
        self.assertEqual(spec_doc["specialty"], "Cardiology")  # type: ignore

    def test_update_account_logic(self):
        """Test the specific business logic of the update_account function.

        Checks that the requested field is written, that a ``created_at`` value
        in the payload does not overwrite the stored one, that ``updated_at``
        advances, and that updating an unknown id returns False.
        """
        account_id = account_repo.create_account(
            "Update Logic", "Doctor", collection_name=self.test_collection
        )
        original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
        self.assertIsNotNone(original_doc)

        # BSON datetimes have millisecond resolution. Without a pause, the
        # create and the update can fall in the same millisecond, and the
        # strict assertLess on "updated_at" below becomes flaky.
        time.sleep(0.01)

        updates = {"name": "Updated Name", "created_at": datetime(2000, 1, 1)}
        success = account_repo.update_account(
            account_id, updates, collection_name=self.test_collection
        )
        self.assertTrue(success)

        updated_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
        self.assertIsNotNone(updated_doc)
        self.assertEqual(updated_doc["name"], "Updated Name")  # type: ignore
        self.assertEqual(updated_doc["created_at"], original_doc["created_at"])  # type: ignore
        self.assertLess(original_doc["updated_at"], updated_doc["updated_at"])  # type: ignore

        # A well-formed but non-existent id is reported as "nothing updated".
        self.assertFalse(
            account_repo.update_account(
                str(ObjectId()), {"name": "No One"}, collection_name=self.test_collection
            )
        )

    def test_get_account_logic(self):
        """Test that get_account updates 'last_seen' and returns a valid Account model."""
        account_id = account_repo.create_account(
            "GetMe", "Doctor", collection_name=self.test_collection
        )
        original_doc = self.get_doc_by_id(Collections.ACCOUNT, account_id)
        self.assertIsNotNone(original_doc)
        time.sleep(0.01)  # Ensure timestamp will be different
        account = account_repo.get_account(account_id, collection_name=self.test_collection)
        self.assertIsNotNone(account)
        self.assertIsInstance(account, Account)
        # Reading bumps last_seen but must not touch updated_at.
        self.assertLess(original_doc["last_seen"], account.last_seen)  # type: ignore
        self.assertEqual(original_doc["updated_at"], account.updated_at)  # type: ignore
        self.assertIsNone(
            account_repo.get_account(str(ObjectId()), collection_name=self.test_collection)
        )

    def test_get_account_by_name(self):
        """Test retrieving an account by name, for both existing and missing names."""
        # TODO(review): if get_account_by_name is deprecated, also assert the
        # warning, e.g. ``with self.assertWarns(DeprecationWarning): ...`` — confirm.
        name = "FindByName"
        account_repo.create_account(name, "Nurse", collection_name=self.test_collection)
        account = account_repo.get_account_by_name(name, collection_name=self.test_collection)
        self.assertIsNotNone(account)
        self.assertIsInstance(account, Account)
        self.assertEqual(account.name, name)  # type: ignore
        self.assertIsNone(
            account_repo.get_account_by_name("NonExistent", collection_name=self.test_collection)
        )

    def test_search_accounts(self):
        """Test search functionality returns a list of Account models."""
        account_repo.create_account("Alpha Doctor", "Doctor", collection_name=self.test_collection)
        account_repo.create_account("Beta Nurse", "Nurse", collection_name=self.test_collection)
        # Lower-case query against a capitalised name: search is expected to be
        # case-insensitive.
        results = account_repo.search_accounts("alpha", collection_name=self.test_collection)
        self.assertEqual(len(results), 1)
        self.assertIsInstance(results[0], Account)
        self.assertEqual(results[0].name, "Alpha Doctor")
        self.assertEqual(
            len(account_repo.search_accounts("NonExistent", collection_name=self.test_collection)),
            0,
        )

    def test_get_all_accounts(self):
        """Test retrieving all accounts, verifying sorting and model type."""
        # Inserted out of alphabetical order so the sort is actually exercised.
        account_repo.create_account("Charlie", "Doctor", collection_name=self.test_collection)
        account_repo.create_account("Alpha", "Nurse", collection_name=self.test_collection)
        all_accounts = account_repo.get_all_accounts(collection_name=self.test_collection)
        self.assertEqual(len(all_accounts), 2)
        self.assertIsInstance(all_accounts[0], Account)
        self.assertEqual(all_accounts[0].name, "Alpha")
        self.assertEqual(all_accounts[1].name, "Charlie")

    def test_get_account_frame(self):
        """Test retrieving accounts as a pandas DataFrame."""
        df_empty = account_repo.get_account_frame(collection_name=self.test_collection)
        self.assertTrue(df_empty.empty)
        account_repo.create_account("Frame Alpha", "Doctor", collection_name=self.test_collection)
        df_full = account_repo.get_account_frame(collection_name=self.test_collection)
        self.assertEqual(len(df_full), 1)
class TestAccountRepositoryExceptions(BaseMongoTest):
    """Test class for the exception handling of all account repository functions."""

    def setUp(self):
        """Set up the test environment before each test."""
        super().setUp()
        self.test_collection = self._collections[Collections.ACCOUNT]
        account_repo.init(collection_name=self.test_collection, drop=True)
        # A unique index on "name" lets tests provoke a duplicate-key write error.
        get_collection(self.test_collection).create_index("name", unique=True)

    def test_create_account_write_error(self):
        """Test that creating an account with invalid data raises ActionFailed."""
        account_repo.create_account(
            "Duplicate Name", "Doctor", collection_name=self.test_collection
        )
        # Violates the unique "name" index created in setUp.
        with self.assertRaises(ActionFailed):
            account_repo.create_account(
                "Duplicate Name", "Nurse", collection_name=self.test_collection
            )
        # NOTE(review): presumably rejected by validation that init() installs
        # on the collection — confirm.
        with self.assertRaises(ActionFailed):
            account_repo.create_account(
                "Schema Test", "InvalidRole", collection_name=self.test_collection
            )

    def test_invalid_id_raises_action_failed(self):
        """Test that functions raise ActionFailed when given a malformed ObjectId string."""
        with self.assertRaises(ActionFailed):
            account_repo.get_account("not-a-valid-id", collection_name=self.test_collection)
        with self.assertRaises(ActionFailed):
            account_repo.update_account(
                "not-a-valid-id", {"name": "test"}, collection_name=self.test_collection
            )

    # Patch get_collection as the repository module looks it up. Without this
    # decorator, unittest calls the method with no ``mock_get_collection``
    # argument and the test errors with TypeError. Assumes account_repo binds
    # get_collection as a module attribute — confirm.
    @patch.object(account_repo, "get_collection")
    def test_all_functions_raise_on_connection_error(self, mock_get_collection):
        """Test that all repo functions catch generic PyMongoErrors and raise ActionFailed."""
        mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
        col = self.test_collection
        calls = [
            ("init", lambda: account_repo.init(collection_name=col, drop=True)),
            ("get_account_frame", lambda: account_repo.get_account_frame(collection_name=col)),
            (
                "create_account",
                lambda: account_repo.create_account("test", "Doctor", collection_name=col),
            ),
            (
                "update_account",
                lambda: account_repo.update_account(
                    str(ObjectId()), {"name": "test"}, collection_name=col
                ),
            ),
            (
                "get_account",
                lambda: account_repo.get_account(str(ObjectId()), collection_name=col),
            ),
            (
                "get_account_by_name",
                lambda: account_repo.get_account_by_name("test", collection_name=col),
            ),
            (
                "search_accounts",
                lambda: account_repo.search_accounts("test", collection_name=col),
            ),
            ("get_all_accounts", lambda: account_repo.get_all_accounts(collection_name=col)),
        ]
        # subTest reports each failing repo function separately instead of
        # stopping at the first one.
        for func_name, call in calls:
            with self.subTest(function=func_name):
                with self.assertRaises(ActionFailed):
                    call()
if __name__ == "__main__":
    logger().info("Starting MongoDB repository integration tests...")
    try:
        # unittest.main() finishes by raising SystemExit (exit=True by default),
        # so code placed after it never runs. The finally clause logs the
        # completion message and still propagates the exit status.
        unittest.main(verbosity=2)
    finally:
        logger().info("Tests completed.")