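"""Llama-based image evaluator backed by W&B Inference.

Defines LlamaEvaluator, an ImageEvaluator subclass that calls a Llama vision
model through the OpenAI-compatible W&B Inference endpoint (traced with
Weave), plus path-based utility wrappers kept for backward compatibility.
"""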
import openai
import weave
import base64
import io
import json
import os
from pathlib import Path
from PIL import Image
from typing import Any, Dict, List, Optional
from weave_prompt import ImageEvaluator
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Weave autopatches OpenAI to log LLM calls to W&B
weave.init(os.getenv("WEAVE_PROJECT", "meta-llama"))

class LlamaEvaluator(ImageEvaluator):
    """Llama-based image evaluator using W&B Inference."""
    
    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Llama evaluator with an OpenAI client pointed at W&B Inference.
        
        Args:
            api_key: Optional API key. If not provided, falls back to the
                    OPENAI_API_KEY or WANDB_API_KEY environment variables.
        """
        # Get API key from the parameter or environment variables, or raise an error
        if api_key is None:
            api_key = os.getenv('OPENAI_API_KEY') or os.getenv('WANDB_API_KEY')
            if api_key is None:
                raise ValueError(
                    "API key not provided. Please either:\n"
                    "1. Pass the api_key parameter to LlamaEvaluator()\n"
                    "2. Set the OPENAI_API_KEY environment variable\n"
                    "3. Set the WANDB_API_KEY environment variable\n"
                    "Get your API key from https://wandb.ai/authorize"
                )
        
        self.client = openai.OpenAI(
            # The custom base URL points to W&B Inference
            base_url='https://api.inference.wandb.ai/v1',
            
            # Get your API key from https://wandb.ai/authorize
            api_key=api_key,
        )
        self.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    
    def _encode_image(self, image: Image.Image) -> Optional[str]:
        """Encode a PIL Image to a base64 string."""
        try:
            # JPEG has no alpha channel, so convert non-RGB modes first
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")
            # Encode in memory instead of round-tripping through a temp file
            buffer = io.BytesIO()
            image.save(buffer, format='JPEG')
            return base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception as e:
            print(f"Error encoding image: {e}")
            return None
    
    def _call_vision_model(self, prompt: str, images: List[Image.Image]) -> Optional[str]:
        """Call the vision model with a text prompt and one or more images."""
        try:
            # Prepare content with text and images
            content = [{"type": "text", "text": prompt}]
            
            for i, img in enumerate(images):
                base64_image = self._encode_image(img)
                if base64_image:
                    if i > 0:  # Add label for multiple images
                        content.append({
                            "type": "text", 
                            "text": f"Image {i+1}:"
                        })
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    })
            
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are an expert image analyst. Provide detailed, accurate analysis."
                    },
                    {
                        "role": "user",
                        "content": content
                    }
                ],
                max_tokens=1000
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error calling vision model: {e}")
            return None
    
    def generate_initial_prompt(self, generated_img: Image.Image) -> str:
        """Generate an initial prompt by describing the generated_img image."""
        prompt = """
        Analyze this image and generate a detailed text prompt that could be used to recreate it. 
        Focus on:
        - Main subjects and objects
        - Visual style and artistic technique
        - Colors, lighting, and mood
        - Composition and layout
        - Important details and textures
        
        Provide a concise but comprehensive prompt suitable for image generation.
        """
        
        description = self._call_vision_model(prompt, [generated_img])
        
        if description:
            return description.strip()
        else:
            # Fallback prompt
            return "A beautiful image with vibrant colors and detailed composition"
    @weave.op()
    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
        """Analyze differences between generated and target images."""
        analysis_prompt = """
        Compare these two images and analyze their differences. The first image is generated, the second is the target.
        
        Please provide a detailed analysis in JSON format with the following structure:
        {
            "missing_elements": ["list of elements present in target but missing in generated"],
            "style_differences": ["list of style differences between the images"],
            "color_differences": ["differences in color, lighting, or tone"],
            "composition_differences": ["differences in layout, positioning, or framing"],
            "quality_differences": ["differences in detail, sharpness, or overall quality"],
            "similarity_score": "percentage of how similar the images are (0-100)",
            "overall_assessment": "brief summary of the main differences"
        }
        
        Focus on identifying what elements, styles, or qualities are present in the target image but missing or different in the generated image.
        """
        
        response_text = self._call_vision_model(analysis_prompt, [generated_img, target_img])
        
        if not response_text:
            return {
                "missing_elements": ["texture", "details"],
                "style_differences": ["color intensity", "composition"],
                "error": "Failed to analyze images"
            }
        
        try:
            # Extract JSON from the response if it's wrapped in markdown
            if "```json" in response_text:
                json_start = response_text.find("```json") + 7
                json_end = response_text.find("```", json_start)
                json_text = response_text[json_start:json_end].strip()
            elif "{" in response_text and "}" in response_text:
                # Find the JSON object in the response
                json_start = response_text.find("{")
                json_end = response_text.rfind("}") + 1
                json_text = response_text[json_start:json_end]
            else:
                json_text = response_text
            
            analysis_result = json.loads(json_text)
            
            # Ensure required keys exist with fallback values
            if "missing_elements" not in analysis_result:
                analysis_result["missing_elements"] = ["texture", "details"]
            if "style_differences" not in analysis_result:
                analysis_result["style_differences"] = ["color intensity", "composition"]
                
            return analysis_result
            
        except json.JSONDecodeError:
            # If JSON parsing fails, return a structured response with fallback values
            return {
                "missing_elements": ["texture", "details"],
                "style_differences": ["color intensity", "composition"],
                "raw_analysis": response_text,
                "note": "JSON parsing failed, using fallback analysis"
            }

    @weave.op()
    def describe_image(self, image: Image.Image, custom_prompt: Optional[str] = None) -> str:
        """Generate a detailed description of an image."""
        if not custom_prompt:
            custom_prompt = "Please describe this image in detail, including objects, people, colors, setting, and any notable features."
        
        description = self._call_vision_model(custom_prompt, [image])
        return description if description else "Failed to generate description"


# Utility functions for backward compatibility
def encode_image_from_path(image_path: str) -> Optional[str]:
    """Encode an image from a file path to a base64 string."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
        return None
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None

def describe_image_from_path(image_path: str, custom_prompt: Optional[str] = None) -> Optional[str]:
    """Generate a description for an image from a file path."""
    if not Path(image_path).exists():
        print(f"Error: Image file does not exist at {image_path}")
        return None
    
    # Load image and use evaluator
    image = Image.open(image_path)
    evaluator = LlamaEvaluator()
    return evaluator.describe_image(image, custom_prompt)

def analyze_differences_from_paths(generated_img_path: str, target_img_path: str) -> Dict[str, Any]:
    """Analyze differences between two images from file paths."""
    try:
        generated_img = Image.open(generated_img_path)
        target_img = Image.open(target_img_path)
        
        evaluator = LlamaEvaluator()
        return evaluator.analyze_differences(generated_img, target_img)
    except Exception as e:
        return {
            "missing_elements": ["texture", "details"],
            "style_differences": ["color intensity", "composition"],
            "error": str(e)
        }
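
# A minimal path-based wrapper for generate_initial_prompt, mirroring the
# helpers above (a sketch added for parity; not part of the original API)
def generate_initial_prompt_from_path(image_path: str) -> Optional[str]:
    """Generate an initial prompt for an image loaded from a file path."""
    if not Path(image_path).exists():
        print(f"Error: Image file does not exist at {image_path}")
        return None

    image = Image.open(image_path)
    evaluator = LlamaEvaluator()
    return evaluator.generate_initial_prompt(image)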


# Example usage
if __name__ == "__main__":
    # Example 1: Using the class directly
    evaluator = LlamaEvaluator()
    
    # Load images
    try:
        image_path = "/Users/chuchwu/Downloads/happy-190806.jpg"
        target_image = Image.open(image_path)
        
        # Generate initial prompt
        print("Generating initial prompt...")
        initial_prompt = evaluator.generate_initial_prompt(target_image)
        print(f"Initial Prompt: {initial_prompt}")
        print("\n" + "="*50 + "\n")
        
        # Describe the image
        print("Describing image...")
        description = evaluator.describe_image(target_image)
        print(f"Description: {description}")
        print("\n" + "="*50 + "\n")
        
        # Example 2: Analyze differences (using same image for demo)
        print("Analyzing differences...")
        differences = evaluator.analyze_differences(target_image, target_image)
        print("Difference Analysis:")
        print(f"Missing Elements: {differences.get('missing_elements', [])}")
        print(f"Style Differences: {differences.get('style_differences', [])}")
        
        if 'similarity_score' in differences:
            print(f"Similarity Score: {differences['similarity_score']}%")
        
        if 'overall_assessment' in differences:
            print(f"Overall Assessment: {differences['overall_assessment']}")
            
    except FileNotFoundError:
        print("Image file not found. Please update the image_path variable.")
    except Exception as e:
        print(f"Error: {e}")