import time

import streamlit as st
from PIL import Image
from dotenv import load_dotenv
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator, MultiModelFalImageGenerator, AVAILABLE_MODELS
from multi_model_optimizer import MultiModelPromptOptimizer

# Load environment variables from .env file
load_dotenv()

st.set_page_config(
    page_title="WeavePrompt",
    page_icon="🎨",
    layout="wide",
)

# Model selected on first load, and the fallback when the user unticks everything.
DEFAULT_MODEL = "FLUX.1 [pro]"
# Initial value of the "Number of Iterations" setting.
DEFAULT_MAX_ITERATIONS = 2
# Passed to the optimizer as ``similarity_threshold`` (presumably the early-stop
# similarity target — confirm against MultiModelPromptOptimizer).
SIMILARITY_THRESHOLD = 0.95
# Pause between auto-mode steps so the user can see the updated history.
AUTO_STEP_DELAY_SECONDS = 0.5
# Model names listed under the "Google Models" heading.
GOOGLE_MODEL_NAMES = {"Imagen 4", "Imagen 4 Ultra", "Gemini 2.5 Flash Image"}
# Medals for the top three models in the comparison view; later ranks get "🏅".
MEDALS = ("🥇", "🥈", "🥉")


def _create_optimizer(selected_models, max_iterations):
    """Build a fresh multi-model optimizer.

    Args:
        selected_models: Model names (keys of ``AVAILABLE_MODELS``) to optimize with.
        max_iterations: Maximum number of optimization iterations per model.

    Returns:
        A new ``MultiModelPromptOptimizer``.
    """
    return MultiModelPromptOptimizer(
        image_generator=MultiModelFalImageGenerator(selected_models),
        evaluator=LlamaEvaluator(),
        refiner=LlamaPromptRefiner(),
        similarity_metric=LPIPSImageSimilarityMetric(),
        max_iterations=max_iterations,
        similarity_threshold=SIMILARITY_THRESHOLD,
    )


