kevin1kevin1k committed on
Commit fb2f0a7 · verified · 1 Parent(s): 4584c11

Upload folder using huggingface_hub

app.py CHANGED
@@ -1,11 +1,11 @@
 import streamlit as st
 from PIL import Image
 from dotenv import load_dotenv
-from image_to_text import LlamaEvaluator
-from prompt_refiner import LlamaPromptRefiner
+from image_evaluators import LlamaEvaluator
+from prompt_refiners import LlamaPromptRefiner
 from weave_prompt import PromptOptimizer
-from lpips_evaluator import LPIPSImageSimilarityMetric
-from fal_image_generator import FalImageGenerator
+from similarity_metrics import LPIPSImageSimilarityMetric
+from image_generators import FalImageGenerator
 
 # Load environment variables from .env file
 load_dotenv()
image_evaluators.py ADDED
@@ -0,0 +1,276 @@
+import openai
+import weave
+import base64
+import json
+import tempfile
+import os
+from pathlib import Path
+from PIL import Image
+from typing import Dict, Any, Optional
+from weave_prompt import ImageEvaluator
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
+
+# Weave autopatches OpenAI to log LLM calls to W&B
+weave.init("meta-llama")
+
+class LlamaEvaluator(ImageEvaluator):
+    """Llama-based image evaluator using W&B Inference."""
+
+    def __init__(self, api_key: Optional[str] = None):
+        """
+        Initialize the Llama evaluator with an OpenAI client.
+
+        Args:
+            api_key: Optional API key. If not provided, will look for the OPENAI_API_KEY
+                     or WANDB_API_KEY environment variables.
+        """
+        # Get API key from parameter, environment variables, or raise error
+        if api_key is None:
+            api_key = os.getenv('OPENAI_API_KEY') or os.getenv('WANDB_API_KEY')
+        if api_key is None:
+            raise ValueError(
+                "API key not provided. Please either:\n"
+                "1. Pass api_key parameter to LlamaEvaluator()\n"
+                "2. Set OPENAI_API_KEY environment variable\n"
+                "3. Set WANDB_API_KEY environment variable\n"
+                "Get your API key from https://wandb.ai/authorize"
+            )
+
+        self.client = openai.OpenAI(
+            # The custom base URL points to W&B Inference
+            base_url='https://api.inference.wandb.ai/v1',
+
+            # Get your API key from https://wandb.ai/authorize
+            api_key=api_key,
+        )
+        self.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    def _encode_image(self, image: Image.Image) -> Optional[str]:
+        """Encode a PIL Image to a base64 string; return None on failure."""
+        try:
+            # Save image to a temporary file and encode it
+            with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
+                image.save(tmp_file.name, format='JPEG')
+                with open(tmp_file.name, "rb") as image_file:
+                    encoded = base64.b64encode(image_file.read()).decode('utf-8')
+                # Clean up temp file
+                Path(tmp_file.name).unlink()
+                return encoded
+        except Exception as e:
+            print(f"Error encoding image: {e}")
+            return None
+
+    def _call_vision_model(self, prompt: str, images: list) -> Optional[str]:
+        """Call the vision model with a prompt and a list of images."""
+        try:
+            # Prepare content with text and images
+            content = [{"type": "text", "text": prompt}]
+
+            for i, img in enumerate(images):
+                base64_image = self._encode_image(img)
+                if base64_image:
+                    if i > 0:  # Add label for multiple images
+                        content.append({
+                            "type": "text",
+                            "text": f"Image {i+1}:"
+                        })
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"
+                        }
+                    })
+
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "You are an expert image analyst. Provide detailed, accurate analysis."
+                    },
+                    {
+                        "role": "user",
+                        "content": content
+                    }
+                ],
+                max_tokens=1000
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            print(f"Error calling vision model: {e}")
+            return None
+
+    def generate_initial_prompt(self, generated_img: Image.Image) -> str:
+        """Generate an initial prompt by describing the given image."""
+        prompt = """
+        Analyze this image and generate a detailed text prompt that could be used to recreate it.
+        Focus on:
+        - Main subjects and objects
+        - Visual style and artistic technique
+        - Colors, lighting, and mood
+        - Composition and layout
+        - Important details and textures
+
+        Provide a concise but comprehensive prompt suitable for image generation.
+        """
+
+        description = self._call_vision_model(prompt, [generated_img])
+
+        if description:
+            return description.strip()
+        else:
+            # Fallback prompt
+            return "A beautiful image with vibrant colors and detailed composition"
+    @weave.op()
+    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
+        """Analyze differences between generated and target images."""
+        analysis_prompt = """
+        Compare these two images and analyze their differences. The first image is generated, the second is the target.
+
+        Please provide a detailed analysis in JSON format with the following structure:
+        {
+            "missing_elements": ["list of elements present in target but missing in generated"],
+            "style_differences": ["list of style differences between the images"],
+            "color_differences": ["differences in color, lighting, or tone"],
+            "composition_differences": ["differences in layout, positioning, or framing"],
+            "quality_differences": ["differences in detail, sharpness, or overall quality"],
+            "similarity_score": "percentage of how similar the images are (0-100)",
+            "overall_assessment": "brief summary of the main differences"
+        }
+
+        Focus on identifying what elements, styles, or qualities are present in the target image but missing or different in the generated image.
+        """
+
+        response_text = self._call_vision_model(analysis_prompt, [generated_img, target_img])
+
+        if not response_text:
+            return {
+                "missing_elements": ["texture", "details"],
+                "style_differences": ["color intensity", "composition"],
+                "error": "Failed to analyze images"
+            }
+
+        try:
+            # Extract JSON from the response if it's wrapped in markdown
+            if "```json" in response_text:
+                json_start = response_text.find("```json") + 7
+                json_end = response_text.find("```", json_start)
+                json_text = response_text[json_start:json_end].strip()
+            elif "{" in response_text and "}" in response_text:
+                # Find the JSON object in the response
+                json_start = response_text.find("{")
+                json_end = response_text.rfind("}") + 1
+                json_text = response_text[json_start:json_end]
+            else:
+                json_text = response_text
+
+            analysis_result = json.loads(json_text)
+
+            # Ensure required keys exist with fallback values
+            if "missing_elements" not in analysis_result:
+                analysis_result["missing_elements"] = ["texture", "details"]
+            if "style_differences" not in analysis_result:
+                analysis_result["style_differences"] = ["color intensity", "composition"]
+
+            return analysis_result
+
+        except json.JSONDecodeError:
+            # If JSON parsing fails, return a structured response with fallback values
+            return {
+                "missing_elements": ["texture", "details"],
+                "style_differences": ["color intensity", "composition"],
+                "raw_analysis": response_text,
+                "note": "JSON parsing failed, using fallback analysis"
+            }
+    @weave.op()
+    def describe_image(self, image: Image.Image, custom_prompt: Optional[str] = None) -> str:
+        """Generate a detailed description of an image."""
+        if not custom_prompt:
+            custom_prompt = "Please describe this image in detail, including objects, people, colors, setting, and any notable features."
+
+        description = self._call_vision_model(custom_prompt, [image])
+        return description if description else "Failed to generate description"
+
+
+# Utility functions for backward compatibility
+def encode_image_from_path(image_path: str) -> Optional[str]:
+    """Encode image from file path to base64 string."""
+    try:
+        with open(image_path, "rb") as image_file:
+            return base64.b64encode(image_file.read()).decode('utf-8')
+    except FileNotFoundError:
+        print(f"Error: Image file not found at {image_path}")
+        return None
+    except Exception as e:
+        print(f"Error encoding image: {e}")
+        return None
+
+def describe_image_from_path(image_path: str, custom_prompt: Optional[str] = None) -> Optional[str]:
+    """Generate description for an image from file path."""
+    if not Path(image_path).exists():
+        print(f"Error: Image file does not exist at {image_path}")
+        return None
+
+    # Load image and use evaluator
+    image = Image.open(image_path)
+    evaluator = LlamaEvaluator()
+    return evaluator.describe_image(image, custom_prompt)
+
+def analyze_differences_from_paths(generated_img_path: str, target_img_path: str) -> Dict[str, Any]:
+    """Analyze differences between two images from file paths."""
+    try:
+        generated_img = Image.open(generated_img_path)
+        target_img = Image.open(target_img_path)
+
+        evaluator = LlamaEvaluator()
+        return evaluator.analyze_differences(generated_img, target_img)
+    except Exception as e:
+        return {
+            "missing_elements": ["texture", "details"],
+            "style_differences": ["color intensity", "composition"],
+            "error": str(e)
+        }
+
+
+# Example usage
+if __name__ == "__main__":
+    # Example 1: Using the class directly
+    evaluator = LlamaEvaluator()
+
+    # Load images
+    try:
+        image_path = "/Users/chuchwu/Downloads/happy-190806.jpg"
+        target_image = Image.open(image_path)
+
+        # Generate initial prompt
+        print("Generating initial prompt...")
+        initial_prompt = evaluator.generate_initial_prompt(target_image)
+        print(f"Initial Prompt: {initial_prompt}")
+        print("\n" + "="*50 + "\n")
+
+        # Describe the image
+        print("Describing image...")
+        description = evaluator.describe_image(target_image)
+        print(f"Description: {description}")
+        print("\n" + "="*50 + "\n")
+
+        # Example 2: Analyze differences (using same image for demo)
+        print("Analyzing differences...")
+        differences = evaluator.analyze_differences(target_image, target_image)
+        print("Difference Analysis:")
+        print(f"Missing Elements: {differences.get('missing_elements', [])}")
+        print(f"Style Differences: {differences.get('style_differences', [])}")
+
+        if 'similarity_score' in differences:
+            print(f"Similarity Score: {differences['similarity_score']}%")
+
+        if 'overall_assessment' in differences:
+            print(f"Overall Assessment: {differences['overall_assessment']}")
+
+    except FileNotFoundError:
+        print("Image file not found. Please update the image_path variable.")
+    except Exception as e:
+        print(f"Error: {e}")
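The module-level helpers at the end of image_evaluators.py can also be exercised straight from disk. A minimal sketch, assuming an API key is set in the environment and that the two image files exist locally (the file names below are placeholders):

from image_evaluators import analyze_differences_from_paths

# analyze_differences_from_paths opens both images, builds a LlamaEvaluator,
# and returns the structured difference report shown above.
report = analyze_differences_from_paths("generated.png", "target.png")
print(report.get("overall_assessment", report))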
image_generators.py ADDED
@@ -0,0 +1,48 @@
+import fal_client
+from PIL import Image
+from typing import Dict, Any
+import requests
+from io import BytesIO
+
+from weave_prompt import ImageGenerator
+
+from dotenv import load_dotenv
+load_dotenv()
+
+class FalImageGenerator(ImageGenerator):
+    """Handles image generation using fal_client."""
+
+    def __init__(self, model_name: str = "fal-ai/flux-pro"):
+        self.model_name = model_name
+
+    def _on_queue_update(self, update):
+        """Handle queue updates during image generation."""
+        if isinstance(update, fal_client.InProgress):
+            for log in update.logs:
+                print(log["message"])
+
+    def generate(self, prompt: str, **kwargs) -> Image.Image:
+        """Generate an image from a text prompt using fal_client."""
+        result = fal_client.subscribe(
+            self.model_name,
+            arguments={
+                "prompt": prompt,
+                **kwargs
+            },
+            with_logs=True,
+            on_queue_update=self._on_queue_update,
+        )
+        print(result)
+
+        return self._extract_image_from_result(result)
+
+    def _extract_image_from_result(self, result: Dict[str, Any]) -> Image.Image:
+        """Extract and download image from fal_client result."""
+        if result and 'images' in result and len(result['images']) > 0:
+            image_url = result['images'][0]['url']
+            response = requests.get(image_url)
+            response.raise_for_status()  # Raise an exception for bad status codes
+            image = Image.open(BytesIO(response.content))
+            return image
+        else:
+            raise ValueError("No image found in the result")
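A quick sketch of how FalImageGenerator is meant to be called, assuming fal_client picks up its credential (FAL_KEY) from the environment and the default fal-ai/flux-pro model is acceptable; the output path is a placeholder:

from image_generators import FalImageGenerator

generator = FalImageGenerator()  # defaults to "fal-ai/flux-pro"
# generate() forwards extra keyword arguments straight to fal_client.subscribe
image = generator.generate("a watercolor lighthouse at dusk")
image.save("generated.png")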
prompt_refiners.py ADDED
@@ -0,0 +1,52 @@
+from typing import Any, Dict
+import openai
+import weave
+import os
+from dotenv import load_dotenv
+
+from weave_prompt import PromptRefiner
+
+# Load environment variables from .env file
+load_dotenv()
+
+# Weave autopatches OpenAI to log LLM calls to W&B
+weave.init(project_name="meta-llama")
+
+
+class LlamaPromptRefiner(PromptRefiner):
+    @weave.op()
+    def refine_prompt(self, current_prompt: str, analysis: Dict[str, Any], similarity_score: float) -> str:
+        client = openai.OpenAI(
+            # The custom base URL points to W&B Inference
+            base_url='https://api.inference.wandb.ai/v1',
+
+            # Get your API key from https://wandb.ai/authorize
+            # Consider setting it in the environment as OPENAI_API_KEY instead for safety
+            api_key=os.getenv("WANDB_API_KEY"),
+        )
+
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            messages=[
+                {
+                    "role": "system",
+                    "content": (
+                        "You are an expert at prompt engineering for text-to-image models. "
+                        "Given a current prompt and an analysis of the differences between a generated image and a target image, "
+                        "your job is to suggest a new prompt that will make the generated image more similar to the target. "
+                        "Limit the new prompt to 100 words at most. "
+                        "The user message will contain two sections: one for the current prompt and one for the analysis, each delimited by 'START OF CURRENT PROMPT'/'END OF CURRENT PROMPT' and 'START OF ANALYSIS'/'END OF ANALYSIS'. "
+                        "Only return the improved prompt."
+                    )
+                },
+                {
+                    "role": "user",
+                    "content": (
+                        f"<START OF CURRENT PROMPT>\n{current_prompt}\n<END OF CURRENT PROMPT>\n"
+                        f"<START OF ANALYSIS>\n{str(analysis)}\n<END OF ANALYSIS>\n"
+                        "Suggest a new, improved prompt. Only return the prompt. Do not exceed 100 words."
+                    )
+                }
+            ],
+        )
+        return response.choices[0].message.content
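A hedged example of calling the refiner on its own; the analysis dict below only mirrors the shape produced by LlamaEvaluator.analyze_differences, and its values are invented for illustration:

from prompt_refiners import LlamaPromptRefiner

refiner = LlamaPromptRefiner()
analysis = {  # invented example values in the analyze_differences format
    "missing_elements": ["sailboat on the horizon"],
    "style_differences": ["flatter lighting than the target"],
}
new_prompt = refiner.refine_prompt(
    current_prompt="a lighthouse at dusk",
    analysis=analysis,
    similarity_score=0.42,
)
print(new_prompt)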
similarity_metrics.py ADDED
@@ -0,0 +1,25 @@
+
+from weave_prompt import ImageSimilarityMetric
+from PIL import Image
+import lpips
+import torch
+import numpy as np
+
+class LPIPSImageSimilarityMetric(ImageSimilarityMetric):
+    """Image similarity metric using LPIPS perceptual similarity."""
+    def __init__(self, net: str = 'alex', device: str = 'cpu'):
+        self.lpips_model = lpips.LPIPS(net=net).to(device)
+        self.device = device
+
+    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
+        def img_to_tensor(img):
+            img = img.convert('RGB')  # Ensure image has 3 channels for handling PNG
+            arr = np.array(img.resize((256, 256))).astype(np.float32) / 255.0
+            arr = arr.transpose(2, 0, 1)  # HWC to CHW
+            tensor = torch.tensor(arr).unsqueeze(0)
+            return tensor * 2 - 1  # LPIPS expects [-1, 1]
+        gen_tensor = img_to_tensor(generated_img).to(self.device)
+        tgt_tensor = img_to_tensor(target_img).to(self.device)
+        distance = self.lpips_model(gen_tensor, tgt_tensor).item()
+        similarity = max(0.0, 1.0 - distance)
+        return similarity
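A short usage sketch for the LPIPS metric; the file names are placeholders, and the returned score is clamped to [0, 1] with 1.0 meaning perceptually identical:

from PIL import Image
from similarity_metrics import LPIPSImageSimilarityMetric

metric = LPIPSImageSimilarityMetric(net='alex', device='cpu')
# compute() resizes both images to 256x256 and converts them to RGB internally
score = metric.compute(Image.open("generated.png"), Image.open("target.png"))
print(f"Similarity: {score:.3f}")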
weave_prompt.py CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Union
 import PIL.Image as Image
 
-class TextToImageModel(ABC):
+class ImageGenerator(ABC):
     """Abstract base class for text-to-image models."""
 
     @abstractmethod
@@ -85,7 +85,7 @@ class PromptOptimizer:
     """Main class that orchestrates the prompt optimization process."""
 
     def __init__(self,
-                 model: TextToImageModel,
+                 image_generator: ImageGenerator,
                  evaluator: ImageEvaluator,
                  refiner: PromptRefiner,
                  similarity_metric: ImageSimilarityMetric,
@@ -94,7 +94,7 @@ class PromptOptimizer:
         """Initialize the optimizer.
 
         Args:
-            model: Text-to-image model to use
+            image_generator: Text-to-image generator to use
             evaluator: Image evaluator for generating initial prompt and analysis
             refiner: Prompt refinement strategy
             similarity_metric: Image similarity metric
@@ -102,7 +102,7 @@ class PromptOptimizer:
             similarity_threshold: Target similarity threshold for early stopping
         """
         # Configuration
-        self.model = model
+        self.image_generator = image_generator
         self.evaluator = evaluator
         self.refiner = refiner
         self.similarity_metric = similarity_metric
@@ -141,9 +141,9 @@
         if self.target_img is None or self.current_prompt is None:
             raise RuntimeError("Must call initialize() before step()")
         if self.iteration >= self.max_iterations:
-            return True, self.current_prompt, self.model.generate(self.current_prompt)
+            return True, self.current_prompt, self.image_generator.generate(self.current_prompt)
         # Generate image with current prompt
-        generated_img = self.model.generate(self.current_prompt)
+        generated_img = self.image_generator.generate(self.current_prompt)
         # Evaluate similarity
         similarity = self.similarity_metric.compute(generated_img, self.target_img)
         # Analyze differences
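Taken together, the renamed modules plug into PromptOptimizer roughly as follows. The constructor keywords come from the weave_prompt.py diff above; the initialize(target_img) call and the step() loop are assumptions inferred from the step() excerpt, since the full weave_prompt.py is not shown in this commit:

from PIL import Image
from weave_prompt import PromptOptimizer
from image_generators import FalImageGenerator
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from similarity_metrics import LPIPSImageSimilarityMetric

optimizer = PromptOptimizer(
    image_generator=FalImageGenerator(),
    evaluator=LlamaEvaluator(),
    refiner=LlamaPromptRefiner(),
    similarity_metric=LPIPSImageSimilarityMetric(),
)
optimizer.initialize(Image.open("target.png"))  # assumed signature; sets target_img and the initial prompt
done = False
while not done:
    # step() returns (done, current_prompt, generated_img) per the excerpt above
    done, prompt, generated_img = optimizer.step()
print(prompt)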