import streamlit as st
from PIL import Image
from dotenv import load_dotenv
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator, MultiModelFalImageGenerator, AVAILABLE_MODELS
from multi_model_optimizer import MultiModelPromptOptimizer

# Load environment variables from .env file
load_dotenv()

st.set_page_config(
    page_title="WeavePrompt",
    page_icon="🎨",
    layout="wide"
)

def main():
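    """Render the WeavePrompt Streamlit UI: model selection, target-image
    upload, step-by-step optimization controls, progress metrics, and the
    per-model history and comparison views."""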
    st.title("🎨 WeavePrompt: Multi-Model Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt optimizes prompts across multiple AI models to find the best result.
    """)
    
    # Model selection state
    if 'selected_models' not in st.session_state:
        st.session_state.selected_models = ["FLUX.1 [pro]"]  # Default selection
    
    # Initialize max_iterations in session state if not exists
    if 'max_iterations' not in st.session_state:
        st.session_state.max_iterations = 2
    
    # Initialize session state
    if 'optimizer' not in st.session_state:
        image_generator = MultiModelFalImageGenerator(st.session_state.selected_models)
        st.session_state.optimizer = MultiModelPromptOptimizer(
            image_generator=image_generator,
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=st.session_state.max_iterations,
            similarity_threshold=0.95
        )
    
    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False
    
    if 'current_results' not in st.session_state:
        st.session_state.current_results = None
    
    # Auto mode state
    if 'auto_mode' not in st.session_state:
        st.session_state.auto_mode = False
    
    if 'auto_paused' not in st.session_state:
        st.session_state.auto_paused = False
    
    # Auto mode step control - use this to control when to step vs when to display
    if 'auto_should_step' not in st.session_state:
        st.session_state.auto_should_step = False

    # Model Selection UI
    st.subheader("πŸ€– Model Selection")
    st.markdown("Choose which AI models to optimize with:")
    
    # Organize models by category
    flux_models = [k for k in AVAILABLE_MODELS.keys() if k.startswith("FLUX")]
    google_models = [k for k in AVAILABLE_MODELS.keys() if k in ["Imagen 4", "Imagen 4 Ultra", "Gemini 2.5 Flash Image"]]
    other_models = [k for k in AVAILABLE_MODELS.keys() if k not in flux_models and k not in google_models]
    
    # Track if selection changed
    new_selection = []
    
    # FLUX Models Section
    st.markdown("### πŸ”₯ FLUX Models")
    cols_flux = st.columns(2)
    for i, model_name in enumerate(flux_models):
        col_idx = i % 2
        with cols_flux[col_idx]:
            is_selected = model_name in st.session_state.selected_models
            if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                new_selection.append(model_name)
    
    # Google Models Section  
    st.markdown("### πŸ” Google Models")
    cols_google = st.columns(2)
    for i, model_name in enumerate(google_models):
        col_idx = i % 2
        with cols_google[col_idx]:
            is_selected = model_name in st.session_state.selected_models
            if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                new_selection.append(model_name)
    
    # Other Models Section
    if other_models:
        st.markdown("### 🎨 Other Models")
        cols_other = st.columns(2)
        for i, model_name in enumerate(other_models):
            col_idx = i % 2
            with cols_other[col_idx]:
                is_selected = model_name in st.session_state.selected_models
                if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                    new_selection.append(model_name)
    
    # Ensure at least one model is selected
    if not new_selection:
        st.error("Please select at least one model!")
        new_selection = ["FLUX.1 [pro]"]  # Default fallback
    
    # Update selection if changed
    if set(new_selection) != set(st.session_state.selected_models):
        st.session_state.selected_models = new_selection
        # Recreate optimizer with new models
        image_generator = MultiModelFalImageGenerator(st.session_state.selected_models)
        st.session_state.optimizer = MultiModelPromptOptimizer(
            image_generator=image_generator,
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=st.session_state.max_iterations,
            similarity_threshold=0.95
        )
        st.success(f"Updated to use {len(new_selection)} model(s): {', '.join(new_selection)}")
    
    st.markdown("---")

    # File uploader
    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    
    if uploaded_file is not None:
        # Display target image
        target_image = Image.open(uploaded_file)
        
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Target Image")
            st.image(target_image, width='stretch')
        
        # Start button
        if not st.session_state.optimization_started:
            if st.button("πŸš€ Start Optimization", type="primary"):
                # Set state first to ensure immediate UI update
                st.session_state.optimization_started = True
                # Force immediate rerun to show disabled state
                st.rerun()
        else:
            # Show disabled button when optimization has started
            st.button("⏳ Optimization Running...", disabled=True, help="Optimization in progress", type="secondary")
            st.info("πŸ’‘ Optimization is running across selected models. Use the controls below to pause/resume or reset.")
            
            # Initialize optimization after state is set (only once)
            if st.session_state.current_results is None:
                try:
                    is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
                    st.session_state.current_results = (is_completed, prompt, generated_image)
                    st.rerun()
                except Exception as e:
                    st.error(f"Error initializing optimization: {str(e)}")
                    st.session_state.optimization_started = False
        
        # Settings (always visible)
        st.subheader("βš™οΈ Settings")
        
        # Check if optimization is actively running (only disable settings during auto mode)
        is_actively_running = (st.session_state.optimization_started and 
                              st.session_state.auto_mode and 
                              not st.session_state.auto_paused)
        
        # Single row: Number of iterations, Auto-progress, and Pause/Resume controls
        col_settings1, col_settings2, col_settings3 = st.columns(3)
        
        with col_settings1:
            new_max_iterations = st.number_input(
                "Number of Iterations", 
                min_value=1, 
                max_value=20, 
                value=st.session_state.max_iterations,
                help="Maximum number of optimization iterations per model",
                disabled=is_actively_running
            )
            
            # Update if changed (only when not running)
            if new_max_iterations != st.session_state.max_iterations and not is_actively_running:
                st.session_state.max_iterations = new_max_iterations
                # Update the optimizer's max_iterations
                st.session_state.optimizer.max_iterations = new_max_iterations
                if hasattr(st.session_state.optimizer, 'current_optimizer') and st.session_state.optimizer.current_optimizer:
                    st.session_state.optimizer.current_optimizer.max_iterations = new_max_iterations
        
        with col_settings2:
            auto_mode = st.checkbox(
                "Auto-progress steps", 
                value=st.session_state.auto_mode,
                disabled=is_actively_running
            )
            if auto_mode != st.session_state.auto_mode and not is_actively_running:
                st.session_state.auto_mode = auto_mode
                if auto_mode:
                    st.session_state.auto_paused = False
                    st.session_state.auto_should_step = True  # Start by stepping
                    st.rerun()
        
        with col_settings3:
            # Pause/Resume controls (only when auto mode is enabled, optimization started, and not completed)
            optimization_completed = st.session_state.current_results is not None and st.session_state.current_results[0]
            if st.session_state.auto_mode and st.session_state.optimization_started and not optimization_completed:
                if st.session_state.auto_paused:
                    if st.button("▢️ Resume", key="resume_btn"):
                        st.session_state.auto_paused = False
                        st.rerun()
                else:
                    if st.button("⏸️ Pause", key="pause_btn"):
                        st.session_state.auto_paused = True
                        st.rerun()
            else:
                # Show placeholder or empty space when not in auto mode
                st.write("")
        
        # Display optimization progress
        if st.session_state.optimization_started:
            if st.session_state.current_results is not None:
                is_completed, prompt, generated_image = st.session_state.current_results
                
                # Get current model info
                current_model_name = st.session_state.optimizer.get_current_model_name()
                current_history = st.session_state.optimizer.history if hasattr(st.session_state.optimizer, 'history') else []
                
                # If we have current model history, show its latest result instead of current_results
                if current_history:
                    latest_step = current_history[-1]
                    display_image = latest_step['image']
                    display_prompt = latest_step['prompt']
                else:
                    # Fallback to current_results
                    display_image = generated_image
                    display_prompt = prompt
                
                # Always show the actual current model's latest result
                with col2:
                    st.subheader("Current Generated Image")
                    st.image(display_image, width='stretch')
                    
                    if current_history:
                        # Show info about the current state
                        if is_completed:
                            st.caption(f"🏁 {current_model_name} - Final Result")
                        else:
                            st.caption(f"🎯 {current_model_name}")
                    else:
                        st.caption(f"πŸš€ {current_model_name} - Initializing...")
                    
                    # Display current prompt under the image
                    st.text_area("Current Prompt", display_prompt, height=100)
            else:
                # Show loading state
                with col2:
                    st.subheader("Generated Image")
                    st.info("Initializing optimization...")
                    # Display loading prompt under the placeholder
                    st.text_area("Current Prompt", "Generating initial prompt...", height=100)
                is_completed = False
                prompt = ""
            
            # Multi-model progress info
            progress_info = st.session_state.optimizer.get_progress_info()
            # Ensure we get the actual current model name (not cached)
            current_model = st.session_state.optimizer.get_current_model_name()
            
            # Progress metrics
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Current Model", current_model)
            with col2:
                # Show current model number vs total models (1-indexed for display)
                current_model_num = progress_info['current_model_index'] + 1
                total_models = progress_info['total_models']
                
                # Fix: Don't let current model number exceed total models
                if current_model_num > total_models:
                    current_model_num = total_models
                
                st.metric("Models Progress", f"{current_model_num}/{total_models}")
            with col3:
                if 'current_iteration' in progress_info:
                    # Show completed iterations vs max iterations
                    current_iteration = progress_info['current_iteration']
                    max_iterations = progress_info['max_iterations']
                    # If not completed and we haven't reached max, show next iteration number
                    if not is_completed and current_iteration < max_iterations:
                        display_iteration = current_iteration + 1
                        st.metric("Current Iteration", f"{display_iteration}/{max_iterations}")
                    else:
                        st.metric("Current Iteration", f"{current_iteration}/{max_iterations}")
            with col4:
                if len(st.session_state.optimizer.history) > 0:
                    similarity = st.session_state.optimizer.history[-1]['similarity']
                    st.metric("Similarity", f"{similarity:.2%}")
                else:
                    status = "Completed" if is_completed else "Paused" if st.session_state.auto_paused else "In Progress"
                    st.metric("Status", status)
            
            # Progress bars
            st.subheader("Progress")
            
            # Overall progress across all models
            overall_progress = progress_info['overall_progress']
            if 'current_iteration' in progress_info and not is_completed:
                # Add current model progress to overall
                model_progress_fraction = progress_info['model_progress']
                overall_progress = (progress_info['models_completed'] + model_progress_fraction) / progress_info['total_models']
            
            # Create more descriptive progress labels
            models_completed = progress_info['models_completed']
            total_models = progress_info['total_models']
            
            if total_models > 1:
                # Multi-model scenario
                if models_completed == total_models:
                    overall_label = f"🏁 All {total_models} models completed!"
                else:
                    overall_label = f"πŸ”„ Processing {current_model} (Model {models_completed + 1} of {total_models})"
            else:
                # Single model scenario
                overall_label = f"🎯 Optimizing with {current_model}"
            
            st.progress(overall_progress, text=overall_label)
            
            # Current model progress
            if 'current_iteration' in progress_info and not is_completed:
                model_progress_value = progress_info['model_progress']
                current_iter = progress_info['current_iteration']
                max_iter = progress_info['max_iterations']
                
                if current_iter == max_iter:
                    iteration_label = f"βœ… {current_model}: Completed all {max_iter} iterations"
                else:
                    iteration_label = f"⚑ {current_model}: Step {current_iter + 1} of {max_iter}"
                
                st.progress(model_progress_value, text=iteration_label)
            
            # Next step logic
            if not is_completed:
                # Auto mode logic - mimic pause button behavior
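                # Two-phase loop: one rerun executes optimizer.step() and stores
                # the result; the next rerun only renders the updated history,
                # waits briefly, then flips the flag back to trigger the next step.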
                if st.session_state.auto_mode and not st.session_state.auto_paused:
                    if st.session_state.auto_should_step:
                        # Execute the step
                        is_completed, prompt, generated_image = st.session_state.optimizer.step()
                        st.session_state.current_results = (is_completed, prompt, generated_image)
                        # Set flag to NOT step on next render (let history display)
                        st.session_state.auto_should_step = False
                        st.rerun()
                    else:
                        # Don't step, just display current state and history, then set flag to step next time
                        st.session_state.auto_should_step = True
                        # Use a small delay then rerun to continue auto mode
                        import time
                        time.sleep(0.5)  # Give user time to see the history
                        st.rerun()
                
                # Manual mode
                elif not st.session_state.auto_mode:
                    if st.button("Next Step"):
                        is_completed, prompt, generated_image = st.session_state.optimizer.step()
                        st.session_state.current_results = (is_completed, prompt, generated_image)
                        st.rerun()
                
                # Show status when auto mode is paused
                elif st.session_state.auto_paused:
                    st.info("Auto mode is paused. Click Resume to continue or uncheck Auto-progress to use manual mode.")
            else:
                st.success("Optimization completed! Click 'Reset' to try another image.")
                # Turn off auto mode when completed
                if st.session_state.auto_mode:
                    st.session_state.auto_mode = False
                    st.session_state.auto_paused = False
            
            # Reset button
            if st.button("Reset"):
                st.session_state.optimization_started = False
                st.session_state.current_results = None
                st.session_state.auto_mode = False
                st.session_state.auto_paused = False
                st.rerun()
            
            # Display multi-model history with tabs
            _display_multi_model_history()
            
            # Model Results Comparison (show when optimization is complete)
            if is_completed and hasattr(st.session_state.optimizer, 'get_all_results'):
                all_results = st.session_state.optimizer.get_all_results()
                best_result = st.session_state.optimizer.get_best_result()
                
                if all_results:
                    st.subheader("πŸ† Model Comparison Results")
                    
                    # Show best result prominently
                    if best_result:
                        st.success(f"πŸ₯‡ **Best Result**: {best_result['model_name']} with {best_result['similarity']:.2%} similarity")
                    
                    # Create columns for each model result
                    num_models = len(all_results)
                    if num_models > 0:
                        cols = st.columns(min(num_models, 3))  # Max 3 columns
                        
                        # Sort results by similarity (best first)
                        sorted_results = sorted(all_results.items(), 
                                              key=lambda x: x[1]['final_similarity'], 
                                              reverse=True)
                        
                        for i, (model_name, result) in enumerate(sorted_results):
                            col_idx = i % 3
                            with cols[col_idx]:
                                # Add medal emoji for top 3
                                medal = "πŸ₯‡" if i == 0 else "πŸ₯ˆ" if i == 1 else "πŸ₯‰" if i == 2 else "πŸ…"
                                st.markdown(f"### {medal} {model_name}")
                                
                                # Show final image
                                st.image(result['final_image'], width='stretch')
                                
                                # Show metrics
                                st.metric("Final Similarity", f"{result['final_similarity']:.2%}")
                                st.metric("Iterations", result['iterations'])
                                
                                # Show final prompt in expander
                                with st.expander("View Final Prompt"):
                                    st.text(result['final_prompt'])


def _display_multi_model_history():
    """Display optimization history with tabs for multi-model scenarios."""
    
    # Get all completed model results
    all_model_results = {}
    if hasattr(st.session_state.optimizer, 'get_all_results'):
        all_model_results = st.session_state.optimizer.get_all_results()
    
    # Build separate histories for each model
    all_histories = {}
    
    # Add completed model histories - each model gets its own stored history
    for model_name, result in all_model_results.items():
        if 'history' in result and result['history']:
            # Make a deep copy to ensure complete separation
            all_histories[model_name] = [step.copy() for step in result['history']]
    
    # Add current model's in-progress history (if any)
    if hasattr(st.session_state.optimizer, 'current_optimizer') and st.session_state.optimizer.current_optimizer:
        current_history = st.session_state.optimizer.current_optimizer.history
        if current_history:
            current_model_name = st.session_state.optimizer.get_current_model_name()
            # Only add if this model doesn't already have a completed history
            if current_model_name not in all_histories:
                all_histories[current_model_name] = [step.copy() for step in current_history]
    
    # Display histories
    if all_histories:
        st.subheader("πŸ“Š Optimization History")
        
        if len(all_histories) == 1:
            # Single model - no tabs needed
            model_name = list(all_histories.keys())[0]
            history = all_histories[model_name]
            st.markdown(f"**{model_name}** ({len(history)} steps)")
            _display_model_history(history, model_name)
        else:
            # Multi-model - use tabs
            tab_names = []
            for model_name, history in all_histories.items():
                tab_names.append(f"{model_name} ({len(history)} steps)")
            
            tabs = st.tabs(tab_names)
            
            for i, (model_name, history) in enumerate(all_histories.items()):
                with tabs[i]:
                    _display_model_history(history, model_name)


def _display_model_history(history, model_name):
    """Display history for a single model."""
    if not history:
        st.info(f"No optimization steps yet for {model_name}")
        return
    
    for idx, hist_entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        col1, col2 = st.columns([2, 3])
        
        with col1:
            st.image(hist_entry['image'], width='stretch')
        
        with col2:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])
            
            # Toggle analysis view per history entry
            expand_key = f"expand_analysis_{model_name}_{idx}"
            if 'analysis_expanded' not in st.session_state:
                st.session_state['analysis_expanded'] = {}
            if expand_key not in st.session_state['analysis_expanded']:
                st.session_state['analysis_expanded'][expand_key] = False

            if st.session_state['analysis_expanded'][expand_key]:
                if st.button("Hide Analysis", key=f"hide_{expand_key}"):
                    st.session_state['analysis_expanded'][expand_key] = False
                    st.rerun()
                st.text("Analysis:")
                for key, value in hist_entry['analysis'].items():
                    st.text(f"{key}: {value}")
            else:
                if st.button("Expand Analysis", key=expand_key):
                    st.session_state['analysis_expanded'][expand_key] = True
                    st.rerun()


if __name__ == "__main__":
    main()
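
# To launch the UI locally (assuming this file is saved as app.py):
#   streamlit run app.py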