# Spaces: Running on Zero
| import gradio as gr | |
| import os | |
| import spaces | |
| import sys | |
| sys.path.append('./VADER-VideoCrafter/scripts/main') | |
| sys.path.append('./VADER-VideoCrafter/scripts') | |
| sys.path.append('./VADER-VideoCrafter') | |
| from train_t2v_lora import main_fn, setup_model | |
# Rows for gr.Examples. Column order matches the `inputs` list handed to
# gr.Examples: prompt, LoRA model, LoRA rank, seed, height, width,
# guidance scale, DDIM steps, DDIM eta, frames, save FPS.
# Only the first four columns differ between rows; every row shares the
# same resolution / sampling settings, appended by the comprehension.
examples = [
    [text, lora_name, rank, seed_value, 384, 512, 12.0, 25, 1.0, 24, 10]
    for text, lora_name, rank, seed_value in (
        ("A fairy tends to enchanted, glowing flowers.",
         "huggingface-hps-aesthetic", 8, 400),
        ("A cat playing an electric guitar in a loft with industrial-style decor and soft, multicolored lights.",
         "huggingface-hps-aesthetic", 8, 206),
        ("A raccoon playing a guitar under a blossoming cherry tree.",
         "huggingface-hps-aesthetic", 8, 204),
        ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
         "huggingface-pickscore", 16, 205),
        ("A talking bird with shimmering feathers and a melodious voice leads an adventure to find a legendary treasure, guiding through enchanted forests, ancient ruins, and mystical challenges.",
         "huggingface-pickscore", 16, 204),
    )
]

# Module-level handle to the loaded model. It stays None until
# setup_model_progress() or generate_example() assigns it.
model = None
def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
                   frames, savefps):
    """Run text-to-video inference with the model loaded via "Load Model".

    Wired to the "Run Inference" button. The UI values are cast to the types
    main_fn is called with here (int / float), presumably because slider
    values can arrive as floats.

    Returns:
        Whatever ``main_fn`` returns; it is sent to the ``gr.Video`` output
        (presumably the path of the generated video — confirm in main_fn).

    Raises:
        gr.Error: if no model has been loaded yet. The output component is a
            ``gr.Video``, so returning a message string would hand it to the
            video player as if it were a file path instead of showing it to
            the user; ``gr.Error`` surfaces the message in the UI.
    """
    # Read-only access to the module-level model, so no `global` is needed.
    if model is None:
        raise gr.Error("Model is not loaded. Please load the model first.")
    return main_fn(prompt=prompt,
                   seed=int(seed),
                   height=int(height),
                   width=int(width),
                   unconditional_guidance_scale=float(unconditional_guidance_scale),
                   ddim_steps=int(ddim_steps),
                   ddim_eta=float(ddim_eta),
                   frames=int(frames),
                   savefps=int(savefps),
                   model=model)
def reset_fn():
    """Return the default value of every control cleared by the "Reset" button.

    The tuple order matches the ``outputs`` list of ``reset_btn.click``:
    prompt, seed, height, width, guidance scale, DDIM steps, DDIM eta,
    frames, LoRA rank, save FPS, LoRA model.
    """
    default_prompt = "A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way."
    seed_value, height_px, width_px = 200, 384, 512
    guidance, steps, eta = 12.0, 25, 1.0
    frame_count, rank, fps = 24, 16, 10
    return (
        default_prompt,
        seed_value,
        height_px,
        width_px,
        guidance,
        steps,
        eta,
        frame_count,
        rank,
        fps,
        "huggingface-pickscore",
    )
def update_lora_rank(lora_model):
    """Return a ``gr.update`` that sets the LoRA-rank slider for *lora_model*.

    "huggingface-pickscore" maps to rank 16. Every other choice
    ("huggingface-hps-aesthetic", "Base Model", or anything unexpected)
    maps to rank 8.
    """
    rank = 16 if lora_model == "huggingface-pickscore" else 8
    return gr.update(value=rank)
def update_dropdown(lora_rank):
    """Return a ``gr.update`` that selects the VADER model matching *lora_rank*.

    16 selects "huggingface-pickscore" and 8 selects
    "huggingface-hps-aesthetic". Any other value falls back to "Base Model".
    The rank slider is declared with min 8, max 16 and step 8, so in practice
    only the first two cases are expected.
    """
    known_ranks = (
        (16, "huggingface-pickscore"),
        (8, "huggingface-hps-aesthetic"),
    )
    for rank, model_name in known_ranks:
        if lora_rank == rank:
            return gr.update(value=model_name)
    return gr.update(value="Base Model")
