import spaces  # Hugging Face Spaces GPU support (imported for its side effects)
import os
import cv2
import torch
import gradio as gr
import torchvision
import warnings
import numpy as np
from PIL import Image, ImageSequence
from moviepy.editor import VideoFileClip
import imageio
from diffusers import (
    TextToVideoSDPipeline,
    AutoencoderKL,
    DDPMScheduler,
    DDIMScheduler,
    UNet3DConditionModel,
)
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers.utils import export_to_video
from typing import List
from text2vid_modded import TextToVideoSDPipelineModded
from invert_utils import ddim_inversion as dd_inversion
from gifs_filter import filter  # note: shadows the builtin filter()
import uuid
from huggingface_hub import snapshot_download


def load_frames(image: Image.Image, mode='RGBA'):
    return np.array([np.array(frame.convert(mode)) for frame in ImageSequence.Iterator(image)])


# Fetch the sketch-animation LoRA checkpoint from the Hugging Face Hub
os.makedirs("t2v_sketch-lora", exist_ok=True)
snapshot_download(
    repo_id="Hmrishav/t2v_sketch-lora",
    local_dir="./t2v_sketch-lora"
)


def save_gif(frames, path):
    imageio.mimsave(
        path,
        [frame.astype(np.uint8) for frame in frames],
        format="GIF",
        duration=1 / 10,
        loop=0  # 0 means infinite loop
    )


def load_image(imgname, target_size=None):
    pil_img = Image.open(imgname).convert('RGB')
    if target_size:
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        pil_img = pil_img.resize(target_size, Image.Resampling.LANCZOS)
    return torchvision.transforms.ToTensor()(pil_img).unsqueeze(0)  # (1, C, H, W)


def prepare_latents(pipe, x_aug):
    with torch.cuda.amp.autocast():
        batch_size, num_frames, channels, height, width = x_aug.shape
        x_aug = x_aug.reshape(batch_size * num_frames, channels, height, width)
        latents = pipe.vae.encode(x_aug).latent_dist.sample()
        latents = latents.view(batch_size, num_frames, -1, latents.shape[2], latents.shape[3])
        latents = latents.permute(0, 2, 1, 3, 4)
        return pipe.vae.config.scaling_factor * latents
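
# Note: prepare_latents folds the frame axis into the batch so the VAE encodes
# each frame independently, then returns latents shaped (batch, channels,
# frames, h, w), the layout UNet3DConditionModel expects, pre-scaled by the
# VAE scaling factor as in the standard diffusers pipelines.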


def invert(pipe, inv, load_name, device="cuda", dtype=torch.bfloat16):
    input_img = [load_image(load_name, 256).to(device, dtype=dtype).unsqueeze(1)] * 5
    input_img = torch.cat(input_img, dim=1)
    latents = prepare_latents(pipe, input_img).to(torch.bfloat16)
    inv.set_timesteps(25)
    id_latents = dd_inversion(pipe, inv, video_latent=latents, num_inv_steps=25, prompt="")[-1].to(dtype)
    return torch.mean(id_latents, dim=2, keepdim=True)
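
# Note: invert() stacks five copies of the sketch into a static "video", runs a
# 25-step DDIM inversion to recover noise latents that reconstruct it, then
# averages over the frame axis (dim=2) so a single inverted latent can seed
# every generated frame.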


def load_primary_models(pretrained_model_path):
    return (
        DDPMScheduler.from_config(pretrained_model_path, subfolder="scheduler"),
        CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
        CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
        AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
        UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet"),
    )


def initialize_pipeline(model: str, device: str = "cuda"):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(model)
    pipe = TextToVideoSDPipeline.from_pretrained(
        pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
        scheduler=scheduler,
        tokenizer=tokenizer,
        text_encoder=text_encoder.to(device=device, dtype=torch.bfloat16),
        vae=vae.to(device=device, dtype=torch.bfloat16),
        unet=unet.to(device=device, dtype=torch.bfloat16),
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    return pipe, pipe.scheduler
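
# Note: the DDPM training scheduler is swapped for a DDIM scheduler built from
# the same config; DDIM's deterministic updates are what make the inversion
# performed in invert() well-defined.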


# Initialize the models
LORA_CHECKPOINT = "t2v_sketch-lora/checkpoint-2500"
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16

pipe_inversion, inv = initialize_pipeline(LORA_CHECKPOINT, device)
pipe = TextToVideoSDPipelineModded.from_pretrained(
    pretrained_model_name_or_path="damo-vilab/text-to-video-ms-1.7b",
    scheduler=pipe_inversion.scheduler,
    tokenizer=pipe_inversion.tokenizer,
    text_encoder=pipe_inversion.text_encoder,
    vae=pipe_inversion.vae,
    unet=pipe_inversion.unet,
).to(device)
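# The modded generation pipeline reuses the scheduler, tokenizer, text encoder,
# VAE, and UNet already loaded for inversion, so both pipelines share weights
# rather than holding two copies of the model in memory.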


def process_video(num_frames, num_seeds, generator, exp_dir, load_name, caption, lambda_):
    pipe_inversion.to(device)
    # NOTE: the incoming `generator` argument is unused; fresh per-seed
    # generators are created below so each sample is reproducible but distinct
    id_latents = invert(pipe_inversion, inv, load_name, device=device).to(device, dtype=dtype)
    latents = id_latents.repeat(num_seeds, 1, 1, 1, 1)
    generator = [torch.Generator(device=device).manual_seed(i) for i in range(num_seeds)]
    video_frames = pipe(
        prompt=caption,
        negative_prompt="",
        num_frames=num_frames,
        num_inference_steps=25,
        inv_latents=latents,
        guidance_scale=9,
        generator=generator,
        lambda_=lambda_,
    ).frames
    gifs = []
    for seed in range(num_seeds):
        vid_name = f"{exp_dir}/mp4_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.mp4"
        gif_name = f"{exp_dir}/gif_logs/vid_{os.path.basename(load_name)[:-4]}-rand{seed}.gif"
        os.makedirs(os.path.dirname(vid_name), exist_ok=True)
        os.makedirs(os.path.dirname(gif_name), exist_ok=True)
        export_to_video(video_frames[seed], output_video_path=vid_name)
        VideoFileClip(vid_name).write_gif(gif_name)
        with Image.open(gif_name) as im:
            frames = load_frames(im)

        # Post-process: resize to 1024x1024, drop alpha, grayscale, and
        # Otsu-binarize each frame so the GIF reads as clean line art
        frames_collect = np.empty((0, 1024, 1024), int)
        for frame in frames:
            frame = cv2.resize(frame, (1024, 1024))[:, :, :3]
            frame = cv2.cvtColor(255 - frame, cv2.COLOR_RGB2GRAY)
            _, frame = cv2.threshold(255 - frame, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            frames_collect = np.append(frames_collect, [frame], axis=0)

        save_gif(frames_collect, gif_name)
        gifs.append(gif_name)
    return gifs


def generate_output(image, apply_filter, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5, progress=gr.Progress(track_tqdm=True)) -> tuple[List[str], List[str]]:
    """Generate output GIFs from the input sketch and motion prompt."""
    unique_id = str(uuid.uuid4())
    exp_dir = f"static/app_tmp_{unique_id}"
    os.makedirs(exp_dir, exist_ok=True)

    # Save the input image temporarily
    temp_image_path = os.path.join(exp_dir, "temp_input.png")
    image.save(temp_image_path)

    # Generate the GIFs
    generated_gifs = process_video(
        num_frames=10,
        num_seeds=num_seeds,
        generator=None,
        exp_dir=exp_dir,
        load_name=temp_image_path,
        caption=prompt,
        lambda_=1 - lambda_value
    )

    if apply_filter:
        print("APPLYING FILTER")
        # Drop generations that do not match the input sketch (gifs_filter.filter)
        filtered_gifs = filter(generated_gifs, temp_image_path)
        return filtered_gifs, filtered_gifs
    else:
        print("NOT APPLYING FILTER")
        return generated_gifs, generated_gifs


def generate_output_from_sketchpad(image, apply_filter, prompt: str, num_seeds: int = 3, lambda_value: float = 0.5, progress=gr.Progress(track_tqdm=True)):
    # The Sketchpad returns a dict of layers; use the flattened composite image
    image = image['composite']
    results, results_to_download = generate_output(image, apply_filter, prompt, num_seeds, lambda_value)
    return results, results_to_download


css = """ """
with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.Markdown(
            """
            <div align="center" id="user-content-toc">
            <img align="left" width="70" height="70" src="https://github.com/user-attachments/assets/c61cec76-3c4b-42eb-8c65-f07e0166b7d8" alt="">

            # [FlipSketch: Flipping Static Drawings to Text-Guided Sketch Animations](https://hmrishavbandy.github.io/flipsketch-web/)
            ## [Hmrishav Bandyopadhyay](https://hmrishavbandy.github.io/) · [Yi-Zhe Song](https://personalpages.surrey.ac.uk/y.song/)
            </div>
            """
        )
        with gr.Tab("Upload your sketch"):
            with gr.Row():
                with gr.Column():
                    input_sketch = gr.Image(
                        type="pil",
                        label="Selected Sketch",
                        scale=1,
                        interactive=True,
                        height=300  # Fixed height for consistency
                    )
                    motion_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the motion...",
                        lines=2
                    )
                    num_seeds = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Seeds"
                    )
                    lambda_ = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Motion Strength"
                    )
                    apply_filter = gr.Checkbox(
                        label="Apply GIF Filters",
                        value=True,
                        info="If checked, results that do not match the input sketch are filtered out",
                    )
                with gr.Column():
                    gr.Examples(
                        examples=[
                            ['./static/examples/sketch1.png', 'The camel walks slowly'],
                            ['./static/examples/sketch2.png', 'The wine in the wine glass sways from side to side'],
                            ['./static/examples/sketch3.png', 'The squirrel is eating a nut'],
                            ['./static/examples/sketch4.png', 'The surfer surfs on the waves'],
                            ['./static/examples/sketch5.png', 'A galloping horse'],
                            ['./static/examples/sketch6.png', 'The cat walks forward'],
                            ['./static/examples/sketch7.png', 'The eagle flies in the sky'],
                            ['./static/examples/sketch8.png', 'The flower is blooming slowly'],
                            ['./static/examples/sketch9.png', 'The reindeer looks around'],
                            ['./static/examples/sketch10.png', 'The cloud floats in the sky'],
                            ['./static/examples/sketch11.png', 'The jazz saxophonist performs on stage, his upper body swaying subtly to the rhythm of the music.'],
                            ['./static/examples/sketch12.png', 'The biker rides on the road']
                        ],
                        inputs=[input_sketch, motion_prompt],
                        examples_per_page=4,
                    )
                    generate_btn = gr.Button(
                        "Generate Animation",
                        variant="primary",
                        elem_classes="generate-btn",
                        interactive=True,
                    )
        with gr.Tab("Draw your own"):
            with gr.Row():
                with gr.Column():
                    draw_sketchpad = gr.Sketchpad(
                        label="Draw your own Sketch",
                        value={
                            "background": "./static/examples/background.jpeg",
                            "layers": None,
                            "composite": None
                        },
                        type="pil",
                        image_mode="RGB",
                        layers=False,
                    )
                with gr.Column():
                    draw_motion_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the motion...",
                        lines=2
                    )
                    draw_num_seeds = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Seeds"
                    )
                    draw_lambda_ = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Motion Strength"
                    )
                    draw_apply_filter = gr.Checkbox(
                        label="Apply GIF Filters",
                        info="If checked, results that do not match the input sketch are filtered out",
                        value=False
                    )
                    sketchpad_generate_btn = gr.Button(
                        "Generate Animation",
                        variant="primary",
                        elem_classes="generate-btn",
                        interactive=True,
                    )
        output_gallery = gr.Gallery(
            label="Results",
            elem_classes="output-gallery",
            columns=3,
            rows=2,
            height="auto"
        )
        download_gifs = gr.Files(
            label="Download GIFs"
        )

    # Event handlers
    generate_btn.click(
        fn=generate_output,
        inputs=[
            input_sketch,
            apply_filter,
            motion_prompt,
            num_seeds,
            lambda_
        ],
        outputs=[output_gallery, download_gifs]
    )

    # Reset the sketchpad to the blank background when cleared
    def reload_pad():
        blank_pad = {
            "background": "./static/examples/background.jpeg",
            "layers": None,
            "composite": None
        }
        return blank_pad

    draw_sketchpad.clear(
        fn=reload_pad,
        inputs=None,
        outputs=[draw_sketchpad],
        queue=False
    )

    sketchpad_generate_btn.click(
        fn=generate_output_from_sketchpad,
        inputs=[
            draw_sketchpad,
            draw_apply_filter,
            draw_motion_prompt,
            draw_num_seeds,
            draw_lambda_,
        ],
        outputs=[output_gallery, download_gifs]
    )


# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
        ssr_mode=False
    )
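

# A minimal headless smoke test (a sketch, not part of the app): assuming a
# sketch image exists at ./static/examples/sketch1.png, the generation path can
# be exercised without the UI:
#
#   from PIL import Image
#   img = Image.open("./static/examples/sketch1.png").convert("RGB")
#   gifs, files = generate_output(img, False, "The camel walks slowly", num_seeds=1)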