kevin1kevin1k commited on
Commit
1282f37
·
verified ·
1 Parent(s): b2feaca

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +365 -78
  2. image_generators.py +77 -1
  3. multi_model_optimizer.py +179 -0
app.py CHANGED
@@ -5,7 +5,8 @@ from image_evaluators import LlamaEvaluator
5
  from prompt_refiners import LlamaPromptRefiner
6
  from weave_prompt import PromptOptimizer
7
  from similarity_metrics import LPIPSImageSimilarityMetric
8
- from image_generators import FalImageGenerator
 
9
 
10
  # Load environment variables from .env file
11
  load_dotenv()
@@ -17,19 +18,28 @@ st.set_page_config(
17
  )
18
 
19
  def main():
20
- st.title("🎨 WeavePrompt: Iterative Prompt Optimization")
21
  st.markdown("""
22
- Upload a target image and watch as WeavePrompt iteratively optimizes a text prompt to recreate it.
23
  """)
24
 
 
 
 
 
 
 
 
 
25
  # Initialize session state
26
  if 'optimizer' not in st.session_state:
27
- st.session_state.optimizer = PromptOptimizer(
28
- image_generator=FalImageGenerator(),
 
29
  evaluator=LlamaEvaluator(),
30
  refiner=LlamaPromptRefiner(),
31
  similarity_metric=LPIPSImageSimilarityMetric(),
32
- max_iterations=10,
33
  similarity_threshold=0.95
34
  )
35
 
@@ -50,6 +60,71 @@ def main():
50
  if 'auto_should_step' not in st.session_state:
51
  st.session_state.auto_should_step = False
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # File uploader
54
  uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
55
 
@@ -64,25 +139,106 @@ def main():
64
 
65
  # Start button
66
  if not st.session_state.optimization_started:
67
- if st.button("Start Optimization"):
 
68
  st.session_state.optimization_started = True
69
- # Initialize optimization
70
- is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
71
- st.session_state.current_results = (is_completed, prompt, generated_image)
72
  st.rerun()
73
  else:
74
- # Show disabled button or status when optimization has started
75
- st.button("Start Optimization", disabled=True, help="Optimization in progress")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Display optimization progress
78
  if st.session_state.optimization_started:
79
  if st.session_state.current_results is not None:
 
 
 
80
  with col2:
81
- st.subheader("Generated Image")
82
- is_completed, prompt, generated_image = st.session_state.current_results
83
  st.image(generated_image, width='stretch')
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # Display prompt and controls
86
  st.text_area("Current Prompt", prompt, height=100)
87
  else:
88
  # Show loading state
@@ -93,49 +249,81 @@ def main():
93
  is_completed = False
94
  prompt = ""
95
 
96
- # Auto mode controls
97
- st.subheader("Auto Mode Controls")
98
- col_auto1, col_auto2 = st.columns(2)
99
-
100
- with col_auto1:
101
- auto_mode = st.checkbox("Auto-progress steps", value=st.session_state.auto_mode)
102
- if auto_mode != st.session_state.auto_mode:
103
- st.session_state.auto_mode = auto_mode
104
- if auto_mode:
105
- st.session_state.auto_paused = False
106
- st.session_state.auto_should_step = True # Start by stepping
107
- st.rerun()
108
-
109
- with col_auto2:
110
- if st.session_state.auto_mode:
111
- if st.session_state.auto_paused:
112
- if st.button("▶️ Resume", key="resume_btn"):
113
- st.session_state.auto_paused = False
114
- st.rerun()
115
- else:
116
- if st.button("⏸️ Pause", key="pause_btn"):
117
- st.session_state.auto_paused = True
118
- st.rerun()
119
 
120
  # Progress metrics
121
- col1, col2, col3 = st.columns(3)
122
  with col1:
123
- # Show current iteration: completed steps + 1 if still in progress
124
- current_iteration = len(st.session_state.optimizer.history) + (0 if is_completed else 1)
125
- st.metric("Iteration", current_iteration)
126
  with col2:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  if len(st.session_state.optimizer.history) > 0:
128
  similarity = st.session_state.optimizer.history[-1]['similarity']
129
  st.metric("Similarity", f"{similarity:.2%}")
130
- with col3:
131
- status = "Completed" if is_completed else "Paused" if st.session_state.auto_paused else "In Progress"
132
- st.metric("Status", status)
133
 
134
- # Progress bar
135
  st.subheader("Progress")
