|
|
import pytest |
|
|
|
|
|
from vsp.llm.openai.openai import AsyncOpenAIService |
|
|
from vsp.llm.openai.openai_model import OpenAIModel |
|
|
from vsp.llm.prompt import Prompt |
|
|
from vsp.llm.prompt_chain import PromptChain |
|
|
from vsp.llm.prompt_text import PromptText |
|
|
from vsp.shared import logger_factory |
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger obtained from the project's logger factory, keyed by this
# module's name; the integration tests use it to log each evaluated chain result
# (passed as a keyword argument rather than interpolated into the message).
logger = logger_factory.get_logger(__name__)
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_integration_prompt_chain_with_openai():
    """Evaluate a two-step greeting -> French translation chain against OpenAI.

    The first step is passed to ``PromptChain.create`` as a ``Prompt`` instance and
    the second as a plain mapping whose keys mirror ``Prompt``'s keyword arguments,
    so both step forms are exercised in one chain.  The first step's formatter
    wraps the stripped model output as ``{"greeting": ...}``; presumably that key
    fills the ``{greeting}`` placeholder of the second prompt (confirm against
    ``PromptChain``).

    NOTE(review): this calls the live ``GPT_4_MINI`` model, so it presumably needs
    network access and OpenAI credentials configured for ``AsyncOpenAIService``.
    """
    llm = AsyncOpenAIService(OpenAIModel.GPT_4_MINI)

    async with llm() as service:
        # Step 1: ask for a greeting and expose it under the "greeting" key.
        greeting_step = Prompt(
            llm_service=service,
            user_prompt=PromptText("Generate a short greeting for {name}.", inputs={"name": "Alice"}),
            output_formatter=lambda text: {"greeting": text.strip()},
        )
        # Step 2: given as raw keyword arguments rather than a Prompt object.
        translation_step = {
            "llm_service": service,
            "user_prompt": PromptText('Translate the following to French: "{greeting}"'),
        }
        chain = PromptChain.create(greeting_step, translation_step)

        final_text = await chain.evaluate()
        logger.info("Prompt chain evaluated during integration test.", result=final_text)

        # The final step has no formatter, so the chain should yield a non-empty string.
        assert isinstance(final_text, str)
        assert final_text
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_integration_prompt_chain_with_system_prompt():
    """Evaluate a greeting -> French translation chain where each step has a system prompt.

    Both steps are supplied to ``PromptChain.create`` as mappings of ``Prompt``-style
    keyword arguments.  The first step's formatter returns ``{"greeting": ...}``,
    which presumably fills the ``{greeting}`` placeholder of the second step
    (confirm against ``PromptChain``).  The second step's system prompt asks the
    model to reply with only the translated sentence.

    NOTE(review): this calls the live ``GPT_4_MINI`` model, so it presumably needs
    network access and OpenAI credentials configured for ``AsyncOpenAIService``.
    """
    llm = AsyncOpenAIService(OpenAIModel.GPT_4_MINI)

    async with llm() as service:
        # Step 1: a greeting generator steered by its own system prompt.
        greeting_step = dict(
            llm_service=service,
            system_prompt=PromptText("You are a helpful assistant that generates greetings."),
            user_prompt=PromptText("Generate a short greeting for {name}.", inputs={"name": "Alice"}),
            output_formatter=lambda text: {"greeting": text.strip()},
        )
        # Step 2: a translator instructed to emit nothing but the translation.
        translation_step = dict(
            llm_service=service,
            system_prompt=PromptText(
                "You are a French translator. You do not write anything except the translated sentence."
            ),
            user_prompt=PromptText("Translate the following greeting to French: {greeting}"),
        )
        chain = PromptChain.create(greeting_step, translation_step)

        final_text = await chain.evaluate()
        logger.info("Prompt chain evaluated during integration test.", result=final_text)

        # The last step has no formatter, so the chain should yield a non-empty string.
        assert isinstance(final_text, str)
        assert final_text
|
|
|