Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
| from image_evaluators import LlamaEvaluator | |
| from prompt_refiners import LlamaPromptRefiner | |
| from weave_prompt import PromptOptimizer | |
| from similarity_metrics import LPIPSImageSimilarityMetric | |
| from image_generators import FalImageGenerator, MultiModelFalImageGenerator, AVAILABLE_MODELS | |
| from multi_model_optimizer import MultiModelPromptOptimizer | |
# Read KEY=VALUE pairs from a local .env file into the process environment.
load_dotenv()

# Browser-tab title/icon and a full-width layout for the whole app.
st.set_page_config(page_title="WeavePrompt", page_icon="🎨", layout="wide")
def _build_optimizer():
    """Create a MultiModelPromptOptimizer for the models currently selected.

    Reads ``st.session_state.selected_models`` and
    ``st.session_state.max_iterations``; the evaluator, refiner and similarity
    metric are constructed fresh on every call.
    """
    image_generator = MultiModelFalImageGenerator(st.session_state.selected_models)
    return MultiModelPromptOptimizer(
        image_generator=image_generator,
        evaluator=LlamaEvaluator(),
        refiner=LlamaPromptRefiner(),
        similarity_metric=LPIPSImageSimilarityMetric(),
        max_iterations=st.session_state.max_iterations,
        similarity_threshold=0.95,
    )


def _reset_run_state():
    """Forget the current optimization run so the next Start begins from scratch."""
    st.session_state.optimization_started = False
    st.session_state.current_results = None
    st.session_state.auto_mode = False
    st.session_state.auto_paused = False
    st.session_state.auto_should_step = False
    # Expand/collapse flags are keyed by model name and step index, so stale
    # flags would otherwise apply to the steps of the next run.
    if 'analysis_expanded' in st.session_state:
        del st.session_state['analysis_expanded']


def _run_optimizer_step():
    """Advance the optimizer one step and store the outcome in session state.

    Returns:
        True when the step succeeded. On failure the error is shown on the
        page, auto mode (if enabled) is paused so it does not retry the failing
        step on every rerun, and False is returned.
    """
    try:
        is_completed, prompt, generated_image = st.session_state.optimizer.step()
    except Exception as e:
        st.error(f"Error during optimization step: {e}")
        if st.session_state.auto_mode:
            st.session_state.auto_paused = True
        return False
    st.session_state.current_results = (is_completed, prompt, generated_image)
    return True


def _render_model_checkboxes(heading, model_names):
    """Render *model_names* as a two-column checkbox grid under *heading*.

    Returns:
        The names of the checked models, in display order.
    """
    st.markdown(heading)
    columns = st.columns(2)
    checked = []
    for i, model_name in enumerate(model_names):
        with columns[i % 2]:
            is_selected = model_name in st.session_state.selected_models
            if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                checked.append(model_name)
    return checked


def _clamp_fraction(value):
    """Clamp *value* into [0.0, 1.0], the float range st.progress accepts."""
    return min(max(float(value), 0.0), 1.0)


def _display_model_comparison(optimizer):
    """Show every model's final result, best similarity first, in up to 3 columns."""
    all_results = optimizer.get_all_results()
    best_result = optimizer.get_best_result()
    if not all_results:
        return

    st.subheader("🏆 Model Comparison Results")
    if best_result:
        st.success(f"🥇 **Best Result**: {best_result['model_name']} with {best_result['similarity']:.2%} similarity")

    cols = st.columns(min(len(all_results), 3))  # Max 3 columns
    sorted_results = sorted(all_results.items(),
                            key=lambda item: item[1]['final_similarity'],
                            reverse=True)
    medals = ["🥇", "🥈", "🥉"]
    for rank, (model_name, result) in enumerate(sorted_results):
        with cols[rank % 3]:
            medal = medals[rank] if rank < len(medals) else "🏅"
            st.markdown(f"### {medal} {model_name}")
            st.image(result['final_image'], width='stretch')
            st.metric("Final Similarity", f"{result['final_similarity']:.2%}")
            st.metric("Iterations", result['iterations'])
            with st.expander("View Final Prompt"):
                st.text(result['final_prompt'])


