dylanglenister commited on
Commit
e74c715
·
1 Parent(s): a57b9dc

Added tests for api endpoints.

Browse files

They use httpx and pytest. Pytest has been added to the dev-requirements

requirements-dev.txt CHANGED
@@ -10,3 +10,4 @@ python-dotenv
10
  pymongo
11
  pandas
12
  python-multipart
 
 
10
  pymongo
11
  pandas
12
  python-multipart
13
+ pytest
tests/routes/test_account_routes.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/routes/test_account_routes.py
2
+
3
+ from datetime import datetime, timezone
4
+ from unittest.mock import MagicMock
5
+
6
+ import pytest
7
+ from fastapi import FastAPI, status
8
+ from fastapi.testclient import TestClient
9
+
10
+ from src.api.routes.account import router as account_router
11
+ from src.core.state import get_state
12
+ from src.data.connection import ActionFailed
13
+ from src.models.account import Account
14
+
15
+ # --- Test Setup: Mocking and Dependency Injection ---
16
+
17
+ # Mock the service layer that the API routes depend on.
18
+ mock_memory_manager = MagicMock()
19
+
20
+ # This mock AppState and override function will replace the real dependencies.
21
+ class MockAppState:
22
+ def __init__(self):
23
+ self.memory_manager = mock_memory_manager
24
+
25
+ def override_get_state() -> MockAppState:
26
+ return MockAppState()
27
+
28
+ # Create a FastAPI app instance for testing and apply the dependency override.
29
+ app = FastAPI()
30
+ app.include_router(account_router)
31
+ app.dependency_overrides[get_state] = override_get_state
32
+
33
+
34
+ # --- Fixtures ---
35
+
36
+ @pytest.fixture
37
+ def client():
38
+ """Provides a TestClient for making requests to the app."""
39
+ with TestClient(app) as c:
40
+ yield c
41
+
42
+ @pytest.fixture(autouse=True)
43
+ def reset_mocks():
44
+ """Resets mocks before each test to ensure test isolation."""
45
+ mock_memory_manager.reset_mock()
46
+
47
+
48
+ # --- Test Data ---
49
+
50
+ # A sample account object that can be reused in tests.
51
+ fake_account_dict = {
52
+ "_id": "60c72b2f9b1d8b3b3c9d8b1a",
53
+ "name": "Dr. Test",
54
+ "role": "Doctor",
55
+ "specialty": "Testing",
56
+ "created_at": datetime.now(timezone.utc).isoformat(),
57
+ "updated_at": datetime.now(timezone.utc).isoformat(),
58
+ "last_seen": datetime.now(timezone.utc).isoformat(),
59
+ }
60
+ # Use model_validate to handle the string-based datetimes from the dict.
61
+ fake_account = Account.model_validate(fake_account_dict)
62
+
63
+
64
+ # --- Tests for GET /account ---
65
+
66
+ def test_get_all_accounts_success(client: TestClient):
67
+ """Tests successfully retrieving all accounts when no query is provided."""
68
+ mock_memory_manager.get_all_accounts.return_value = [fake_account]
69
+
70
+ response = client.get("/account?limit=10")
71
+
72
+ assert response.status_code == status.HTTP_200_OK
73
+ assert len(response.json()) == 1
74
+ assert response.json()[0]["name"] == "Dr. Test"
75
+ mock_memory_manager.get_all_accounts.assert_called_once_with(limit=10)
76
+ mock_memory_manager.search_accounts.assert_not_called()
77
+
78
+ def test_search_accounts_success(client: TestClient):
79
+ """Tests successfully searching for accounts with a query."""
80
+ mock_memory_manager.search_accounts.return_value = [fake_account]
81
+
82
+ response = client.get("/account?q=Test&limit=5")
83
+
84
+ assert response.status_code == status.HTTP_200_OK
85
+ assert response.json()[0]["name"] == "Dr. Test"
86
+ mock_memory_manager.search_accounts.assert_called_once_with("Test", limit=5)
87
+ mock_memory_manager.get_all_accounts.assert_not_called()
88
+
89
+ def test_get_accounts_db_error(client: TestClient):
90
+ """Tests that a 500 error is returned if the database fails."""
91
+ mock_memory_manager.get_all_accounts.side_effect = ActionFailed("DB connection lost")
92
+
93
+ response = client.get("/account")
94
+
95
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
96
+ assert response.json()["detail"] == "A database error occurred."
97
+
98
+
99
+ # --- Tests for POST /account ---
100
+
101
+ def test_create_account_success(client: TestClient):
102
+ """Tests successful creation of a new account."""
103
+ new_account_id = "new_fake_id_123"
104
+ mock_memory_manager.create_account.return_value = new_account_id
105
+ mock_memory_manager.get_account.return_value = fake_account
106
+
107
+ account_data = {"name": "Dr. Test", "role": "Doctor", "specialty": "Testing"}
108
+ response = client.post("/account", json=account_data)
109
+
110
+ assert response.status_code == status.HTTP_201_CREATED
111
+ assert response.json()["name"] == "Dr. Test"
112
+ mock_memory_manager.create_account.assert_called_once_with(
113
+ name="Dr. Test", role="Doctor", specialty="Testing"
114
+ )
115
+ mock_memory_manager.get_account.assert_called_once_with(new_account_id)
116
+
117
+ def test_create_account_creation_fails(client: TestClient):
118
+ """Tests the case where the memory manager fails to return an ID."""
119
+ mock_memory_manager.create_account.return_value = None
120
+
121
+ account_data = {"name": "Dr. Fail", "role": "Doctor"}
122
+ response = client.post("/account", json=account_data)
123
+
124
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
125
+ assert response.json()["detail"] == "Failed to create account ID."
126
+
127
+ def test_create_account_not_found_after_creation(client: TestClient):
128
+ """Tests the edge case where the account can't be retrieved after creation."""
129
+ new_account_id = "new_fake_id_123"
130
+ mock_memory_manager.create_account.return_value = new_account_id
131
+ mock_memory_manager.get_account.return_value = None # Simulate not found
132
+
133
+ account_data = {"name": "Dr. Ghost", "role": "Doctor"}
134
+ response = client.post("/account", json=account_data)
135
+
136
+ assert response.status_code == status.HTTP_404_NOT_FOUND
137
+ assert response.json()["detail"] == "Could not find newly created account."
138
+
139
+ def test_create_account_action_failed(client: TestClient):
140
+ """Tests that a 400 error is returned for data-related creation failures."""
141
+ error_message = "Account with this name already exists."
142
+ mock_memory_manager.create_account.side_effect = ActionFailed(error_message)
143
+
144
+ account_data = {"name": "Dr. Duplicate", "role": "Doctor"}
145
+ response = client.post("/account", json=account_data)
146
+
147
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
148
+ assert response.json()["detail"] == error_message
149
+
150
+
151
+ # --- Tests for GET /account/{account_id} ---
152
+
153
+ def test_get_account_by_id_success(client: TestClient):
154
+ """Tests successfully retrieving a single account by its ID."""
155
+ mock_memory_manager.get_account.return_value = fake_account
156
+
157
+ response = client.get(f"/account/{fake_account.id}")
158
+
159
+ assert response.status_code == status.HTTP_200_OK
160
+ assert response.json()["name"] == fake_account.name
161
+ mock_memory_manager.get_account.assert_called_once_with(str(fake_account.id))
162
+
163
+ def test_get_account_by_id_not_found(client: TestClient):
164
+ """Tests that a 404 error is returned for a non-existent account ID."""
165
+ mock_memory_manager.get_account.return_value = None
166
+
167
+ response = client.get("/account/non_existent_id")
168
+
169
+ assert response.status_code == status.HTTP_404_NOT_FOUND
170
+ assert response.json()["detail"] == "Account not found"
171
+
172
+ def test_get_account_by_id_db_error(client: TestClient):
173
+ """Tests that a 500 error is returned if the database fails during retrieval."""
174
+ mock_memory_manager.get_account.side_effect = ActionFailed("DB connection lost")
175
+
176
+ response = client.get(f"/account/{fake_account.id}")
177
+
178
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
179
+ assert response.json()["detail"] == "A database error occurred."
tests/routes/test_patient_routes.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/routes/test_patient_routes.py
2
+
3
+ from datetime import datetime, timezone
4
+ from unittest.mock import MagicMock
5
+
6
+ import pytest
7
+ from fastapi import FastAPI, status
8
+ from fastapi.testclient import TestClient
9
+
10
+ from src.api.routes.patient import router as patient_router
11
+ from src.core.state import get_state
12
+ from src.models.patient import Patient
13
+ from src.models.session import Session
14
+
15
+ # --- Test Setup: Mocking and Dependency Injection ---
16
+
17
+ # Mock the service layer that the API routes depend on.
18
+ mock_memory_manager = MagicMock()
19
+
20
+ # This mock AppState and override function will replace the real dependencies.
21
+ class MockAppState:
22
+ def __init__(self):
23
+ self.memory_manager = mock_memory_manager
24
+
25
+ def override_get_state() -> MockAppState:
26
+ return MockAppState()
27
+
28
+ # Create a FastAPI app instance for testing and apply the dependency override.
29
+ app = FastAPI()
30
+ app.include_router(patient_router)
31
+ app.dependency_overrides[get_state] = override_get_state
32
+
33
+
34
+ # --- Fixtures ---
35
+
36
+ @pytest.fixture
37
+ def client():
38
+ """Provides a TestClient for making requests to the app."""
39
+ with TestClient(app) as c:
40
+ yield c
41
+
42
+ @pytest.fixture(autouse=True)
43
+ def reset_mocks():
44
+ """Resets mocks before each test to ensure test isolation."""
45
+ mock_memory_manager.reset_mock()
46
+
47
+
48
+ # --- Test Data ---
49
+
50
+ fake_patient_dict = {
51
+ "_id": "patient123",
52
+ "name": "Jane Doe",
53
+ "age": 42,
54
+ "sex": "Female",
55
+ "ethnicity": "Caucasian",
56
+ "created_at": datetime.now(timezone.utc).isoformat(),
57
+ "updated_at": datetime.now(timezone.utc).isoformat(),
58
+ }
59
+ fake_patient = Patient.model_validate(fake_patient_dict)
60
+
61
+ fake_session_dict = {
62
+ "_id": "session456",
63
+ "account_id": "doctor789",
64
+ "patient_id": "patient123",
65
+ "title": "Checkup",
66
+ "created_at": datetime.now(timezone.utc).isoformat(),
67
+ "updated_at": datetime.now(timezone.utc).isoformat(),
68
+ "messages": [],
69
+ }
70
+ fake_session = Session.model_validate(fake_session_dict)
71
+
72
+
73
+ # --- Tests for GET /patient ---
74
+
75
+ def test_search_patients_success(client: TestClient):
76
+ """Tests successfully searching for patients with a query."""
77
+ mock_memory_manager.search_patients.return_value = [fake_patient]
78
+
79
+ response = client.get("/patient?q=Jane&limit=5")
80
+
81
+ assert response.status_code == status.HTTP_200_OK
82
+ assert len(response.json()) == 1
83
+ assert response.json()[0]["name"] == "Jane Doe"
84
+ mock_memory_manager.search_patients.assert_called_once_with("Jane", limit=5)
85
+
86
+ def test_search_patients_requires_query(client: TestClient):
87
+ """Tests that a 400 error is returned if the search query 'q' is missing."""
88
+ response = client.get("/patient")
89
+
90
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
91
+ assert response.json()["detail"] == "A search query 'q' is required."
92
+
93
+
94
+ # --- Tests for POST /patient ---
95
+
96
+ def test_create_patient_success(client: TestClient):
97
+ """Tests successful creation of a new patient profile."""
98
+ new_patient_id = "new_patient_abc"
99
+ mock_memory_manager.create_patient.return_value = new_patient_id
100
+ mock_memory_manager.get_patient_by_id.return_value = fake_patient
101
+
102
+ patient_data = {"name": "Jane Doe", "age": 42, "sex": "Female", "ethnicity": "Caucasian"}
103
+ response = client.post("/patient", json=patient_data)
104
+
105
+ assert response.status_code == status.HTTP_201_CREATED
106
+ assert response.json()["name"] == "Jane Doe"
107
+ mock_memory_manager.create_patient.assert_called_once()
108
+ mock_memory_manager.get_patient_by_id.assert_called_once_with(new_patient_id)
109
+
110
+ def test_create_patient_invalid_data(client: TestClient):
111
+ """Tests failure when the memory manager cannot create a patient (e.g., bad data)."""
112
+ mock_memory_manager.create_patient.return_value = None
113
+
114
+ patient_data = {"name": "Invalid", "age": -5, "sex": "F", "ethnicity": "Unknown"}
115
+ response = client.post("/patient", json=patient_data)
116
+
117
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
118
+ assert response.json()["detail"] == "Patient could not be created due to invalid data."
119
+
120
+ def test_create_patient_not_found_after_creation(client: TestClient):
121
+ """Tests the edge case where the patient can't be retrieved after creation."""
122
+ new_patient_id = "new_patient_abc"
123
+ mock_memory_manager.create_patient.return_value = new_patient_id
124
+ mock_memory_manager.get_patient_by_id.return_value = None
125
+
126
+ patient_data = {"name": "Jane Doe", "age": 42, "sex": "Female", "ethnicity": "Caucasian"}
127
+ response = client.post("/patient", json=patient_data)
128
+
129
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
130
+ assert response.json()["detail"] == "Could not find newly created patient."
131
+
132
+
133
+ # --- Tests for GET /patient/{patient_id} ---
134
+
135
+ def test_get_patient_by_id_success(client: TestClient):
136
+ """Tests successfully retrieving a single patient by their ID."""
137
+ mock_memory_manager.get_patient_by_id.return_value = fake_patient
138
+
139
+ response = client.get(f"/patient/{fake_patient.id}")
140
+
141
+ assert response.status_code == status.HTTP_200_OK
142
+ assert response.json()["name"] == fake_patient.name
143
+ mock_memory_manager.get_patient_by_id.assert_called_once_with(str(fake_patient.id))
144
+
145
+ def test_get_patient_by_id_not_found(client: TestClient):
146
+ """Tests that a 404 error is returned for a non-existent patient ID."""
147
+ mock_memory_manager.get_patient_by_id.return_value = None
148
+
149
+ response = client.get("/patient/non_existent_id")
150
+
151
+ assert response.status_code == status.HTTP_404_NOT_FOUND
152
+ assert response.json()["detail"] == "Patient not found"
153
+
154
+
155
+ # --- Tests for PATCH /patient/{patient_id} ---
156
+
157
+ def test_update_patient_success(client: TestClient):
158
+ """Tests a successful patient profile update."""
159
+ mock_memory_manager.update_patient_profile.return_value = 1 # 1 document modified
160
+ mock_memory_manager.get_patient_by_id.return_value = fake_patient
161
+
162
+ update_data = {"age": 43}
163
+ response = client.patch(f"/patient/{fake_patient.id}", json=update_data)
164
+
165
+ assert response.status_code == status.HTTP_200_OK
166
+ assert response.json()["name"] == fake_patient.name
167
+ mock_memory_manager.update_patient_profile.assert_called_once_with(str(fake_patient.id), update_data)
168
+
169
+ def test_update_patient_no_fields_provided(client: TestClient):
170
+ """Tests that a 400 error is returned if the update request body is empty."""
171
+ response = client.patch(f"/patient/{fake_patient.id}", json={})
172
+
173
+ assert response.status_code == status.HTTP_400_BAD_REQUEST
174
+ assert response.json()["detail"] == "No update fields provided."
175
+
176
+ def test_update_patient_not_found(client: TestClient):
177
+ """Tests updating a non-existent patient."""
178
+ mock_memory_manager.update_patient_profile.return_value = 0
179
+ # The route logic then checks if the patient exists. Simulate it not existing.
180
+ mock_memory_manager.get_patient_by_id.return_value = None
181
+
182
+ response = client.patch("/patient/non_existent_id", json={"age": 50})
183
+
184
+ assert response.status_code == status.HTTP_404_NOT_FOUND
185
+ assert response.json()["detail"] == "Patient not found"
186
+ # Check that get_patient_by_id was called as part of the 404 check
187
+ mock_memory_manager.get_patient_by_id.assert_called_once_with("non_existent_id")
188
+
189
+
190
+ # --- Tests for GET /patient/{patient_id}/session ---
191
+
192
+ def test_list_sessions_for_patient_success(client: TestClient):
193
+ """Tests successfully listing all sessions for a given patient."""
194
+ mock_memory_manager.get_patient_by_id.return_value = fake_patient
195
+ mock_memory_manager.list_patient_sessions.return_value = [fake_session]
196
+
197
+ response = client.get(f"/patient/{fake_patient.id}/session")
198
+
199
+ assert response.status_code == status.HTTP_200_OK
200
+ assert len(response.json()) == 1
201
+ assert response.json()[0]["title"] == "Checkup"
202
+ # Verify it first checked that the patient exists
203
+ mock_memory_manager.get_patient_by_id.assert_called_once_with(str(fake_patient.id))
204
+ mock_memory_manager.list_patient_sessions.assert_called_once_with(str(fake_patient.id))
205
+
206
+ def test_list_sessions_for_patient_not_found(client: TestClient):
207
+ """Tests listing sessions for a non-existent patient."""
208
+ mock_memory_manager.get_patient_by_id.return_value = None
209
+
210
+ response = client.get("/patient/non_existent_id/session")
211
+
212
+ assert response.status_code == status.HTTP_404_NOT_FOUND
213
+ assert response.json()["detail"] == "Patient not found"
214
+ mock_memory_manager.list_patient_sessions.assert_not_called()
tests/routes/test_session_routes.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/routes/test_session_routes.py
2
+
3
+ from datetime import datetime, timezone
4
+ from unittest.mock import AsyncMock, MagicMock, patch
5
+
6
+ import pytest
7
+ from fastapi import FastAPI, status
8
+ from fastapi.testclient import TestClient
9
+
10
+ from src.api.routes.session import router as session_router
11
+ from src.core.state import get_state
12
+ from src.models.session import Message, Session
13
+
14
+ # --- Test Setup: Mocking and Dependency Injection ---
15
+
16
+ mock_memory_manager = MagicMock()
17
+
18
+ class MockAppState:
19
+ def __init__(self):
20
+ self.memory_manager = mock_memory_manager
21
+ # Add mock rotators for the session endpoint
22
+ self.gemini_rotator = MagicMock()
23
+ self.nvidia_rotator = MagicMock()
24
+
25
+ def override_get_state() -> MockAppState:
26
+ return MockAppState()
27
+
28
+ app = FastAPI()
29
+ app.include_router(session_router)
30
+ app.dependency_overrides[get_state] = override_get_state
31
+
32
+
33
+ # --- Fixtures ---
34
+
35
+ @pytest.fixture
36
+ def client():
37
+ """Provides a TestClient for making requests to the app."""
38
+ with TestClient(app) as c:
39
+ yield c
40
+
41
+ @pytest.fixture(autouse=True)
42
+ def reset_mocks():
43
+ """Resets mocks before each test to ensure test isolation."""
44
+ mock_memory_manager.reset_mock()
45
+
46
+
47
+ # --- Test Data ---
48
+
49
+ fake_session_dict = {
50
+ "_id": "session456",
51
+ "account_id": "doctor789",
52
+ "patient_id": "patient123",
53
+ "title": "Checkup",
54
+ "created_at": datetime.now(timezone.utc).isoformat(),
55
+ "updated_at": datetime.now(timezone.utc).isoformat(),
56
+ "messages": [],
57
+ }
58
+ fake_session = Session.model_validate(fake_session_dict)
59
+
60
+ fake_message_dict = {
61
+ "_id": 1,
62
+ "sent_by_user": True,
63
+ "content": "Hello",
64
+ "timestamp": datetime.now(timezone.utc).isoformat(),
65
+ }
66
+ fake_message = Message.model_validate(fake_message_dict)
67
+
68
+
69
+ # --- Tests for POST /session ---
70
+
71
+ def test_create_session_success(client: TestClient):
72
+ """Tests successful creation of a new chat session."""
73
+ mock_memory_manager.create_session.return_value = fake_session
74
+
75
+ session_data = {"account_id": "doctor789", "patient_id": "patient123", "title": "Checkup"}
76
+ response = client.post("/session", json=session_data)
77
+
78
+ assert response.status_code == status.HTTP_201_CREATED
79
+ assert response.json()["title"] == "Checkup"
80
+ mock_memory_manager.create_session.assert_called_once_with(
81
+ user_id="doctor789", patient_id="patient123", title="Checkup"
82
+ )
83
+
84
+ def test_create_session_failure(client: TestClient):
85
+ """Tests that a 500 error is returned if session creation fails."""
86
+ mock_memory_manager.create_session.return_value = None
87
+
88
+ session_data = {"account_id": "doctor789", "patient_id": "patient123"}
89
+ response = client.post("/session", json=session_data)
90
+
91
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
92
+ assert response.json()["detail"] == "Failed to create session."
93
+
94
+
95
+ # --- Tests for GET /session/{session_id} ---
96
+
97
+ def test_get_session_success(client: TestClient):
98
+ """Tests successfully retrieving a session by its ID."""
99
+ mock_memory_manager.get_session.return_value = fake_session
100
+
101
+ response = client.get(f"/session/{fake_session.id}")
102
+
103
+ assert response.status_code == status.HTTP_200_OK
104
+ assert response.json()["title"] == fake_session.title
105
+ mock_memory_manager.get_session.assert_called_once_with(str(fake_session.id))
106
+
107
+ def test_get_session_not_found(client: TestClient):
108
+ """Tests that a 404 error is returned for a non-existent session."""
109
+ mock_memory_manager.get_session.return_value = None
110
+
111
+ response = client.get("/session/non_existent_id")
112
+
113
+ assert response.status_code == status.HTTP_404_NOT_FOUND
114
+ assert response.json()["detail"] == "Session not found"
115
+
116
+
117
+ # --- Tests for DELETE /session/{session_id} ---
118
+
119
+ def test_delete_session_success(client: TestClient):
120
+ """Tests successful deletion of a session."""
121
+ mock_memory_manager.delete_session.return_value = True
122
+
123
+ response = client.delete(f"/session/{fake_session.id}")
124
+
125
+ assert response.status_code == status.HTTP_204_NO_CONTENT
126
+ mock_memory_manager.delete_session.assert_called_once_with(str(fake_session.id))
127
+
128
+ def test_delete_session_not_found(client: TestClient):
129
+ """Tests that a 404 is returned when trying to delete a non-existent session."""
130
+ mock_memory_manager.delete_session.return_value = False
131
+
132
+ response = client.delete("/session/non_existent_id")
133
+
134
+ assert response.status_code == status.HTTP_404_NOT_FOUND
135
+ assert response.json()["detail"] == "Session not found or already deleted"
136
+
137
+
138
+ # --- Tests for GET /session/{session_id}/messages ---
139
+
140
+ def test_list_messages_success(client: TestClient):
141
+ """Tests successfully listing messages for a session."""
142
+ mock_memory_manager.get_session.return_value = fake_session
143
+ mock_memory_manager.get_session_messages.return_value = [fake_message]
144
+
145
+ response = client.get(f"/session/{fake_session.id}/messages?limit=10")
146
+
147
+ assert response.status_code == status.HTTP_200_OK
148
+ assert len(response.json()) == 1
149
+ assert response.json()[0]["content"] == "Hello"
150
+ mock_memory_manager.get_session.assert_called_once_with(str(fake_session.id))
151
+ mock_memory_manager.get_session_messages.assert_called_once_with(str(fake_session.id), 10)
152
+
153
+ def test_list_messages_session_not_found(client: TestClient):
154
+ """Tests listing messages for a non-existent session."""
155
+ mock_memory_manager.get_session.return_value = None
156
+
157
+ response = client.get("/session/non_existent_id/messages")
158
+
159
+ assert response.status_code == status.HTTP_404_NOT_FOUND
160
+ assert response.json()["detail"] == "Session not found"
161
+
162
+
163
+ # --- Tests for POST /session/{session_id}/messages ---
164
+
165
+ @patch('src.api.routes.session.generate_medical_response', new_callable=AsyncMock)
166
+ def test_post_chat_message_success(mock_generate_response: AsyncMock, client: TestClient):
167
+ """Tests the full, successful flow of posting a message and getting a response."""
168
+ # Arrange: Mock all async dependencies
169
+ mock_memory_manager.get_enhanced_context = AsyncMock(return_value="Enhanced context.")
170
+ mock_memory_manager.process_medical_exchange = AsyncMock(return_value="Generated summary.")
171
+ mock_generate_response.return_value = "This is the AI response."
172
+
173
+ chat_data = {"account_id": "doc1", "patient_id": "pat1", "message": "Patient has a fever."}
174
+ response = client.post(f"/session/{fake_session.id}/messages", json=chat_data)
175
+
176
+ assert response.status_code == status.HTTP_200_OK
177
+ json_response = response.json()
178
+ assert json_response["response"] == "This is the AI response."
179
+ assert json_response["medical_context"] == "Enhanced context."
180
+
181
+ # Assert that all async functions were called correctly
182
+ mock_memory_manager.get_enhanced_context.assert_awaited_once()
183
+ mock_generate_response.assert_awaited_once()
184
+ mock_memory_manager.process_medical_exchange.assert_awaited_once()
185
+
186
+ @patch('src.api.routes.session.generate_medical_response', new_callable=AsyncMock)
187
+ def test_post_chat_message_context_error(mock_generate_response: AsyncMock, client: TestClient):
188
+ """Tests failure during the context generation step."""
189
+ mock_memory_manager.get_enhanced_context = AsyncMock(side_effect=Exception("Context DB failed"))
190
+
191
+ chat_data = {"account_id": "doc1", "patient_id": "pat1", "message": "Patient has a fever."}
192
+ response = client.post(f"/session/{fake_session.id}/messages", json=chat_data)
193
+
194
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
195
+ assert response.json()["detail"] == "Failed to build medical context."
196
+ mock_generate_response.assert_not_awaited() # Should fail before this is called
197
+
198
+ @patch('src.api.routes.session.generate_medical_response', new_callable=AsyncMock)
199
+ def test_post_chat_message_generation_error(mock_generate_response: AsyncMock, client: TestClient):
200
+ """Tests failure during the AI response generation step."""
201
+ mock_memory_manager.get_enhanced_context = AsyncMock(return_value="Context")
202
+ mock_generate_response.side_effect = Exception("AI model API is down")
203
+
204
+ chat_data = {"account_id": "doc1", "patient_id": "pat1", "message": "Patient has a fever."}
205
+ response = client.post(f"/session/{fake_session.id}/messages", json=chat_data)
206
+
207
+ assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
208
+ assert response.json()["detail"] == "Failed to generate AI response."