File size: 6,656 Bytes
c541ca0
f15fec7
c541ca0
 
f15fec7
c541ca0
b2429d1
c541ca0
21c55a3
c541ca0
d753c16
c541ca0
 
 
f15fec7
c541ca0
 
 
 
 
 
 
 
 
 
 
 
 
 
f15fec7
c541ca0
f15fec7
c541ca0
 
 
 
 
 
f15fec7
 
57635b5
 
 
 
c541ca0
f15fec7
 
 
 
c541ca0
 
57635b5
c541ca0
 
 
57635b5
 
 
f15fec7
c541ca0
 
f15fec7
 
c541ca0
f15fec7
c541ca0
f15fec7
57635b5
f15fec7
 
c541ca0
 
57635b5
c541ca0
 
 
 
57635b5
 
f15fec7
c541ca0
 
f15fec7
 
c541ca0
f15fec7
 
 
 
 
c541ca0
f15fec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c541ca0
 
3a9698b
 
57635b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import unittest
from unittest.mock import patch

from bson import ObjectId
from pymongo.errors import ConnectionFailure

from src.data.connection import ActionFailed, Collections, get_collection
from src.data.repositories import patient as patient_repo
from src.models.patient import Patient
from src.utils.logger import logger
from ..base_test import BaseMongoTest


class TestPatientRepository(BaseMongoTest):
	"""Happy-path and edge-case coverage for the patient repository functions."""

	def setUp(self):
		"""Re-initialise the patient collection (with drop=True) before every test."""
		super().setUp()
		collection_name = self._collections[Collections.PATIENT]
		self.test_collection = collection_name
		patient_repo.init(collection_name=collection_name, drop=True)

	def test_init_functionality(self):
		"""Check that init creates the collection together with the assigned_doctor_id index."""
		existing_collections = self.db.list_collection_names()
		self.assertIn(self.test_collection, existing_collections)
		indexes = get_collection(self.test_collection).index_information()
		self.assertIn("assigned_doctor_id_1", indexes)

	def test_create_patient(self):
		"""Create one patient with only the positional fields and one with every field set."""
		minimal_id = patient_repo.create_patient(
			"John Doe", 45, "Male", "Caucasian", collection_name=self.test_collection
		)
		self.assertIsInstance(minimal_id, str)
		minimal_doc = self.get_doc_by_id(Collections.PATIENT, minimal_id)
		self.assertIsNotNone(minimal_doc)
		self.assertEqual(minimal_doc["name"], "John Doe")  # type: ignore

		# A random (non-persisted) ObjectId is enough to exercise the reference field.
		doctor_id = str(ObjectId())
		complete_fields = dict(
			name="Jane Doe",
			age=30,
			sex="Female",
			ethnicity="Asian",
			address="123 Wellness Way",
			phone="555-123-4567",
			email="jane.doe@example.com",
			medications=["Lisinopril"],
			past_assessment_summary="Routine check-up.",
			assigned_doctor_id=doctor_id,
		)
		complete_id = patient_repo.create_patient(**complete_fields, collection_name=self.test_collection)
		complete_doc = self.get_doc_by_id(Collections.PATIENT, complete_id)
		self.assertIsNotNone(complete_doc)
		self.assertEqual(complete_doc["email"], "jane.doe@example.com")  # type: ignore
		# The stored value is compared via str(), so it may be kept as an ObjectId.
		self.assertEqual(str(complete_doc["assigned_doctor_id"]), doctor_id)  # type: ignore

	def test_get_patient_by_id(self):
		"""Fetch a stored patient as a Patient model; an unknown id yields None."""
		new_id = patient_repo.create_patient("GetMe", 33, "Female", "Other", collection_name=self.test_collection)
		fetched = patient_repo.get_patient_by_id(new_id, collection_name=self.test_collection)
		self.assertIsNotNone(fetched)
		self.assertIsInstance(fetched, Patient)
		self.assertEqual(fetched.id, new_id)  # type: ignore
		self.assertEqual(fetched.name, "GetMe")  # type: ignore

		missing = patient_repo.get_patient_by_id(str(ObjectId()), collection_name=self.test_collection)
		self.assertIsNone(missing)

	def test_update_patient_profile(self):
		"""Updating a stored patient modifies one document; an unknown id modifies none."""
		target_id = patient_repo.create_patient("Update Test", 25, "Male", "Hispanic", collection_name=self.test_collection)
		changed = patient_repo.update_patient_profile(
			target_id, {"age": 26, "phone": "555-9999"}, collection_name=self.test_collection
		)
		self.assertEqual(changed, 1)
		stored = self.get_doc_by_id(Collections.PATIENT, target_id)
		self.assertIsNotNone(stored)
		self.assertEqual(stored["age"], 26)  # type: ignore

		ghost_id = str(ObjectId())
		ghost_changed = patient_repo.update_patient_profile(ghost_id, {"name": "Ghost"}, collection_name=self.test_collection)
		self.assertEqual(ghost_changed, 0)

	def test_search_patients(self):
		"""Searching returns Patient models and honours the limit argument."""
		seed_rows = (
			("Alice Smith", 30, "Female", "Asian"),
			("Bob Smith", 45, "Male", "Caucasian"),
		)
		for name, age, sex, ethnicity in seed_rows:
			patient_repo.create_patient(name, age, sex, ethnicity, collection_name=self.test_collection)

		# Lower-case query against mixed-case names: the search is expected to be case-insensitive.
		matches = patient_repo.search_patients("smith", collection_name=self.test_collection)
		self.assertEqual(len(matches), 2)
		first_match = matches[0]
		self.assertIsInstance(first_match, Patient)
		# NOTE(review): relies on search_patients returning results in a deterministic
		# order (presumably sorted by name) — confirm against the repository implementation.
		self.assertEqual(first_match.name, "Alice Smith")

		limited = patient_repo.search_patients("s", limit=1, collection_name=self.test_collection)
		self.assertEqual(len(limited), 1)