136
- max_iterations = st.session_state.optimizer.max_iterations
137
- progress_value = min(current_iteration / max_iterations, 1.0) if max_iterations > 0 else 0.0
138
- st.progress(progress_value, text=f"Step {current_iteration} of {max_iterations}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  # Next step logic
141
  if not is_completed:
@@ -181,37 +369,136 @@ def main():
181
  st.session_state.auto_paused = False
182
  st.rerun()
183
 
184
- # Display history - simple approach
185
- if len(st.session_state.optimizer.history) > 0:
186
- st.subheader(f"Optimization History ({len(st.session_state.optimizer.history)} steps)")
 
 
 
 
187
 
188
- for idx, hist_entry in enumerate(st.session_state.optimizer.history):
189
- st.markdown(f"### Step {idx + 1}")
190
- col1, col2 = st.columns([2, 3])
191
- with col1:
192
- st.image(hist_entry['image'], width='stretch')
193
- with col2:
194
- st.text(f"Similarity: {hist_entry['similarity']:.2%}")
195
- st.text("Prompt:")
196
- st.text(hist_entry['prompt'])
197
- # Toggle analysis view per history entry
198
- expand_key = f"expand_analysis_{idx}"
199
- if 'analysis_expanded' not in st.session_state:
200
- st.session_state['analysis_expanded'] = {}
201
- if expand_key not in st.session_state['analysis_expanded']:
202
- st.session_state['analysis_expanded'][expand_key] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
- if st.session_state['analysis_expanded'][expand_key]:
205
- if st.button("Hide Analysis", key=f"hide_{expand_key}"):
206
- st.session_state['analysis_expanded'][expand_key] = False
207
- st.rerun()
208
- st.text("Analysis:")
209
- for key, value in hist_entry['analysis'].items():
210
- st.text(f"{key}: {value}")
211
- else:
212
- if st.button("Expand Analysis", key=expand_key):
213
- st.session_state['analysis_expanded'][expand_key] = True
214
- st.rerun()
215
 
216
  if __name__ == "__main__":
217
  main()
 
5
  from prompt_refiners import LlamaPromptRefiner
6
  from weave_prompt import PromptOptimizer
7
  from similarity_metrics import LPIPSImageSimilarityMetric
8
+ from image_generators import FalImageGenerator, MultiModelFalImageGenerator, AVAILABLE_MODELS
9
+ from multi_model_optimizer import MultiModelPromptOptimizer
10
 
11
  # Load environment variables from .env file
12
  load_dotenv()
 
18
  )
19
 
20
  def main():
21
+ st.title("🎨 WeavePrompt: Multi-Model Prompt Optimization")
22
  st.markdown("""
23
+ Upload a target image and watch as WeavePrompt optimizes prompts across multiple AI models to find the best result.
24
  """)
25
 
26
+ # Model selection state
27
+ if 'selected_models' not in st.session_state:
28
+ st.session_state.selected_models = ["FLUX.1 [pro]"] # Default selection
29
+
30
+ # Initialize max_iterations in session state if not exists
31
+ if 'max_iterations' not in st.session_state:
32
+ st.session_state.max_iterations = 2
33
+
34
  # Initialize session state
35
  if 'optimizer' not in st.session_state:
36
+ image_generator = MultiModelFalImageGenerator(st.session_state.selected_models)
37
+ st.session_state.optimizer = MultiModelPromptOptimizer(
38
+ image_generator=image_generator,
39
  evaluator=LlamaEvaluator(),
40
  refiner=LlamaPromptRefiner(),
41
  similarity_metric=LPIPSImageSimilarityMetric(),
42
+ max_iterations=st.session_state.max_iterations,
43
  similarity_threshold=0.95
44
  )
45
 
 
60
  if 'auto_should_step' not in st.session_state:
61
  st.session_state.auto_should_step = False
62
 
