File size: 2,381 Bytes
3b993c4
 
24d33b9
 
 
 
 
 
3b993c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pytest

from vsp.llm.openai.openai import AsyncOpenAIService
from vsp.llm.openai.openai_model import OpenAIModel
from vsp.llm.prompt import Prompt
from vsp.llm.prompt_chain import PromptChain
from vsp.llm.prompt_text import PromptText
from vsp.shared import logger_factory

# Note: This test requires an actual OpenAI API key and will make real API calls.
# Every test in this module talks to the live OpenAI service, so results are not
# deterministic; assertions below only check the shape of the output.
# NOTE(review): how the API key is supplied (env var, config file, ...) is not
# visible in this module — confirm against AsyncOpenAIService before running.

# Module-level logger from the project's logger factory. The tests pass
# structured keyword fields to it (e.g. ``result=...``).
logger = logger_factory.get_logger(__name__)


@pytest.mark.asyncio
async def test_integration_prompt_chain_with_openai():
    """Evaluate a two-step PromptChain against the live OpenAI model.

    The first step is an explicit ``Prompt`` object. It asks for a greeting
    and maps the stripped reply to ``{"greeting": ...}``. The second step is
    a plain keyword dict whose user prompt template references
    ``{greeting}``, presumably filled from the first step's formatted output.
    Because the model's wording varies, only the final output's type and
    non-emptiness are checked.
    """
    async with AsyncOpenAIService(OpenAIModel.GPT_4_MINI)() as llm:
        greeting_step = Prompt(
            llm_service=llm,
            user_prompt=PromptText("Generate a short greeting for {name}.", inputs={"name": "Alice"}),
            # Publish the cleaned reply under the key the next template uses.
            output_formatter=lambda raw: {"greeting": raw.strip()},
        )
        translation_step = {
            "llm_service": llm,
            "user_prompt": PromptText('Translate the following to French: "{greeting}"'),
        }
        chain = PromptChain.create(greeting_step, translation_step)

        output = await chain.evaluate()
        logger.info("Prompt chain evaluated during integration test.", result=output)

        assert isinstance(output, str)
        # For a str, truthiness is equivalent to a non-zero length.
        assert output


@pytest.mark.asyncio
async def test_integration_prompt_chain_with_system_prompt():
    """Evaluate a two-step PromptChain, built from dicts, that sets system prompts.

    Step one pairs a greeting-assistant system prompt with a request for a
    greeting for "Alice". It maps the stripped reply to
    ``{"greeting": ...}``. Step two uses a translator system prompt and a
    user template referencing ``{greeting}``, presumably filled from step
    one's formatted output. Only the final output's type and non-emptiness
    are asserted, since live model output is not deterministic.
    """
    async with AsyncOpenAIService(OpenAIModel.GPT_4_MINI)() as llm:
        greeting_step = {
            "llm_service": llm,
            "system_prompt": PromptText("You are a helpful assistant that generates greetings."),
            "user_prompt": PromptText("Generate a short greeting for {name}.", inputs={"name": "Alice"}),
            # Publish the cleaned reply under the key the next template uses.
            "output_formatter": lambda raw: {"greeting": raw.strip()},
        }
        translator_persona = PromptText(
            "You are a French translator. You do not write anything except the translated sentence."
        )
        translation_step = {
            "llm_service": llm,
            "system_prompt": translator_persona,
            "user_prompt": PromptText("Translate the following greeting to French: {greeting}"),
        }
        chain = PromptChain.create(greeting_step, translation_step)

        output = await chain.evaluate()
        logger.info("Prompt chain evaluated during integration test.", result=output)

        assert isinstance(output, str)
        # For a str, truthiness is equivalent to a non-zero length.
        assert output