import openai
import weave
import base64
import io
import json
import os
from pathlib import Path
from PIL import Image
from typing import Dict, Any, Optional
from weave_prompt import ImageEvaluator
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Weave autopatches OpenAI to log LLM calls to W&B
weave.init(os.getenv("WEAVE_PROJECT", "meta-llama"))
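
# A minimal .env sketch for the variables read above; the values are
# placeholders, assuming W&B Inference access (key from https://wandb.ai/authorize):
#
#   WANDB_API_KEY=<your-wandb-api-key>
#   WEAVE_PROJECT=meta-llama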

class LlamaEvaluator(ImageEvaluator):
    """Llama-based image evaluator using W&B Inference."""

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Llama evaluator with an OpenAI client.

        Args:
            api_key: Optional API key. If not provided, falls back to the
                OPENAI_API_KEY or WANDB_API_KEY environment variable.
        """
        # Get the API key from the parameter or the environment, or raise an error
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY") or os.getenv("WANDB_API_KEY")
        if api_key is None:
            raise ValueError(
                "API key not provided. Please either:\n"
                "1. Pass the api_key parameter to LlamaEvaluator()\n"
                "2. Set the OPENAI_API_KEY environment variable\n"
                "3. Set the WANDB_API_KEY environment variable\n"
                "Get your API key from https://wandb.ai/authorize"
            )
        self.client = openai.OpenAI(
            # The custom base URL points to W&B Inference
            base_url="https://api.inference.wandb.ai/v1",
            # Get your API key from https://wandb.ai/authorize
            api_key=api_key,
        )
        self.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

    def _encode_image(self, image: Image.Image) -> Optional[str]:
        """Encode a PIL Image to a base64 JPEG string."""
        try:
            # JPEG has no alpha channel, so normalize to RGB first
            if image.mode != "RGB":
                image = image.convert("RGB")
            # Encode in memory instead of round-tripping through a temp file
            buffer = io.BytesIO()
            image.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
        except Exception as e:
            print(f"Error encoding image: {e}")
            return None

    def _call_vision_model(self, prompt: str, images: list) -> Optional[str]:
        """Call the vision model with a prompt and a list of images."""
        try:
            # Prepare multimodal content: the text prompt first, then the images
            content = [{"type": "text", "text": prompt}]
            for i, img in enumerate(images):
                base64_image = self._encode_image(img)
                if base64_image:
                    if i > 0:  # Label subsequent images so the model can tell them apart
                        content.append({
                            "type": "text",
                            "text": f"Image {i + 1}:"
                        })
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    })
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "system",
                        "content": "You are an expert image analyst. Provide detailed, accurate analysis."
                    },
                    {
                        "role": "user",
                        "content": content
                    }
                ],
                max_tokens=1000
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error calling vision model: {e}")
            return None

    def generate_initial_prompt(self, generated_img: Image.Image) -> str:
        """Generate an initial text prompt by describing the generated image."""
        prompt = """
        Analyze this image and generate a detailed text prompt that could be used to recreate it.
        Focus on:
        - Main subjects and objects
        - Visual style and artistic technique
        - Colors, lighting, and mood
        - Composition and layout
        - Important details and textures
        Provide a concise but comprehensive prompt suitable for image generation.
        """
        description = self._call_vision_model(prompt, [generated_img])
        if description:
            return description.strip()
        # Fallback prompt if the model call fails
        return "A beautiful image with vibrant colors and detailed composition"

    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
        """Analyze differences between generated and target images."""
        analysis_prompt = """
        Compare these two images and analyze their differences. The first image is generated, the second is the target.
        Please provide a detailed analysis in JSON format with the following structure:
        {
            "missing_elements": ["list of elements present in target but missing in generated"],
            "style_differences": ["list of style differences between the images"],
            "color_differences": ["differences in color, lighting, or tone"],
            "composition_differences": ["differences in layout, positioning, or framing"],
            "quality_differences": ["differences in detail, sharpness, or overall quality"],
            "similarity_score": "percentage of how similar the images are (0-100)",
            "overall_assessment": "brief summary of the main differences"
        }
        Focus on identifying what elements, styles, or qualities are present in the target image but missing or different in the generated image.
        """
        response_text = self._call_vision_model(analysis_prompt, [generated_img, target_img])
        if not response_text:
            return {
                "missing_elements": ["texture", "details"],
                "style_differences": ["color intensity", "composition"],
                "error": "Failed to analyze images"
            }
        try:
            # Extract JSON from the response if it's wrapped in a markdown fence
            if "```json" in response_text:
                json_start = response_text.find("```json") + 7
                json_end = response_text.find("```", json_start)
                json_text = response_text[json_start:json_end].strip()
            elif "{" in response_text and "}" in response_text:
                # Otherwise take the outermost JSON object in the response
                json_start = response_text.find("{")
                json_end = response_text.rfind("}") + 1
                json_text = response_text[json_start:json_end]
            else:
                json_text = response_text
            analysis_result = json.loads(json_text)
            # Ensure required keys exist, with fallback values
            if "missing_elements" not in analysis_result:
                analysis_result["missing_elements"] = ["texture", "details"]
            if "style_differences" not in analysis_result:
                analysis_result["style_differences"] = ["color intensity", "composition"]
            return analysis_result
        except json.JSONDecodeError:
            # If JSON parsing fails, return a structured response with fallback values
            return {
                "missing_elements": ["texture", "details"],
                "style_differences": ["color intensity", "composition"],
                "raw_analysis": response_text,
                "note": "JSON parsing failed, using fallback analysis"
            }

    def describe_image(self, image: Image.Image, custom_prompt: Optional[str] = None) -> str:
        """Generate a detailed description of an image."""
        if not custom_prompt:
            custom_prompt = "Please describe this image in detail, including objects, people, colors, setting, and any notable features."
        description = self._call_vision_model(custom_prompt, [image])
        return description if description else "Failed to generate description"

# Utility functions for backward compatibility
def encode_image_from_path(image_path: str) -> Optional[str]:
    """Encode an image from a file path to a base64 string."""
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
        return None
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None

def describe_image_from_path(image_path: str, custom_prompt: Optional[str] = None) -> Optional[str]:
    """Generate a description for an image from a file path."""
    if not Path(image_path).exists():
        print(f"Error: Image file does not exist at {image_path}")
        return None
    # Load the image and delegate to the evaluator
    image = Image.open(image_path)
    evaluator = LlamaEvaluator()
    return evaluator.describe_image(image, custom_prompt)

def analyze_differences_from_paths(generated_img_path: str, target_img_path: str) -> Dict[str, Any]:
    """Analyze differences between two images given their file paths."""
    try:
        generated_img = Image.open(generated_img_path)
        target_img = Image.open(target_img_path)
        evaluator = LlamaEvaluator()
        return evaluator.analyze_differences(generated_img, target_img)
    except Exception as e:
        return {
            "missing_elements": ["texture", "details"],
            "style_differences": ["color intensity", "composition"],
            "error": str(e)
        }
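
# A minimal usage sketch for the path-based helpers; the file names below are
# placeholders, assuming a generated image and a target image exist on disk:
#
#   result = analyze_differences_from_paths("generated.jpg", "target.jpg")
#   print(result.get("overall_assessment", "no assessment"))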

# Example usage
if __name__ == "__main__":
    # Example 1: Using the class directly
    evaluator = LlamaEvaluator()

    # Load images
    try:
        image_path = "/Users/chuchwu/Downloads/happy-190806.jpg"
        target_image = Image.open(image_path)

        # Generate an initial prompt
        print("Generating initial prompt...")
        initial_prompt = evaluator.generate_initial_prompt(target_image)
        print(f"Initial Prompt: {initial_prompt}")
        print("\n" + "=" * 50 + "\n")

        # Describe the image
        print("Describing image...")
        description = evaluator.describe_image(target_image)
        print(f"Description: {description}")
        print("\n" + "=" * 50 + "\n")

        # Example 2: Analyze differences (using the same image for the demo)
        print("Analyzing differences...")
        differences = evaluator.analyze_differences(target_image, target_image)
        print("Difference Analysis:")
        print(f"Missing Elements: {differences.get('missing_elements', [])}")
        print(f"Style Differences: {differences.get('style_differences', [])}")
        if 'similarity_score' in differences:
            print(f"Similarity Score: {differences['similarity_score']}%")
        if 'overall_assessment' in differences:
            print(f"Overall Assessment: {differences['overall_assessment']}")
    except FileNotFoundError:
        print("Image file not found. Please update the image_path variable.")
    except Exception as e:
        print(f"Error: {e}")