63
+ # Model Selection UI
64
+ st.subheader("🤖 Model Selection")
65
+ st.markdown("Choose which AI models to optimize with:")
66
+
67
+ # Organize models by category
68
+ flux_models = [k for k in AVAILABLE_MODELS.keys() if k.startswith("FLUX")]
69
+ google_models = [k for k in AVAILABLE_MODELS.keys() if k in ["Imagen 4", "Imagen 4 Ultra", "Gemini 2.5 Flash Image"]]
70
+ other_models = [k for k in AVAILABLE_MODELS.keys() if k not in flux_models and k not in google_models]
71
+
72
+ # Track if selection changed
73
+ new_selection = []
74
+
75
+ # FLUX Models Section
76
+ st.markdown("### 🔥 FLUX Models")
77
+ cols_flux = st.columns(2)
78
+ for i, model_name in enumerate(flux_models):
79
+ col_idx = i % 2
80
+ with cols_flux[col_idx]:
81
+ is_selected = model_name in st.session_state.selected_models
82
+ if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
83
+ new_selection.append(model_name)
84
+
85
+ # Google Models Section
86
+ st.markdown("### 🔍 Google Models")
87
+ cols_google = st.columns(2)
88
+ for i, model_name in enumerate(google_models):
89
+ col_idx = i % 2
90
+ with cols_google[col_idx]:
91
+ is_selected = model_name in st.session_state.selected_models
92
+ if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
93
+ new_selection.append(model_name)
94
+
95
+ # Other Models Section
96
+ if other_models:
97
+ st.markdown("### 🎨 Other Models")
98
+ cols_other = st.columns(2)
99
+ for i, model_name in enumerate(other_models):
100
+ col_idx = i % 2
101
+ with cols_other[col_idx]:
102
+ is_selected = model_name in st.session_state.selected_models
103
+ if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
104
+ new_selection.append(model_name)
105
+
106
+ # Ensure at least one model is selected
107
+ if not new_selection:
108
+ st.error("Please select at least one model!")
109
+ new_selection = ["FLUX.1 [pro]"] # Default fallback
110
+
111
+ # Update selection if changed
112
+ if set(new_selection) != set(st.session_state.selected_models):
113
+ st.session_state.selected_models = new_selection
114
+ # Recreate optimizer with new models
115
+ image_generator = MultiModelFalImageGenerator(st.session_state.selected_models)
116
+ st.session_state.optimizer = MultiModelPromptOptimizer(
117
+ image_generator=image_generator,
118
+ evaluator=LlamaEvaluator(),
119
+ refiner=LlamaPromptRefiner(),
120
+ similarity_metric=LPIPSImageSimilarityMetric(),
121
+ max_iterations=st.session_state.max_iterations,
122
+ similarity_threshold=0.95
123
+ )
124
+ st.success(f"Updated to use {len(new_selection)} model(s): {', '.join(new_selection)}")
125
+
126
+ st.markdown("---")
127
+
128
  # File uploader
129
  uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
130
 
 
139
 
140
  # Start button
141
  if not st.session_state.optimization_started:
142
+ if st.button("🚀 Start Optimization", type="primary"):
143
+ # Set state first to ensure immediate UI update
144
  st.session_state.optimization_started = True
145
+ # Force immediate rerun to show disabled state
 
 
146
  st.rerun()
147
  else:
148
+ # Show disabled button when optimization has started
149
+ st.button(" Optimization Running...", disabled=True, help="Optimization in progress", type="secondary")
150
+ st.info("💡 Optimization is running across selected models. Use the controls below to pause/resume or reset.")
151
+
152
+ # Initialize optimization after state is set (only once)
153
+ if st.session_state.current_results is None:
154
+ try:
155
+ is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
156
+ st.session_state.current_results = (is_completed, prompt, generated_image)
157
+ st.rerun()
158
+ except Exception as e:
159
+ st.error(f"Error initializing optimization: {str(e)}")
160
+ st.session_state.optimization_started = False
161
+
162
+ # Settings (always visible)
163
+ st.subheader("⚙️ Settings")
164
+
165
+ # Check if optimization is actively running (only disable settings during auto mode)
166
+ is_actively_running = (st.session_state.optimization_started and
167
+ st.session_state.auto_mode and
168
+ not st.session_state.auto_paused)
169
+
170
+ # Single row: Number of iterations, Auto-progress, and Pause/Resume controls
171
+ col_settings1, col_settings2, col_settings3 = st.columns(3)
172
+
173
+ with col_settings1:
174
+ new_max_iterations = st.number_input(
175
+ "Number of Iterations",
176
+ min_value=1,
177
+ max_value=20,
178
+ value=st.session_state.max_iterations,
179
+ help="Maximum number of optimization iterations per model",
180
+ disabled=is_actively_running
181
+ )
182
+
183
+ # Update if changed (only when not running)
184
+ if new_max_iterations != st.session_state.max_iterations and not is_actively_running:
185
+ st.session_state.max_iterations = new_max_iterations
186
+ # Update the optimizer's max_iterations
187
+ st.session_state.optimizer.max_iterations = new_max_iterations
188
+ if hasattr(st.session_state.optimizer, 'current_optimizer') and st.session_state.optimizer.current_optimizer:
189
+ st.session_state.optimizer.current_optimizer.max_iterations = new_max_iterations
190
+
191
+ with col_settings2:
192
+ auto_mode = st.checkbox(
193
+ "Auto-progress steps",
194
+ value=st.session_state.auto_mode,
195
+ disabled=is_actively_running
196
+ )
197
+ if auto_mode != st.session_state.auto_mode and not is_actively_running:
198
+ st.session_state.auto_mode = auto_mode
199
+ if auto_mode:
200
+ st.session_state.auto_paused = False
201
+ st.session_state.auto_should_step = True # Start by stepping
202
+ st.rerun()
203
+
204
+ with col_settings3:
205
+ # Pause/Resume controls (only when auto mode is enabled and optimization started)
206
+ if st.session_state.auto_mode and st.session_state.optimization_started:
207
+ if st.session_state.auto_paused:
208
+ if st.button("▶️ Resume", key="resume_btn"):
209
+ st.session_state.auto_paused = False
210
+ st.rerun()
211
+ else:
212
+ if st.button("⏸️ Pause", key="pause_btn"):
213
+ st.session_state.auto_paused = True
214
+ st.rerun()
215
+ else:
216
+ # Show placeholder or empty space when not in auto mode
217
+ st.write("")
218
 
