navkast committed
Tidy up VSP implementation so far (#9)

* feat: Add integration test for LLM caching using OpenRouter
* fix: Fix linter issues in test_integration_openrouter.py
* feat: Add unit tests for LLMCache
* style: Fix linter issues
* feat: Add unit test for CachedLLMService
* style: Run linter
* fix: Ensure LLMCache table is created
* fix: Update the service with the context manager in CachedLLMService
* refactor: replace in-memory SQLite databases with temporary files
* style: Fix import order in test files
* fix: wrap `async with` block in `try-finally` to clean up temporary database file
* Commit some tests
src/vsp/llm/cached_llm_service.py (CHANGED)

@@ -7,8 +7,8 @@ logger = logger_factory.get_logger(__name__)
 
 class CachedLLMService(LLMService):
     def __init__(self, llm_service: LLMService, cache: LLMCache | None = None):
-        self.
-        self.
+        self._llm_service = llm_service
+        self._cache = cache or LLMCache()
 
     async def invoke(
         self,
@@ -19,12 +19,12 @@ class CachedLLMService(LLMService):
         temperature: float = 0.0,
     ) -> str | None:
         cache_key = f"{user_prompt}_{system_prompt}_{partial_assistant_prompt}_{max_tokens}_{temperature}"
-        cached_response = self.
+        cached_response = self._cache.get(cache_key, {})
         if cached_response is not None:
             logger.debug("LLM cache hit")
             return cached_response
 
-        response = await self.
+        response = await self._llm_service.invoke(
             user_prompt=user_prompt,
             system_prompt=system_prompt,
             partial_assistant_prompt=partial_assistant_prompt,
@@ -33,6 +33,6 @@ class CachedLLMService(LLMService):
         )
 
         if response is not None:
-            self.
+            self._cache.set(cache_key, response, {})
 
         return response
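For orientation, the wrapper builds a cache key from the full prompt/parameter tuple, returns a cached response when one exists, and otherwise delegates to the wrapped service and stores the result. A minimal usage sketch, assuming LLMService can be subclassed directly and that max_tokens must be passed explicitly; EchoService and the llm_cache.db path are invented for illustration and are not part of this commit:

# Minimal usage sketch of CachedLLMService. EchoService and llm_cache.db are
# made up for illustration; only the public signatures come from the diff.
import asyncio

from vsp.llm.cached_llm_service import CachedLLMService
from vsp.llm.llm_cache import LLMCache
from vsp.llm.llm_service import LLMService


class EchoService(LLMService):
    async def invoke(self, user_prompt: str, **kwargs) -> str | None:
        return f"echo: {user_prompt}"


async def main() -> None:
    service = CachedLLMService(EchoService(), LLMCache("llm_cache.db"))
    first = await service.invoke(user_prompt="hello", max_tokens=32)   # hits EchoService
    second = await service.invoke(user_prompt="hello", max_tokens=32)  # served from the SQLite cache
    assert first == second


asyncio.run(main())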
src/vsp/llm/llm_cache.py (CHANGED)

@@ -12,7 +12,7 @@ class LLMCache:
         self._init_db()
 
     def _init_db(self) -> None:
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
             conn.execute(
                 """
                 CREATE TABLE IF NOT EXISTS llm_cache (
@@ -21,7 +21,7 @@ class LLMCache:
                     response TEXT,
                     metadata TEXT
                 )
-
+                """
             )
 
     def _hash_prompt(self, prompt: str, metadata: dict[str, Any]) -> str:
@@ -30,7 +30,7 @@ class LLMCache:
 
     def get(self, prompt: str, metadata: dict[str, Any]) -> str | None:
         prompt_hash = self._hash_prompt(prompt, metadata)
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT response FROM llm_cache WHERE prompt_hash = ?", (prompt_hash,))
            result = cursor.fetchone()
@@ -38,12 +38,12 @@ class LLMCache:
 
     def set(self, prompt: str, response: str, metadata: dict[str, Any]) -> None:
         prompt_hash = self._hash_prompt(prompt, metadata)
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
             conn.execute(
                 "INSERT OR REPLACE INTO llm_cache (prompt_hash, prompt, response, metadata) VALUES (?, ?, ?, ?)",
                 (prompt_hash, prompt, response, json.dumps(metadata)),
             )
 
     def clear(self) -> None:
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
             conn.execute("DELETE FROM llm_cache")
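The autocommit=True change is easier to follow with the standard-library semantics in mind: sqlite3.connect() gained the autocommit parameter in Python 3.12, and with autocommit=True every statement takes effect immediately, while the connection used as a context manager only commits or rolls back a pending transaction and never closes the connection. A short sketch of those semantics (standard-library behaviour, not code from this repository; it mirrors the temp-file pattern the tests use):

# Demonstrates the sqlite3 behaviour the change above relies on (stdlib
# semantics, not repository code). Requires Python 3.12+ for autocommit=True.
import os
import sqlite3
import tempfile

tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
tmp.close()

with sqlite3.connect(tmp.name, autocommit=True) as conn:
    # Each statement is committed as soon as it runs.
    conn.execute("CREATE TABLE IF NOT EXISTS kv (k TEXT PRIMARY KEY, v TEXT)")
    conn.execute("INSERT OR REPLACE INTO kv VALUES (?, ?)", ("greeting", "hello"))

# A second, independent connection sees the rows immediately.
with sqlite3.connect(tmp.name, autocommit=True) as other:
    assert other.execute("SELECT v FROM kv WHERE k = ?", ("greeting",)).fetchone() == ("hello",)

# The `with` blocks above did not close the connections; do it explicitly.
conn.close()
other.close()
os.unlink(tmp.name)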
tests/vsp/llm/openrouter/test_integration_openrouter.py (CHANGED)

@@ -1,7 +1,10 @@
 import asyncio
+from unittest.mock import patch
 
 import pytest
 
+from vsp.llm.cached_llm_service import CachedLLMService
+from vsp.llm.llm_cache import LLMCache
 from vsp.llm.openrouter.openrouter import AsyncOpenRouterService
 from vsp.shared import logger_factory
 
@@ -44,5 +47,62 @@ async def test_openrouter_integration():
         raise
 
 
+@pytest.mark.asyncio
+async def test_cached_openrouter_integration():
+    """
+    Integration test for CachedLLMService with AsyncOpenRouterService.
+
+    This test verifies that:
+    1. The first call goes to OpenRouter
+    2. The second call with the same prompt returns the cached response
+    3. A new prompt triggers another call to OpenRouter
+    """
+    model = "nousresearch/hermes-3-llama-3.1-405b:free"
+    openrouter_service = AsyncOpenRouterService(model)
+    import os
+    import tempfile
+
+    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    temp_db_file.close()
+    cache = LLMCache(temp_db_file.name)
+    cached_service = CachedLLMService(openrouter_service, cache)
+
+    try:
+        async with openrouter_service() as service:
+            cached_service._llm_service = service  # Update the service with the context manager
+
+            # Mock the invoke method to track calls
+            with patch.object(service, "invoke", wraps=service.invoke) as mock_invoke:
+                # First call
+                response1 = await cached_service.invoke(
+                    user_prompt="What is the capital of France?", max_tokens=100, temperature=0.7
+                )
+                assert mock_invoke.call_count == 1
+                assert response1 is not None
+                assert "Paris" in response1
+
+                # Second call with the same prompt
+                response2 = await cached_service.invoke(
+                    user_prompt="What is the capital of France?", max_tokens=100, temperature=0.7
+                )
+                assert mock_invoke.call_count == 1  # Should not have increased
+                assert response2 == response1
+
+                # Third call with a different prompt
+                response3 = await cached_service.invoke(
+                    user_prompt="What is the capital of Spain?", max_tokens=100, temperature=0.7
+                )
+                assert mock_invoke.call_count == 2
+                assert response3 is not None
+                assert response3 != response1
+                assert "Madrid" in response3
+
+            logger.info("Cached OpenRouter integration test passed successfully")
+    finally:
+        # Clean up the temporary database file
+        os.unlink(temp_db_file.name)
+
+
 if __name__ == "__main__":
     asyncio.run(test_openrouter_integration())
+    asyncio.run(test_cached_openrouter_integration())
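The test above counts real OpenRouter calls by wrapping service.invoke with patch.object(..., wraps=service.invoke), so a cache hit shows up as an unchanged call count. A minimal, synchronous illustration of that unittest.mock pattern (Greeter is made up for the example; the test applies the same idea to the async invoke method):

# A mock created with `wraps=` delegates to the real callable while still
# recording how many times it was invoked.
from unittest.mock import patch


class Greeter:
    def greet(self, name: str) -> str:
        return f"Hello, {name}!"


greeter = Greeter()
with patch.object(greeter, "greet", wraps=greeter.greet) as mock_greet:
    assert greeter.greet("Ada") == "Hello, Ada!"  # real behaviour preserved
    assert mock_greet.call_count == 1             # and the call was recorded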
tests/vsp/llm/test_llm_cache.py (ADDED, 129 lines)

import os
import tempfile
from unittest.mock import AsyncMock

import pytest

from vsp.llm.cached_llm_service import CachedLLMService
from vsp.llm.llm_cache import LLMCache
from vsp.llm.llm_service import LLMService


@pytest.fixture
def llm_cache():
    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db_file.close()
    cache = LLMCache(temp_db_file.name)
    yield cache
    # Clean up the temporary database file after the test
    os.unlink(temp_db_file.name)


def test_llm_cache_set_and_get(llm_cache):
    prompt = "What is the capital of France?"
    response = "The capital of France is Paris."
    metadata = {"model": "test_model", "temperature": 0.7}

    # Test setting a value in the cache
    llm_cache.set(prompt, response, metadata)

    # Test getting the value from the cache
    cached_response = llm_cache.get(prompt, metadata)
    assert cached_response == response


def test_llm_cache_get_nonexistent(llm_cache):
    prompt = "What is the capital of Germany?"
    metadata = {"model": "test_model", "temperature": 0.7}

    # Test getting a non-existent value from the cache
    cached_response = llm_cache.get(prompt, metadata)
    assert cached_response is None


def test_llm_cache_clear(llm_cache):
    prompt1 = "What is the capital of France?"
    response1 = "The capital of France is Paris."
    prompt2 = "What is the capital of Italy?"
    response2 = "The capital of Italy is Rome."
    metadata = {"model": "test_model", "temperature": 0.7}

    # Set multiple values in the cache
    llm_cache.set(prompt1, response1, metadata)
    llm_cache.set(prompt2, response2, metadata)

    # Clear the cache
    llm_cache.clear()

    # Verify that the cache is empty
    assert llm_cache.get(prompt1, metadata) is None
    assert llm_cache.get(prompt2, metadata) is None


def test_llm_cache_different_metadata(llm_cache):
    prompt = "What is the capital of France?"
    response1 = "The capital of France is Paris."
    response2 = "La capitale de la France est Paris."
    metadata1 = {"model": "test_model_en", "temperature": 0.7}
    metadata2 = {"model": "test_model_fr", "temperature": 0.7}

    # Set values with different metadata
    llm_cache.set(prompt, response1, metadata1)
    llm_cache.set(prompt, response2, metadata2)

    # Verify that different metadata produces different cache results
    assert llm_cache.get(prompt, metadata1) == response1
    assert llm_cache.get(prompt, metadata2) == response2


@pytest.mark.asyncio
async def test_cached_llm_service():
    # Create a mock LLMService
    mock_llm_service = AsyncMock(spec=LLMService)
    mock_llm_service.invoke.side_effect = ["First response", "Second response", "Third response"]

    # Create a CachedLLMService with the mock service
    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db_file.close()
    cache = LLMCache(temp_db_file.name)
    cached_service = CachedLLMService(mock_llm_service, cache)

    # Test first call (should use the mock service)
    response1 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
    assert response1 == "First response"
    mock_llm_service.invoke.assert_called_once()

    # Test second call with the same parameters (should use cache)
    response2 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
    assert response2 == "First response"
    assert mock_llm_service.invoke.call_count == 1  # Should not have increased

    # Test third call with different parameters (should use the mock service again)
    response3 = await cached_service.invoke(user_prompt="Test prompt 2", max_tokens=20, temperature=0.7)
    assert response3 == "Second response"
    assert mock_llm_service.invoke.call_count == 2

    # Test fourth call with system and partial assistant prompts
    response4 = await cached_service.invoke(
        user_prompt="Test prompt 3",
        system_prompt="System prompt",
        partial_assistant_prompt="Partial assistant prompt",
        max_tokens=30,
        temperature=0.8,
    )
    assert response4 == "Third response"
    assert mock_llm_service.invoke.call_count == 3

    # Test fifth call with the same parameters as the fourth (should use cache)
    response5 = await cached_service.invoke(
        user_prompt="Test prompt 3",
        system_prompt="System prompt",
        partial_assistant_prompt="Partial assistant prompt",
        max_tokens=30,
        temperature=0.8,
    )
    assert response5 == "Third response"
    assert mock_llm_service.invoke.call_count == 3  # Should not have increased

    # Clean up the temporary database file
    os.unlink(temp_db_file.name)
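test_llm_cache_different_metadata only passes if LLMCache._hash_prompt folds the metadata into the key alongside the prompt. The method body is not part of this commit; a hypothetical implementation consistent with that behaviour might look like the following (illustrative only, not the repository's actual code):

# Hypothetical stand-in for LLMCache._hash_prompt; the real implementation is
# not shown in this diff. It only needs to be deterministic and to change
# whenever either the prompt or the metadata changes.
import hashlib
import json
from typing import Any


def hash_prompt(prompt: str, metadata: dict[str, Any]) -> str:
    payload = json.dumps({"prompt": prompt, "metadata": metadata}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


assert hash_prompt("hi", {"model": "a"}) != hash_prompt("hi", {"model": "b"})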