def main():
    """Render the WeavePrompt page and drive the optimization loop.

    Streamlit re-executes this function on every interaction, so all
    cross-run state lives in ``st.session_state``:

    * ``selected_models`` / ``max_iterations`` - user settings.
    * ``optimizer`` - the MultiModelPromptOptimizer for the selected models.
    * ``optimization_started`` / ``current_results`` - the active run; the
      results are the ``(is_completed, prompt, image)`` tuple returned by
      ``initialize``/``step``.
    * ``auto_mode`` / ``auto_paused`` / ``auto_should_step`` - auto-progress
      control. Auto mode alternates a *step* run (advance one step, rerun)
      with a *display* run (render everything, wait briefly, rerun).
    """
    import time

    st.title("🎨 WeavePrompt: Multi-Model Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt optimizes prompts across multiple AI models to find the best result.
    """)

    # --- Session state defaults -------------------------------------------
    if 'selected_models' not in st.session_state:
        st.session_state.selected_models = ["FLUX.1 [pro]"]  # Default selection
    if 'max_iterations' not in st.session_state:
        st.session_state.max_iterations = 2
    if 'optimizer' not in st.session_state:
        st.session_state.optimizer = _build_optimizer()
    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False
    if 'current_results' not in st.session_state:
        st.session_state.current_results = None
    if 'auto_mode' not in st.session_state:
        st.session_state.auto_mode = False
    if 'auto_paused' not in st.session_state:
        st.session_state.auto_paused = False
    # Decides whether the next auto-mode run steps or only displays.
    if 'auto_should_step' not in st.session_state:
        st.session_state.auto_should_step = False

    # --- Model selection ----------------------------------------------------
    st.subheader("🤖 Model Selection")
    st.markdown("Choose which AI models to optimize with:")

    google_names = ("Imagen 4", "Imagen 4 Ultra", "Gemini 2.5 Flash Image")
    flux_models = [k for k in AVAILABLE_MODELS if k.startswith("FLUX")]
    google_models = [k for k in AVAILABLE_MODELS if k in google_names]
    other_models = [k for k in AVAILABLE_MODELS if k not in flux_models and k not in google_models]

    new_selection = _render_model_checkboxes("### 🔥 FLUX Models", flux_models)
    new_selection += _render_model_checkboxes("### 🔍 Google Models", google_models)
    if other_models:
        new_selection += _render_model_checkboxes("### 🎨 Other Models", other_models)

    # Ensure at least one model is selected
    if not new_selection:
        st.error("Please select at least one model!")
        new_selection = ["FLUX.1 [pro]"]  # Default fallback

    if set(new_selection) != set(st.session_state.selected_models):
        st.session_state.selected_models = new_selection
        st.session_state.optimizer = _build_optimizer()
        # A run in progress belonged to the old optimizer; the new one has not
        # had initialize() called, so continuing the old run is not safe.
        if st.session_state.optimization_started:
            _reset_run_state()
            st.info("The previous optimization run was reset because the model selection changed.")
        st.success(f"Updated to use {len(new_selection)} model(s): {', '.join(new_selection)}")

    st.markdown("---")

    # --- Target image and start -------------------------------------------
    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return

    target_image = Image.open(uploaded_file)
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Target Image")
        st.image(target_image, width='stretch')

        if not st.session_state.optimization_started:
            if st.button("🚀 Start Optimization", type="primary"):
                # Set state first so the rerun immediately shows the disabled button
                st.session_state.optimization_started = True
                st.rerun()
        else:
            st.button("⏳ Optimization Running...", disabled=True, help="Optimization in progress", type="secondary")
            st.info("💡 Optimization is running across selected models. Use the controls below to pause/resume or reset.")
            # Initialize the optimizer once per run
            if st.session_state.current_results is None:
                try:
                    is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
                except Exception as e:
                    st.error(f"Error initializing optimization: {e}")
                    st.session_state.optimization_started = False
                else:
                    # st.rerun() stays outside the try so only initialize()
                    # failures can reach the error handler above.
                    st.session_state.current_results = (is_completed, prompt, generated_image)
                    st.rerun()

    # --- Settings (always visible) ----------------------------------------
    st.subheader("⚙️ Settings")
    # Settings are locked only while auto mode is actively stepping.
    is_actively_running = (st.session_state.optimization_started
                           and st.session_state.auto_mode
                           and not st.session_state.auto_paused)

    col_settings1, col_settings2, col_settings3 = st.columns(3)
    with col_settings1:
        new_max_iterations = st.number_input(
            "Number of Iterations",
            min_value=1,
            max_value=20,
            value=st.session_state.max_iterations,
            help="Maximum number of optimization iterations per model",
            disabled=is_actively_running,
        )
        if new_max_iterations != st.session_state.max_iterations and not is_actively_running:
            st.session_state.max_iterations = new_max_iterations
            st.session_state.optimizer.max_iterations = new_max_iterations
            # Also update the per-model optimizer already in flight, if any.
            if hasattr(st.session_state.optimizer, 'current_optimizer') and st.session_state.optimizer.current_optimizer:
                st.session_state.optimizer.current_optimizer.max_iterations = new_max_iterations
    with col_settings2:
        auto_mode = st.checkbox(
            "Auto-progress steps",
            value=st.session_state.auto_mode,
            disabled=is_actively_running,
        )
        if auto_mode != st.session_state.auto_mode and not is_actively_running:
            st.session_state.auto_mode = auto_mode
            if auto_mode:
                st.session_state.auto_paused = False
                st.session_state.auto_should_step = True  # Start by stepping
            st.rerun()
    with col_settings3:
        optimization_completed = (st.session_state.current_results is not None
                                  and st.session_state.current_results[0])
        if st.session_state.auto_mode and st.session_state.optimization_started and not optimization_completed:
            if st.session_state.auto_paused:
                if st.button("▶️ Resume", key="resume_btn"):
                    st.session_state.auto_paused = False
                    st.rerun()
            elif st.button("⏸️ Pause", key="pause_btn"):
                st.session_state.auto_paused = True
                st.rerun()
        else:
            st.write("")  # Keep the column's slot when there is nothing to show

    if not st.session_state.optimization_started:
        return

    # --- Current result -----------------------------------------------------
    optimizer = st.session_state.optimizer
    history = getattr(optimizer, 'history', [])

    if st.session_state.current_results is not None:
        is_completed, prompt, generated_image = st.session_state.current_results
        current_model_name = optimizer.get_current_model_name()
        # Prefer the current model's latest recorded step over current_results.
        if history:
            display_image = history[-1]['image']
            display_prompt = history[-1]['prompt']
        else:
            display_image = generated_image
            display_prompt = prompt
        with col2:
            st.subheader("Current Generated Image")
            st.image(display_image, width='stretch')
            if not history:
                st.caption(f"🚀 {current_model_name} - Initializing...")
            elif is_completed:
                st.caption(f"🏁 {current_model_name} - Final Result")
            else:
                st.caption(f"🎯 {current_model_name}")
            st.text_area("Current Prompt", display_prompt, height=100)
    else:
        with col2:
            st.subheader("Generated Image")
            st.info("Initializing optimization...")
            st.text_area("Current Prompt", "Generating initial prompt...", height=100)
        is_completed = False

    # --- Progress metrics ---------------------------------------------------
    progress_info = optimizer.get_progress_info()
    # Re-query rather than reuse the name above, in case it has changed.
    current_model = optimizer.get_current_model_name()
    has_iteration_info = 'current_iteration' in progress_info

    metric_model, metric_models, metric_iteration, metric_status = st.columns(4)
    with metric_model:
        st.metric("Current Model", current_model)
    with metric_models:
        # 1-indexed for display, capped so it never exceeds the model count
        total_models = progress_info['total_models']
        current_model_num = min(progress_info['current_model_index'] + 1, total_models)
        st.metric("Models Progress", f"{current_model_num}/{total_models}")
    with metric_iteration:
        if has_iteration_info:
            current_iteration = progress_info['current_iteration']
            max_iterations = progress_info['max_iterations']
            # While running, show the iteration that is about to execute.
            if not is_completed and current_iteration < max_iterations:
                st.metric("Current Iteration", f"{current_iteration + 1}/{max_iterations}")
            else:
                st.metric("Current Iteration", f"{current_iteration}/{max_iterations}")
    with metric_status:
        if history:
            st.metric("Similarity", f"{history[-1]['similarity']:.2%}")
        else:
            status = "Completed" if is_completed else "Paused" if st.session_state.auto_paused else "In Progress"
            st.metric("Status", status)

    # --- Progress bars ------------------------------------------------------
    st.subheader("Progress")
    overall_progress = progress_info['overall_progress']
    if has_iteration_info and not is_completed:
        # Fold the current model's partial progress into the overall figure
        overall_progress = (progress_info['models_completed'] + progress_info['model_progress']) / progress_info['total_models']

    models_completed = progress_info['models_completed']
    total_models = progress_info['total_models']
    if total_models > 1:
        if models_completed == total_models:
            overall_label = f"🏁 All {total_models} models completed!"
        else:
            overall_label = f"🔄 Processing {current_model} (Model {models_completed + 1} of {total_models})"
    else:
        overall_label = f"🎯 Optimizing with {current_model}"
    st.progress(_clamp_fraction(overall_progress), text=overall_label)

    if has_iteration_info and not is_completed:
        current_iter = progress_info['current_iteration']
        max_iter = progress_info['max_iterations']
        if current_iter == max_iter:
            iteration_label = f"✅ {current_model}: Completed all {max_iter} iterations"
        else:
            iteration_label = f"⚡ {current_model}: Step {current_iter + 1} of {max_iter}"
        st.progress(_clamp_fraction(progress_info['model_progress']), text=iteration_label)

    # --- Next step ----------------------------------------------------------
    # Set on an auto-mode display run; the wait + rerun happens at the very end
    # so the history and the Reset button are actually rendered first.
    auto_advance = False
    if not is_completed:
        if st.session_state.auto_mode and not st.session_state.auto_paused:
            if st.session_state.auto_should_step:
                # Step run: advance, then rerun so a display run shows the result
                if _run_optimizer_step():
                    st.session_state.auto_should_step = False
                    st.rerun()
            else:
                # Display run: render everything below, then step next time
                st.session_state.auto_should_step = True
                auto_advance = True
        elif not st.session_state.auto_mode:
            if st.button("Next Step"):
                if _run_optimizer_step():
                    st.rerun()
        else:
            # Auto mode is on but paused
            st.info("Auto mode is paused. Click Resume to continue or uncheck Auto-progress to use manual mode.")
    else:
        st.success("Optimization completed! Click 'Reset' to try another image.")
        if st.session_state.auto_mode:
            # Turn auto mode off and rerun so the settings widgets reflect it
            st.session_state.auto_mode = False
            st.session_state.auto_paused = False
            st.rerun()

    if st.button("Reset"):
        _reset_run_state()
        # Fresh optimizer so no history/results from the previous run carry over
        st.session_state.optimizer = _build_optimizer()
        st.rerun()

    _display_multi_model_history()

    if is_completed and hasattr(optimizer, 'get_all_results'):
        _display_model_comparison(optimizer)

    if auto_advance:
        time.sleep(0.5)  # Give the user time to see the history
        st.rerun()
def _display_multi_model_history():
    """Render the optimization history, with one tab per model when there are several.

    Finished models contribute the ``history`` stored in the optimizer's
    ``get_all_results()``. The model currently being optimized contributes
    its in-progress history, unless it already has a finished entry.
    """
    optimizer = st.session_state.optimizer

    finished = optimizer.get_all_results() if hasattr(optimizer, 'get_all_results') else {}

    # Each step dict is shallow-copied so the display works on its own containers.
    histories = {
        name: [step.copy() for step in result['history']]
        for name, result in finished.items()
        if 'history' in result and result['history']
    }

    active = getattr(optimizer, 'current_optimizer', None)
    in_progress = active.history if active else None
    if in_progress:
        active_name = optimizer.get_current_model_name()
        if active_name not in histories:
            histories[active_name] = [step.copy() for step in in_progress]

    if not histories:
        return

    st.subheader("📊 Optimization History")

    if len(histories) == 1:
        [(only_name, only_steps)] = histories.items()
        st.markdown(f"**{only_name}** ({len(only_steps)} steps)")
        _display_model_history(only_steps, only_name)
        return

    labels = [f"{name} ({len(steps)} steps)" for name, steps in histories.items()]
    for tab, (name, steps) in zip(st.tabs(labels), histories.items()):
        with tab:
            _display_model_history(steps, name)
def _display_model_history(history, model_name):
    """Render each recorded optimization step for one model.

    Args:
        history: sequence of step records indexed with 'image', 'similarity',
            'prompt' and 'analysis' (the last is iterated via ``.items()``).
        model_name: shown in the empty-state message and used to namespace
            the per-step widget and session-state keys.
    """
    if not history:
        st.info(f"No optimization steps yet for {model_name}")
        return

    for position, step in enumerate(history):
        st.markdown(f"### Step {position + 1}")
        image_col, detail_col = st.columns([2, 3])

        with image_col:
            st.image(step['image'], width='stretch')

        with detail_col:
            st.text(f"Similarity: {step['similarity']:.2%}")
            st.text("Prompt:")
            st.text(step['prompt'])

            # Expand/collapse flags persist across reruns in session state.
            flag_key = f"expand_analysis_{model_name}_{position}"
            if 'analysis_expanded' not in st.session_state:
                st.session_state['analysis_expanded'] = {}
            toggles = st.session_state['analysis_expanded']
            toggles.setdefault(flag_key, False)

            if not toggles[flag_key]:
                if st.button("Expand Analysis", key=flag_key):
                    toggles[flag_key] = True
                    st.rerun()
            else:
                if st.button("Hide Analysis", key=f"hide_{flag_key}"):
                    toggles[flag_key] = False
                    st.rerun()
                st.text("Analysis:")
                for field, detail in step['analysis'].items():
                    st.text(f"{field}: {detail}")
if __name__ == "__main__":
    # Entry point when this file is executed as the main script.
    main()