# tests/test_guard.py
import pytest
import os
from unittest.mock import Mock, patch, MagicMock
from src.services.guard import SafetyGuard
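

# NOTE: These tests assume (inferred from the assertions in this file, not
# confirmed against guard.py) that SafetyGuard exposes roughly this interface:
#   SafetyGuard(nvidia_rotator)          -> raises ValueError when no API key is found
#   .check_user_query(query)             -> (is_safe, reason)
#   .check_model_answer(query, answer)   -> (is_safe, reason)
# plus the internal helpers exercised directly below (_chunk_text,
# _parse_guard_reply, _is_medical_query, _detect_harmful_content,
# _assess_risk_level, _call_guard, _enhance_messages_with_context).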
class TestSafetyGuard:
    """Test suite for SafetyGuard functionality."""

    @pytest.fixture
    def mock_settings(self):
        """Mock settings for testing."""
        with patch('src.services.guard.settings') as mock_settings:
            mock_settings.SAFETY_GUARD_TIMEOUT = 30
            mock_settings.SAFETY_GUARD_ENABLED = True
            mock_settings.SAFETY_GUARD_FAIL_OPEN = True
            yield mock_settings

    @pytest.fixture
    def mock_rotator(self):
        """Mock NVIDIA rotator for testing."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = "test_api_key"
        mock_rotator.rotate.return_value = "test_api_key_2"
        return mock_rotator

    @pytest.fixture
    def safety_guard(self, mock_settings, mock_rotator):
        """Create SafetyGuard instance for testing."""
        return SafetyGuard(mock_rotator)

    def test_init_with_valid_config(self, mock_settings, mock_rotator):
        """Test SafetyGuard initialization with valid configuration."""
        guard = SafetyGuard(mock_rotator)
        assert guard.nvidia_rotator == mock_rotator
        assert guard.timeout_s == 30
        assert guard.enabled is True
        assert guard.fail_open is True

    def test_init_without_api_key(self, mock_settings):
        """Test SafetyGuard initialization without API key."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = None
        with pytest.raises(ValueError, match="No NVIDIA API keys found"):
            SafetyGuard(mock_rotator)

    def test_chunk_text_empty(self, safety_guard):
        """Test text chunking with empty text."""
        result = safety_guard._chunk_text("")
        assert result == [""]

    def test_chunk_text_short(self, safety_guard):
        """Test text chunking with short text."""
        text = "Short text"
        result = safety_guard._chunk_text(text)
        assert result == [text]

    def test_chunk_text_long(self, safety_guard):
        """Test text chunking with long text."""
        text = "a" * 5000  # Longer than default chunk size
        result = safety_guard._chunk_text(text)
        assert len(result) > 1
        assert all(len(chunk) <= 2800 for chunk in result)

    def test_parse_guard_reply_safe(self, safety_guard):
        """Test parsing safe guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("SAFE")
        assert is_safe is True
        assert reason == ""

    def test_parse_guard_reply_unsafe(self, safety_guard):
        """Test parsing unsafe guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("UNSAFE: contains harmful content")
        assert is_safe is False
        assert reason == "contains harmful content"

    def test_parse_guard_reply_empty(self, safety_guard):
        """Test parsing empty guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("")
        assert is_safe is True
        assert reason == "guard_unavailable"

    def test_parse_guard_reply_unknown(self, safety_guard):
        """Test parsing unknown guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("UNKNOWN_RESPONSE")
        assert is_safe is False
        assert len(reason) <= 180

    def test_is_medical_query_symptoms(self, safety_guard):
        """Test medical query detection for symptoms."""
        queries = [
            "I have a headache",
            "My chest hurts",
            "I'm experiencing nausea",
            "I have a fever"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True

    def test_is_medical_query_conditions(self, safety_guard):
        """Test medical query detection for conditions."""
        queries = [
            "What is diabetes?",
            "How to treat hypertension",
            "Symptoms of depression",
            "Cancer treatment options"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True

    def test_is_medical_query_treatments(self, safety_guard):
        """Test medical query detection for treatments."""
        queries = [
            "What medication should I take?",
            "Physical therapy exercises",
            "Surgery options",
            "Vaccination schedule"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True

    def test_is_medical_query_non_medical(self, safety_guard):
        """Test medical query detection for non-medical content."""
        queries = [
            "What's the weather like?",
            "How to cook pasta",
            "Best restaurants in town",
            "Python programming tutorial"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is False
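
    # A parametrized sketch equivalent to the loop-based checks above (added in
    # addition to them, not as a replacement): pytest reports each query as its
    # own case, which makes individual failures easier to spot. The queries and
    # expected results are taken directly from the lists above.
    @pytest.mark.parametrize("query,expected", [
        ("I have a headache", True),
        ("What is diabetes?", True),
        ("What medication should I take?", True),
        ("How to cook pasta", False),
        ("Python programming tutorial", False),
    ])
    def test_is_medical_query_parametrized(self, safety_guard, query, expected):
        assert safety_guard._is_medical_query(query) is expected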

    def test_detect_harmful_content_safe_medical(self, safety_guard):
        """Test harmful content detection for safe medical content."""
        text = "I have a headache and need advice on pain management"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is True
        assert reason == "medical_content"

    def test_detect_harmful_content_dangerous_medical(self, safety_guard):
        """Test harmful content detection for dangerous medical content."""
        text = "You should overdose on your medication"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "dangerous_medical_advice" in reason

    def test_detect_harmful_content_suicide(self, safety_guard):
        """Test harmful content detection for suicide-related content."""
        text = "How to kill yourself"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "suicide_self_harm" in reason

    def test_detect_harmful_content_violence(self, safety_guard):
        """Test harmful content detection for violent content."""
        text = "How to harm others with weapons"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "violence" in reason

    def test_assess_risk_level_medical_low(self, safety_guard):
        """Test risk assessment for low-risk medical content."""
        text = "I have a mild headache"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0

    def test_assess_risk_level_medical_high(self, safety_guard):
        """Test risk assessment for high-risk medical content."""
        text = "You should commit suicide and overdose on drugs"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "high"
        assert risk_score > 0.6

    def test_assess_risk_level_non_medical(self, safety_guard):
        """Test risk assessment for non-medical content."""
        text = "This is just a normal conversation"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0
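
    # The two tests below patch requests.post in the guard module's namespace,
    # so no real HTTP request is sent. The mocked payload mirrors an OpenAI-style
    # chat-completions response; that "choices"/"message"/"content" shape is an
    # assumption about what _call_guard parses, based only on the values used here.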
    @patch('src.services.guard.requests.post')
    def test_call_guard_success(self, mock_post, safety_guard):
        """Test successful guard API call."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "SAFE"}}]
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == "SAFE"

    @patch('src.services.guard.requests.post')
    def test_call_guard_failure(self, mock_post, safety_guard):
        """Test guard API call failure."""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.text = "Bad Request"
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == ""
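
    # Expected check_user_query flow (inferred from the assertions below, not
    # confirmed against guard.py): a disabled guard short-circuits with
    # "guard_disabled", medical queries are allowed without calling the remote
    # guard, and everything else goes through _call_guard + _parse_guard_reply.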
    def test_check_user_query_disabled(self, mock_settings, mock_rotator):
        """Test user query check when guard is disabled."""
        mock_settings.SAFETY_GUARD_ENABLED = False
        # Construct with a rotator; the constructor raises without an API key
        # (see test_init_without_api_key), so a bare SafetyGuard() would fail.
        guard = SafetyGuard(mock_rotator)
        is_safe, reason = guard.check_user_query("any query")
        assert is_safe is True
        assert reason == "guard_disabled"

    def test_check_user_query_medical(self, safety_guard):
        """Test user query check for medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call:
            is_safe, reason = safety_guard.check_user_query("I have a headache")
            assert is_safe is True
            assert reason == "medical_query"
            mock_call.assert_not_called()

    def test_check_user_query_non_medical_safe(self, safety_guard):
        """Test user query check for safe non-medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
                patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")
            is_safe, reason = safety_guard.check_user_query("What's the weather?")
            assert is_safe is True
            assert reason == ""

    def test_check_user_query_non_medical_unsafe(self, safety_guard):
        """Test user query check for unsafe non-medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
                patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "UNSAFE: harmful content"
            mock_parse.return_value = (False, "harmful content")
            is_safe, reason = safety_guard.check_user_query("How to harm others")
            assert is_safe is False
            assert reason == "harmful content"

    def test_check_model_answer_disabled(self, mock_settings, mock_rotator):
        """Test model answer check when guard is disabled."""
        mock_settings.SAFETY_GUARD_ENABLED = False
        # Same constructor note as test_check_user_query_disabled.
        guard = SafetyGuard(mock_rotator)
        is_safe, reason = guard.check_model_answer("query", "answer")
        assert is_safe is True
        assert reason == "guard_disabled"

    def test_check_model_answer_medical_safe(self, safety_guard):
        """Test model answer check for safe medical content."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should rest and drink water"
        )
        assert is_safe is True
        assert reason == "medical_content"

    def test_check_model_answer_medical_unsafe(self, safety_guard):
        """Test model answer check for unsafe medical content."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should overdose on medication"
        )
        assert is_safe is False
        assert "dangerous medical advice" in reason

    def test_check_model_answer_high_risk(self, safety_guard):
        """Test model answer check for high-risk content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
                patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")
            is_safe, reason = safety_guard.check_model_answer(
                "How to harm others",
                "You should use weapons to attack people"
            )
            assert is_safe is True
            assert reason == "high_risk_validated"
    def test_enhance_messages_with_context_medical(self, safety_guard):
        """Test message enhancement for medical content."""
        messages = [{"role": "user", "content": "I have a headache"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)
        assert len(enhanced) == 1
        assert "MEDICAL CONTEXT" in enhanced[0]["content"]
        assert "I have a headache" in enhanced[0]["content"]

    def test_enhance_messages_with_context_non_medical(self, safety_guard):
        """Test message enhancement for non-medical content."""
        messages = [{"role": "user", "content": "What's the weather?"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)
        assert enhanced == messages  # Should remain unchanged

    def test_global_safety_guard_instance(self):
        """Test that global safety guard instance is None by default."""
        from src.services.guard import safety_guard
        assert safety_guard is None


if __name__ == "__main__":
    pytest.main([__file__])