import os
import tempfile
from unittest.mock import AsyncMock

import pytest

from vsp.llm.cached_llm_service import CachedLLMService
from vsp.llm.llm_cache import LLMCache
from vsp.llm.llm_service import LLMService
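
# Interfaces exercised below, inferred from the call sites rather than stated
# as a definitive API reference:
#   LLMCache(db_path): set(prompt, response, metadata),
#       get(prompt, metadata) -> cached response or None, clear()
#   CachedLLMService(llm_service, cache): async invoke(...) mirroring LLMService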


@pytest.fixture
def llm_cache():
    """Yield an LLMCache backed by a temporary database file, removing the file on teardown."""
    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db_file.close()
    cache = LLMCache(temp_db_file.name)
    yield cache
    os.unlink(temp_db_file.name)
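

# --- LLMCache: set/get round-trips, misses, clear, and metadata keying ---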


def test_llm_cache_set_and_get(llm_cache):
    """A stored response is returned for the same prompt and metadata."""
    prompt = "What is the capital of France?"
    response = "The capital of France is Paris."
    metadata = {"model": "test_model", "temperature": 0.7}

    llm_cache.set(prompt, response, metadata)

    cached_response = llm_cache.get(prompt, metadata)
    assert cached_response == response


def test_llm_cache_get_nonexistent(llm_cache):
    """Looking up a prompt that was never cached returns None."""
    prompt = "What is the capital of Germany?"
    metadata = {"model": "test_model", "temperature": 0.7}

    cached_response = llm_cache.get(prompt, metadata)
    assert cached_response is None


def test_llm_cache_clear(llm_cache):
    """clear() removes every previously cached entry."""
    prompt1 = "What is the capital of France?"
    response1 = "The capital of France is Paris."
    prompt2 = "What is the capital of Italy?"
    response2 = "The capital of Italy is Rome."
    metadata = {"model": "test_model", "temperature": 0.7}

    llm_cache.set(prompt1, response1, metadata)
    llm_cache.set(prompt2, response2, metadata)

    llm_cache.clear()

    assert llm_cache.get(prompt1, metadata) is None
    assert llm_cache.get(prompt2, metadata) is None


def test_llm_cache_different_metadata(llm_cache):
    """The same prompt cached under different metadata yields distinct entries."""
    prompt = "What is the capital of France?"
    response1 = "The capital of France is Paris."
    response2 = "La capitale de la France est Paris."
    metadata1 = {"model": "test_model_en", "temperature": 0.7}
    metadata2 = {"model": "test_model_fr", "temperature": 0.7}

    llm_cache.set(prompt, response1, metadata1)
    llm_cache.set(prompt, response2, metadata2)

    assert llm_cache.get(prompt, metadata1) == response1
    assert llm_cache.get(prompt, metadata2) == response2
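

# --- CachedLLMService: caching behavior against a mocked LLMService ---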


@pytest.mark.asyncio
async def test_cached_llm_service():
    """CachedLLMService invokes the underlying service only on cache misses."""
    mock_llm_service = AsyncMock(spec=LLMService)
    mock_llm_service.invoke.side_effect = ["First response", "Second response", "Third response"]

    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db_file.close()
    cache = LLMCache(temp_db_file.name)
    cached_service = CachedLLMService(mock_llm_service, cache)

    # First call: cache miss, so the underlying service is invoked.
    response1 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
    assert response1 == "First response"
    mock_llm_service.invoke.assert_called_once()

    # Identical arguments: served from the cache, no additional service call.
    response2 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
    assert response2 == "First response"
    assert mock_llm_service.invoke.call_count == 1

    # Different prompt and parameters: cache miss, second service call.
    response3 = await cached_service.invoke(user_prompt="Test prompt 2", max_tokens=20, temperature=0.7)
    assert response3 == "Second response"
    assert mock_llm_service.invoke.call_count == 2

    # A new prompt with system and partial-assistant prompts: cache miss again.
    response4 = await cached_service.invoke(
        user_prompt="Test prompt 3",
        system_prompt="System prompt",
        partial_assistant_prompt="Partial assistant prompt",
        max_tokens=30,
        temperature=0.8,
    )
    assert response4 == "Third response"
    assert mock_llm_service.invoke.call_count == 3

    # The identical full call: served from the cache, call count unchanged.
    response5 = await cached_service.invoke(
        user_prompt="Test prompt 3",
        system_prompt="System prompt",
        partial_assistant_prompt="Partial assistant prompt",
        max_tokens=30,
        temperature=0.8,
    )
    assert response5 == "Third response"
    assert mock_llm_service.invoke.call_count == 3

    os.unlink(temp_db_file.name)