# WeavePrompt / app.py
import streamlit as st
from PIL import Image
from dotenv import load_dotenv
from image_evaluators import LlamaEvaluator
from prompt_refiners import LlamaPromptRefiner
from weave_prompt import PromptOptimizer
from similarity_metrics import LPIPSImageSimilarityMetric
from image_generators import FalImageGenerator, MultiModelFalImageGenerator, AVAILABLE_MODELS
from multi_model_optimizer import MultiModelPromptOptimizer
# Load environment variables from .env file
load_dotenv()
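# Note: load_dotenv() above is expected to supply the API credentials used by the
# fal.ai image generators and the Llama-based evaluator/refiner. The exact variable
# names live in image_generators / image_evaluators / prompt_refiners; an illustrative
# (not verified) .env might look like:
#
#   FAL_KEY=...
#   # plus whatever key the Llama evaluator/refiner backend requires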
st.set_page_config(
    page_title="WeavePrompt",
    page_icon="🎨",
    layout="wide"
)


def main():
    st.title("🎨 WeavePrompt: Multi-Model Prompt Optimization")
    st.markdown("""
    Upload a target image and watch as WeavePrompt optimizes prompts across multiple AI models to find the best result.
    """)
    # Model selection state
    if 'selected_models' not in st.session_state:
        st.session_state.selected_models = ["FLUX.1 [pro]"]  # Default selection

    # Initialize max_iterations in session state if not exists
    if 'max_iterations' not in st.session_state:
        st.session_state.max_iterations = 2

    # Initialize session state
    if 'optimizer' not in st.session_state:
        image_generator = MultiModelFalImageGenerator(st.session_state.selected_models)
        st.session_state.optimizer = MultiModelPromptOptimizer(
            image_generator=image_generator,
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=st.session_state.max_iterations,
            similarity_threshold=0.95
        )
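
    # For reference, the optimizer exposes a simple initialize/step protocol that this UI
    # drives across Streamlit reruns. A minimal non-Streamlit sketch (illustrative only,
    # assuming a PIL image in `target_image`; it mirrors how the tuples are unpacked below):
    #
    #   optimizer = MultiModelPromptOptimizer(...)            # as constructed above
    #   is_completed, prompt, image = optimizer.initialize(target_image)
    #   while not is_completed:
    #       is_completed, prompt, image = optimizer.step()
    #   best = optimizer.get_best_result()                    # per-model winner, see below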
    if 'optimization_started' not in st.session_state:
        st.session_state.optimization_started = False
    if 'current_results' not in st.session_state:
        st.session_state.current_results = None

    # Auto mode state
    if 'auto_mode' not in st.session_state:
        st.session_state.auto_mode = False
    if 'auto_paused' not in st.session_state:
        st.session_state.auto_paused = False
    # Auto mode step control - use this to control when to step vs when to display
    if 'auto_should_step' not in st.session_state:
        st.session_state.auto_should_step = False
    # Model Selection UI
    st.subheader("🤖 Model Selection")
    st.markdown("Choose which AI models to optimize with:")

    # Organize models by category
    flux_models = [k for k in AVAILABLE_MODELS.keys() if k.startswith("FLUX")]
    google_models = [k for k in AVAILABLE_MODELS.keys() if k in ["Imagen 4", "Imagen 4 Ultra", "Gemini 2.5 Flash Image"]]
    other_models = [k for k in AVAILABLE_MODELS.keys() if k not in flux_models and k not in google_models]

    # Track if selection changed
    new_selection = []
    # FLUX Models Section
    st.markdown("### 🔥 FLUX Models")
    cols_flux = st.columns(2)
    for i, model_name in enumerate(flux_models):
        col_idx = i % 2
        with cols_flux[col_idx]:
            is_selected = model_name in st.session_state.selected_models
            if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                new_selection.append(model_name)

    # Google Models Section
    st.markdown("### 🔍 Google Models")
    cols_google = st.columns(2)
    for i, model_name in enumerate(google_models):
        col_idx = i % 2
        with cols_google[col_idx]:
            is_selected = model_name in st.session_state.selected_models
            if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                new_selection.append(model_name)

    # Other Models Section
    if other_models:
        st.markdown("### 🎨 Other Models")
        cols_other = st.columns(2)
        for i, model_name in enumerate(other_models):
            col_idx = i % 2
            with cols_other[col_idx]:
                is_selected = model_name in st.session_state.selected_models
                if st.checkbox(model_name, value=is_selected, key=f"model_{model_name}"):
                    new_selection.append(model_name)
    # Ensure at least one model is selected
    if not new_selection:
        st.error("Please select at least one model!")
        new_selection = ["FLUX.1 [pro]"]  # Default fallback

    # Update selection if changed
    if set(new_selection) != set(st.session_state.selected_models):
        st.session_state.selected_models = new_selection
        # Recreate optimizer with new models
        image_generator = MultiModelFalImageGenerator(st.session_state.selected_models)
        st.session_state.optimizer = MultiModelPromptOptimizer(
            image_generator=image_generator,
            evaluator=LlamaEvaluator(),
            refiner=LlamaPromptRefiner(),
            similarity_metric=LPIPSImageSimilarityMetric(),
            max_iterations=st.session_state.max_iterations,
            similarity_threshold=0.95
        )
        st.success(f"Updated to use {len(new_selection)} model(s): {', '.join(new_selection)}")

    st.markdown("---")
    # File uploader
    uploaded_file = st.file_uploader("Choose a target image", type=['png', 'jpg', 'jpeg'])

    if uploaded_file is not None:
        # Display target image
        target_image = Image.open(uploaded_file)

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Target Image")
            st.image(target_image, width='stretch')

        # Start button
        if not st.session_state.optimization_started:
            if st.button("🚀 Start Optimization", type="primary"):
                # Set state first to ensure immediate UI update
                st.session_state.optimization_started = True
                # Force immediate rerun to show disabled state
                st.rerun()
        else:
            # Show disabled button when optimization has started
            st.button("⏳ Optimization Running...", disabled=True, help="Optimization in progress", type="secondary")
            st.info("💡 Optimization is running across selected models. Use the controls below to pause/resume or reset.")

            # Initialize optimization after state is set (only once)
            if st.session_state.current_results is None:
                try:
                    is_completed, prompt, generated_image = st.session_state.optimizer.initialize(target_image)
                    st.session_state.current_results = (is_completed, prompt, generated_image)
                    st.rerun()
                except Exception as e:
                    st.error(f"Error initializing optimization: {str(e)}")
                    st.session_state.optimization_started = False
        # Settings (always visible)
        st.subheader("⚙️ Settings")

        # Check if optimization is actively running (only disable settings during auto mode)
        is_actively_running = (st.session_state.optimization_started and
                               st.session_state.auto_mode and
                               not st.session_state.auto_paused)

        # Single row: Number of iterations, Auto-progress, and Pause/Resume controls
        col_settings1, col_settings2, col_settings3 = st.columns(3)

        with col_settings1:
            new_max_iterations = st.number_input(
                "Number of Iterations",
                min_value=1,
                max_value=20,
                value=st.session_state.max_iterations,
                help="Maximum number of optimization iterations per model",
                disabled=is_actively_running
            )
            # Update if changed (only when not running)
            if new_max_iterations != st.session_state.max_iterations and not is_actively_running:
                st.session_state.max_iterations = new_max_iterations
                # Update the optimizer's max_iterations
                st.session_state.optimizer.max_iterations = new_max_iterations
                if hasattr(st.session_state.optimizer, 'current_optimizer') and st.session_state.optimizer.current_optimizer:
                    st.session_state.optimizer.current_optimizer.max_iterations = new_max_iterations

        with col_settings2:
            auto_mode = st.checkbox(
                "Auto-progress steps",
                value=st.session_state.auto_mode,
                disabled=is_actively_running
            )
            if auto_mode != st.session_state.auto_mode and not is_actively_running:
                st.session_state.auto_mode = auto_mode
                if auto_mode:
                    st.session_state.auto_paused = False
                    st.session_state.auto_should_step = True  # Start by stepping
                st.rerun()

        with col_settings3:
            # Pause/Resume controls (only when auto mode is enabled, optimization started, and not completed)
            optimization_completed = st.session_state.current_results is not None and st.session_state.current_results[0]
            if st.session_state.auto_mode and st.session_state.optimization_started and not optimization_completed:
                if st.session_state.auto_paused:
                    if st.button("▶️ Resume", key="resume_btn"):
                        st.session_state.auto_paused = False
                        st.rerun()
                else:
                    if st.button("⏸️ Pause", key="pause_btn"):
                        st.session_state.auto_paused = True
                        st.rerun()
            else:
                # Show placeholder or empty space when not in auto mode
                st.write("")
        # Display optimization progress
        if st.session_state.optimization_started:
            if st.session_state.current_results is not None:
                is_completed, prompt, generated_image = st.session_state.current_results

                # Get current model info
                current_model_name = st.session_state.optimizer.get_current_model_name()
                current_history = st.session_state.optimizer.history if hasattr(st.session_state.optimizer, 'history') else []

                # If we have current model history, show its latest result instead of current_results
                if current_history:
                    latest_step = current_history[-1]
                    display_image = latest_step['image']
                    display_prompt = latest_step['prompt']
                else:
                    # Fallback to current_results
                    display_image = generated_image
                    display_prompt = prompt

                # Always show the actual current model's latest result
                with col2:
                    st.subheader("Current Generated Image")
                    st.image(display_image, width='stretch')
                    if current_history:
                        # Show info about the current state
                        if is_completed:
                            st.caption(f"🏁 {current_model_name} - Final Result")
                        else:
                            st.caption(f"🎯 {current_model_name}")
                    else:
                        st.caption(f"🚀 {current_model_name} - Initializing...")
                    # Display current prompt under the image
                    st.text_area("Current Prompt", display_prompt, height=100)
            else:
                # Show loading state
                with col2:
                    st.subheader("Generated Image")
                    st.info("Initializing optimization...")
                    # Display loading prompt under the placeholder
                    st.text_area("Current Prompt", "Generating initial prompt...", height=100)
                is_completed = False
                prompt = ""
            # Multi-model progress info
            progress_info = st.session_state.optimizer.get_progress_info()
            # Ensure we get the actual current model name (not cached)
            current_model = st.session_state.optimizer.get_current_model_name()
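            # progress_info is a plain dict; the keys consumed in this view are
            # 'overall_progress', 'models_completed', 'total_models', 'current_model_index',
            # 'model_progress', and (while a model is mid-run) 'current_iteration' /
            # 'max_iterations'.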
            # Progress metrics
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Current Model", current_model)
            with col2:
                # Show current model number vs total models (1-indexed for display)
                current_model_num = progress_info['current_model_index'] + 1
                total_models = progress_info['total_models']
                # Fix: Don't let current model number exceed total models
                if current_model_num > total_models:
                    current_model_num = total_models
                st.metric("Models Progress", f"{current_model_num}/{total_models}")
            with col3:
                if 'current_iteration' in progress_info:
                    # Show completed iterations vs max iterations
                    current_iteration = progress_info['current_iteration']
                    max_iterations = progress_info['max_iterations']
                    # If not completed and we haven't reached max, show next iteration number
                    if not is_completed and current_iteration < max_iterations:
                        display_iteration = current_iteration + 1
                        st.metric("Current Iteration", f"{display_iteration}/{max_iterations}")
                    else:
                        st.metric("Current Iteration", f"{current_iteration}/{max_iterations}")
            with col4:
                if len(st.session_state.optimizer.history) > 0:
                    similarity = st.session_state.optimizer.history[-1]['similarity']
                    st.metric("Similarity", f"{similarity:.2%}")
                else:
                    status = "Completed" if is_completed else "Paused" if st.session_state.auto_paused else "In Progress"
                    st.metric("Status", status)
            # Progress bars
            st.subheader("Progress")

            # Overall progress across all models
            overall_progress = progress_info['overall_progress']
            if 'current_iteration' in progress_info and not is_completed:
                # Add current model progress to overall
                model_progress_fraction = progress_info['model_progress']
                overall_progress = (progress_info['models_completed'] + model_progress_fraction) / progress_info['total_models']
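                # Example: 3 models selected, 1 already finished, current model halfway
                # through its iterations -> (1 + 0.5) / 3 = 0.5 overall progress.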
            # Create more descriptive progress labels
            models_completed = progress_info['models_completed']
            total_models = progress_info['total_models']
            if total_models > 1:
                # Multi-model scenario
                if models_completed == total_models:
                    overall_label = f"🏁 All {total_models} models completed!"
                else:
                    overall_label = f"🔄 Processing {current_model} (Model {models_completed + 1} of {total_models})"
            else:
                # Single model scenario
                overall_label = f"🎯 Optimizing with {current_model}"

            st.progress(overall_progress, text=overall_label)

            # Current model progress
            if 'current_iteration' in progress_info and not is_completed:
                model_progress_value = progress_info['model_progress']
                current_iter = progress_info['current_iteration']
                max_iter = progress_info['max_iterations']
                if current_iter == max_iter:
                    iteration_label = f"✅ {current_model}: Completed all {max_iter} iterations"
                else:
                    iteration_label = f"⚡ {current_model}: Step {current_iter + 1} of {max_iter}"
                st.progress(model_progress_value, text=iteration_label)
            # Next step logic
            if not is_completed:
                # Auto mode logic - mimic pause button behavior
                if st.session_state.auto_mode and not st.session_state.auto_paused:
                    if st.session_state.auto_should_step:
                        # Execute the step
                        is_completed, prompt, generated_image = st.session_state.optimizer.step()
                        st.session_state.current_results = (is_completed, prompt, generated_image)
                        # Set flag to NOT step on next render (let history display)
                        st.session_state.auto_should_step = False
                        st.rerun()
                    else:
                        # Don't step, just display current state and history, then set flag to step next time
                        st.session_state.auto_should_step = True
                        # Use a small delay then rerun to continue auto mode
                        import time
                        time.sleep(0.5)  # Give user time to see the history
                        st.rerun()
                # Manual mode
                elif not st.session_state.auto_mode:
                    if st.button("Next Step"):
                        is_completed, prompt, generated_image = st.session_state.optimizer.step()
                        st.session_state.current_results = (is_completed, prompt, generated_image)
                        st.rerun()
                # Show status when auto mode is paused
                elif st.session_state.auto_paused:
                    st.info("Auto mode is paused. Click Resume to continue or uncheck Auto-progress to use manual mode.")
            else:
                st.success("Optimization completed! Click 'Reset' to try another image.")
                # Turn off auto mode when completed
                if st.session_state.auto_mode:
                    st.session_state.auto_mode = False
                    st.session_state.auto_paused = False
            # Reset button
            if st.button("Reset"):
                st.session_state.optimization_started = False
                st.session_state.current_results = None
                st.session_state.auto_mode = False
                st.session_state.auto_paused = False
                st.rerun()

            # Display multi-model history with tabs
            _display_multi_model_history()
            # Model Results Comparison (show when optimization is complete)
            if is_completed and hasattr(st.session_state.optimizer, 'get_all_results'):
                all_results = st.session_state.optimizer.get_all_results()
                best_result = st.session_state.optimizer.get_best_result()
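                # Each entry of all_results maps a model name to its final outcome; the fields
                # read below are 'final_image', 'final_similarity', 'final_prompt', and
                # 'iterations' (best_result additionally carries 'model_name' and 'similarity').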
                if all_results:
                    st.subheader("🏆 Model Comparison Results")

                    # Show best result prominently
                    if best_result:
                        st.success(f"🥇 **Best Result**: {best_result['model_name']} with {best_result['similarity']:.2%} similarity")

                    # Create columns for each model result
                    num_models = len(all_results)
                    if num_models > 0:
                        cols = st.columns(min(num_models, 3))  # Max 3 columns

                        # Sort results by similarity (best first)
                        sorted_results = sorted(all_results.items(),
                                                key=lambda x: x[1]['final_similarity'],
                                                reverse=True)

                        for i, (model_name, result) in enumerate(sorted_results):
                            col_idx = i % 3
                            with cols[col_idx]:
                                # Add medal emoji for top 3
                                medal = "🥇" if i == 0 else "🥈" if i == 1 else "🥉" if i == 2 else "🏅"
                                st.markdown(f"### {medal} {model_name}")
                                # Show final image
                                st.image(result['final_image'], width='stretch')
                                # Show metrics
                                st.metric("Final Similarity", f"{result['final_similarity']:.2%}")
                                st.metric("Iterations", result['iterations'])
                                # Show final prompt in expander
                                with st.expander("View Final Prompt"):
                                    st.text(result['final_prompt'])


def _display_multi_model_history():
    """Display optimization history with tabs for multi-model scenarios."""
    # Get all completed model results
    all_model_results = {}
    if hasattr(st.session_state.optimizer, 'get_all_results'):
        all_model_results = st.session_state.optimizer.get_all_results()

    # Build separate histories for each model
    all_histories = {}

    # Add completed model histories - each model gets its own stored history
    for model_name, result in all_model_results.items():
        if 'history' in result and result['history']:
            # Make a deep copy to ensure complete separation
            all_histories[model_name] = [step.copy() for step in result['history']]

    # Add current model's in-progress history (if any)
    if hasattr(st.session_state.optimizer, 'current_optimizer') and st.session_state.optimizer.current_optimizer:
        current_history = st.session_state.optimizer.current_optimizer.history
        if current_history:
            current_model_name = st.session_state.optimizer.get_current_model_name()
            # Only add if this model doesn't already have a completed history
            if current_model_name not in all_histories:
                all_histories[current_model_name] = [step.copy() for step in current_history]

    # Display histories
    if all_histories:
        st.subheader("📊 Optimization History")

        if len(all_histories) == 1:
            # Single model - no tabs needed
            model_name = list(all_histories.keys())[0]
            history = all_histories[model_name]
            st.markdown(f"**{model_name}** ({len(history)} steps)")
            _display_model_history(history, model_name)
        else:
            # Multi-model - use tabs
            tab_names = []
            for model_name, history in all_histories.items():
                tab_names.append(f"{model_name} ({len(history)} steps)")

            tabs = st.tabs(tab_names)
            for i, (model_name, history) in enumerate(all_histories.items()):
                with tabs[i]:
                    _display_model_history(history, model_name)


def _display_model_history(history, model_name):
    """Display history for a single model."""
    if not history:
        st.info(f"No optimization steps yet for {model_name}")
        return

    for idx, hist_entry in enumerate(history):
        st.markdown(f"### Step {idx + 1}")
        col1, col2 = st.columns([2, 3])
        with col1:
            st.image(hist_entry['image'], width='stretch')
        with col2:
            st.text(f"Similarity: {hist_entry['similarity']:.2%}")
            st.text("Prompt:")
            st.text(hist_entry['prompt'])

            # Toggle analysis view per history entry
            expand_key = f"expand_analysis_{model_name}_{idx}"
            if 'analysis_expanded' not in st.session_state:
                st.session_state['analysis_expanded'] = {}
            if expand_key not in st.session_state['analysis_expanded']:
                st.session_state['analysis_expanded'][expand_key] = False

            if st.session_state['analysis_expanded'][expand_key]:
                if st.button("Hide Analysis", key=f"hide_{expand_key}"):
                    st.session_state['analysis_expanded'][expand_key] = False
                    st.rerun()
                st.text("Analysis:")
                for key, value in hist_entry['analysis'].items():
                    st.text(f"{key}: {value}")
            else:
                if st.button("Expand Analysis", key=expand_key):
                    st.session_state['analysis_expanded'][expand_key] = True
                    st.rerun()


if __name__ == "__main__":
    main()