File size: 12,949 Bytes
29c313c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# tests/test_guard.py

import pytest
import os
from unittest.mock import Mock, patch, MagicMock
from src.services.guard import SafetyGuard


class TestSafetyGuard:
    """Unit tests for ``SafetyGuard`` from ``src.services.guard``.

    Every guard built here receives a ``Mock`` key rotator, and the
    module-level ``settings`` object is patched through the ``mock_settings``
    fixture.  The two ``_call_guard`` tests additionally patch
    ``requests.post`` inside the guard module.
    """

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_settings(self):
        """Patch ``src.services.guard.settings`` for the duration of a test.

        Yields the patched object with a 30 second timeout, the guard
        enabled, and fail-open switched on, so individual tests can override
        single attributes before building a guard.
        """
        with patch('src.services.guard.settings') as mock_settings:
            mock_settings.SAFETY_GUARD_TIMEOUT = 30
            mock_settings.SAFETY_GUARD_ENABLED = True
            mock_settings.SAFETY_GUARD_FAIL_OPEN = True
            yield mock_settings

    @pytest.fixture
    def mock_rotator(self):
        """Return a ``Mock`` rotator whose ``get_key``/``rotate`` yield fake keys."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = "test_api_key"
        mock_rotator.rotate.return_value = "test_api_key_2"
        return mock_rotator

    @pytest.fixture
    def safety_guard(self, mock_settings, mock_rotator):
        """Return a ``SafetyGuard`` built from the patched settings and mock rotator."""
        return SafetyGuard(mock_rotator)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def test_init_with_valid_config(self, mock_settings, mock_rotator):
        """Constructor stores the rotator and mirrors the patched settings."""
        guard = SafetyGuard(mock_rotator)
        assert guard.nvidia_rotator == mock_rotator
        assert guard.timeout_s == 30
        assert guard.enabled is True
        assert guard.fail_open is True

    def test_init_without_api_key(self, mock_settings):
        """Constructor raises ``ValueError`` when the rotator has no key."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = None
        with pytest.raises(ValueError, match="No NVIDIA API keys found"):
            SafetyGuard(mock_rotator)

    # ------------------------------------------------------------------
    # _chunk_text
    # ------------------------------------------------------------------

    def test_chunk_text_empty(self, safety_guard):
        """Empty input is returned as a single empty chunk."""
        result = safety_guard._chunk_text("")
        assert result == [""]

    def test_chunk_text_short(self, safety_guard):
        """Short input is returned unchanged as a single chunk."""
        text = "Short text"
        result = safety_guard._chunk_text(text)
        assert result == [text]

    def test_chunk_text_long(self, safety_guard):
        """Long input is split into several chunks of at most 2800 characters."""
        text = "a" * 5000  # Longer than default chunk size
        result = safety_guard._chunk_text(text)
        assert len(result) > 1
        assert all(len(chunk) <= 2800 for chunk in result)

    # ------------------------------------------------------------------
    # _parse_guard_reply
    # ------------------------------------------------------------------

    def test_parse_guard_reply_safe(self, safety_guard):
        """A ``SAFE`` reply parses to ``(True, "")``."""
        is_safe, reason = safety_guard._parse_guard_reply("SAFE")
        assert is_safe is True
        assert reason == ""

    def test_parse_guard_reply_unsafe(self, safety_guard):
        """An ``UNSAFE: ...`` reply parses to ``False`` plus the trailing reason."""
        is_safe, reason = safety_guard._parse_guard_reply("UNSAFE: contains harmful content")
        assert is_safe is False
        assert reason == "contains harmful content"

    def test_parse_guard_reply_empty(self, safety_guard):
        """An empty reply is treated as safe with reason ``guard_unavailable``.

        The fixture enables ``SAFETY_GUARD_FAIL_OPEN``; presumably that is
        what makes the empty reply pass -- confirm against the guard module.
        """
        is_safe, reason = safety_guard._parse_guard_reply("")
        assert is_safe is True
        assert reason == "guard_unavailable"

    def test_parse_guard_reply_unknown(self, safety_guard):
        """An unrecognised reply is unsafe and its reason is at most 180 characters."""
        is_safe, reason = safety_guard._parse_guard_reply("UNKNOWN_RESPONSE")
        assert is_safe is False
        assert len(reason) <= 180

    # ------------------------------------------------------------------
    # _is_medical_query
    # ------------------------------------------------------------------

    def test_is_medical_query_symptoms(self, safety_guard):
        """Symptom descriptions are classified as medical queries."""
        queries = [
            "I have a headache",
            "My chest hurts",
            "I'm experiencing nausea",
            "I have a fever",
        ]
        for query in queries:
            # The message names the offending query when the loop fails.
            assert safety_guard._is_medical_query(query) is True, query

    def test_is_medical_query_conditions(self, safety_guard):
        """Questions about named conditions are classified as medical queries."""
        queries = [
            "What is diabetes?",
            "How to treat hypertension",
            "Symptoms of depression",
            "Cancer treatment options",
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True, query

    def test_is_medical_query_treatments(self, safety_guard):
        """Questions about treatments are classified as medical queries."""
        queries = [
            "What medication should I take?",
            "Physical therapy exercises",
            "Surgery options",
            "Vaccination schedule",
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True, query

    def test_is_medical_query_non_medical(self, safety_guard):
        """Everyday, non-medical questions are not classified as medical."""
        queries = [
            "What's the weather like?",
            "How to cook pasta",
            "Best restaurants in town",
            "Python programming tutorial",
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is False, query

    # ------------------------------------------------------------------
    # _detect_harmful_content
    # ------------------------------------------------------------------

    def test_detect_harmful_content_safe_medical(self, safety_guard):
        """Benign medical text is safe with reason ``medical_content``."""
        text = "I have a headache and need advice on pain management"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is True
        assert reason == "medical_content"

    def test_detect_harmful_content_dangerous_medical(self, safety_guard):
        """Overdose advice is unsafe and tagged ``dangerous_medical_advice``."""
        text = "You should overdose on your medication"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "dangerous_medical_advice" in reason

    def test_detect_harmful_content_suicide(self, safety_guard):
        """Self-harm text is unsafe and tagged ``suicide_self_harm``."""
        text = "How to kill yourself"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "suicide_self_harm" in reason

    def test_detect_harmful_content_violence(self, safety_guard):
        """Violent text is unsafe and tagged ``violence``."""
        text = "How to harm others with weapons"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "violence" in reason

    # ------------------------------------------------------------------
    # _assess_risk_level
    # ------------------------------------------------------------------

    def test_assess_risk_level_medical_low(self, safety_guard):
        """A mild symptom rates ``low`` with a score inside ``[0, 1]``."""
        text = "I have a mild headache"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0

    def test_assess_risk_level_medical_high(self, safety_guard):
        """Suicide/overdose advice rates ``high`` with a score above 0.6."""
        text = "You should commit suicide and overdose on drugs"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "high"
        assert risk_score > 0.6

    def test_assess_risk_level_non_medical(self, safety_guard):
        """Neutral small talk rates ``low`` with a score inside ``[0, 1]``."""
        text = "This is just a normal conversation"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0

    # ------------------------------------------------------------------
    # _call_guard (HTTP layer patched)
    # ------------------------------------------------------------------

    @patch('src.services.guard.requests.post')
    def test_call_guard_success(self, mock_post, safety_guard):
        """A 200 response yields the message content of the first choice."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "SAFE"}}]
        }
        mock_post.return_value = mock_response

        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == "SAFE"

    @patch('src.services.guard.requests.post')
    def test_call_guard_failure(self, mock_post, safety_guard):
        """A 400 response yields an empty string rather than an exception."""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.text = "Bad Request"
        mock_post.return_value = mock_response

        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == ""

    # ------------------------------------------------------------------
    # check_user_query
    # ------------------------------------------------------------------

    def test_check_user_query_disabled(self, mock_settings, mock_rotator):
        """A disabled guard accepts any query with reason ``guard_disabled``."""
        # Flip the flag before building the guard (presumably it is read in
        # the constructor -- confirm against the guard module).
        mock_settings.SAFETY_GUARD_ENABLED = False
        # Build the guard with the mock rotator like every other test here;
        # constructing it with no rotator is never exercised elsewhere in
        # this suite and would not reach the disabled path being tested.
        guard = SafetyGuard(mock_rotator)

        is_safe, reason = guard.check_user_query("any query")
        assert is_safe is True
        assert reason == "guard_disabled"

    def test_check_user_query_medical(self, safety_guard):
        """A medical query is accepted without calling the remote guard."""
        with patch.object(safety_guard, '_call_guard') as mock_call:
            is_safe, reason = safety_guard.check_user_query("I have a headache")
            assert is_safe is True
            assert reason == "medical_query"
            mock_call.assert_not_called()

    def test_check_user_query_non_medical_safe(self, safety_guard):
        """A non-medical query returns the verdict produced by the parsed reply."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
             patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")

            is_safe, reason = safety_guard.check_user_query("What's the weather?")
            assert is_safe is True
            assert reason == ""

    def test_check_user_query_non_medical_unsafe(self, safety_guard):
        """An unsafe parsed reply is surfaced to the caller unchanged."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
             patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "UNSAFE: harmful content"
            mock_parse.return_value = (False, "harmful content")

            is_safe, reason = safety_guard.check_user_query("How to harm others")
            assert is_safe is False
            assert reason == "harmful content"

    # ------------------------------------------------------------------
    # check_model_answer
    # ------------------------------------------------------------------

    def test_check_model_answer_disabled(self, mock_settings, mock_rotator):
        """A disabled guard accepts any answer with reason ``guard_disabled``."""
        # Flip the flag before building the guard (presumably it is read in
        # the constructor -- confirm against the guard module).
        mock_settings.SAFETY_GUARD_ENABLED = False
        # Same construction as the rest of the suite: pass the mock rotator.
        guard = SafetyGuard(mock_rotator)

        is_safe, reason = guard.check_model_answer("query", "answer")
        assert is_safe is True
        assert reason == "guard_disabled"

    def test_check_model_answer_medical_safe(self, safety_guard):
        """A benign answer to a medical query is safe with reason ``medical_content``."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should rest and drink water",
        )
        assert is_safe is True
        assert reason == "medical_content"

    def test_check_model_answer_medical_unsafe(self, safety_guard):
        """An overdose recommendation is rejected as dangerous medical advice."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should overdose on medication",
        )
        assert is_safe is False
        assert "dangerous medical advice" in reason

    def test_check_model_answer_high_risk(self, safety_guard):
        """A violent answer that the guard reply clears is reported as ``high_risk_validated``."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
             patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")

            is_safe, reason = safety_guard.check_model_answer(
                "How to harm others",
                "You should use weapons to attack people",
            )
            assert is_safe is True
            assert reason == "high_risk_validated"

    # ------------------------------------------------------------------
    # _enhance_messages_with_context
    # ------------------------------------------------------------------

    def test_enhance_messages_with_context_medical(self, safety_guard):
        """Medical messages gain a ``MEDICAL CONTEXT`` marker and keep their text."""
        messages = [{"role": "user", "content": "I have a headache"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)

        assert len(enhanced) == 1
        assert "MEDICAL CONTEXT" in enhanced[0]["content"]
        assert "I have a headache" in enhanced[0]["content"]

    def test_enhance_messages_with_context_non_medical(self, safety_guard):
        """Non-medical messages come back equal to the input."""
        messages = [{"role": "user", "content": "What's the weather?"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)

        assert enhanced == messages  # Should remain unchanged

    # ------------------------------------------------------------------
    # Module-level singleton
    # ------------------------------------------------------------------

    def test_global_safety_guard_instance(self):
        """The module-level ``safety_guard`` is ``None`` when imported here."""
        from src.services.guard import safety_guard
        assert safety_guard is None


if __name__ == "__main__":
    # pytest.main() returns the run's exit code; raise it so that running
    # this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__]))