from collections.abc import Callable
from typing import Any

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from vsp.llm.llm_service import LLMService
from vsp.llm.prompt_text import PromptText


class PromptEvaluationError(Exception):
    """
    Custom exception for errors during prompt evaluation.

    This exception is raised when an error occurs during the evaluation of a prompt.
    """

    pass


class Prompt:
    """
    A class to manage different types of prompts for language models.

    This class handles system prompts, user prompts, partial assistant prompts,
    and output formatting for language model interactions.
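
    Example:
        A minimal usage sketch; ``service`` and ``user_text`` are assumed to be
        an ``LLMService`` and a ``PromptText`` constructed elsewhere (their
        constructors are not defined in this module)::

            prompt = Prompt(llm_service=service, user_prompt=user_text)
            prompt.upsert_inputs({"topic": "retries"})
            result = await prompt.evaluate()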
    """

    def __init__(
        self,
        llm_service: LLMService,
        system_prompt: PromptText | None = None,
        user_prompt: PromptText | None = None,
        partial_assistant_prompt: PromptText | None = None,
        max_tokens: int = 1000,
        temperature: float = 0.0,
        output_formatter: Callable[[str], Any] = lambda x: x,
    ):
        """
        Initialize a Prompt object.

        Args:
            llm_service (LLMService): The language model service to use for evaluation.
            system_prompt (PromptText | None): The system prompt.
            user_prompt (PromptText | None): The user prompt.
            partial_assistant_prompt (PromptText | None): The partial assistant prompt.
            max_tokens (int): The maximum number of tokens to generate.
            temperature (float): The sampling temperature.
            output_formatter (Callable[[str], Any]): A function applied to the
                raw LLM output; defaults to the identity function.

        Raises:
            ValueError: If both system_prompt and user_prompt are None.
        """
        self._system_prompt = system_prompt
        self._user_prompt = user_prompt
        self._partial_assistant_prompt = partial_assistant_prompt
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._output_formatter: Callable[[str], Any] = output_formatter
        self._llm_service = llm_service

        if not self._system_prompt and not self._user_prompt:
            raise ValueError("At least one of system_prompt or user_prompt must be provided")

    @property
    def llm_service(self) -> LLMService:
        """The language model service used for evaluation."""
        return self._llm_service

    @llm_service.setter
    def llm_service(self, value: LLMService) -> None:
        """
        Set the language model service for the prompt to something other than the constructor value.

        Args:
            llm_service (LLMService): The language model service to use for evaluation.
        """
        self._llm_service = value

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(3),
        retry=retry_if_exception_type(PromptEvaluationError),
        reraise=True,
    )
    async def evaluate(self) -> Any:
        """
        Evaluate the prompt using the LLM service.

        Returns:
            Any: The formatted output of the LLM service.

        Raises:
            PromptEvaluationError: If an error occurs during prompt evaluation.
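
        Example:
            With ``output_formatter=json.loads`` (an illustrative choice; this
            module does not require JSON output), malformed model output raises
            in the formatter and the ``@retry`` decorator re-runs the prompt::

                data = await prompt.evaluate()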
        """
        result = await self._llm_service.invoke(
            user_prompt=self._user_prompt.get_prompt() if self._user_prompt else None,
            system_prompt=self._system_prompt.get_prompt() if self._system_prompt else None,
            partial_assistant_prompt=(
                self._partial_assistant_prompt.get_prompt() if self._partial_assistant_prompt else None
            ),
            max_tokens=self._max_tokens,
            temperature=self._temperature,
        )
        if result is None:
            raise PromptEvaluationError("No result from LLM service")
        try:
            return self._output_formatter(result)
        except Exception as e:
            # Formatting can fail due to nondeterministic LLM output, so raise
            # PromptEvaluationError to trigger a retry of the whole prompt.
            raise PromptEvaluationError("Error formatting output") from e

    def upsert_inputs(self, new_inputs: dict[str, Any]) -> None:
        """
        Update the prompts with new inputs.

        This method updates all non-None prompts (system, user, and partial assistant)
        with the provided new inputs.

        Args:
            new_inputs (dict[str, Any]): A dictionary of new input values to update the prompts.
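
        Example:
            Assuming the wrapped ``PromptText`` templates use named placeholders
            (the template syntax is defined by ``PromptText``, not here)::

                prompt.upsert_inputs({"audience": "engineers"})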
        """
        if self._system_prompt:
            self._system_prompt.upsert_inputs(new_inputs)
        if self._user_prompt:
            self._user_prompt.upsert_inputs(new_inputs)
        if self._partial_assistant_prompt:
            self._partial_assistant_prompt.upsert_inputs(new_inputs)