def setup_model_progress(lora_model, lora_rank):
    """Load the selected model while streaming UI status updates.

    This is a generator. Each yielded 4-tuple maps onto the outputs
    ``[load_btn, run_btn, reset_btn, loading_indicator]``: the three buttons
    are disabled while loading and re-enabled afterwards, and the indicator
    shows a status message.

    On success the loaded model is stored in the module-level ``model``.
    If ``setup_model`` raises, the buttons are re-enabled and the indicator
    reports the failure before the exception is re-raised. Otherwise the
    buttons would stay disabled from the first yield and the user could not
    retry without reloading the page.
    """
    global model
    # Disable buttons and show loading indicator
    yield (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), "Loading model...")
    try:
        model = setup_model(lora_model, lora_rank)
    except Exception as err:
        # Unlock the UI and show what went wrong, then let the error propagate
        # so it is still reported. A bare `raise` inside the except clause
        # still re-raises the caught exception after the generator resumes.
        yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), f"Model loading failed: {err}")
        raise
    # Enable buttons after loading and update indicator
    yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), "Model loaded successfully")
def generate_example(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
                     frames, savefps):
    """Load the example's model and render one ``gr.Examples`` row.

    Used as the ``fn`` of ``gr.Examples``. Unlike ``gradio_main_fn``, it
    always calls ``setup_model`` for the row's LoRA model and rank, and stores
    the result in the module-level ``model``, replacing whatever was loaded.

    Returns:
        Whatever ``main_fn`` returns, which is shown in the video output.
    """
    global model
    model = setup_model(lora_model, lora_rank)
    # Cast UI values to the numeric types main_fn is called with.
    sampling_kwargs = dict(
        prompt=prompt,
        seed=int(seed),
        height=int(height),
        width=int(width),
        unconditional_guidance_scale=float(unconditional_guidance_scale),
        ddim_steps=int(ddim_steps),
        ddim_eta=float(ddim_eta),
        frames=int(frames),
        savefps=int(savefps),
    )
    return main_fn(model=model, **sampling_kwargs)
