vsp-demo / src /vsp /llm /prompt.py
navkast
Update location of the VSP module (#1)
c1f8477 unverified
from typing import Any, Callable
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from vsp.llm.llm_service import LLMService
from vsp.llm.prompt_text import PromptText
class PromptEvaluationError(Exception):
    """
    Signal that evaluating a prompt did not produce a usable result.

    Raised when something goes wrong while a prompt is being evaluated.
    It carries no extra state beyond what ``Exception`` already provides.
    """
class Prompt:
    """
    Bundle the prompts and generation settings for one language model call.

    A Prompt holds an optional system prompt, user prompt and partial
    assistant prompt, the generation settings (max tokens, temperature),
    the LLM service to invoke, and a formatter applied to the raw output
    returned by that service.
    """

    def __init__(
        self,
        llm_service: LLMService,
        system_prompt: PromptText | None = None,
        user_prompt: PromptText | None = None,
        partial_assistant_prompt: PromptText | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
        output_formatter: Callable[[str], Any] = lambda x: x,
    ):
        """
        Create a Prompt.

        Args:
            llm_service (LLMService): Service whose ``invoke`` is awaited by ``evaluate``.
            system_prompt (PromptText | None): Optional system prompt.
            user_prompt (PromptText | None): Optional user prompt.
            partial_assistant_prompt (PromptText | None): Optional partial assistant prompt.
            max_tokens (int): Forwarded to the service as ``max_tokens``.
            temperature (float): Forwarded to the service as ``temperature``.
            output_formatter (Callable[[str], Any]): Applied to the service result;
                the default returns the result unchanged.

        Raises:
            ValueError: If neither system_prompt nor user_prompt is truthy
                (for example, when both are None).
        """
        self._llm_service = llm_service
        self._output_formatter: Callable[[str], Any] = output_formatter
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._system_prompt = system_prompt
        self._user_prompt = user_prompt
        self._partial_assistant_prompt = partial_assistant_prompt

        # A prompt needs at least a system or a user part to be meaningful.
        if not (self._system_prompt or self._user_prompt):
            raise ValueError("At least one of system_prompt or user_prompt must be provided")

    @staticmethod
    def _render(prompt: PromptText | None) -> Any:
        """Return ``prompt.get_prompt()`` for a truthy prompt, otherwise None."""
        if prompt:
            return prompt.get_prompt()
        return None

    @property
    def llm_service(self) -> LLMService:
        """The language model service that ``evaluate`` invokes."""
        return self._llm_service

    @llm_service.setter
    def llm_service(self, value: LLMService) -> None:
        """
        Swap the language model service for a different one than was given to the constructor.

        Args:
            value (LLMService): The service to use for subsequent evaluations.
        """
        self._llm_service = value

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(3),
        retry=retry_if_exception_type(PromptEvaluationError),
        reraise=True,
    )
    async def evaluate(self) -> Any:
        """
        Invoke the LLM service with the stored prompts and format its result.

        The ``retry`` decorator re-runs this coroutine when it raises
        PromptEvaluationError, stopping after 3 attempts and re-raising
        the last error (``reraise=True``).

        Returns:
            Any: Whatever ``output_formatter`` returns for the service result.

        Raises:
            PromptEvaluationError: If the service returns None, or if the
                output formatter raises any ``Exception``.
        """
        # Render in the same order the service arguments are listed below.
        user_text = self._render(self._user_prompt)
        system_text = self._render(self._system_prompt)
        assistant_prefix = self._render(self._partial_assistant_prompt)

        raw_output = await self._llm_service.invoke(
            user_prompt=user_text,
            system_prompt=system_text,
            partial_assistant_prompt=assistant_prefix,
            max_tokens=self._max_tokens,
            temperature=self._temperature,
        )
        if raw_output is None:
            raise PromptEvaluationError("No result from LLM service")

        try:
            formatted = self._output_formatter(raw_output)
        except Exception as exc:
            # Formatting can fail simply because of randomness in the LLM output,
            # so surface it as the exception type that triggers a retry.
            raise PromptEvaluationError("Error formatting output") from exc
        return formatted

    def upsert_inputs(self, new_inputs: dict[str, Any]) -> None:
        """
        Push new input values into every prompt this object holds.

        Each of the system, user and partial assistant prompts that is set
        (truthy) has its own ``upsert_inputs`` called with ``new_inputs``.

        Args:
            new_inputs (dict[str, Any]): Input values to hand to each prompt.
        """
        held_prompts = (
            self._system_prompt,
            self._user_prompt,
            self._partial_assistant_prompt,
        )
        for prompt in held_prompts:
            if prompt:
                prompt.upsert_inputs(new_inputs)