219
  # Display optimization progress
220
  if st.session_state.optimization_started:
221
  if st.session_state.current_results is not None:
222
+ is_completed, prompt, generated_image = st.session_state.current_results
223
+
224
+ # Always show the actual current results (most up-to-date)
225
  with col2:
226
+ st.subheader("Current Generated Image")
 
227
  st.image(generated_image, width='stretch')
228
+ # Show current step info
229
+ current_model_name = st.session_state.optimizer.get_current_model_name()
230
+ current_history = st.session_state.optimizer.history if hasattr(st.session_state.optimizer, 'history') else []
231
+
232
+ if current_history:
233
+ # Show info about the current state
234
+ if is_completed:
235
+ st.caption(f"🏁 {current_model_name} - Final Result")
236
+ else:
237
+ st.caption(f"🎯 {current_model_name}")
238
+ else:
239
+ st.caption(f"🚀 {current_model_name} - Initializing...")
240
 
241
+ # Display current prompt
242
  st.text_area("Current Prompt", prompt, height=100)
243
  else:
244
  # Show loading state
 
249
  is_completed = False
250
  prompt = ""
251
 
252
+ # Multi-model progress info
253
+ progress_info = st.session_state.optimizer.get_progress_info()
254
+ current_model = progress_info['current_model_name']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  # Progress metrics
257
+ col1, col2, col3, col4 = st.columns(4)
258
  with col1:
259
+ st.metric("Current Model", current_model)
 
 
260
  with col2:
261
+ # Show current model number vs total models (1-indexed for display)
262
+ current_model_num = progress_info['current_model_index'] + 1
263
+ total_models = progress_info['total_models']
264
+
265
+ # Fix: Don't let current model number exceed total models
266
+ if current_model_num > total_models:
267
+ current_model_num = total_models
268
+
269
+ st.metric("Models Progress", f"{current_model_num}/{total_models}")
270
+ with col3:
271
+ if 'current_iteration' in progress_info:
272
+ # Show completed iterations vs max iterations
273
+ current_iteration = progress_info['current_iteration']
274
+ max_iterations = progress_info['max_iterations']
275
+ # If not completed and we haven't reached max, show next iteration number
276
+ if not is_completed and current_iteration < max_iterations:
277
+ display_iteration = current_iteration + 1
278
+ st.metric("Current Iteration", f"{display_iteration}/{max_iterations}")
279
+ else:
280
+ st.metric("Current Iteration", f"{current_iteration}/{max_iterations}")
281
+ with col4:
282
  if len(st.session_state.optimizer.history) > 0:
283
  similarity = st.session_state.optimizer.history[-1]['similarity']
284
  st.metric("Similarity", f"{similarity:.2%}")
285
+ else:
286
+ status = "Completed" if is_completed else "Paused" if st.session_state.auto_paused else "In Progress"
287
+ st.metric("Status", status)
288
 
289
+ # Progress bars
290
  st.subheader("Progress")