# Stylesheet passed to gr.Blocks(css=...). The ID selectors target elem_id
# values used in the layout:
#   #centered     - rows whose children are horizontally centered (flex).
#   #image-upload - the output video component, allowed to grow (flex-grow).
#   #params       - the model-settings column; the nested rules let its
#                   inner containers grow to fill the available space.
# NOTE(review): ".column-centered" does not appear to be referenced by any
# elem_classes in the visible code — confirm before relying on or removing it.
custom_css = """
#centered {
display: flex;
justify-content: center;
}
.column-centered {
display: flex;
flex-direction: column;
align-items: center;
width: 60%;
}
#image-upload {
flex-grow: 1;
}
#params .tabs {
display: flex;
flex-direction: column;
flex-grow: 1;
}
#params .tabitem[style="display: block;"] {
flex-grow: 1;
display: flex !important;
}
#params .gap {
flex-grow: 1;
}
#params .form {
flex-grow: 1 !important;
}
#params .form > :last-child{
flex-grow: 1;
}
"""
# Build the Gradio UI: a header, a model-selection column next to the video
# output, and a row with the prompt / sampling controls and examples.
with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        with gr.Column():
            # Title banner.
            gr.HTML(
                """
<h1 style='text-align: center; font-size: 3.2em; margin-bottom: 0.5em; font-family: Arial, sans-serif; margin: 20px;'>
Video Diffusion Alignment via Reward Gradient
</h1>
"""
            )
            # Author list and affiliation.
            gr.HTML(
                """
<style>
body {
font-family: Arial, sans-serif;
text-align: center;
margin: 50px;
}
a {
text-decoration: none !important;
color: black !important;
}
</style>
<body>
<div style="font-size: 1.4em; margin-bottom: 0.5em; ">
<a href="https://mihirp1998.github.io">Mihir Prabhudesai</a><sup>*</sup>
<a href="https://russellmendonca.github.io/">Russell Mendonca</a><sup>*</sup>
<a href="mailto:zheyangqin.qzy@gmail.com">Zheyang Qin</a><sup>*</sup>
<a href="https://www.cs.cmu.edu/~katef/">Katerina Fragkiadaki</a><sup></sup>
<a href="https://www.cs.cmu.edu/~dpathak/">Deepak Pathak</a><sup></sup>
</div>
<div style="font-size: 1.3em; font-style: italic;">
Carnegie Mellon University
</div>
</body>
"""
            )
            # Link buttons: paper, project website, code.
            gr.HTML(
                """
<head>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
<style>
.button-container {
display: flex;
justify-content: center;
gap: 10px;
margin-top: 10px;
}
.button-container a {
display: inline-flex;
align-items: center;
padding: 10px 20px;
border-radius: 30px;
border: 1px solid #ccc;
text-decoration: none;
color: #333 !important;
font-size: 16px;
text-decoration: none !important;
}
.button-container a i {
margin-right: 8px;
}
</style>
</head>
<div class="button-container">
<a href="https://arxiv.org/abs/2407.08737" class="btn btn-outline-primary">
<i class="fa-solid fa-file-pdf"></i> Paper
</a>
<a href="https://vader-vid.github.io/" class="btn btn-outline-danger">
<i class="fa-solid fa-video"></i> Website
</a>
<a href="https://github.com/mihirp1998/VADER" class="btn btn-outline-secondary">
<i class="fa-brands fa-github"></i> Code
</a>
</div>
"""
            )
    # NOTE(review): the fractional `scale` values (0.3 + 0.3, 0.6) look
    # intentional, keeping these rows narrower than full width. Newer Gradio
    # versions may warn that scale should be an integer — confirm the layout
    # before changing them.
    with gr.Row(elem_id="centered"):
        # Model selection column.
        with gr.Column(scale=0.3, elem_id="params"):
            lora_model = gr.Dropdown(
                label="VADER Model",
                choices=["huggingface-pickscore", "huggingface-hps-aesthetic", "Base Model"],
                value="huggingface-pickscore",
            )
            lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step=8, value=16)
            load_btn = gr.Button("Load Model")
            # Status line updated by setup_model_progress.
            loading_indicator = gr.Label(value="", label="Loading Indicator")
        with gr.Column(scale=0.3):
            output_video = gr.Video(elem_id="image-upload")
    with gr.Row(elem_id="centered"):
        with gr.Column(scale=0.6):
            prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
                                value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
            seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step=1, value=200)
            run_btn = gr.Button("Run Inference")
            with gr.Row():
                height = gr.Slider(minimum=0, maximum=1024, label="Height", step=16, value=384)
                width = gr.Slider(minimum=0, maximum=1024, label="Width", step=16, value=512)
            with gr.Row():
                frames = gr.Slider(minimum=0, maximum=50, label="Frames", step=1, value=24)
                savefps = gr.Slider(minimum=0, maximum=60, label="Save FPS", step=1, value=10)
            with gr.Row():
                DDIM_Steps = gr.Slider(minimum=0, maximum=100, label="DDIM Steps", step=1, value=25)
                unconditional_guidance_scale = gr.Slider(minimum=0, maximum=50, label="Guidance Scale", step=0.1, value=12.0)
                DDIM_Eta = gr.Slider(minimum=0, maximum=1, label="DDIM Eta", step=0.01, value=1.0)
            # reset button
            reset_btn = gr.Button("Reset")

            # --- Event wiring ---------------------------------------------
            # reset_fn's tuple order matches this outputs list.
            reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])
            load_btn.click(fn=setup_model_progress, inputs=[lora_model, lora_rank], outputs=[load_btn, run_btn, reset_btn, loading_indicator])
            run_btn.click(fn=gradio_main_fn,
                          inputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
                          outputs=output_video)
            # Keep the model dropdown and the rank slider in sync. These are
            # bound to `.input` (user edits only), not `.change`: `.change`
            # also fires on programmatic updates, so each handler retriggered
            # the other. Choosing "Base Model" set the rank to 8, which fired
            # update_dropdown(8) and flipped the dropdown to
            # "huggingface-hps-aesthetic", so "Base Model" could not be
            # selected while the rank was 16.
            lora_model.input(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
            lora_rank.input(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)

            # Column order of each example row matches `inputs` below.
            gr.Examples(examples=examples,
                        inputs=[prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
                        outputs=output_video,
                        fn=generate_example,
                        run_on_click=False,
                        cache_examples="lazy",
                        )

# NOTE(review): share=True is presumably ignored when hosted on Hugging Face
# Spaces; it only matters when the app is run locally.
demo.launch(share=True)