# Author: LiamKhoaLe
# Commit: Refactor tests organisation
# Revision: d753c16
# tests/test_guard.py
import pytest
import os
from unittest.mock import Mock, patch, MagicMock
from src.services.guard import SafetyGuard
class TestSafetyGuard:
    """Test suite for SafetyGuard functionality.

    Every test builds its guard from mocked settings and a mocked NVIDIA key
    rotator, so no real configuration or network access is needed.
    """

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def mock_settings(self):
        """Patch ``src.services.guard.settings`` with known values.

        The patch stays active for the whole test because the fixture yields
        from inside the ``with`` block.
        """
        with patch('src.services.guard.settings') as mock_settings:
            mock_settings.SAFETY_GUARD_TIMEOUT = 30
            mock_settings.SAFETY_GUARD_ENABLED = True
            mock_settings.SAFETY_GUARD_FAIL_OPEN = True
            yield mock_settings

    @pytest.fixture
    def mock_rotator(self):
        """Mock NVIDIA rotator that always hands out a test API key."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = "test_api_key"
        mock_rotator.rotate.return_value = "test_api_key_2"
        return mock_rotator

    @pytest.fixture
    def safety_guard(self, mock_settings, mock_rotator):
        """Create an enabled SafetyGuard backed by the mocked rotator."""
        return SafetyGuard(mock_rotator)

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def test_init_with_valid_config(self, mock_settings, mock_rotator):
        """Test SafetyGuard initialization with valid configuration."""
        guard = SafetyGuard(mock_rotator)
        assert guard.nvidia_rotator == mock_rotator
        assert guard.timeout_s == 30
        assert guard.enabled is True
        assert guard.fail_open is True

    def test_init_without_api_key(self, mock_settings):
        """Test SafetyGuard initialization without API key."""
        mock_rotator = Mock()
        mock_rotator.get_key.return_value = None
        with pytest.raises(ValueError, match="No NVIDIA API keys found"):
            SafetyGuard(mock_rotator)

    # ------------------------------------------------------------------
    # _chunk_text
    # ------------------------------------------------------------------

    def test_chunk_text_empty(self, safety_guard):
        """Test text chunking with empty text."""
        result = safety_guard._chunk_text("")
        assert result == [""]

    def test_chunk_text_short(self, safety_guard):
        """Test text chunking with short text."""
        text = "Short text"
        result = safety_guard._chunk_text(text)
        assert result == [text]

    def test_chunk_text_long(self, safety_guard):
        """Test text chunking with long text."""
        text = "a" * 5000  # Longer than default chunk size
        result = safety_guard._chunk_text(text)
        assert len(result) > 1
        # NOTE(review): 2800 is presumably the guard's maximum chunk length;
        # confirm against the implementation's default.
        assert all(len(chunk) <= 2800 for chunk in result)

    # ------------------------------------------------------------------
    # _parse_guard_reply
    # ------------------------------------------------------------------

    def test_parse_guard_reply_safe(self, safety_guard):
        """Test parsing safe guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("SAFE")
        assert is_safe is True
        assert reason == ""

    def test_parse_guard_reply_unsafe(self, safety_guard):
        """Test parsing unsafe guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("UNSAFE: contains harmful content")
        assert is_safe is False
        assert reason == "contains harmful content"

    def test_parse_guard_reply_empty(self, safety_guard):
        """Test parsing empty guard reply (treated as safe, guard unavailable)."""
        is_safe, reason = safety_guard._parse_guard_reply("")
        assert is_safe is True
        assert reason == "guard_unavailable"

    def test_parse_guard_reply_unknown(self, safety_guard):
        """Test parsing unknown guard reply."""
        is_safe, reason = safety_guard._parse_guard_reply("UNKNOWN_RESPONSE")
        assert is_safe is False
        assert len(reason) <= 180

    # ------------------------------------------------------------------
    # _is_medical_query
    # ------------------------------------------------------------------

    def test_is_medical_query_symptoms(self, safety_guard):
        """Test medical query detection for symptoms."""
        queries = [
            "I have a headache",
            "My chest hurts",
            "I'm experiencing nausea",
            "I have a fever",
        ]
        for query in queries:
            # The message identifies which query failed.
            assert safety_guard._is_medical_query(query) is True, query

    def test_is_medical_query_conditions(self, safety_guard):
        """Test medical query detection for conditions."""
        queries = [
            "What is diabetes?",
            "How to treat hypertension",
            "Symptoms of depression",
            "Cancer treatment options",
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True, query

    def test_is_medical_query_treatments(self, safety_guard):
        """Test medical query detection for treatments."""
        queries = [
            "What medication should I take?",
            "Physical therapy exercises",
            "Surgery options",
            "Vaccination schedule",
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is True, query

    def test_is_medical_query_non_medical(self, safety_guard):
        """Test medical query detection for non-medical content."""
        queries = [
            "What's the weather like?",
            "How to cook pasta",
            "Best restaurants in town",
            "Python programming tutorial",
        ]
        for query in queries:
            assert safety_guard._is_medical_query(query) is False, query

    # ------------------------------------------------------------------
    # _detect_harmful_content
    # ------------------------------------------------------------------

    def test_detect_harmful_content_safe_medical(self, safety_guard):
        """Test harmful content detection for safe medical content."""
        text = "I have a headache and need advice on pain management"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is True
        assert reason == "medical_content"

    def test_detect_harmful_content_dangerous_medical(self, safety_guard):
        """Test harmful content detection for dangerous medical content."""
        text = "You should overdose on your medication"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "dangerous_medical_advice" in reason

    def test_detect_harmful_content_suicide(self, safety_guard):
        """Test harmful content detection for suicide-related content."""
        text = "How to kill yourself"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "suicide_self_harm" in reason

    def test_detect_harmful_content_violence(self, safety_guard):
        """Test harmful content detection for violent content."""
        text = "How to harm others with weapons"
        is_safe, reason = safety_guard._detect_harmful_content(text)
        assert is_safe is False
        assert "violence" in reason

    # ------------------------------------------------------------------
    # _assess_risk_level
    # ------------------------------------------------------------------

    def test_assess_risk_level_medical_low(self, safety_guard):
        """Test risk assessment for low-risk medical content."""
        text = "I have a mild headache"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0

    def test_assess_risk_level_medical_high(self, safety_guard):
        """Test risk assessment for high-risk medical content."""
        text = "You should commit suicide and overdose on drugs"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "high"
        assert risk_score > 0.6

    def test_assess_risk_level_non_medical(self, safety_guard):
        """Test risk assessment for non-medical content."""
        text = "This is just a normal conversation"
        risk_level, risk_score = safety_guard._assess_risk_level(text)
        assert risk_level == "low"
        assert 0.0 <= risk_score <= 1.0

    # ------------------------------------------------------------------
    # _call_guard (HTTP layer mocked)
    # ------------------------------------------------------------------

    @patch('src.services.guard.requests.post')
    def test_call_guard_success(self, mock_post, safety_guard):
        """Test successful guard API call."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "SAFE"}}]
        }
        mock_post.return_value = mock_response

        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == "SAFE"

    @patch('src.services.guard.requests.post')
    def test_call_guard_failure(self, mock_post, safety_guard):
        """Test guard API call failure (non-200 response yields an empty reply)."""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.text = "Bad Request"
        mock_post.return_value = mock_response

        messages = [{"role": "user", "content": "test message"}]
        result = safety_guard._call_guard(messages)
        assert result == ""

    # ------------------------------------------------------------------
    # check_user_query
    # ------------------------------------------------------------------

    def test_check_user_query_disabled(self, mock_settings, mock_rotator):
        """Test user query check when guard is disabled."""
        mock_settings.SAFETY_GUARD_ENABLED = False
        # Built with the mocked rotator like every other test here:
        # test_init_without_api_key shows construction raises ValueError when
        # no key is available, so a bare SafetyGuard() would depend on real
        # key configuration.
        guard = SafetyGuard(mock_rotator)
        is_safe, reason = guard.check_user_query("any query")
        assert is_safe is True
        assert reason == "guard_disabled"

    def test_check_user_query_medical(self, safety_guard):
        """Test user query check for medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call:
            is_safe, reason = safety_guard.check_user_query("I have a headache")
            assert is_safe is True
            assert reason == "medical_query"
            # Medical queries are accepted without contacting the guard API.
            mock_call.assert_not_called()

    def test_check_user_query_non_medical_safe(self, safety_guard):
        """Test user query check for safe non-medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
                patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")
            is_safe, reason = safety_guard.check_user_query("What's the weather?")
            assert is_safe is True
            assert reason == ""

    def test_check_user_query_non_medical_unsafe(self, safety_guard):
        """Test user query check for unsafe non-medical content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
                patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "UNSAFE: harmful content"
            mock_parse.return_value = (False, "harmful content")
            is_safe, reason = safety_guard.check_user_query("How to harm others")
            assert is_safe is False
            assert reason == "harmful content"

    # ------------------------------------------------------------------
    # check_model_answer
    # ------------------------------------------------------------------

    def test_check_model_answer_disabled(self, mock_settings, mock_rotator):
        """Test model answer check when guard is disabled."""
        mock_settings.SAFETY_GUARD_ENABLED = False
        # Same reasoning as test_check_user_query_disabled: never construct
        # the guard without the mocked rotator.
        guard = SafetyGuard(mock_rotator)
        is_safe, reason = guard.check_model_answer("query", "answer")
        assert is_safe is True
        assert reason == "guard_disabled"

    def test_check_model_answer_medical_safe(self, safety_guard):
        """Test model answer check for safe medical content."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should rest and drink water"
        )
        assert is_safe is True
        assert reason == "medical_content"

    def test_check_model_answer_medical_unsafe(self, safety_guard):
        """Test model answer check for unsafe medical content."""
        is_safe, reason = safety_guard.check_model_answer(
            "I have a headache",
            "You should overdose on medication"
        )
        assert is_safe is False
        # NOTE(review): this expects the space-separated wording, whereas
        # test_detect_harmful_content_dangerous_medical expects the
        # underscore form "dangerous_medical_advice" - confirm both match the
        # implementation.
        assert "dangerous medical advice" in reason

    def test_check_model_answer_high_risk(self, safety_guard):
        """Test model answer check for high-risk content."""
        with patch.object(safety_guard, '_call_guard') as mock_call, \
                patch.object(safety_guard, '_parse_guard_reply') as mock_parse:
            mock_call.return_value = "SAFE"
            mock_parse.return_value = (True, "")
            is_safe, reason = safety_guard.check_model_answer(
                "How to harm others",
                "You should use weapons to attack people"
            )
            assert is_safe is True
            assert reason == "high_risk_validated"

    # ------------------------------------------------------------------
    # _enhance_messages_with_context
    # ------------------------------------------------------------------

    def test_enhance_messages_with_context_medical(self, safety_guard):
        """Test message enhancement for medical content."""
        messages = [{"role": "user", "content": "I have a headache"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)
        assert len(enhanced) == 1
        assert "MEDICAL CONTEXT" in enhanced[0]["content"]
        assert "I have a headache" in enhanced[0]["content"]

    def test_enhance_messages_with_context_non_medical(self, safety_guard):
        """Test message enhancement for non-medical content."""
        messages = [{"role": "user", "content": "What's the weather?"}]
        enhanced = safety_guard._enhance_messages_with_context(messages)
        assert enhanced == messages  # Should remain unchanged

    # ------------------------------------------------------------------
    # Module-level singleton
    # ------------------------------------------------------------------

    def test_global_safety_guard_instance(self):
        """Test that global safety guard instance is None by default."""
        # Imported here so the module attribute is read when the test runs.
        from src.services.guard import safety_guard
        assert safety_guard is None
# Allow running this test module directly: `python tests/test_guard.py`.
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; propagate it so a failing
    # run produces a non-zero process exit status instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))