291
+
292
+ # Overall progress across all models
293
+ overall_progress = progress_info['overall_progress']
294
+ if 'current_iteration' in progress_info and not is_completed:
295
+ # Add current model progress to overall
296
+ model_progress_fraction = progress_info['model_progress']
297
+ overall_progress = (progress_info['models_completed'] + model_progress_fraction) / progress_info['total_models']
298
+
299
+ # Create more descriptive progress labels
300
+ models_completed = progress_info['models_completed']
301
+ total_models = progress_info['total_models']
302
+
303
+ if total_models > 1:
304
+ # Multi-model scenario
305
+ if models_completed == total_models:
306
+ overall_label = f"🏁 All {total_models} models completed!"
307
+ else:
308
+ overall_label = f"🔄 Processing {current_model} (Model {models_completed + 1} of {total_models})"
309
+ else:
310
+ # Single model scenario
311
+ overall_label = f"🎯 Optimizing with {current_model}"
312
+
313
+ st.progress(overall_progress, text=overall_label)
314
+
315
+ # Current model progress
316
+ if 'current_iteration' in progress_info and not is_completed:
317
+ model_progress_value = progress_info['model_progress']
318
+ current_iter = progress_info['current_iteration']
319
+ max_iter = progress_info['max_iterations']
320
+
321
+ if current_iter == max_iter:
322
+ iteration_label = f"✅ {current_model}: Completed all {max_iter} iterations"
323
+ else:
324
+ iteration_label = f"⚡ {current_model}: Step {current_iter + 1} of {max_iter}"
325
+
326
+ st.progress(model_progress_value, text=iteration_label)
327
 
328
  # Next step logic
329
  if not is_completed:
 
369
  st.session_state.auto_paused = False
370
  st.rerun()
371
 
372
+ # Display multi-model history with tabs
373
+ _display_multi_model_history()
374
+
375
+ # Model Results Comparison (show when optimization is complete)
376
+ if is_completed and hasattr(st.session_state.optimizer, 'get_all_results'):
377
+ all_results = st.session_state.optimizer.get_all_results()
378
+ best_result = st.session_state.optimizer.get_best_result()
379
 
380
+ if all_results:
381
+ st.subheader("🏆 Model Comparison Results")
382
+
383
+ # Show best result prominently
384
+ if best_result:
385
+ st.success(f"🥇 **Best Result**: {best_result['model_name']} with {best_result['similarity']:.2%} similarity")
386
+
387
+ # Create columns for each model result
388
+ num_models = len(all_results)
389
+ if num_models > 0:
390
+ cols = st.columns(min(num_models, 3)) # Max 3 columns
391
+
392
+ # Sort results by similarity (best first)
393
+ sorted_results = sorted(all_results.items(),
394
+ key=lambda x: x[1]['final_similarity'],
395
+ reverse=True)
396
+
397
+ for i, (model_name, result) in enumerate(sorted_results):
398
+ col_idx = i % 3
399
+ with cols[col_idx]:
400
+ # Add medal emoji for top 3
401
+ medal = "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else "🏅"
402
+ st.markdown(f"### {medal} {model_name}")
403
+
404
+ # Show final image
405
+ st.image(result['final_image'], width='stretch')
406
+
407
+ # Show metrics
408
+ st.metric("Final Similarity", f"{result['final_similarity']:.2%}")
409
+ st.metric("Iterations", result['iterations'])
410
+
411
+ # Show final prompt in expander
412
+ with st.expander("View Final Prompt"):
413
+ st.text(result['final_prompt'])
414
+
415
+
416
+ def _display_multi_model_history():
417
+ """Display optimization history with tabs for multi-model scenarios."""
418
+
419
+ # Get all completed model results
420
+ all_model_results = {}
421
+ if hasattr(st.session_state.optimizer, 'get_all_results'):
422
+ all_model_results = st.session_state.optimizer.get_all_results()
423
+
424
+ # Build separate histories for each model
425
+ all_histories = {}
426
+
427
+ # Add completed model histories - each model gets its own stored history
428
+ for model_name, result in all_model_results.items():
429
+ if 'history' in result and result['history']:
430
+ # Make a deep copy to ensure complete separation
431
+ all_histories[model_name] = [step.copy() for step in result['history']]
432
+
433
+ # Add current model's in-progress history (if any)
434
+ if hasattr(st.session_state.optimizer, 'current_optimizer') and st.session_state.optimizer.current_optimizer:
435
+ current_history = st.session_state.optimizer.current_optimizer.history
436
+ if current_history:
437
+ current_model_name = st.session_state.optimizer.get_current_model_name()
438
+ # Only add if this model doesn't already have a completed history
439
+ if current_model_name not in all_histories:
440
+ all_histories[current_model_name] = [step.copy() for step in current_history]
441
+
442
+ # Display histories
443
+ if all_histories:
444
+ st.subheader("📊 Optimization History")
445
+
446
+ if len(all_histories) == 1:
447
+ # Single model - no tabs needed
448
+ model_name = list(all_histories.keys())[0]
449
+ history = all_histories[model_name]
450
+ st.markdown(f"**{model_name}** ({len(history)} steps)")
451
+ _display_model_history(history, model_name)
452
+ else:
453
+ # Multi-model - use tabs
454
+ tab_names = []
455
+ for model_name, history in all_histories.items():
456
+ tab_names.append(f"{model_name} ({len(history)} steps)")
457
+
458
+ tabs = st.tabs(tab_names)
459
+
460
+ for i, (model_name, history) in enumerate(all_histories.items()):
461
+ with tabs[i]:
462
+ _display_model_history(history, model_name)
463
+
464
+
465
+ def _display_model_history(history, model_name):
466
+ """Display history for a single model."""
467
+ if not history:
468
+ st.info(f"No optimization steps yet for {model_name}")
469
+ return
470
+
471
+ for idx, hist_entry in enumerate(history):
472
+ st.markdown(f"### Step {idx + 1}")
473
+ col1, col2 = st.columns([2, 3])
474
+
475
+ with col1:
476
+ st.image(hist_entry['image'], width='stretch')
477
+
478
+ with col2:
479
+ st.text(f"Similarity: {hist_entry['similarity']:.2%}")
480
+ st.text("Prompt:")
481
+ st.text(hist_entry['prompt'])
482
+
483
+ # Toggle analysis view per history entry
484
+ expand_key = f"expand_analysis_{model_name}_{idx}"
485
+ if 'analysis_expanded' not in st.session_state:
486
+ st.session_state['analysis_expanded'] = {}
487
+ if expand_key not in st.session_state['analysis_expanded']:
488
+ st.session_state['analysis_expanded'][expand_key] = False
489
+
490
+ if st.session_state['analysis_expanded'][expand_key]:
491
+ if st.button("Hide Analysis", key=f"hide_{expand_key}"):
492
+ st.session_state['analysis_expanded'][expand_key] = False
493
+ st.rerun()
494
+ st.text("Analysis:")
495
+ for key, value in hist_entry['analysis'].items():
496
+ st.text(f"{key}: {value}")
497
+ else:
498
+ if st.button("Expand Analysis", key=expand_key):
499
+ st.session_state['analysis_expanded'][expand_key] = True
500
+ st.rerun()
501
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
  if __name__ == "__main__":
