# WeavePrompt / image_evaluators.py
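"""Llama-based image evaluators for WeavePrompt.

Defines LlamaEvaluator, which calls a multimodal Llama model through W&B
Inference to describe images, generate image-generation prompts, and analyze
differences between a generated image and a target image. Path-based helper
functions are provided for backward compatibility.
"""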
import base64
import io
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import openai
import weave
from dotenv import load_dotenv
from PIL import Image

from weave_prompt import ImageEvaluator
# Load environment variables from .env file
load_dotenv()
# Weave autopatches OpenAI to log LLM calls to W&B
weave.init(os.getenv("WEAVE_PROJECT", "meta-llama"))
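# NOTE: weave.init assumes W&B credentials are already configured
# (e.g. via WANDB_API_KEY or `wandb login`).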
class LlamaEvaluator(ImageEvaluator):
"""Llama-based image evaluator using W&B Inference."""
def __init__(self, api_key: Optional[str] = None):
"""
Initialize the Llama evaluator with OpenAI client.
Args:
api_key: Optional API key. If not provided, will look for OPENAI_API_KEY
or WANDB_API_KEY environment variables.
"""
        # Get API key from parameter, environment variables, or raise error
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY") or os.getenv("WANDB_API_KEY")
if api_key is None:
raise ValueError(
"API key not provided. Please either:\n"
"1. Pass api_key parameter to LlamaEvaluator()\n"
"2. Set OPENAI_API_KEY environment variable\n"
"3. Set WANDB_API_KEY environment variable\n"
"Get your API key from https://wandb.ai/authorize"
)
self.client = openai.OpenAI(
# The custom base URL points to W&B Inference
base_url='https://api.inference.wandb.ai/v1',
# Get your API key from https://wandb.ai/authorize
api_key=api_key,
)
self.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
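        # NOTE: Llama 4 Scout is multimodal; any vision-capable model ID
        # served by W&B Inference could be substituted here.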
    def _encode_image(self, image: Image.Image) -> Optional[str]:
        """Encode a PIL Image to a base64 JPEG string."""
        try:
            # JPEG has no alpha channel, so normalize the mode first
            if image.mode != "RGB":
                image = image.convert("RGB")
            # Encode in memory rather than via a temporary file
            buffer = io.BytesIO()
            image.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
        except Exception as e:
            print(f"Error encoding image: {e}")
            return None
    def _call_vision_model(self, prompt: str, images: List[Image.Image]) -> Optional[str]:
        """Call the vision model with a text prompt and one or more images."""
try:
# Prepare content with text and images
content = [{"type": "text", "text": prompt}]
for i, img in enumerate(images):
base64_image = self._encode_image(img)
if base64_image:
                    if i > 0:  # Label the second and later images so the model can tell them apart
content.append({
"type": "text",
"text": f"Image {i+1}:"
})
content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
})
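            # W&B Inference exposes an OpenAI-compatible endpoint, so this is a
            # standard multimodal chat-completions request.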
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "system",
"content": "You are an expert image analyst. Provide detailed, accurate analysis."
},
{
"role": "user",
"content": content
}
],
max_tokens=1000
)
return response.choices[0].message.content
except Exception as e:
print(f"Error calling vision model: {e}")
return None
    @weave.op()
    def generate_initial_prompt(self, generated_img: Image.Image) -> str:
        """Generate an initial prompt by describing the generated image."""
prompt = """
Analyze this image and generate a detailed text prompt that could be used to recreate it.
Focus on:
- Main subjects and objects
- Visual style and artistic technique
- Colors, lighting, and mood
- Composition and layout
- Important details and textures
Provide a concise but comprehensive prompt suitable for image generation.
"""
description = self._call_vision_model(prompt, [generated_img])
if description:
return description.strip()
else:
# Fallback prompt
return "A beautiful image with vibrant colors and detailed composition"
@weave.op()
def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
"""Analyze differences between generated and target images."""
analysis_prompt = """
Compare these two images and analyze their differences. The first image is generated, the second is the target.
Please provide a detailed analysis in JSON format with the following structure:
{
"missing_elements": ["list of elements present in target but missing in generated"],
"style_differences": ["list of style differences between the images"],
"color_differences": ["differences in color, lighting, or tone"],
"composition_differences": ["differences in layout, positioning, or framing"],
"quality_differences": ["differences in detail, sharpness, or overall quality"],
"similarity_score": "percentage of how similar the images are (0-100)",
"overall_assessment": "brief summary of the main differences"
}
Focus on identifying what elements, styles, or qualities are present in the target image but missing or different in the generated image.
"""
response_text = self._call_vision_model(analysis_prompt, [generated_img, target_img])
if not response_text:
return {
"missing_elements": ["texture", "details"],
"style_differences": ["color intensity", "composition"],
"error": "Failed to analyze images"
}
try:
# Extract JSON from the response if it's wrapped in markdown
if "```json" in response_text:
json_start = response_text.find("```json") + 7
json_end = response_text.find("```", json_start)
json_text = response_text[json_start:json_end].strip()
elif "{" in response_text and "}" in response_text:
# Find the JSON object in the response
json_start = response_text.find("{")
json_end = response_text.rfind("}") + 1
json_text = response_text[json_start:json_end]
else:
json_text = response_text
analysis_result = json.loads(json_text)
# Ensure required keys exist with fallback values
if "missing_elements" not in analysis_result:
analysis_result["missing_elements"] = ["texture", "details"]
if "style_differences" not in analysis_result:
analysis_result["style_differences"] = ["color intensity", "composition"]
return analysis_result
except json.JSONDecodeError:
# If JSON parsing fails, return a structured response with fallback values
return {
"missing_elements": ["texture", "details"],
"style_differences": ["color intensity", "composition"],
"raw_analysis": response_text,
"note": "JSON parsing failed, using fallback analysis"
}
@weave.op()
    def describe_image(self, image: Image.Image, custom_prompt: Optional[str] = None) -> str:
"""Generate a detailed description of an image."""
if not custom_prompt:
custom_prompt = "Please describe this image in detail, including objects, people, colors, setting, and any notable features."
description = self._call_vision_model(custom_prompt, [image])
return description if description else "Failed to generate description"
# Utility functions for backward compatibility
def encode_image_from_path(image_path: str) -> Optional[str]:
"""Encode image from file path to base64 string."""
try:
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
except FileNotFoundError:
print(f"Error: Image file not found at {image_path}")
return None
except Exception as e:
print(f"Error encoding image: {e}")
return None
def describe_image_from_path(image_path: str, custom_prompt: Optional[str] = None) -> Optional[str]:
"""Generate description for an image from file path."""
if not Path(image_path).exists():
print(f"Error: Image file does not exist at {image_path}")
return None
# Load image and use evaluator
image = Image.open(image_path)
evaluator = LlamaEvaluator()
return evaluator.describe_image(image, custom_prompt)
def analyze_differences_from_paths(generated_img_path: str, target_img_path: str) -> Dict[str, Any]:
"""Analyze differences between two images from file paths."""
try:
generated_img = Image.open(generated_img_path)
target_img = Image.open(target_img_path)
evaluator = LlamaEvaluator()
return evaluator.analyze_differences(generated_img, target_img)
except Exception as e:
return {
"missing_elements": ["texture", "details"],
"style_differences": ["color intensity", "composition"],
"error": str(e)
}
# Example usage
if __name__ == "__main__":
# Example 1: Using the class directly
evaluator = LlamaEvaluator()
# Load images
try:
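        # Update this path to point at a local image before running.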
image_path = "/Users/chuchwu/Downloads/happy-190806.jpg"
target_image = Image.open(image_path)
# Generate initial prompt
print("Generating initial prompt...")
initial_prompt = evaluator.generate_initial_prompt(target_image)
print(f"Initial Prompt: {initial_prompt}")
print("\n" + "="*50 + "\n")
# Describe the image
print("Describing image...")
description = evaluator.describe_image(target_image)
print(f"Description: {description}")
print("\n" + "="*50 + "\n")
# Example 2: Analyze differences (using same image for demo)
print("Analyzing differences...")
differences = evaluator.analyze_differences(target_image, target_image)
print("Difference Analysis:")
print(f"Missing Elements: {differences.get('missing_elements', [])}")
print(f"Style Differences: {differences.get('style_differences', [])}")
if 'similarity_score' in differences:
print(f"Similarity Score: {differences['similarity_score']}%")
if 'overall_assessment' in differences:
print(f"Overall Assessment: {differences['overall_assessment']}")
except FileNotFoundError:
print("Image file not found. Please update the image_path variable.")
except Exception as e:
print(f"Error: {e}")