# tests/test_guard.py
import pytest
import os
from unittest.mock import Mock, patch, MagicMock

from src.services.guard import SafetyGuard
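
# The assertions below exercise the SafetyGuard surface sketched in this comment. The sketch is
# inferred from the tests themselves, not from src/services/guard.py, so treat the exact
# signatures as assumptions and adjust if the real module differs:
#
#   SafetyGuard(nvidia_rotator)             # raises ValueError when no NVIDIA API key is available
#   guard.check_user_query(query)           # -> (is_safe: bool, reason: str)
#   guard.check_model_answer(query, answer) # -> (is_safe: bool, reason: str)
#   guard._chunk_text(text)                 # -> list of chunks, each capped around 2800 chars
#   guard._parse_guard_reply(reply)         # -> (is_safe: bool, reason: str)
#   guard._is_medical_query(text)           # -> bool
#   guard._detect_harmful_content(text)     # -> (is_safe: bool, reason: str)
#   guard._assess_risk_level(text)          # -> (risk_level: str, risk_score: float)
#   guard._call_guard(messages)             # -> raw guard reply string ("" on failure)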

class TestSafetyGuard:
    """Test suite for SafetyGuard functionality."""
    @pytest.fixture
    def mock_settings(self):
        """Mock settings for testing."""
        with patch('src.services.guard.settings') as mock_settings:
            mock_settings.SAFETY_GUARD_TIMEOUT = 30
            mock_settings.SAFETY_GUARD_ENABLED = True
            mock_settings.SAFETY_GUARD_FAIL_OPEN = True
            yield mock_settings
    @pytest.fixture
    def mock_rotator(self):
        """Mock NVIDIA rotator for testing."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = "test_api_key"
        mock_rotator.rotate.return_value = "test_api_key_2"
        return mock_rotator
    @pytest.fixture
    def safety_guard(self, mock_settings, mock_rotator):
        """Create a SafetyGuard instance for testing."""
        return SafetyGuard(mock_rotator)
    def test_init_with_valid_config(self, mock_settings, mock_rotator):
        """Test SafetyGuard initialization with valid configuration."""
        guard = SafetyGuard(mock_rotator)
        assert guard.nvidia_rotator == mock_rotator
        assert guard.timeout_s == 30
        assert guard.enabled is True
        assert guard.fail_open is True

    def test_init_without_api_key(self, mock_settings):
        """Test SafetyGuard initialization without an API key."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = None
        with pytest.raises(ValueError, match="No NVIDIA API keys found"):
            SafetyGuard(mock_rotator)
    def test_chunk_text_empty(self, safety_guard):
        """Test text chunking with empty text."""
        result = safety_guard._chunk_text("")
        assert result == [""]

    def test_chunk_text_short(self, safety_guard):
        """Test text chunking with short text."""
        text = "Short text"
        result = safety_guard._chunk_text(text)
        assert result == [text]

    def test_chunk_text_long(self, safety_guard):
        """Test text chunking with long text."""
        text = "a" * 5000  # Longer than the default chunk size
        result = safety_guard._chunk_text(text)
        assert len(result) > 1
        assert all(len(chunk) <= 2800 for chunk in result)
    def test_parse_guard_reply_safe(self, safety_guard):
        """Test parsing a safe guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("SAFE")
        assert is_safe is True
        assert reason == ""

    def test_parse_guard_reply_unsafe(self, safety_guard):
        """Test parsing an unsafe guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("UNSAFE: contains harmful content")
        assert is_safe is False
        assert reason == "contains harmful content"

    def test_parse_guard_reply_empty(self, safety_guard):
        """Test parsing an empty guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("")
        assert is_safe is True
        assert reason == "guard_unavailable"

    def test_parse_guard_reply_unknown(self, safety_guard):
        """Test parsing an unknown guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("UNKNOWN_RESPONSE")
        assert is_safe is False
        assert len(reason) <= 180
    def test_is_medical_query_symptoms(self, safety_guard):
        """Test medical query detection for symptoms."""
        queries = [
            "I have a headache",
            "My chest hurts",
            "I'm experiencing nausea",
            "I have a fever"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True

    def test_is_medical_query_conditions(self, safety_guard):
        """Test medical query detection for conditions."""
        queries = [
            "What is diabetes?",
            "How to treat hypertension",
            "Symptoms of depression",
            "Cancer treatment options"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True

    def test_is_medical_query_treatments(self, safety_guard):
        """Test medical query detection for treatments."""
        queries = [
            "What medication should I take?",
            "Physical therapy exercises",
            "Surgery options",
            "Vaccination schedule"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True

    def test_is_medical_query_non_medical(self, safety_guard):
        """Test medical query detection for non-medical content."""
        queries = [
            "What's the weather like?",
            "How to cook pasta",
            "Best restaurants in town",
            "Python programming tutorial"
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is False
    def test_detect_harmful_content_safe_medical(self, safety_guard):
        """Test harmful content detection for safe medical content."""
        text = "I have a headache and need advice on pain management"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is True
        assert reason == "medical_content"

    def test_detect_harmful_content_dangerous_medical(self, safety_guard):
        """Test harmful content detection for dangerous medical content."""
        text = "You should overdose on your medication"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "dangerous_medical_advice" in reason

    def test_detect_harmful_content_suicide(self, safety_guard):
        """Test harmful content detection for suicide-related content."""
        text = "How to kill yourself"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "suicide_self_harm" in reason

    def test_detect_harmful_content_violence(self, safety_guard):
        """Test harmful content detection for violent content."""
        text = "How to harm others with weapons"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "violence" in reason
    def test_assess_risk_level_medical_low(self, safety_guard):
        """Test risk assessment for low-risk medical content."""
        text = "I have a mild headache"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0

    def test_assess_risk_level_medical_high(self, safety_guard):
        """Test risk assessment for high-risk medical content."""
        text = "You should commit suicide and overdose on drugs"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "high"
        assert risk_score > 0.6

    def test_assess_risk_level_non_medical(self, safety_guard):
        """Test risk assessment for non-medical content."""
        text = "This is just a normal conversation"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0
    # NOTE: the patch target below assumes SafetyGuard issues its guard call via requests.post;
    # adjust the target string if the module uses a different HTTP client.
    @patch("requests.post")
    def test_call_guard_success(self, mock_post, safety_guard):
        """Test successful guard API call."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "SAFE"}}]
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == "SAFE"

    @patch("requests.post")  # same assumption as above
    def test_call_guard_failure(self, mock_post, safety_guard):
        """Test guard API call failure."""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.text = "Bad Request"
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == ""
    def test_check_user_query_disabled(self, mock_settings):
        """Test user query check when guard is disabled."""
        mock_settings.SAFETY_GUARD_ENABLED = False
        # Assumes the rotator argument is optional when the guard is disabled.
        guard = SafetyGuard()
        is_safe, reason = guard.check_user_query("any query")
        assert is_safe is True
        assert reason == "guard_disabled"
    def test_check_user_query_medical(self, safety_guard):
        """Test user query check for medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call:
            is_safe, reason = safety_guard.check_user_query("I have a headache")
            assert is_safe is True
            assert reason == "medical_query"
            mock_call.assert_not_called()

    def test_check_user_query_non_medical_safe(self, safety_guard):
        """Test user query check for safe non-medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
             patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")
            is_safe, reason = safety_guard.check_user_query("What's the weather?")
            assert is_safe is True
            assert reason == ""

    def test_check_user_query_non_medical_unsafe(self, safety_guard):
        """Test user query check for unsafe non-medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
             patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "UNSAFE: harmful content"
            mock_parse.return_value = (False, "harmful content")
            is_safe, reason = safety_guard.check_user_query("How to harm others")
            assert is_safe is False
            assert reason == "harmful content"
    def test_check_model_answer_disabled(self, mock_settings):
        """Test model answer check when guard is disabled."""
        mock_settings.SAFETY_GUARD_ENABLED = False
        guard = SafetyGuard()
        is_safe, reason = guard.check_model_answer("query", "answer")
        assert is_safe is True
        assert reason == "guard_disabled"
    def test_check_model_answer_medical_safe(self, safety_guard):
        """Test model answer check for safe medical content."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should rest and drink water"
        )
        assert is_safe is True
        assert reason == "medical_content"

    def test_check_model_answer_medical_unsafe(self, safety_guard):
        """Test model answer check for unsafe medical content."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should overdose on medication"
        )
        assert is_safe is False
        assert "dangerous medical advice" in reason

    def test_check_model_answer_high_risk(self, safety_guard):
        """Test model answer check for high-risk content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
             patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")
            is_safe, reason = safety_guard.check_model_answer(
                "How to harm others",
                "You should use weapons to attack people"
            )
            assert is_safe is True
            assert reason == "high_risk_validated"
    def test_enhance_messages_with_context_medical(self, safety_guard):
        """Test message enhancement for medical content."""
        messages = [{"role": "user", "content": "I have a headache"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)
        assert len(enhanced) == 1
        assert "MEDICAL CONTEXT" in enhanced[0]["content"]
        assert "I have a headache" in enhanced[0]["content"]

    def test_enhance_messages_with_context_non_medical(self, safety_guard):
        """Test message enhancement for non-medical content."""
        messages = [{"role": "user", "content": "What's the weather?"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)
        assert enhanced == messages  # Should remain unchanged

    def test_global_safety_guard_instance(self):
        """Test that the global safety guard instance is None by default."""
        from src.services.guard import safety_guard
        assert safety_guard is None

if __name__ == "__main__":
    pytest.main([__file__])
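
# Usage: run this file directly (python tests/test_guard.py, which calls pytest.main above),
# or run it through pytest from the repository root, e.g. `pytest tests/test_guard.py -v`.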