File size: 4,479 Bytes
3b993c4 24d33b9 3b993c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from typing import Any, Union
from pydantic import BaseModel
from vsp.llm.prompt import Prompt
class PromptChain:
    """
    A class to manage and execute a chain of Prompt objects.

    This class allows for the creation of a sequence of prompts where the output
    of each prompt is used as input for the subsequent prompt in the chain.

    Attributes:
        prompts (list[Prompt]): An ordered list of Prompt objects to be executed in sequence.
    """

    def __init__(self, prompts: list[Prompt]):
        """
        Initialize a PromptChain object.

        Args:
            prompts (list[Prompt]): An ordered list of Prompt objects to be chained together.
                May be empty, in which case evaluate() returns None.
        """
        self._prompts = prompts

    @property
    def prompts(self) -> list[Prompt]:
        """
        All prompts in the chain. Potentially useful for changing the prompts dynamically.

        Returns:
            list[Prompt]: A list of all Prompt objects in the chain.
        """
        return self._prompts

    async def evaluate(self) -> Union[Any, dict[str, Any], BaseModel]:
        """
        Evaluate the entire chain of prompts.

        This method executes each prompt in the chain sequentially, passing the output
        of each prompt as input to the next. The final result of the last prompt in the
        chain is returned.

        Returns:
            Union[Any, dict[str, Any], BaseModel]: The output of the last prompt in the chain.
                The type depends on the output_formatter of the last prompt.
                Returns None if the chain contains no prompts.

        Raises:
            Any exception that might be raised by the individual Prompt.evaluate() calls.

        Note:
            - If a prompt's output is a dictionary, it will be used to update the input for the next prompt.
            - If a prompt's output is a Pydantic model, its dictionary representation will be used to update the input.
            - For any other type of output, it will be stored under the key 'previous_output' in the input dictionary.
        """
        # Pre-bind result so an empty chain returns None instead of raising
        # UnboundLocalError at the `return result` below.
        result: Union[Any, dict[str, Any], BaseModel] = None
        current_input: dict[str, Any] = {}
        for prompt in self._prompts:
            # Feed the accumulated outputs of earlier prompts into this one.
            prompt.upsert_inputs(current_input)
            # Evaluate the prompt
            result = await prompt.evaluate()
            # Prepare the input for the next prompt: merge structured outputs
            # directly; stash anything else under 'previous_output'.
            if isinstance(result, dict):
                current_input.update(result)
            elif isinstance(result, BaseModel):
                current_input.update(result.model_dump())
            else:
                current_input["previous_output"] = result
        return result

    @classmethod
    def create(cls, *args: dict[str, Any] | Prompt) -> "PromptChain":
        """
        Create a PromptChain from a series of prompt configurations or Prompt objects.

        This class method provides a convenient way to create a PromptChain by
        specifying either configurations for each Prompt or direct Prompt objects.

        Args:
            *args: Variable number of arguments, each being either:
                - A dictionary containing the configuration for a single Prompt in the chain
                  (passed to Prompt(**config)).
                - A Prompt object, used as-is.

        Returns:
            PromptChain: A new PromptChain instance with the specified prompts.

        Raises:
            ValueError: If an argument is neither a dict nor a Prompt object.

        Example::

            chain = PromptChain.create(
                {
                    "llm_service": llm_service1,
                    "system_prompt": PromptText("You are a helpful assistant."),
                    "user_prompt": PromptText("Summarize this: {text}", inputs={"text": "A long text."}),
                    "output_formatter": lambda x: {"summary": x}
                },
                # You can also use Prompt() objects directly
                Prompt(llm_service2, user_prompt=PromptText("Translate to French: {summary}")),
                {
                    "llm_service": llm_service3,
                    "user_prompt": PromptText("Make it formal: {previous_output}")
                }
            )
        """
        prompts: list[Prompt] = []
        for arg in args:
            if isinstance(arg, dict):
                prompts.append(Prompt(**arg))
            elif isinstance(arg, Prompt):
                prompts.append(arg)
            else:
                raise ValueError(f"Invalid argument type: {type(arg)}. Expected dict or Prompt.")
        return cls(prompts)
|