File size: 2,289 Bytes
fb2f0a7
 
 
 
 
 
 
 
 
 
 
 
64daa59
fb2f0a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import Any, Dict
import openai
import weave
import os
from dotenv import load_dotenv

from weave_prompt import PromptRefiner

# Read key/value pairs from a local .env file into the process environment
# before anything below looks up configuration.
load_dotenv()

# Initialise Weave; it autopatches OpenAI so LLM calls are logged to W&B.
# The project name is taken from WEAVE_PROJECT, defaulting to "meta-llama".
weave.init(os.environ.get("WEAVE_PROJECT", "meta-llama"))


class LlamaPromptRefiner(PromptRefiner):
    """Prompt refiner that asks a Llama 4 chat model on W&B Inference for a better prompt."""

    @weave.op()
    def refine_prompt(self, current_prompt: str, analysis: Dict[str, Any], similarity_score) -> str:
        """Ask the model for an improved text-to-image prompt.

        Args:
            current_prompt: The prompt that produced the most recent image.
            analysis: Description of the differences between the generated
                image and the target image; it is stringified verbatim into
                the user message.
            similarity_score: Accepted for interface compatibility; this
                implementation does not use it.

        Returns:
            The model's suggested prompt with surrounding whitespace removed.
            If the model returns no choices or empty/``None`` content, the
            unchanged ``current_prompt`` is returned so callers always
            receive a usable string.

        Raises:
            Exceptions raised by the OpenAI client (authentication, network,
            API errors) are not caught here and propagate to the caller.
        """
        client = openai.OpenAI(
            # The custom base URL points to W&B Inference
            base_url='https://api.inference.wandb.ai/v1',

            # Get your API key from https://wandb.ai/authorize and expose it
            # as WANDB_API_KEY (e.g. via the .env file). If it is unset this
            # passes None, in which case the OpenAI client looks for
            # OPENAI_API_KEY in the environment instead.
            api_key=os.getenv("WANDB_API_KEY"),
        )

        system_message = (
            "You are an expert at prompt engineering for text-to-image models. "
            "Given a current prompt and an analysis of the differences between a generated image and a target image, "
            "your job is to suggest a new prompt that will make the generated image more similar to the target. "
            "Limit the new prompt to 100 words at most. "
            "The user message will contain two sections: one for the current prompt and one for the analysis, each delimited by 'START OF CURRENT PROMPT'/'END OF CURRENT PROMPT' and 'START OF ANALYSIS'/'END OF ANALYSIS'. "
            "Only return the improved prompt."
        )
        user_message = (
            f"<START OF CURRENT PROMPT>\n{current_prompt}\n<END OF CURRENT PROMPT>\n"
            f"<START OF ANALYSIS>\n{analysis!s}\n<END OF ANALYSIS>\n"
            "Suggest a new, improved prompt. Only return the prompt. Do not exceed 100 words."
        )

        response = client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
        )

        # `message.content` is nullable in the chat completions API and
        # `choices` may be empty; never hand None back to the caller.
        choices = response.choices
        suggestion = choices[0].message.content if choices else None
        if suggestion is None or not suggestion.strip():
            return current_prompt
        return suggestion.strip()