from vsp.llm.llm_cache import LLMCache
from vsp.llm.llm_service import LLMService
from vsp.shared import logger_factory

logger = logger_factory.get_logger(__name__)


class CachedLLMService(LLMService):
    """LLMService decorator that caches responses keyed by the prompts and sampling parameters."""

    def __init__(self, llm_service: LLMService, cache: LLMCache | None = None):
        self._llm_service = llm_service
        self._cache = cache or LLMCache()

    async def invoke(
        self,
        user_prompt: str | None = None,
        system_prompt: str | None = None,
        partial_assistant_prompt: str | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> str | None:
        # Build the cache key from every argument that influences the model's output.
        cache_key = f"{user_prompt}_{system_prompt}_{partial_assistant_prompt}_{max_tokens}_{temperature}"
        # Return any previously cached response for this key.
        cached_response = self._cache.get(cache_key, {})
        if cached_response is not None:
            logger.debug("LLM cache hit")
            return cached_response
        # Cache miss: delegate to the wrapped service.
        response = await self._llm_service.invoke(
            user_prompt=user_prompt,
            system_prompt=system_prompt,
            partial_assistant_prompt=partial_assistant_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        # Only cache successful (non-None) responses.
        if response is not None:
            self._cache.set(cache_key, response, {})
        return response
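

# Usage sketch (illustrative only, not part of the module): wrap a concrete
# LLMService so that repeated identical calls are served from the cache.
# "SomeLLMService" below is a hypothetical concrete implementation used
# purely for illustration; substitute whichever service the project provides.
#
#     base_service = SomeLLMService()
#     llm = CachedLLMService(base_service)
#     answer = await llm.invoke(
#         user_prompt="Summarize the release notes.",
#         system_prompt="You are a concise assistant.",
#     )
#     # A second call with identical arguments returns the cached response
#     # without invoking the underlying service again.
#     answer_again = await llm.invoke(
#         user_prompt="Summarize the release notes.",
#         system_prompt="You are a concise assistant.",
#     )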