def _init_session_state():
    """Populate ``st.session_state`` with defaults for any keys not yet set."""
    defaults = {
        'selected_models': [DEFAULT_MODEL],
        'max_iterations': DEFAULT_MAX_ITERATIONS,
        'optimization_started': False,
        'current_results': None,  # (is_completed, prompt, image) from the optimizer
        'auto_mode': False,
        'auto_paused': False,
        # Auto mode alternates between "step" renders and "display" renders;
        # this flag says which one the next render should be.
        'auto_should_step': False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    # Built after the keys above because it depends on the selection and iteration budget.
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = _create_optimizer(
            st.session_state.selected_models, st.session_state.max_iterations
        )


def _reset_run_state():
    """Forget the current optimization run so the next "Start" begins from scratch.

    The optimizer object is kept; its ``initialize`` is called again after the
    next press of "Start Optimization" because ``current_results`` is cleared.
    """
    st.session_state.optimization_started = False
    st.session_state.current_results = None
    st.session_state.auto_mode = False
    st.session_state.auto_paused = False
    st.session_state.auto_should_step = False
    # Analysis toggles are keyed by model name and step index, so stale entries
    # would otherwise reappear on the same steps of the next run.
    if 'analysis_expanded' in st.session_state:
        del st.session_state['analysis_expanded']


def _render_model_checkboxes(model_names, selected_models):
    """Render ``model_names`` as a two-column grid of checkboxes.

    Returns:
        The names whose checkbox is ticked, in display order.
    """
    checked = []
    columns = st.columns(2)
    for i, model_name in enumerate(model_names):
        with columns[i % 2]:
            is_selected = model_name in selected_models
            if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                checked.append(model_name)
    return checked


def _render_model_selection():
    """Render the model picker and rebuild the optimizer when the selection changes."""
    st.subheader("🤖 Model Selection")
    st.markdown("Choose which AI models to optimize with:")

    # Organize models by category
    flux_models = [k for k in AVAILABLE_MODELS.keys() if k.startswith("FLUX")]
    google_models = [k for k in AVAILABLE_MODELS.keys() if k in GOOGLE_MODEL_NAMES]
    other_models = [
        k for k in AVAILABLE_MODELS.keys() if k not in flux_models and k not in google_models
    ]

    current_selection = st.session_state.selected_models
    new_selection = []

    st.markdown("### 🔥 FLUX Models")
    new_selection += _render_model_checkboxes(flux_models, current_selection)

    st.markdown("### 🔍 Google Models")
    new_selection += _render_model_checkboxes(google_models, current_selection)

    if other_models:
        st.markdown("### 🎨 Other Models")
        new_selection += _render_model_checkboxes(other_models, current_selection)

    # Ensure at least one model is selected
    if not new_selection:
        st.error("Please select at least one model!")
        new_selection = [DEFAULT_MODEL]

    if set(new_selection) != set(st.session_state.selected_models):
        st.session_state.selected_models = new_selection
        st.session_state.optimizer = _create_optimizer(
            new_selection, st.session_state.max_iterations
        )
        st.success(f"Updated to use {len(new_selection)} model(s): {', '.join(new_selection)}")
        # The replaced optimizer held the run's progress; the new one has never been
        # initialized, so any run state built on the old one must be discarded.
        if st.session_state.optimization_started or st.session_state.current_results is not None:
            _reset_run_state()
            st.info("Model selection changed, so the current optimization was reset.")


def _render_start_controls(target_image):
    """Show the Start button (or a disabled "running" indicator) and, right after
    Start is pressed, initialize the optimizer with ``target_image``."""
    if not st.session_state.optimization_started:
        if st.button("🚀 Start Optimization", type="primary"):
            # Set state first, then rerun so the disabled button shows immediately.
            st.session_state.optimization_started = True
            st.rerun()
        return

    st.button("⏳ Optimization Running...", disabled=True,
              help="Optimization in progress", type="secondary")
    st.info("💡 Optimization is running across selected models. "
            "Use the controls below to pause/resume or reset.")

    # Initialize only once per run (current_results is cleared on reset).
    if st.session_state.current_results is None:
        try:
            is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
        except Exception as e:  # UI boundary: report instead of crashing the page
            st.error(f"Error initializing optimization: {str(e)}")
            st.session_state.optimization_started = False
            return
        st.session_state.current_results = (is_completed, prompt, generated_image)
        st.rerun()


def _render_settings():
    """Render the iteration-count, auto-progress and pause/resume controls."""
    st.subheader("⚙️ Settings")
    state = st.session_state

    # Settings are locked only while auto mode is actively stepping.
    is_actively_running = state.optimization_started and state.auto_mode and not state.auto_paused

    col_iterations, col_auto, col_pause = st.columns(3)

    with col_iterations:
        new_max_iterations = st.number_input(
            "Number of Iterations",
            min_value=1,
            max_value=20,
            value=state.max_iterations,
            help="Maximum number of optimization iterations per model",
            disabled=is_actively_running,
        )
        if new_max_iterations != state.max_iterations and not is_actively_running:
            state.max_iterations = new_max_iterations
            state.optimizer.max_iterations = new_max_iterations
            # Also update the per-model optimizer that is currently in progress, if any.
            current_optimizer = getattr(state.optimizer, 'current_optimizer', None)
            if current_optimizer:
                current_optimizer.max_iterations = new_max_iterations

    with col_auto:
        auto_mode = st.checkbox(
            "Auto-progress steps",
            value=state.auto_mode,
            disabled=is_actively_running,
        )
        if auto_mode != state.auto_mode and not is_actively_running:
            state.auto_mode = auto_mode
            if auto_mode:
                state.auto_paused = False
                state.auto_should_step = True  # Start by stepping
            st.rerun()

    with col_pause:
        # Pause/Resume only makes sense for an unfinished auto-mode run.
        optimization_completed = state.current_results is not None and state.current_results[0]
        if state.auto_mode and state.optimization_started and not optimization_completed:
            if state.auto_paused:
                if st.button("▶️ Resume", key="resume_btn"):
                    state.auto_paused = False
                    st.rerun()
            elif st.button("⏸️ Pause", key="pause_btn"):
                state.auto_paused = True
                st.rerun()
        else:
            st.write("")  # keep the column's space when there is nothing to show


def _run_optimization_step():
    """Advance the optimizer by one step and store the result in session state.

    Returns:
        True if the step succeeded; False if it raised (the error is shown in the UI).
    """
    try:
        is_completed, prompt, generated_image = st.session_state.optimizer.step()
    except Exception as e:  # UI boundary: report instead of crashing the page
        st.error(f"Error during optimization step: {e}")
        return False
    st.session_state.current_results = (is_completed, prompt, generated_image)
    return True


def _render_model_comparison(optimizer):
    """Show every model's final result, best similarity first."""
    all_results = optimizer.get_all_results()
    best_result = optimizer.get_best_result()
    if not all_results:
        return

    st.subheader("🏆 Model Comparison Results")
    if best_result:
        st.success(
            f"🥇 **Best Result**: {best_result['model_name']} "
            f"with {best_result['similarity']:.2%} similarity"
        )

    sorted_results = sorted(
        all_results.items(), key=lambda item: item[1]['final_similarity'], reverse=True
    )
    columns = st.columns(min(len(sorted_results), 3))  # Max 3 columns
    for rank, (model_name, result) in enumerate(sorted_results):
        with columns[rank % len(columns)]:
            medal = MEDALS[rank] if rank < len(MEDALS) else "🏅"
            st.markdown(f"### {medal} {model_name}")
            st.image(result['final_image'], width='stretch')
            st.metric("Final Similarity", f"{result['final_similarity']:.2%}")
            st.metric("Iterations", result['iterations'])
            with st.expander("View Final Prompt"):
                st.text(result['final_prompt'])


def _render_optimization_run(image_column):
    """Render the live result, progress, step controls, history and final comparison.

    Args:
        image_column: Column (next to the target image) that shows the current result.
    """
    optimizer = st.session_state.optimizer
    current_history = getattr(optimizer, 'history', None) or []
    current_model = optimizer.get_current_model_name()

    with image_column:
        if st.session_state.current_results is not None:
            is_completed, prompt, generated_image = st.session_state.current_results
            # Prefer the current model's latest history entry over current_results.
            if current_history:
                latest_step = current_history[-1]
                display_image = latest_step['image']
                display_prompt = latest_step['prompt']
            else:
                display_image = generated_image
                display_prompt = prompt

            st.subheader("Current Generated Image")
            st.image(display_image, width='stretch')
            if not current_history:
                st.caption(f"🚀 {current_model} - Initializing...")
            elif is_completed:
                st.caption(f"🏁 {current_model} - Final Result")
            else:
                st.caption(f"🎯 {current_model}")
            st.text_area("Current Prompt", display_prompt, height=100)
        else:
            # Show loading state
            is_completed = False
            st.subheader("Generated Image")
            st.info("Initializing optimization...")
            st.text_area("Current Prompt", "Generating initial prompt...", height=100)

    progress_info = optimizer.get_progress_info()
    total_models = progress_info['total_models']
    models_completed = progress_info['models_completed']
    has_iteration_info = 'current_iteration' in progress_info

    # Progress metrics
    metric_model, metric_models, metric_iteration, metric_status = st.columns(4)
    with metric_model:
        st.metric("Current Model", current_model)
    with metric_models:
        # 1-indexed for display, capped so it never exceeds the model count.
        current_model_num = min(progress_info['current_model_index'] + 1, total_models)
        st.metric("Models Progress", f"{current_model_num}/{total_models}")
    with metric_iteration:
        if has_iteration_info:
            current_iteration = progress_info['current_iteration']
            max_iterations = progress_info['max_iterations']
            # While running, show the iteration in progress rather than the last finished one.
            if not is_completed and current_iteration < max_iterations:
                current_iteration += 1
            st.metric("Current Iteration", f"{current_iteration}/{max_iterations}")
    with metric_status:
        if current_history:
            st.metric("Similarity", f"{current_history[-1]['similarity']:.2%}")
        else:
            if is_completed:
                status = "Completed"
            elif st.session_state.auto_paused:
                status = "Paused"
            else:
                status = "In Progress"
            st.metric("Status", status)

    # Progress bars
    st.subheader("Progress")
    overall_progress = progress_info['overall_progress']
    if has_iteration_info and not is_completed:
        # Include the partial progress of the model currently being optimized.
        overall_progress = (models_completed + progress_info['model_progress']) / total_models

    if total_models == 1:
        overall_label = f"🎯 Optimizing with {current_model}"
    elif models_completed == total_models:
        overall_label = f"🏁 All {total_models} models completed!"
    else:
        overall_label = f"🔄 Processing {current_model} (Model {models_completed + 1} of {total_models})"
    st.progress(overall_progress, text=overall_label)

    if has_iteration_info and not is_completed:
        current_iter = progress_info['current_iteration']
        max_iter = progress_info['max_iterations']
        if current_iter == max_iter:
            iteration_label = f"✅ {current_model}: Completed all {max_iter} iterations"
        else:
            iteration_label = f"⚑ {current_model}: Step {current_iter + 1} of {max_iter}"
        st.progress(progress_info['model_progress'], text=iteration_label)

    # Next step logic
    if not is_completed:
        if st.session_state.auto_mode and not st.session_state.auto_paused:
            # Auto mode alternates: one render steps, the next only displays.
            if st.session_state.auto_should_step:
                if _run_optimization_step():
                    st.session_state.auto_should_step = False
                    st.rerun()
                else:
                    # Stop auto mode from retrying the failing step on every rerun.
                    st.session_state.auto_paused = True
            else:
                st.session_state.auto_should_step = True
                time.sleep(AUTO_STEP_DELAY_SECONDS)  # give the user time to see the history
                st.rerun()
        elif not st.session_state.auto_mode:
            if st.button("Next Step") and _run_optimization_step():
                st.rerun()
        else:
            st.info("Auto mode is paused. Click Resume to continue or uncheck Auto-progress to use manual mode.")
    else:
        st.success("Optimization completed! Click 'Reset' to try another image.")
        # Turn off auto mode when completed
        if st.session_state.auto_mode:
            st.session_state.auto_mode = False
            st.session_state.auto_paused = False

    if st.button("Reset"):
        _reset_run_state()
        st.rerun()

    _display_multi_model_history()

    if is_completed and hasattr(optimizer, 'get_all_results'):
        _render_model_comparison(optimizer)


def main():
    """Render the WeavePrompt page: model picker, upload, settings and optimization run."""
    st.title("🎨 WeavePrompt: Multi-Model Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt optimizes prompts across multiple AI models to find the best result.
    """)

    _init_session_state()
    _render_model_selection()
    st.markdown("---")

    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    target_image = Image.open(uploaded_file)
    target_column, result_column = st.columns(2)
    with target_column:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

    _render_start_controls(target_image)
    _render_settings()

    if st.session_state.optimization_started:
        _render_optimization_run(result_column)


def _display_multi_model_history():
    """Display optimization history, with one tab per model when several models ran.

    Completed models' histories come from ``get_all_results()``; the model still
    in progress contributes its ``current_optimizer.history`` unless it already
    has a completed entry.
    """
    optimizer = st.session_state.optimizer

    all_model_results = {}
    if hasattr(optimizer, 'get_all_results'):
        all_model_results = optimizer.get_all_results() or {}

    all_histories = {}
    for model_name, result in all_model_results.items():
        if result.get('history'):
            # Shallow-copy each step dict so later mutation of one list does not
            # affect the other (the step values themselves are shared).
            all_histories[model_name] = [step.copy() for step in result['history']]

    current_optimizer = getattr(optimizer, 'current_optimizer', None)
    if current_optimizer and current_optimizer.history:
        current_model_name = optimizer.get_current_model_name()
        if current_model_name not in all_histories:
            all_histories[current_model_name] = [step.copy() for step in current_optimizer.history]

    if not all_histories:
        return

    st.subheader("📊 Optimization History")
    if len(all_histories) == 1:
        # Single model - no tabs needed
        (model_name, history), = all_histories.items()
        st.markdown(f"**{model_name}** ({len(history)} steps)")
        _display_model_history(history, model_name)
        return

    tab_names = [f"{name} ({len(history)} steps)" for name, history in all_histories.items()]
    for tab, (model_name, history) in zip(st.tabs(tab_names), all_histories.items()):
        with tab:
            _display_model_history(history, model_name)


def _display_model_history(history, model_name):
    """Display history for a single model.

    Each entry is shown as a numbered step with its image, similarity and prompt,
    plus a per-step button that toggles display of the entry's ``analysis`` dict.

    Args:
        history: List of step dicts with 'image', 'similarity', 'prompt' and
            (optionally) 'analysis' keys.
        model_name: Used in messages and to make widget keys unique per model.
    """
    if not history:
        st.info(f"No optimization steps yet for {model_name}")
        return

    if 'analysis_expanded' not in st.session_state:
        st.session_state['analysis_expanded'] = {}
    expanded = st.session_state['analysis_expanded']

    for idx, hist_entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        image_col, text_col = st.columns([2, 3])
        with image_col:
            st.image(hist_entry['image'], width='stretch')
        with text_col:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])

            # Toggle analysis view per history entry
            expand_key = f"expand_analysis_{model_name}_{idx}"
            if expanded.get(expand_key, False):
                if st.button("Hide Analysis", key=f"hide_{expand_key}"):
                    expanded[expand_key] = False
                    st.rerun()
                st.text("Analysis:")
                for key, value in (hist_entry.get('analysis') or {}).items():
                    st.text(f"{key}: {value}")
            else:
                expanded.setdefault(expand_key, False)
                if st.button("Expand Analysis", key=expand_key):
                    expanded[expand_key] = True
                    st.rerun()


if __name__ == "__main__":
    main()