504
  main()
image_generators.py CHANGED
@@ -5,10 +5,29 @@ import requests
5
  from io import BytesIO
6
 
7
  from weave_prompt import ImageGenerator
 
8
 
9
  from dotenv import load_dotenv
10
  load_dotenv()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class FalImageGenerator(ImageGenerator):
13
  """Handles image generation using fal_client."""
14
 
@@ -45,4 +64,61 @@ class FalImageGenerator(ImageGenerator):
45
  image = Image.open(BytesIO(response.content))
46
  return image
47
  else:
48
- raise ValueError("No image found in the result")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from io import BytesIO
6
 
7
  from weave_prompt import ImageGenerator
8
+ from typing import List, Tuple
9
 
10
  from dotenv import load_dotenv
11
  load_dotenv()
12
 
13
+ # Available fal.ai models for text-to-image generation
14
+ AVAILABLE_MODELS = {
15
+ # FLUX Models
16
+ "FLUX.1 [pro]": "fal-ai/flux-pro",
17
+ "FLUX.1 [dev]": "fal-ai/flux/dev",
18
+ "FLUX.1 [schnell]": "fal-ai/flux/schnell",
19
+ "FLUX.1 with LoRAs": "fal-ai/flux-lora",
20
+
21
+ # Google Models
22
+ "Imagen 4": "fal-ai/imagen4/preview",
23
+ "Imagen 4 Ultra": "fal-ai/imagen4/preview/ultra",
24
+ "Gemini 2.5 Flash Image": "fal-ai/gemini-25-flash-image",
25
+
26
+ # Other Models
27
+ "Stable Diffusion 3.5 Large": "fal-ai/stable-diffusion-v35-large",
28
+ "Qwen Image": "fal-ai/qwen-image"
29
+ }
30
+
31
  class FalImageGenerator(ImageGenerator):
32
  """Handles image generation using fal_client."""
33
 
 
64
  image = Image.open(BytesIO(response.content))
65
  return image
66
  else:
