kevin1kevin1k committed (verified)
Commit c6eb9ce · 1 parent: 0514939

Upload folder using huggingface_hub

.dockerignore ADDED
@@ -0,0 +1,28 @@
+.git
+.github
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+env
+.env
+.venv/
+venv/
+pip-log.txt
+pip-delete-this-directory.txt
+.tox
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.log
+.gitignore
+.vscode
+.idea
+*.swp
+*.swo
+*~
+.DS_Store
.gitignore ADDED
@@ -0,0 +1,23 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
+
+.python-version
+
+# Environment variables and secrets
+.env
+*.env
+
+# API keys and sensitive data
+config.json
+secrets.json
+
+# History files
+.history/
Dockerfile CHANGED
@@ -1,20 +1,53 @@
-FROM python:3.13.5-slim
+# Use Python 3.11 slim image
+FROM python:3.11-slim
 
+# Set working directory
 WORKDIR /app
 
+# Install system dependencies
 RUN apt-get update && apt-get install -y \
-    build-essential \
-    curl \
     git \
+    curl \
+    build-essential \
     && rm -rf /var/lib/apt/lists/*
 
-COPY requirements.txt ./
-COPY src/ ./src/
-
-RUN pip3 install -r requirements.txt
-
-EXPOSE 8501
-
-HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
-
-ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
+# Install UV package manager
+RUN pip install uv
+
+# Copy UV configuration files first for better caching
+COPY pyproject.toml uv.lock ./
+
+# Install Python dependencies using UV
+RUN uv venv /opt/venv && \
+    . /opt/venv/bin/activate && \
+    uv sync --frozen
+
+# Set the virtual environment as the default Python
+ENV PATH="/opt/venv/bin:$PATH"
+
+# Copy application code
+COPY . .
+
+# Create a non-root user
+RUN useradd -m -u 1000 user
+USER user
+
+# Set environment variables
+ENV HOME=/home/user \
+    PATH="/opt/venv/bin:/home/user/.local/bin:$PATH" \
+    PYTHONPATH=/app
+
+# Change to user's home directory
+WORKDIR $HOME/app
+
+# Copy app to user directory
+COPY --chown=user . $HOME/app
+
+# Expose the port Streamlit runs on
+EXPOSE 7860
+
+# Health check
+HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
+
+# Run the Streamlit application
+CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.headless=true", "--server.fileWatcherType=none", "--browser.gatherUsageStats=false"]
README.md CHANGED
@@ -1,19 +1,33 @@
----
-title: WeavePrompt
-emoji: 🚀
-colorFrom: red
-colorTo: red
-sdk: docker
-app_port: 8501
-tags:
-  - streamlit
-pinned: false
-short_description: Streamlit template space
----
-
-# Welcome to Streamlit!
-
-Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
+# WeavePrompt
+
+Iterative prompt refinement for text-to-image models.
+Given a target image, WeavePrompt automatically generates and refines text prompts to make a model's output resemble the target image, using vision-language models and perceptual metrics.
+
+## Features
+- Upload a target image
+- Step-by-step prompt optimization
+- View prompt and generated image at each iteration
+- Full optimization history
+
+## Installation
+
+1. Clone the repository:
+```bash
+git clone <repo-url>
+cd WeavePrompt
+```
+2. Install dependencies:
+```bash
+uv venv
+uv sync
+source .venv/bin/activate
+```
+
+## Usage
+
+Run the demo app:
+```bash
+streamlit run app.py
+```
+
+Follow the instructions in the browser to upload an image and step through the optimization process.
README_HF.md ADDED
@@ -0,0 +1,41 @@
+---
+title: WeavePrompt
+emoji: 🎨
+colorFrom: blue
+colorTo: purple
+sdk: docker
+pinned: false
+license: mit
+app_port: 7860
+---
+
+# WeavePrompt
+
+An intelligent prompt optimization system that iteratively refines text-to-image generation prompts to better match target images.
+
+## Features
+
+- 🎯 **Target-driven optimization**: Upload a target image and get optimized prompts
+- 🔄 **Iterative refinement**: Automatically improves prompts through multiple iterations
+- 📊 **Similarity tracking**: Monitor progress with visual similarity metrics
+- 🎨 **High-quality generation**: Uses advanced text-to-image models
+
+## How it works
+
+1. Upload your target image
+2. Provide an initial prompt (or let the system generate one)
+3. Watch as the system iteratively refines the prompt
+4. Get optimized prompts that better match your target image
+
+## Usage
+
+Simply run the Streamlit app and follow the interactive interface to optimize your prompts!
+
+## Configuration
+
+Set your API keys as environment variables:
+- `FAL_KEY`: Your FAL AI API key for image generation
+
+---
+
+Built with ❤️ using Streamlit and advanced AI models.
app.py ADDED
@@ -0,0 +1,115 @@
+import streamlit as st
+from PIL import Image
+import time
+from image_to_text import LlamaEvaluator
+from prompt_refiner import LlamaPromptRefiner
+from weave_prompt import PromptOptimizer
+from mock_components import MockTextToImageModel, MockImageEvaluator, MockPromptRefiner
+from lpips_evaluator import LPIPSImageSimilarityMetric
+from fal_image_generator import FalImageGenerator
+import io
+
+st.set_page_config(
+    page_title="WeavePrompt Demo",
+    page_icon="🎨",
+    layout="wide"
+)
+
+def main():
+    st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
+    st.markdown("""
+    Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
+    This demo uses a FAL image generator with Llama-based evaluation and refinement.
+    """)
+
+    # Initialize session state
+    if 'optimizer' not in st.session_state:
+        st.session_state.optimizer = PromptOptimizer(
+            model=FalImageGenerator(),
+            evaluator=LlamaEvaluator(),
+            refiner=LlamaPromptRefiner(),
+            similarity_metric=LPIPSImageSimilarityMetric(),
+            max_iterations=10,
+            similarity_threshold=0.95
+        )
+
+    if 'optimization_started' not in st.session_state:
+        st.session_state.optimization_started = False
+
+    if 'current_results' not in st.session_state:
+        st.session_state.current_results = None
+
+    # File uploader
+    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
+
+    if uploaded_file is not None:
+        # Display target image
+        target_image = Image.open(uploaded_file)
+
+        col1, col2 = st.columns(2)
+        with col1:
+            st.subheader("Target Image")
+            st.image(target_image, width='stretch')
+
+        # Start button
+        if not st.session_state.optimization_started:
+            if st.button("Start Optimization"):
+                st.session_state.optimization_started = True
+                # Initialize optimization
+                is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
+                st.session_state.current_results = (is_completed, prompt, generated_image)
+
+        # Display optimization progress
+        if st.session_state.optimization_started:
+            with col2:
+                st.subheader("Generated Image")
+                is_completed, prompt, generated_image = st.session_state.current_results
+                st.image(generated_image, width='stretch')
+
+            # Display prompt and controls
+            st.text_area("Current Prompt", prompt, height=100)
+
+            # Progress metrics
+            col1, col2, col3 = st.columns(3)
+            with col1:
+                st.metric("Iteration", len(st.session_state.optimizer.history))
+            with col2:
+                if len(st.session_state.optimizer.history) > 0:
+                    similarity = st.session_state.optimizer.history[-1]['similarity']
+                    st.metric("Similarity", f"{similarity:.2%}")
+            with col3:
+                st.metric("Status", "Completed" if is_completed else "In Progress")
+
+            # Next step button
+            if not is_completed:
+                if st.button("Next Step"):
+                    is_completed, prompt, generated_image = st.session_state.optimizer.step()
+                    st.session_state.current_results = (is_completed, prompt, generated_image)
+                    st.rerun()
+            else:
+                st.success("Optimization completed! Click 'Reset' to try another image.")
+
+            # Reset button
+            if st.button("Reset"):
+                st.session_state.optimization_started = False
+                st.session_state.current_results = None
+                st.rerun()
+
+            # Display history
+            if len(st.session_state.optimizer.history) > 0:
+                st.subheader("Optimization History")
+                for idx, hist_entry in enumerate(st.session_state.optimizer.history):
+                    st.markdown(f"### Step {idx + 1}")
+                    col1, col2 = st.columns([2, 3])
+                    with col1:
+                        st.image(hist_entry['image'], width='stretch')
+                    with col2:
+                        st.text(f"Similarity: {hist_entry['similarity']:.2%}")
+                        st.text("Prompt:")
+                        st.text(hist_entry['prompt'])
+                        st.text("\nAnalysis:")
+                        for key, value in hist_entry['analysis'].items():
+                            st.text(f"{key}: {value}")
+
+if __name__ == "__main__":
+    main()
fal_image_generator.py ADDED
@@ -0,0 +1,47 @@
+import fal_client
+from PIL import Image
+from typing import Dict, Any
+import requests
+from io import BytesIO
+
+from weave_prompt import TextToImageModel
+import load_keys
+
+
+class FalImageGenerator(TextToImageModel):
+    """Handles image generation using fal_client."""
+
+    def __init__(self, model_name: str = "fal-ai/flux-pro"):
+        self.model_name = model_name
+
+    def _on_queue_update(self, update):
+        """Handle queue updates during image generation."""
+        if isinstance(update, fal_client.InProgress):
+            for log in update.logs:
+                print(log["message"])
+
+    def generate(self, prompt: str, **kwargs) -> Image.Image:
+        """Generate an image from a text prompt using fal_client."""
+        result = fal_client.subscribe(
+            self.model_name,
+            arguments={
+                "prompt": prompt,
+                **kwargs
+            },
+            with_logs=True,
+            on_queue_update=self._on_queue_update,
+        )
+        print(result)
+
+        return self._extract_image_from_result(result)
+
+    def _extract_image_from_result(self, result: Dict[str, Any]) -> Image.Image:
+        """Extract and download image from fal_client result."""
+        if result and 'images' in result and len(result['images']) > 0:
+            image_url = result['images'][0]['url']
+            response = requests.get(image_url)
+            response.raise_for_status()  # Raise an exception for bad status codes
+            image = Image.open(BytesIO(response.content))
+            return image
+        else:
+            raise ValueError("No image found in the result")
+
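
For reference, a minimal usage sketch for the generator above. It assumes `FAL_KEY` is set in the environment (see the Configuration section of README_HF.md); the prompt text and output path are illustrative:

```python
from fal_image_generator import FalImageGenerator

generator = FalImageGenerator()  # defaults to the "fal-ai/flux-pro" model
image = generator.generate("a watercolor painting of a lighthouse at dusk")
image.save("generated.png")  # generate() returns a PIL Image
```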
image_to_text.py ADDED
@@ -0,0 +1,273 @@
+import openai
+import weave
+import base64
+import json
+import tempfile
+import os
+from pathlib import Path
+from PIL import Image
+from typing import Dict, Any, Optional
+from weave_prompt import ImageEvaluator
+import load_keys
+
+# Weave autopatches OpenAI to log LLM calls to W&B
+weave.init("meta-llama")
+
+class LlamaEvaluator(ImageEvaluator):
+    """Llama-based image evaluator using W&B Inference."""
+
+    def __init__(self, api_key: Optional[str] = None):
+        """
+        Initialize the Llama evaluator with OpenAI client.
+
+        Args:
+            api_key: Optional API key. If not provided, will look for OPENAI_API_KEY
+                     or WANDB_API_KEY environment variables.
+        """
+        # Get API key from parameter, environment variables, or raise error
+        if api_key is None:
+            api_key = os.getenv('WANDB_API_KEY')
+        if api_key is None:
+            raise ValueError(
+                "API key not provided. Please either:\n"
+                "1. Pass api_key parameter to LlamaEvaluator()\n"
+                "2. Set OPENAI_API_KEY environment variable\n"
+                "3. Set WANDB_API_KEY environment variable\n"
+                "Get your API key from https://wandb.ai/authorize"
+            )
+
+        self.client = openai.OpenAI(
+            # The custom base URL points to W&B Inference
+            base_url='https://api.inference.wandb.ai/v1',
+
+            # Get your API key from https://wandb.ai/authorize
+            api_key=api_key,
+        )
+        self.model = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    def _encode_image(self, image: Image.Image) -> str:
+        """Encode PIL Image to base64 string."""
+        try:
+            # Save image to temporary file and encode
+            with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
+                image.save(tmp_file.name, format='JPEG')
+                with open(tmp_file.name, "rb") as image_file:
+                    encoded = base64.b64encode(image_file.read()).decode('utf-8')
+                # Clean up temp file
+                Path(tmp_file.name).unlink()
+                return encoded
+        except Exception as e:
+            print(f"Error encoding image: {e}")
+            return None
+
+    def _call_vision_model(self, prompt: str, images: list) -> str:
+        """Call the vision model with prompt and images."""
+        try:
+            # Prepare content with text and images
+            content = [{"type": "text", "text": prompt}]
+
+            for i, img in enumerate(images):
+                base64_image = self._encode_image(img)
+                if base64_image:
+                    if i > 0:  # Add label for multiple images
+                        content.append({
+                            "type": "text",
+                            "text": f"Image {i+1}:"
+                        })
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{base64_image}"
+                        }
+                    })
+
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "You are an expert image analyst. Provide detailed, accurate analysis."
+                    },
+                    {
+                        "role": "user",
+                        "content": content
+                    }
+                ],
+                max_tokens=1000
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            print(f"Error calling vision model: {e}")
+            return None
+
+    def generate_initial_prompt(self, generated_img: Image.Image) -> str:
+        """Generate an initial prompt by describing the generated_img image."""
+        prompt = """
+        Analyze this image and generate a detailed text prompt that could be used to recreate it.
+        Focus on:
+        - Main subjects and objects
+        - Visual style and artistic technique
+        - Colors, lighting, and mood
+        - Composition and layout
+        - Important details and textures
+
+        Provide a concise but comprehensive prompt suitable for image generation.
+        """
+
+        description = self._call_vision_model(prompt, [generated_img])
+
+        if description:
+            return description.strip()
+        else:
+            # Fallback prompt
+            return "A beautiful image with vibrant colors and detailed composition"
+
+    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
+        """Analyze differences between generated and target images."""
+        analysis_prompt = """
+        Compare these two images and analyze their differences. The first image is generated, the second is the target.
+
+        Please provide a detailed analysis in JSON format with the following structure:
+        {
+            "missing_elements": ["list of elements present in target but missing in generated"],
+            "style_differences": ["list of style differences between the images"],
+            "color_differences": ["differences in color, lighting, or tone"],
+            "composition_differences": ["differences in layout, positioning, or framing"],
+            "quality_differences": ["differences in detail, sharpness, or overall quality"],
+            "similarity_score": "percentage of how similar the images are (0-100)",
+            "overall_assessment": "brief summary of the main differences"
+        }
+
+        Focus on identifying what elements, styles, or qualities are present in the target image but missing or different in the generated image.
+        """
+
+        response_text = self._call_vision_model(analysis_prompt, [generated_img, target_img])
+
+        if not response_text:
+            return {
+                "missing_elements": ["texture", "details"],
+                "style_differences": ["color intensity", "composition"],
+                "error": "Failed to analyze images"
+            }
+
+        try:
+            # Extract JSON from the response if it's wrapped in markdown
+            if "```json" in response_text:
+                json_start = response_text.find("```json") + 7
+                json_end = response_text.find("```", json_start)
+                json_text = response_text[json_start:json_end].strip()
+            elif "{" in response_text and "}" in response_text:
+                # Find the JSON object in the response
+                json_start = response_text.find("{")
+                json_end = response_text.rfind("}") + 1
+                json_text = response_text[json_start:json_end]
+            else:
+                json_text = response_text
+
+            analysis_result = json.loads(json_text)
+
+            # Ensure required keys exist with fallback values
+            if "missing_elements" not in analysis_result:
+                analysis_result["missing_elements"] = ["texture", "details"]
+            if "style_differences" not in analysis_result:
+                analysis_result["style_differences"] = ["color intensity", "composition"]
+
+            return analysis_result
+
+        except json.JSONDecodeError:
+            # If JSON parsing fails, return a structured response with fallback values
+            return {
+                "missing_elements": ["texture", "details"],
+                "style_differences": ["color intensity", "composition"],
+                "raw_analysis": response_text,
+                "note": "JSON parsing failed, using fallback analysis"
+            }
+
+    def describe_image(self, image: Image.Image, custom_prompt: str = None) -> str:
+        """Generate a detailed description of an image."""
+        if not custom_prompt:
+            custom_prompt = "Please describe this image in detail, including objects, people, colors, setting, and any notable features."
+
+        description = self._call_vision_model(custom_prompt, [image])
+        return description if description else "Failed to generate description"
+
+
+# Utility functions for backward compatibility
+def encode_image_from_path(image_path: str) -> str:
+    """Encode image from file path to base64 string."""
+    try:
+        with open(image_path, "rb") as image_file:
+            return base64.b64encode(image_file.read()).decode('utf-8')
+    except FileNotFoundError:
+        print(f"Error: Image file not found at {image_path}")
+        return None
+    except Exception as e:
+        print(f"Error encoding image: {e}")
+        return None
+
+def describe_image_from_path(image_path: str, custom_prompt: str = None) -> str:
+    """Generate description for an image from file path."""
+    if not Path(image_path).exists():
+        print(f"Error: Image file does not exist at {image_path}")
+        return None
+
+    # Load image and use evaluator
+    image = Image.open(image_path)
+    evaluator = LlamaEvaluator()
+    return evaluator.describe_image(image, custom_prompt)
+
+def analyze_differences_from_paths(generated_img_path: str, target_img_path: str) -> Dict[str, Any]:
+    """Analyze differences between two images from file paths."""
+    try:
+        generated_img = Image.open(generated_img_path)
+        target_img = Image.open(target_img_path)
+
+        evaluator = LlamaEvaluator()
+        return evaluator.analyze_differences(generated_img, target_img)
+    except Exception as e:
+        return {
+            "missing_elements": ["texture", "details"],
+            "style_differences": ["color intensity", "composition"],
+            "error": str(e)
+        }
+
+
+# Example usage
+if __name__ == "__main__":
+    # Example 1: Using the class directly
+    evaluator = LlamaEvaluator()
+
+    # Load images
+    try:
+        image_path = "/Users/chuchwu/Downloads/happy-190806.jpg"
+        target_image = Image.open(image_path)
+
+        # Generate initial prompt
+        print("Generating initial prompt...")
+        initial_prompt = evaluator.generate_initial_prompt(target_image)
+        print(f"Initial Prompt: {initial_prompt}")
+        print("\n" + "="*50 + "\n")
+
+        # Describe the image
+        print("Describing image...")
+        description = evaluator.describe_image(target_image)
+        print(f"Description: {description}")
+        print("\n" + "="*50 + "\n")
+
+        # Example 2: Analyze differences (using same image for demo)
+        print("Analyzing differences...")
+        differences = evaluator.analyze_differences(target_image, target_image)
+        print("Difference Analysis:")
+        print(f"Missing Elements: {differences.get('missing_elements', [])}")
+        print(f"Style Differences: {differences.get('style_differences', [])}")
+
+        if 'similarity_score' in differences:
+            print(f"Similarity Score: {differences['similarity_score']}%")
+
+        if 'overall_assessment' in differences:
+            print(f"Overall Assessment: {differences['overall_assessment']}")
+
+    except FileNotFoundError:
+        print("Image file not found. Please update the image_path variable.")
+    except Exception as e:
+        print(f"Error: {e}")
lpips_evaluator.py ADDED
@@ -0,0 +1,25 @@
+
+from weave_prompt import ImageSimilarityMetric
+from PIL import Image
+import lpips
+import torch
+import numpy as np
+
+class LPIPSImageSimilarityMetric(ImageSimilarityMetric):
+    """Image similarity metric using LPIPS perceptual similarity."""
+    def __init__(self, net: str = 'alex', device: str = 'cpu'):
+        self.lpips_model = lpips.LPIPS(net=net).to(device)
+        self.device = device
+
+    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
+        def img_to_tensor(img):
+            img = img.convert('RGB')  # Ensure image has 3 channels for handling PNG
+            arr = np.array(img.resize((256, 256))).astype(np.float32) / 255.0
+            arr = arr.transpose(2, 0, 1)  # HWC to CHW
+            tensor = torch.tensor(arr).unsqueeze(0)
+            return tensor * 2 - 1  # LPIPS expects [-1, 1]
+        gen_tensor = img_to_tensor(generated_img).to(self.device)
+        tgt_tensor = img_to_tensor(target_img).to(self.device)
+        distance = self.lpips_model(gen_tensor, tgt_tensor).item()
+        similarity = max(0.0, 1.0 - distance)
+        return similarity
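
A short sketch of exercising the metric above on its own, with illustrative file paths; identical images should score near 1.0:

```python
from PIL import Image
from lpips_evaluator import LPIPSImageSimilarityMetric

metric = LPIPSImageSimilarityMetric(net='alex', device='cpu')
target = Image.open("target.png")        # illustrative paths
candidate = Image.open("generated.png")
# compute() maps the LPIPS perceptual distance to max(0, 1 - distance),
# so higher values mean more perceptually similar images.
print(f"similarity: {metric.compute(candidate, target):.3f}")
```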
mock_components.py ADDED
@@ -0,0 +1,53 @@
+import fal_client
+from weave_prompt import TextToImageModel, ImageEvaluator, PromptRefiner
+from PIL import Image
+import numpy as np
+from typing import Dict, Any
+import os
+from fal_image_generator import FalImageGenerator
+
+class MockTextToImageModel(TextToImageModel):
+    """Mock text-to-image model for demonstration."""
+
+    def __init__(self):
+        self.image_generator = FalImageGenerator()
+
+    def generate(self, prompt: str, **kwargs) -> Image.Image:
+        """Generate an image by delegating to the fal image generator."""
+        return self.image_generator.generate(prompt, **kwargs)
+
+class MockImageEvaluator(ImageEvaluator):
+    """Mock image evaluator for demonstration."""
+
+    def generate_initial_prompt(self, target_img: Image.Image) -> str:
+        """Generate a mock initial prompt."""
+        return "A beautiful image with vibrant colors"
+
+    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
+        """Mock difference analysis."""
+        return {
+            "missing_elements": ["texture", "details"],
+            "style_differences": ["color intensity", "composition"]
+        }
+
+class MockSimilarityMetric:
+    """Mock similarity metric returning a random score."""
+
+    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
+        """Mock similarity computation: a random draw in [0.5, 0.95)."""
+        return np.random.uniform(0.5, 0.95)
+
+class MockPromptRefiner(PromptRefiner):
+    """Mock prompt refiner for demonstration."""
+
+    def refine_prompt(self, current_prompt: str, analysis: Dict[str, Any], similarity_score: float) -> str:
+        """Mock prompt refinement by adding random modifiers."""
+        modifiers = [
+            "with more detail",
+            "in vibrant colors",
+            "with better composition",
+            "high quality",
+            "masterfully crafted"
+        ]
+        return f"{current_prompt}, {np.random.choice(modifiers)}"
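
A sketch of wiring these mocks into the optimizer. Note that `MockTextToImageModel` still wraps `FalImageGenerator`, so `FAL_KEY` must be set; the mock evaluator, metric, and refiner run locally without API calls. The target path is illustrative:

```python
from PIL import Image
from weave_prompt import PromptOptimizer
from mock_components import (
    MockTextToImageModel,
    MockImageEvaluator,
    MockSimilarityMetric,
    MockPromptRefiner,
)

optimizer = PromptOptimizer(
    model=MockTextToImageModel(),        # still calls the FAL API internally
    evaluator=MockImageEvaluator(),      # canned prompt and analysis, no API
    refiner=MockPromptRefiner(),         # appends random modifiers, no API
    similarity_metric=MockSimilarityMetric(),
    max_iterations=3,
)
done, prompt, image = optimizer.initialize(Image.open("target.png"))
while not done:
    done, prompt, image = optimizer.step()
print(prompt)
```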
prompt_refiner.py ADDED
@@ -0,0 +1,49 @@
+from typing import Any, Dict
+import openai
+import weave
+import os
+
+from weave_prompt import PromptRefiner
+import load_keys
+
+# Weave autopatches OpenAI to log LLM calls to W&B
+weave.init(project_name="meta-llama")
+
+
+class LlamaPromptRefiner(PromptRefiner):
+    @weave.op()
+    def refine_prompt(self, current_prompt: str, analysis: Dict[str, Any], similarity_score):
+        client = openai.OpenAI(
+            # The custom base URL points to W&B Inference
+            base_url='https://api.inference.wandb.ai/v1',
+
+            # Get your API key from https://wandb.ai/authorize
+            # Consider setting it in the environment as OPENAI_API_KEY instead for safety
+            api_key=os.getenv("WANDB_API_KEY"),
+        )
+
+        response = client.chat.completions.create(
+            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            messages=[
+                {
+                    "role": "system",
+                    "content": (
+                        "You are an expert at prompt engineering for text-to-image models. "
+                        "Given a current prompt and an analysis of the differences between a generated image and a target image, "
+                        "your job is to suggest a new prompt that will make the generated image more similar to the target. "
+                        "Limit the new prompt to 100 words at most. "
+                        "The user message will contain two sections: one for the current prompt and one for the analysis, each delimited by 'START OF CURRENT PROMPT'/'END OF CURRENT PROMPT' and 'START OF ANALYSIS'/'END OF ANALYSIS'. "
+                        "Only return the improved prompt."
+                    )
+                },
+                {
+                    "role": "user",
+                    "content": (
+                        f"<START OF CURRENT PROMPT>\n{current_prompt}\n<END OF CURRENT PROMPT>\n"
+                        f"<START OF ANALYSIS>\n{str(analysis)}\n<END OF ANALYSIS>\n"
+                        "Suggest a new, improved prompt. Only return the prompt. Do not exceed 100 words."
+                    )
+                }
+            ],
+        )
+        return response.choices[0].message.content
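
A sketch of one refinement call in isolation, assuming `WANDB_API_KEY` is set; the prompt, analysis, and score values are illustrative:

```python
from prompt_refiner import LlamaPromptRefiner

refiner = LlamaPromptRefiner()
new_prompt = refiner.refine_prompt(
    current_prompt="a city skyline at night",
    analysis={
        "missing_elements": ["rain", "neon reflections"],
        "style_differences": ["overall too bright"],
    },
    similarity_score=0.62,
)
print(new_prompt)  # the model returns only the improved prompt text
```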
pyproject.toml ADDED
@@ -0,0 +1,17 @@
+[project]
+name = "weaveprompt"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.13"
+dependencies = [
+    "lpips>=0.1.4",
+    "numpy>=2.3.3",
+    "openai>=2.3.0",
+    "pillow>=11.3.0",
+    "streamlit>=1.50.0",
+    "wandb>=0.22.2",
+    "weave>=0.52.9",
+    "fal-client",
+    "python-dotenv>=1.1.1",
+]
requirements.txt CHANGED
@@ -1,3 +1,9 @@
-altair
-pandas
-streamlit
+lpips>=0.1.4
+numpy>=2.3.3
+openai>=2.3.0
+pillow>=11.3.0
+streamlit>=1.50.0
+wandb>=0.22.2
+weave>=0.52.9
+fal-client
+requests
spaces_config.yml ADDED
@@ -0,0 +1,2 @@
+sdk: docker
+app_port: 7860
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
weave_prompt.py ADDED
@@ -0,0 +1,166 @@
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+import PIL.Image as Image
+
+class TextToImageModel(ABC):
+    """Abstract base class for text-to-image models."""
+
+    @abstractmethod
+    def generate(self, prompt: str, **kwargs) -> Image.Image:
+        """Generate an image from a text prompt.
+
+        Args:
+            prompt: The text prompt to generate from
+            **kwargs: Additional model-specific parameters
+
+        Returns:
+            A PIL Image object
+        """
+        pass
+
+class ImageSimilarityMetric(ABC):
+    """Abstract base class for image similarity metrics."""
+    @abstractmethod
+    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
+        """Compute similarity score between generated and target images.
+        Args:
+            generated_img: The generated image to evaluate
+            target_img: The target image to compare against
+        Returns:
+            Similarity score (higher means more similar)
+        """
+        pass
+
+class ImageEvaluator(ABC):
+    """Abstract base class for evaluating image similarity."""
+
+    @abstractmethod
+    def generate_initial_prompt(self, target_img: Image.Image) -> str:
+        """Generate initial prompt from target image using VLM.
+
+        Args:
+            target_img: The target image to analyze
+
+        Returns:
+            Initial prompt describing the target image
+        """
+        pass
+
+    @abstractmethod
+    def analyze_differences(self, generated_img: Image.Image, target_img: Image.Image) -> Dict[str, Any]:
+        """Analyze differences between generated and target images using VLM.
+
+        Args:
+            generated_img: The generated image to analyze
+            target_img: The target image to compare against
+
+        Returns:
+            Dictionary containing analysis results (e.g. missing elements, style differences)
+        """
+        pass
+
+class PromptRefiner(ABC):
+    """Abstract base class for prompt refinement strategies."""
+
+    @abstractmethod
+    def refine_prompt(self,
+                      current_prompt: str,
+                      analysis: Dict[str, Any],
+                      similarity_score: float) -> str:
+        """Refine the current prompt based on image analysis.
+
+        Args:
+            current_prompt: The current prompt PMT_i
+            analysis: Analysis results from ImageEvaluator
+            similarity_score: Current similarity score
+
+        Returns:
+            Refined prompt PMT_{i+1}
+        """
+        pass
+
+class PromptOptimizer:
+    """Main class that orchestrates the prompt optimization process."""
+
+    def __init__(self,
+                 model: TextToImageModel,
+                 evaluator: ImageEvaluator,
+                 refiner: PromptRefiner,
+                 similarity_metric: ImageSimilarityMetric,
+                 max_iterations: int = 10,
+                 similarity_threshold: float = 0.95):
+        """Initialize the optimizer.
+
+        Args:
+            model: Text-to-image model to use
+            evaluator: Image evaluator for generating initial prompt and analysis
+            refiner: Prompt refinement strategy
+            similarity_metric: Image similarity metric
+            max_iterations: Maximum number of optimization iterations
+            similarity_threshold: Target similarity threshold for early stopping
+        """
+        # Configuration
+        self.model = model
+        self.evaluator = evaluator
+        self.refiner = refiner
+        self.similarity_metric = similarity_metric
+        self.max_iterations = max_iterations
+        self.similarity_threshold = similarity_threshold
+        # Optimization state
+        self.target_img: Optional[Image.Image] = None
+        self.current_prompt: Optional[str] = None
+        self.iteration: int = 0
+        # Progress tracking
+        self.history: List[Dict[str, Any]] = []
+
+    def initialize(self, target_img: Image.Image) -> tuple[bool, str, Image.Image]:
+        """Initialize the optimization process with a target image.
+
+        Args:
+            target_img: Target image to optimize towards
+        Returns:
+            Tuple of (is_completed, current_prompt, current_generated_image)
+        """
+        self.target_img = target_img
+        self.current_prompt = self.evaluator.generate_initial_prompt(target_img)
+        self.iteration = 0
+        self.history = []
+        return self.step()
+
+    def step(self) -> tuple[bool, str, Image.Image]:
+        """Perform one optimization step.
+
+        Returns:
+            Tuple of (is_completed, current_prompt, current_generated_image)
+            is_completed: True if optimization is complete (reached threshold or max iterations)
+            current_prompt: The current prompt
+            current_generated_image: The image generated from current prompt
+        """
+        if self.target_img is None or self.current_prompt is None:
+            raise RuntimeError("Must call initialize() before step()")
+        if self.iteration >= self.max_iterations:
+            return True, self.current_prompt, self.model.generate(self.current_prompt)
+        # Generate image with current prompt
+        generated_img = self.model.generate(self.current_prompt)
+        # Evaluate similarity
+        similarity = self.similarity_metric.compute(generated_img, self.target_img)
+        # Analyze differences
+        analysis = self.evaluator.analyze_differences(generated_img, self.target_img)
+        # Track progress
+        self.history.append({
+            'iteration': self.iteration,
+            'prompt': self.current_prompt,
+            'similarity': similarity,
+            'analysis': analysis,
+            'image': generated_img
+        })
+        # Check if we've reached target similarity
+        is_completed = similarity >= self.similarity_threshold
+        if not is_completed:
+            # Refine prompt
+            self.current_prompt = self.refiner.refine_prompt(
+                self.current_prompt, analysis, similarity)
+        self.iteration += 1
+        return is_completed, self.current_prompt, generated_img
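
The abstract base classes above are the extension points. As one hypothetical example (not part of the repository), a custom metric only needs to subclass `ImageSimilarityMetric` and implement `compute`; here is a toy per-pixel MSE metric mapped into [0, 1]:

```python
import numpy as np
from PIL import Image
from weave_prompt import ImageSimilarityMetric

class MSESimilarityMetric(ImageSimilarityMetric):
    """Toy metric: similarity = 1 / (1 + MSE) over 64x64 RGB thumbnails."""

    def compute(self, generated_img: Image.Image, target_img: Image.Image) -> float:
        # Normalize both images to float arrays in [0, 1] at a common size
        a = np.asarray(generated_img.convert("RGB").resize((64, 64)), dtype=np.float32) / 255.0
        b = np.asarray(target_img.convert("RGB").resize((64, 64)), dtype=np.float32) / 255.0
        mse = float(np.mean((a - b) ** 2))
        return 1.0 / (1.0 + mse)  # higher means more similar, as the interface requires
```

Any such subclass can be passed as `similarity_metric` to `PromptOptimizer` in place of the LPIPS metric.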