# tests/vsp/llm/test_llm_cache.py

import os
import tempfile
from unittest.mock import AsyncMock
import pytest
from vsp.llm.cached_llm_service import CachedLLMService
from vsp.llm.llm_cache import LLMCache
from vsp.llm.llm_service import LLMService


@pytest.fixture
def llm_cache():
temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
temp_db_file.close()
cache = LLMCache(temp_db_file.name)
yield cache
# Clean up the temporary database file after the test
os.unlink(temp_db_file.name)
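

# A leaner alternative sketch of the same fixture, shown for illustration: pytest's
# built-in tmp_path fixture handles cleanup itself. This assumes LLMCache needs
# nothing more than a writable file path; the fixture above remains the one the
# tests below request.
@pytest.fixture
def llm_cache_tmp(tmp_path):
    return LLMCache(str(tmp_path / "llm_cache.db"))
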
def test_llm_cache_set_and_get(llm_cache):
prompt = "What is the capital of France?"
response = "The capital of France is Paris."
metadata = {"model": "test_model", "temperature": 0.7}
# Test setting a value in the cache
llm_cache.set(prompt, response, metadata)
# Test getting the value from the cache
cached_response = llm_cache.get(prompt, metadata)
assert cached_response == response
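

# The round-trip above generalizes naturally to several prompt/response pairs; a
# hypothetical parametrized variant (cases reused from the tests further down) could
# look like this sketch:
@pytest.mark.parametrize(
    ("prompt", "response"),
    [
        ("What is the capital of France?", "The capital of France is Paris."),
        ("What is the capital of Italy?", "The capital of Italy is Rome."),
    ],
)
def test_llm_cache_round_trip_parametrized(llm_cache, prompt, response):
    metadata = {"model": "test_model", "temperature": 0.7}
    llm_cache.set(prompt, response, metadata)
    assert llm_cache.get(prompt, metadata) == response
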
def test_llm_cache_get_nonexistent(llm_cache):
prompt = "What is the capital of Germany?"
metadata = {"model": "test_model", "temperature": 0.7}
# Test getting a non-existent value from the cache
cached_response = llm_cache.get(prompt, metadata)
assert cached_response is None


def test_llm_cache_clear(llm_cache):
prompt1 = "What is the capital of France?"
response1 = "The capital of France is Paris."
prompt2 = "What is the capital of Italy?"
response2 = "The capital of Italy is Rome."
metadata = {"model": "test_model", "temperature": 0.7}
# Set multiple values in the cache
llm_cache.set(prompt1, response1, metadata)
llm_cache.set(prompt2, response2, metadata)
# Clear the cache
llm_cache.clear()
# Verify that the cache is empty
assert llm_cache.get(prompt1, metadata) is None
assert llm_cache.get(prompt2, metadata) is None


def test_llm_cache_different_metadata(llm_cache):
prompt = "What is the capital of France?"
response1 = "The capital of France is Paris."
response2 = "La capitale de la France est Paris."
metadata1 = {"model": "test_model_en", "temperature": 0.7}
metadata2 = {"model": "test_model_fr", "temperature": 0.7}
# Set values with different metadata
llm_cache.set(prompt, response1, metadata1)
llm_cache.set(prompt, response2, metadata2)
# Verify that different metadata produces different cache results
assert llm_cache.get(prompt, metadata1) == response1
assert llm_cache.get(prompt, metadata2) == response2
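

# The two assertions above rely on the cache key covering both the prompt and the
# metadata. The helper below is a plausible keying scheme, included only to document
# that assumption; the real LLMCache may derive its keys differently.
def _example_cache_key(prompt: str, metadata: dict) -> str:
    # Local imports keep this illustrative helper self-contained.
    import hashlib
    import json

    payload = json.dumps({"prompt": prompt, "metadata": metadata}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
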
@pytest.mark.asyncio
async def test_cached_llm_service(llm_cache):
    # Create a mock LLMService that returns a fresh canned response on each real call
    mock_llm_service = AsyncMock(spec=LLMService)
    mock_llm_service.invoke.side_effect = ["First response", "Second response", "Third response"]
    # Wrap the mock service with the cache provided by the llm_cache fixture,
    # which also takes care of deleting the temporary database file
    cached_service = CachedLLMService(mock_llm_service, llm_cache)
# Test first call (should use the mock service)
response1 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
assert response1 == "First response"
mock_llm_service.invoke.assert_called_once()
# Test second call with the same parameters (should use cache)
response2 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
assert response2 == "First response"
assert mock_llm_service.invoke.call_count == 1 # Should not have increased
# Test third call with different parameters (should use the mock service again)
response3 = await cached_service.invoke(user_prompt="Test prompt 2", max_tokens=20, temperature=0.7)
assert response3 == "Second response"
assert mock_llm_service.invoke.call_count == 2
# Test fourth call with system and partial assistant prompts
response4 = await cached_service.invoke(
user_prompt="Test prompt 3",
system_prompt="System prompt",
partial_assistant_prompt="Partial assistant prompt",
max_tokens=30,
temperature=0.8,
)
assert response4 == "Third response"
assert mock_llm_service.invoke.call_count == 3
# Test fifth call with the same parameters as the fourth (should use cache)
response5 = await cached_service.invoke(
user_prompt="Test prompt 3",
system_prompt="System prompt",
partial_assistant_prompt="Partial assistant prompt",
max_tokens=30,
temperature=0.8,
)
assert response5 == "Third response"
assert mock_llm_service.invoke.call_count == 3 # Should not have increased
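

# For reference, the behaviour exercised above follows the cache-aside pattern. The
# class below is a minimal sketch of that flow, assuming CachedLLMService treats the
# remaining invoke() arguments as the metadata part of the cache key; the real
# implementation in vsp.llm.cached_llm_service may differ.
class _CacheAsideSketch:
    def __init__(self, llm_service, cache):
        self._llm_service = llm_service
        self._cache = cache

    async def invoke(self, user_prompt, **kwargs):
        metadata = dict(kwargs)  # e.g. system_prompt, max_tokens, temperature
        cached = self._cache.get(user_prompt, metadata)
        if cached is not None:
            return cached  # cache hit: skip the underlying LLM call
        response = await self._llm_service.invoke(user_prompt=user_prompt, **kwargs)
        self._cache.set(user_prompt, response, metadata)
        return response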