File size: 2,381 Bytes
3b993c4 24d33b9 3b993c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import pytest
from vsp.llm.openai.openai import AsyncOpenAIService
from vsp.llm.openai.openai_model import OpenAIModel
from vsp.llm.prompt import Prompt
from vsp.llm.prompt_chain import PromptChain
from vsp.llm.prompt_text import PromptText
from vsp.shared import logger_factory
# NOTE: The integration tests in this module call the real OpenAI API through
# AsyncOpenAIService, so they require an actual OpenAI API key and will make
# real (billable) API calls. How the key is supplied is not visible here —
# presumably via the service's own configuration; confirm before running.

# Module-level logger; the tests below use it to record each chain's result.
logger = logger_factory.get_logger(__name__)
@pytest.mark.asyncio
async def test_integration_prompt_chain_with_openai():
    """Run a two-step greeting -> French translation chain against the live API.

    The first step is an explicit ``Prompt`` whose output formatter wraps the
    stripped model reply as ``{"greeting": ...}``. The second step is given as a
    plain mapping whose user prompt contains a ``{greeting}`` placeholder
    (presumably filled from the first step's formatted output — confirm in
    ``PromptChain``). Only a non-empty string result is asserted.
    """
    # Calling the service instance returns an async context manager that yields
    # the service object each step is bound to.
    async with AsyncOpenAIService(OpenAIModel.GPT_4_MINI)() as llm:
        # Step 1: ask the model for a short greeting addressed to "Alice".
        greet = Prompt(
            llm_service=llm,
            user_prompt=PromptText(
                "Generate a short greeting for {name}.",
                inputs={"name": "Alice"},
            ),
            output_formatter=lambda reply: {"greeting": reply.strip()},
        )
        # Step 2: a bare keyword mapping rather than a Prompt instance.
        translate = {
            "llm_service": llm,
            "user_prompt": PromptText('Translate the following to French: "{greeting}"'),
        }

        result = await PromptChain.create(greet, translate).evaluate()
        logger.info("Prompt chain evaluated during integration test.", result=result)

        assert isinstance(result, str)
        assert result  # non-empty, given the str check above
@pytest.mark.asyncio
async def test_integration_prompt_chain_with_system_prompt():
    """Run a greeting -> French translation chain where each step has a system prompt.

    Both steps are passed to ``PromptChain.create`` as keyword mappings. The first
    generates a greeting for "Alice" and formats the stripped reply as
    ``{"greeting": ...}``. The second, primed as a terse French translator, uses a
    ``{greeting}`` placeholder (presumably filled from the first step's output —
    confirm in ``PromptChain``). Only a non-empty string result is asserted.
    """
    # Calling the service instance returns an async context manager that yields
    # the service object each step is bound to.
    async with AsyncOpenAIService(OpenAIModel.GPT_4_MINI)() as llm:
        greeting_step = dict(
            llm_service=llm,
            system_prompt=PromptText("You are a helpful assistant that generates greetings."),
            user_prompt=PromptText(
                "Generate a short greeting for {name}.",
                inputs={"name": "Alice"},
            ),
            output_formatter=lambda reply: {"greeting": reply.strip()},
        )
        translation_step = dict(
            llm_service=llm,
            # Restrict the reply to the bare translation.
            system_prompt=PromptText(
                "You are a French translator. You do not write anything except the translated sentence."
            ),
            user_prompt=PromptText("Translate the following greeting to French: {greeting}"),
        )

        chain = PromptChain.create(greeting_step, translation_step)
        result = await chain.evaluate()
        logger.info("Prompt chain evaluated during integration test.", result=result)

        assert isinstance(result, str)
        assert result  # non-empty, given the str check above
|