67
+ raise ValueError("No image found in the result")
68
+
69
+
70
+ class MultiModelFalImageGenerator(ImageGenerator):
71
+ """Handles image generation using multiple fal.ai models."""
72
+
73
+ def __init__(self, selected_models: List[str] = None):
74
+ """Initialize with selected model names.
75
+
76
+ Args:
77
+ selected_models: List of model display names from AVAILABLE_MODELS keys
78
+ """
79
+ if selected_models is None:
80
+ selected_models = ["FLUX.1 [pro]"] # Default to single model
81
+
82
+ self.selected_models = selected_models
83
+ self.current_model_index = 0
84
+ self.generators = {}
85
+
86
+ # Create individual generators for each selected model
87
+ for model_name in selected_models:
88
+ if model_name in AVAILABLE_MODELS:
89
+ model_id = AVAILABLE_MODELS[model_name]
90
+ self.generators[model_name] = FalImageGenerator(model_id)
91
+
92
+ def get_current_model_name(self) -> str:
93
+ """Get the name of the currently active model."""
94
+ if self.current_model_index < len(self.selected_models):
95
+ return self.selected_models[self.current_model_index]
96
+ return self.selected_models[0] if self.selected_models else "Unknown"
97
+
98
+ def switch_to_next_model(self) -> bool:
99
+ """Switch to the next model in the sequence.
100
+
101
+ Returns:
102
+ True if switched to next model, False if no more models
103
+ """
104
+ self.current_model_index += 1
105
+ return self.current_model_index < len(self.selected_models)
106
+
107
+ def reset_to_first_model(self):
108
+ """Reset to the first model in the sequence."""
109
+ self.current_model_index = 0
110
+
111
+ def generate(self, prompt: str, **kwargs) -> Image.Image:
112
+ """Generate an image using the current model."""
113
+ current_model = self.get_current_model_name()
114
+ if current_model in self.generators:
115
+ return self.generators[current_model].generate(prompt, **kwargs)
116
+ else:
117
+ raise ValueError(f"Model {current_model} not available")
118
+
119
+ def generate_with_model(self, model_name: str, prompt: str, **kwargs) -> Image.Image:
120
+ """Generate an image using a specific model."""
121
+ if model_name in self.generators:
122
+ return self.generators[model_name].generate(prompt, **kwargs)
123
+ else:
124
+ raise ValueError(f"Model {model_name} not available")
multi_model_optimizer.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Tuple
2
+ from PIL import Image
3
+ from weave_prompt import PromptOptimizer, ImageEvaluator, PromptRefiner, ImageSimilarityMetric
4
+ from image_generators import MultiModelFalImageGenerator
5
+
6
+
7
+ class MultiModelPromptOptimizer:
8
+ """Sequential multi-model prompt optimizer that finds the best model-prompt combination."""
9
+
10
+ def __init__(self,
11
+ image_generator: MultiModelFalImageGenerator,
12
+ evaluator: ImageEvaluator,
13
+ refiner: PromptRefiner,
14
+ similarity_metric: ImageSimilarityMetric,
15
+ max_iterations: int = 10,
16
+ similarity_threshold: float = 0.95):
17
+ """Initialize the multi-model optimizer.
18
+
19
+ Args:
20
+ image_generator: Multi-model image generator
21
+ evaluator: Image evaluator for generating initial prompt and analysis
22
+ refiner: Prompt refinement strategy
23
+ similarity_metric: Image similarity metric
24
+ max_iterations: Maximum number of optimization iterations per model
25
+ similarity_threshold: Target similarity threshold for early stopping
26
+ """
27
+ self.image_generator = image_generator
28
+ self.evaluator = evaluator
29
+ self.refiner = refiner
30
+ self.similarity_metric = similarity_metric
31
+ self.max_iterations = max_iterations
32
+ self.similarity_threshold = similarity_threshold
33
+
34
+ # Multi-model state
35
+ self.target_img = None
36
+ self.current_model_index = 0
37
+ self.model_results = {} # Results per model
38
+ self.current_optimizer = None
39
+ self.best_result = None
40
+
41
+ # Initialize individual optimizers for each model
42
+ self._create_current_optimizer()
43
+
44
+ def _create_current_optimizer(self):
45
+ """Create optimizer for the current model."""
46
+ if self.current_model_index < len(self.image_generator.selected_models):
47
+ # Set the image generator to current model
48
+ self.image_generator.current_model_index = self.current_model_index
49
+
50
+ # Create individual optimizer for current model
51
+ self.current_optimizer = PromptOptimizer(
52
+ image_generator=self.image_generator,
53
+ evaluator=self.evaluator,
54
+ refiner=self.refiner,
55
+ similarity_metric=self.similarity_metric,
56
+ max_iterations=self.max_iterations,
57
+ similarity_threshold=self.similarity_threshold
58
+ )
59
+
60
+ def get_current_model_name(self) -> str:
61
+ """Get the name of the currently active model."""
62
+ # Ensure the image generator index is synchronized
63
+ self.image_generator.current_model_index = self.current_model_index
64
+ return self.image_generator.get_current_model_name()
65
+
66
+ def get_progress_info(self) -> Dict[str, Any]:
67
+ """Get current progress information."""
68
+ total_models = len(self.image_generator.selected_models)
69
+ current_model = self.current_model_index + 1
70
+
71
+ info = {
72
+ 'current_model_index': self.current_model_index,
73
+ 'current_model_name': self.get_current_model_name(),
74
+ 'total_models': total_models,
75
+ 'models_completed': self.current_model_index,
76
+ 'overall_progress': self.current_model_index / total_models if total_models > 0 else 0,
77
+ 'is_last_model': self.current_model_index >= total_models - 1
78
+ }
79
+
80
+ if self.current_optimizer:
81
+ info['current_iteration'] = len(self.current_optimizer.history)
82
+ info['max_iterations'] = self.max_iterations
83
+ info['model_progress'] = len(self.current_optimizer.history) / self.max_iterations
84
+
85
+ return info
86
+
87
+ def initialize(self, target_img: Image.Image) -> Tuple[bool, str, Image.Image]:
88
+ """Initialize the multi-model optimization process.
89
+
90
+ Args:
91
+ target_img: Target image to optimize towards
92
+ Returns:
93
+ Tuple of (is_completed, current_prompt, current_generated_image)
94
+ """
95
+ self.target_img = target_img
96
+ self.current_model_index = 0
97
+ self.model_results = {}
98
+ self.best_result = None
99
+
100
+ # Reset image generator to first model
101
+ self.image_generator.reset_to_first_model()
102
+ self._create_current_optimizer()
103
+
104
+ # Initialize first model
105
+ return self.current_optimizer.initialize(target_img)
106
+
107
+ def step(self) -> Tuple[bool, str, Image.Image]:
108
+ """Perform one optimization step.
109
+
110
+ Returns:
111
+ Tuple of (is_completed, current_prompt, current_generated_image)
112
+ """
113
+ if not self.current_optimizer:
114
+ raise RuntimeError("Must call initialize() before step()")
115
+
116
+ # Step the current model optimizer
117
+ is_model_completed, prompt, generated_image = self.current_optimizer.step()
118
+
119
+ if is_model_completed:
120
+ # Store results for current model - use data from history to ensure consistency
121
+ model_name = self.get_current_model_name()
122
+
123
+ if len(self.current_optimizer.history) > 0:
124
+ # Use the last step from history as the final result (ensures consistency)
125
+ last_step = self.current_optimizer.history[-1]
126
+ final_prompt = last_step['prompt']
127
+ final_image = last_step['image']
128
+ final_similarity = last_step['similarity']
129
+ else:
130
+ # Fallback to step results if no history (shouldn't happen)
131
+ final_prompt = prompt
132
+ final_image = generated_image
133
+ final_similarity = 0.0
134
+
135
+ self.model_results[model_name] = {
136
+ 'final_prompt': final_prompt,
137
+ 'final_image': final_image,
138
+ 'final_similarity': final_similarity,
139
+ 'history': self.current_optimizer.history.copy(),
140
+ 'iterations': len(self.current_optimizer.history)
141
+ }
142
+
143
+ # Update best result if this is better
144
+ if self.best_result is None or final_similarity > self.best_result['similarity']:
145
+ self.best_result = {
146
+ 'model_name': model_name,
147
+ 'prompt': final_prompt,
148
+ 'image': final_image,
149
+ 'similarity': final_similarity
150
+ }
151
+
152
+ # Move to next model
153
+ self.current_model_index += 1
154
+
155
+ if self.current_model_index < len(self.image_generator.selected_models):
156
+ # Initialize next model - ensure both indices are synchronized
157
+ self.image_generator.current_model_index = self.current_model_index
158
+ self._create_current_optimizer()
159
+ return self.current_optimizer.initialize(self.target_img)
160
+ else:
161
+ # All models completed - return best result
162
+ return True, self.best_result['prompt'], self.best_result['image']
163
+
164
+ return is_model_completed, prompt, generated_image
165
+
166
+ def get_all_results(self) -> Dict[str, Dict[str, Any]]:
167
+ """Get results from all completed models."""
168
+ return self.model_results.copy()
169
+
170
+ def get_best_result(self) -> Dict[str, Any]:
171
+ """Get the best result across all models."""
172
+ return self.best_result.copy() if self.best_result else None
173
+
174
+ @property
175
+ def history(self):
176
+ """Get history from current optimizer for compatibility."""
177
+ if self.current_optimizer:
178
+ return self.current_optimizer.history
179
+ return []