navkast committed
Commit 324115d · unverified · Parent: 49b13c6

Tidy up VSP implementation so far (#9)


* feat: Add integration test for LLM caching using OpenRouter

* fix: Fix linter issues in test_integration_openrouter.py

* feat: Add unit tests for LLMCache

* style: Fix linter issues

* feat: Add unit test for CachedLLMService

* style: Run linter

* fix: Ensure LLMCache table is created

* fix: Update the service with the context manager in CachedLLMService

* refactor: replace in-memory SQLite databases with temporary files

* style: Fix import order in test files

* fix: wrap `async with` block in `try-finally` to clean up temporary database file

* Commit some tests

src/vsp/llm/cached_llm_service.py CHANGED
@@ -7,8 +7,8 @@ logger = logger_factory.get_logger(__name__)
 
 class CachedLLMService(LLMService):
     def __init__(self, llm_service: LLMService, cache: LLMCache | None = None):
-        self.llm_service = llm_service
-        self.cache = cache or LLMCache()
+        self._llm_service = llm_service
+        self._cache = cache or LLMCache()
 
     async def invoke(
         self,
@@ -19,12 +19,12 @@ class CachedLLMService(LLMService):
         temperature: float = 0.0,
     ) -> str | None:
         cache_key = f"{user_prompt}_{system_prompt}_{partial_assistant_prompt}_{max_tokens}_{temperature}"
-        cached_response = self.cache.get(cache_key, {})
+        cached_response = self._cache.get(cache_key, {})
         if cached_response is not None:
             logger.debug("LLM cache hit")
             return cached_response
 
-        response = await self.llm_service.invoke(
+        response = await self._llm_service.invoke(
             user_prompt=user_prompt,
             system_prompt=system_prompt,
             partial_assistant_prompt=partial_assistant_prompt,
@@ -33,6 +33,6 @@ class CachedLLMService(LLMService):
         )
 
         if response is not None:
-            self.cache.set(cache_key, response, {})
+            self._cache.set(cache_key, response, {})
 
         return response
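
Note on the change above: the injected service and cache become private attributes (`_llm_service`, `_cache`), and the cache key folds every invoke parameter into one string while the metadata dict handed to `LLMCache` stays empty. A minimal usage sketch of the wrapper, assuming `LLMService` can be subclassed by providing just an async `invoke` method (the `EchoService` stub, its defaults, and the `demo_cache.db` path are hypothetical, not part of the repository):

    import asyncio

    from vsp.llm.cached_llm_service import CachedLLMService
    from vsp.llm.llm_cache import LLMCache
    from vsp.llm.llm_service import LLMService


    class EchoService(LLMService):
        # Hypothetical backend stub; a real service would call an LLM API here.
        async def invoke(self, user_prompt, system_prompt=None, partial_assistant_prompt=None,
                         max_tokens=1024, temperature=0.0):
            return f"echo: {user_prompt}"


    async def main():
        cached = CachedLLMService(EchoService(), LLMCache("demo_cache.db"))
        first = await cached.invoke(user_prompt="hello", max_tokens=16, temperature=0.0)
        second = await cached.invoke(user_prompt="hello", max_tokens=16, temperature=0.0)
        assert first == second  # second call is served from the SQLite cache


    asyncio.run(main())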
src/vsp/llm/llm_cache.py CHANGED
@@ -12,7 +12,7 @@ class LLMCache:
         self._init_db()
 
     def _init_db(self) -> None:
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
             conn.execute(
                 """
                 CREATE TABLE IF NOT EXISTS llm_cache (
@@ -21,7 +21,7 @@ class LLMCache:
                     response TEXT,
                     metadata TEXT
                 )
-            """
+                """
             )
 
     def _hash_prompt(self, prompt: str, metadata: dict[str, Any]) -> str:
@@ -30,7 +30,7 @@
 
     def get(self, prompt: str, metadata: dict[str, Any]) -> str | None:
         prompt_hash = self._hash_prompt(prompt, metadata)
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT response FROM llm_cache WHERE prompt_hash = ?", (prompt_hash,))
            result = cursor.fetchone()
@@ -38,12 +38,12 @@
 
     def set(self, prompt: str, response: str, metadata: dict[str, Any]) -> None:
         prompt_hash = self._hash_prompt(prompt, metadata)
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
             conn.execute(
                 "INSERT OR REPLACE INTO llm_cache (prompt_hash, prompt, response, metadata) VALUES (?, ?, ?, ?)",
                 (prompt_hash, prompt, response, json.dumps(metadata)),
             )
 
     def clear(self) -> None:
-        with sqlite3.connect(self.db_path) as conn:
+        with sqlite3.connect(self.db_path, autocommit=True) as conn:
             conn.execute("DELETE FROM llm_cache")
tests/vsp/llm/openrouter/test_integration_openrouter.py CHANGED
@@ -1,7 +1,10 @@
 import asyncio
+from unittest.mock import patch
 
 import pytest
 
+from vsp.llm.cached_llm_service import CachedLLMService
+from vsp.llm.llm_cache import LLMCache
 from vsp.llm.openrouter.openrouter import AsyncOpenRouterService
 from vsp.shared import logger_factory
 
@@ -44,5 +47,62 @@ async def test_openrouter_integration():
         raise
 
 
+@pytest.mark.asyncio
+async def test_cached_openrouter_integration():
+    """
+    Integration test for CachedLLMService with AsyncOpenRouterService.
+
+    This test verifies that:
+    1. The first call goes to OpenRouter
+    2. The second call with the same prompt returns the cached response
+    3. A new prompt triggers another call to OpenRouter
+    """
+    model = "nousresearch/hermes-3-llama-3.1-405b:free"
+    openrouter_service = AsyncOpenRouterService(model)
+    import os
+    import tempfile
+
+    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    temp_db_file.close()
+    cache = LLMCache(temp_db_file.name)
+    cached_service = CachedLLMService(openrouter_service, cache)
+
+    try:
+        async with openrouter_service() as service:
+            cached_service._llm_service = service  # Update the service with the context manager
+
+            # Mock the invoke method to track calls
+            with patch.object(service, "invoke", wraps=service.invoke) as mock_invoke:
+                # First call
+                response1 = await cached_service.invoke(
+                    user_prompt="What is the capital of France?", max_tokens=100, temperature=0.7
+                )
+                assert mock_invoke.call_count == 1
+                assert response1 is not None
+                assert "Paris" in response1
+
+                # Second call with the same prompt
+                response2 = await cached_service.invoke(
+                    user_prompt="What is the capital of France?", max_tokens=100, temperature=0.7
+                )
+                assert mock_invoke.call_count == 1  # Should not have increased
+                assert response2 == response1
+
+                # Third call with a different prompt
+                response3 = await cached_service.invoke(
+                    user_prompt="What is the capital of Spain?", max_tokens=100, temperature=0.7
+                )
+                assert mock_invoke.call_count == 2
+                assert response3 is not None
+                assert response3 != response1
+                assert "Madrid" in response3
+
+                logger.info("Cached OpenRouter integration test passed successfully")
+    finally:
+        # Clean up the temporary database file
+        os.unlink(temp_db_file.name)
+
+
 if __name__ == "__main__":
     asyncio.run(test_openrouter_integration())
+    asyncio.run(test_cached_openrouter_integration())
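
The new integration test spies on the real service with `unittest.mock.patch.object(..., wraps=...)`, which records calls while still delegating to the wrapped method. A standalone sketch of that technique (the `Greeter` class is made up for illustration and is not part of the repository):

    from unittest.mock import patch


    class Greeter:
        def hello(self, name: str) -> str:
            return f"hello {name}"


    g = Greeter()
    with patch.object(g, "hello", wraps=g.hello) as spy:
        assert g.hello("world") == "hello world"  # the real method still runs
        assert spy.call_count == 1                # and every call is recorded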
tests/vsp/llm/test_llm_cache.py ADDED
@@ -0,0 +1,129 @@
+import os
+import tempfile
+from unittest.mock import AsyncMock
+
+import pytest
+
+from vsp.llm.cached_llm_service import CachedLLMService
+from vsp.llm.llm_cache import LLMCache
+from vsp.llm.llm_service import LLMService
+
+
+@pytest.fixture
+def llm_cache():
+    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    temp_db_file.close()
+    cache = LLMCache(temp_db_file.name)
+    yield cache
+    # Clean up the temporary database file after the test
+    os.unlink(temp_db_file.name)
+
+
+def test_llm_cache_set_and_get(llm_cache):
+    prompt = "What is the capital of France?"
+    response = "The capital of France is Paris."
+    metadata = {"model": "test_model", "temperature": 0.7}
+
+    # Test setting a value in the cache
+    llm_cache.set(prompt, response, metadata)
+
+    # Test getting the value from the cache
+    cached_response = llm_cache.get(prompt, metadata)
+    assert cached_response == response
+
+
+def test_llm_cache_get_nonexistent(llm_cache):
+    prompt = "What is the capital of Germany?"
+    metadata = {"model": "test_model", "temperature": 0.7}
+
+    # Test getting a non-existent value from the cache
+    cached_response = llm_cache.get(prompt, metadata)
+    assert cached_response is None
+
+
+def test_llm_cache_clear(llm_cache):
+    prompt1 = "What is the capital of France?"
+    response1 = "The capital of France is Paris."
+    prompt2 = "What is the capital of Italy?"
+    response2 = "The capital of Italy is Rome."
+    metadata = {"model": "test_model", "temperature": 0.7}
+
+    # Set multiple values in the cache
+    llm_cache.set(prompt1, response1, metadata)
+    llm_cache.set(prompt2, response2, metadata)
+
+    # Clear the cache
+    llm_cache.clear()
+
+    # Verify that the cache is empty
+    assert llm_cache.get(prompt1, metadata) is None
+    assert llm_cache.get(prompt2, metadata) is None
+
+
+def test_llm_cache_different_metadata(llm_cache):
+    prompt = "What is the capital of France?"
+    response1 = "The capital of France is Paris."
+    response2 = "La capitale de la France est Paris."
+    metadata1 = {"model": "test_model_en", "temperature": 0.7}
+    metadata2 = {"model": "test_model_fr", "temperature": 0.7}
+
+    # Set values with different metadata
+    llm_cache.set(prompt, response1, metadata1)
+    llm_cache.set(prompt, response2, metadata2)
+
+    # Verify that different metadata produces different cache results
+    assert llm_cache.get(prompt, metadata1) == response1
+    assert llm_cache.get(prompt, metadata2) == response2
+
+
+@pytest.mark.asyncio
+async def test_cached_llm_service():
+    # Create a mock LLMService
+    mock_llm_service = AsyncMock(spec=LLMService)
+    mock_llm_service.invoke.side_effect = ["First response", "Second response", "Third response"]
+
+    # Create a CachedLLMService with the mock service
+    temp_db_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
+    temp_db_file.close()
+    cache = LLMCache(temp_db_file.name)
+    cached_service = CachedLLMService(mock_llm_service, cache)
+
+    # Test first call (should use the mock service)
+    response1 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
+    assert response1 == "First response"
+    mock_llm_service.invoke.assert_called_once()
+
+    # Test second call with the same parameters (should use cache)
+    response2 = await cached_service.invoke(user_prompt="Test prompt 1", max_tokens=10, temperature=0.5)
+    assert response2 == "First response"
+    assert mock_llm_service.invoke.call_count == 1  # Should not have increased
+
+    # Test third call with different parameters (should use the mock service again)
+    response3 = await cached_service.invoke(user_prompt="Test prompt 2", max_tokens=20, temperature=0.7)
+    assert response3 == "Second response"
+    assert mock_llm_service.invoke.call_count == 2
+
+    # Test fourth call with system and partial assistant prompts
+    response4 = await cached_service.invoke(
+        user_prompt="Test prompt 3",
+        system_prompt="System prompt",
+        partial_assistant_prompt="Partial assistant prompt",
+        max_tokens=30,
+        temperature=0.8,
+    )
+    assert response4 == "Third response"
+    assert mock_llm_service.invoke.call_count == 3
+
+    # Test fifth call with the same parameters as the fourth (should use cache)
+    response5 = await cached_service.invoke(
+        user_prompt="Test prompt 3",
+        system_prompt="System prompt",
+        partial_assistant_prompt="Partial assistant prompt",
+        max_tokens=30,
+        temperature=0.8,
+    )
+    assert response5 == "Third response"
+    assert mock_llm_service.invoke.call_count == 3  # Should not have increased
+
+    # Clean up the temporary database file
+    os.unlink(temp_db_file.name)
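
For reference, the `AsyncMock(spec=LLMService)` used in the unit test above returns the next item of its `side_effect` list on each await; a minimal standalone sketch of that behaviour, independent of the vsp code:

    import asyncio
    from unittest.mock import AsyncMock

    mock = AsyncMock(side_effect=["First response", "Second response"])


    async def main():
        assert await mock() == "First response"   # first await consumes the first item
        assert await mock() == "Second response"  # second await consumes the next
        assert mock.call_count == 2


    asyncio.run(main())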