class TestPatientRepositoryExceptions(BaseMongoTest):
	"""Verify that the patient repository functions surface failures as ActionFailed."""

	def setUp(self):
		"""Re-initialise the patient collection (with drop=True) before every test."""
		super().setUp()
		collection_name = self._collections[Collections.PATIENT]
		self.test_collection = collection_name
		patient_repo.init(collection_name=collection_name, drop=True)

	def test_write_error_raises_action_failed(self):
		"""Writes with values the collection rejects (bad sex on create, non-string ethnicity on update) raise ActionFailed."""
		coll = self.test_collection
		with self.assertRaises(ActionFailed):
			patient_repo.create_patient("Schema Test", 25, "InvalidValue", "Other", collection_name=coll)

		# A valid document is needed so the failing update targets an existing patient.
		valid_id = patient_repo.create_patient("UpdateSchema", 30, "Male", "Other", collection_name=coll)
		with self.assertRaises(ActionFailed):
			patient_repo.update_patient_profile(valid_id, {"ethnicity": 123}, collection_name=coll)

	def test_invalid_id_raises_action_failed(self):
		"""Every id-taking function raises ActionFailed for a malformed ObjectId string."""
		coll = self.test_collection
		bad_id = "not-a-valid-id"
		# Deferred calls keep each invocation inside its own assertRaises context, in order.
		calls = (
			lambda: patient_repo.create_patient(
				"Test", 30, "Male", "Other", assigned_doctor_id=bad_id, collection_name=coll
			),
			lambda: patient_repo.get_patient_by_id(bad_id, collection_name=coll),
			lambda: patient_repo.update_patient_profile(bad_id, {"name": "test"}, collection_name=coll),
		)
		for call in calls:
			with self.assertRaises(ActionFailed):
				call()

	@patch('src.data.repositories.patient.get_collection')
	def test_all_functions_raise_on_connection_error(self, mock_get_collection):
		"""A ConnectionFailure from get_collection is re-raised as ActionFailed by every function."""
		mock_get_collection.side_effect = ConnectionFailure("Simulated connection error")
		coll = self.test_collection
		calls = (
			lambda: patient_repo.init(collection_name=coll, drop=True),
			lambda: patient_repo.create_patient("Test", 30, "Male", "Other", collection_name=coll),
			lambda: patient_repo.get_patient_by_id(str(ObjectId()), collection_name=coll),
			lambda: patient_repo.update_patient_profile(str(ObjectId()), {"name": "test"}, collection_name=coll),
			lambda: patient_repo.search_patients("test", collection_name=coll),
		)
		for call in calls:
			with self.assertRaises(ActionFailed):
				call()

if __name__ == "__main__":
	logger().info("Starting MongoDB repository integration tests...")
	# unittest.main() exits the interpreter by default, which would skip the
	# completion log below; run without exiting, then exit explicitly.
	program = unittest.main(verbosity=2, exit=False)
	logger().info("Tests completed.")
	# Preserve unittest.main's exit status: non-zero when any test failed/errored.
	raise SystemExit(not program.result.wasSuccessful())