from typing import Any, Union

from pydantic import BaseModel

from vsp.llm.prompt import Prompt


class PromptChain:
    """
    A class to manage and execute a chain of Prompt objects.

    This class allows for the creation of a sequence of prompts where the
    output of each prompt is used as input for the subsequent prompt in
    the chain.

    Attributes:
        prompts (list[Prompt]): An ordered list of Prompt objects to be
            executed in sequence.
    """

    def __init__(self, prompts: list[Prompt]):
        """
        Initialize a PromptChain object.

        Args:
            prompts (list[Prompt]): An ordered list of Prompt objects to be
                chained together.
        """
        self._prompts = prompts

    @property
    def prompts(self) -> list[Prompt]:
        """
        All prompts in the chain.

        Potentially useful for changing the prompts dynamically.

        Returns:
            list[Prompt]: A list of all Prompt objects in the chain.
        """
        return self._prompts

    async def evaluate(self) -> Union[Any, dict[str, Any], BaseModel]:
        """
        Evaluate the entire chain of prompts.

        This method executes each prompt in the chain sequentially, passing
        the output of each prompt as input to the next. The final result of
        the last prompt in the chain is returned.

        Returns:
            Union[Any, dict[str, Any], BaseModel]: The output of the last
                prompt in the chain. The type depends on the
                output_formatter of the last prompt.

        Raises:
            ValueError: If the chain contains no prompts.
            Any exception that might be raised by the individual
                Prompt.evaluate() calls.

        Note:
            - If a prompt's output is a dictionary, it will be used to
              update the input for the next prompt.
            - If a prompt's output is a Pydantic model, its dictionary
              representation will be used to update the input.
            - For any other type of output, it will be stored under the key
              'previous_output' in the input dictionary.
        """
        # Guard the empty-chain case explicitly: without it, `return result`
        # below would raise a confusing UnboundLocalError because `result`
        # is only bound inside the loop.
        if not self._prompts:
            raise ValueError("Cannot evaluate an empty PromptChain.")

        current_input: dict[str, Any] = {}
        result: Any = None
        for prompt in self._prompts:
            # Update the prompt with the accumulated outputs so far.
            prompt.upsert_inputs(current_input)

            # Evaluate the prompt.
            result = await prompt.evaluate()

            # Prepare the input for the next prompt based on the result type.
            if isinstance(result, dict):
                current_input.update(result)
            elif isinstance(result, BaseModel):
                current_input.update(result.model_dump())
            else:
                current_input["previous_output"] = result

        return result

    @classmethod
    def create(cls, *args: dict[str, Any] | Prompt) -> "PromptChain":
        """
        Create a PromptChain from a series of prompt configurations or
        Prompt objects.

        This class method provides a convenient way to create a PromptChain
        by specifying either configurations for each Prompt or direct
        Prompt objects.

        Args:
            *args: Variable number of arguments, each being either:
                - A dictionary containing the configuration for a single
                  Prompt in the chain.
                - A Prompt object.

        Returns:
            PromptChain: A new PromptChain instance with the specified
                prompts.

        Raises:
            ValueError: If an argument is neither a dict nor a Prompt
                object.

        Example::

            chain = PromptChain.create(
                {
                    "llm_service": llm_service1,
                    "system_prompt": PromptText("You are a helpful assistant."),
                    "user_prompt": PromptText("Summarize this: {text}", inputs={"text": "A long text."}),
                    "output_formatter": lambda x: {"summary": x}
                },
                # You can also use Prompt() objects directly
                Prompt(llm_service2, user_prompt=PromptText("Translate to French: {summary}")),
                {
                    "llm_service": llm_service3,
                    "user_prompt": PromptText("Make it formal: {previous_output}")
                }
            )
        """
        prompts: list[Prompt] = []
        for arg in args:
            if isinstance(arg, dict):
                # Dict configs are expanded into Prompt keyword arguments.
                prompts.append(Prompt(**arg))
            elif isinstance(arg, Prompt):
                prompts.append(arg)
            else:
                raise ValueError(f"Invalid argument type: {type(arg)}. Expected dict or Prompt.")
        return cls(prompts)