import os
import shutil
import multiprocessing
import subprocess
import nltk
import gradio as gr
import matplotlib.pyplot as plt
import gc
from huggingface_hub import snapshot_download, hf_hub_download
from typing import List
import numpy as np
import random
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, CLIPFeatureExtractor
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from moviepy.editor import VideoFileClip, CompositeVideoClip
import moviepy.editor as mpy
from PIL import Image, ImageDraw, ImageFont
from mutagen.mp3 import MP3
from gtts import gTTS
from pydub import AudioSegment
import uuid
from safetensors.torch import load_file
import textwrap
# -------------------------------------------------------------------
# No more ImageMagick dependency!
# -------------------------------------------------------------------
print("ImageMagick dependency removed. Using Pillow for text rendering.")

# Ensure NLTK's 'punkt_tab' (and other data) is present
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)
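# NOTE: the NLTK downloads only make tokenizer data available if needed; the summary
# below is actually split into sentences with a plain '.' split, not nltk.sent_tokenize.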
# -------------------------------------------------------------------
# GPU / Environment Setup
# -------------------------------------------------------------------
def log_gpu_memory():
    """Log GPU memory usage."""
    if torch.cuda.is_available():
        print(subprocess.check_output('nvidia-smi').decode('utf-8'))
    else:
        print("CUDA is not available. Cannot log GPU memory.")


def check_gpu_availability():
    """Print GPU availability and device details."""
    if torch.cuda.is_available():
        print(f"CUDA devices: {torch.cuda.device_count()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(torch.cuda.get_device_properties(torch.cuda.current_device()))
    else:
        print("CUDA is not available. Running on CPU.")


check_gpu_availability()
# Ensure proper multiprocessing start method
multiprocessing.set_start_method("spawn", force=True)
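# 'spawn' is required because CUDA cannot be safely re-initialized in worker processes
# created with the default 'fork' start method.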
# -------------------------------------------------------------------
# Constants & Model Setup
# -------------------------------------------------------------------
dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE_720 = 720  # Maximum image dimension; output is capped at 720p
MAX_IMAGE_SIZE = MAX_IMAGE_SIZE_720

RESOLUTIONS = {
    "16:9": [
        {"resolution": "360p", "width": 640, "height": 360},
        {"resolution": "480p", "width": 854, "height": 480},
        {"resolution": "720p", "width": 1280, "height": 720},
        # {"resolution": "1080p", "width": 1920, "height": 1080},  # disabled: above 720p
    ],
    "4:3": [
        {"resolution": "360p", "width": 480, "height": 360},
        {"resolution": "480p", "width": 640, "height": 480},
        {"resolution": "720p", "width": 960, "height": 720},
        # {"resolution": "1080p", "width": 1440, "height": 1080},  # disabled: above 720p
    ],
    "1:1": [
        {"resolution": "360p", "width": 360, "height": 360},
        {"resolution": "480p", "width": 480, "height": 480},
        {"resolution": "720p", "width": 720, "height": 720},
        # {"resolution": "1080p", "width": 1080, "height": 1080},  # disabled: above 720p
        # {"resolution": "1920p", "width": 1920, "height": 1920},  # disabled: above 720p
    ],
    "9:16": [
        {"resolution": "360p", "width": 360, "height": 640},
        {"resolution": "480p", "width": 480, "height": 854},
        {"resolution": "720p", "width": 720, "height": 1280},
        # {"resolution": "1080p", "width": 1080, "height": 1920},  # disabled: above 720p
    ],
}
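# NOTE: RESOLUTIONS is not wired into the UI below (the interface exposes free-form
# width/height sliders instead); it is kept here as a reference table of standard sizes.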
DESCRIPTION = (
    "Video Story Generator with Audio\n"
    "Generates a video from a story using AnimateDiff, DistilBART, and gTTS."
)
TITLE = "Video Story Generator with Audio (AnimateDiff, DistilBART, and gTTS)"
def load_text_summarization_model():
    """Load the tokenizer and model for text summarization on GPU/CPU."""
    print("Loading text summarization model...")
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
    model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6")
    return tokenizer, model


tokenizer, model = load_text_summarization_model()
# Base models for AnimateDiff-Lightning
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2"
}
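# These are Hugging Face model repo IDs. AnimateDiff-Lightning is designed for
# Stable Diffusion 1.5-style checkpoints, which all of the bases above are.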
# Keep track of what's loaded to avoid reloading each time
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None
# Initialize AnimateDiff pipeline
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")

# The pipeline needs a MotionAdapter component; the Lightning motion weights themselves
# are loaded into the UNet on demand in generate_short_animation() below.
adapter = MotionAdapter().to(device, dtype)
pipe = AnimateDiffPipeline.from_pretrained(
    bases[base_loaded],
    motion_adapter=adapter,
    torch_dtype=dtype
).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    beta_schedule="linear"
)
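# NOTE: the 'trailing' timestep spacing and linear beta schedule follow the
# AnimateDiff-Lightning model card; the distilled 1/2/4/8-step checkpoints are
# tuned for this scheduler configuration.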
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
# -------------------------------------------------------------------
# Function: Generate Short Animation
# -------------------------------------------------------------------
def generate_short_animation(
    prompt_text: str,
    base: str = "Realistic",
    motion: str = "",
    step: int = 4,
    seed: int = 42,
    width: int = 512,
    height: int = 512,
) -> str:
    """
    Generates a short animated video (MP4) from a given prompt using AnimateDiff-Lightning.
    Returns the local path to the resulting MP4.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    # 1) Possibly reload the correct step weights
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(
            load_file(hf_hub_download(repo, ckpt), device=device),
            strict=False
        )
        step_loaded = step

    # 2) Possibly reload the correct base model
    if base_loaded != base:
        pipe.unet.load_state_dict(
            torch.load(
                hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"),
                map_location=device
            ),
            strict=False
        )
        base_loaded = base

    # 3) Possibly unload/load motion LoRA
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion:
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])  # weighting can be adjusted
        motion_loaded = motion

    # 4) Generate frames
    print(f"[INFO] Generating short animation for prompt: '{prompt_text}' ...")
    generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
    output = pipe(
        prompt=prompt_text,
        guidance_scale=1.2,
        num_inference_steps=step,
        generator=generator,
        width=width,
        height=height
    )
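    # AnimateDiffPipeline returns 16 frames per clip by default, so at fps=10 below
    # each clip runs for roughly 1.6 seconds.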
    # 5) Export frames to a short MP4
    short_mp4_path = f"short_{uuid.uuid4().hex}.mp4"
    export_to_video(output.frames[0], short_mp4_path, fps=10)
    return short_mp4_path
# -------------------------------------------------------------------
# Function: Merge MP3 files
# -------------------------------------------------------------------
def merge_audio_files(mp3_names: List[str]) -> str:
    """
    Merges a list of MP3 files into a single MP3 file.
    Returns the path to the merged MP3 file.
    """
    combined = AudioSegment.empty()
    for f_name in mp3_names:
        audio = AudioSegment.from_mp3(f_name)
        combined += audio
    export_path = f"merged_audio_{uuid.uuid4().hex}.mp3"  # Dynamic output path for merged audio
    combined.export(export_path, format="mp3")
    print(f"DEBUG: Audio files merged and saved to {export_path}")
    return export_path
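# NOTE: pydub and moviepy both rely on a system ffmpeg binary for MP3/MP4 handling;
# make sure ffmpeg is installed in the Space (e.g. via packages.txt).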
# -------------------------------------------------------------------
# Function: Overlay Subtitles on a Video
# -------------------------------------------------------------------
def add_subtitles_to_video(input_video_path: str, text: str, duration: float) -> str:
    """
    Overlays `text` as subtitles over the entire `input_video_path` for `duration` seconds using Pillow.
    Returns the path to the newly generated MP4 with subtitles.
    """
    base_clip = VideoFileClip(input_video_path)
    final_dur = max(duration, base_clip.duration)

    def make_frame(t):
        # Clamp t so the last frame is held once the (longer) audio outlasts the clip
        t = min(t, max(base_clip.duration - 1e-3, 0))
        frame_pil = Image.fromarray(base_clip.get_frame(t))
        draw = ImageDraw.Draw(frame_pil)
        try:
            font = ImageFont.truetype("arial.ttf", 40)  # Change the font size if needed
        except IOError:
            font = ImageFont.load_default()  # Use default font if Arial is not found
        # Compute the text size using `textbbox()`
        bbox = draw.textbbox((0, 0), text, font=font)
        textwidth, textheight = bbox[2] - bbox[0], bbox[3] - bbox[1]
        x = (frame_pil.width - textwidth) / 2
        y = frame_pil.height - 70 - textheight  # Position near the bottom
        draw.text((x, y), text, font=font, fill=(255, 255, 0))  # Yellow text
        return np.array(frame_pil)

    # Build a clip from the subtitled frames (no `size` argument needed)
    subtitled_clip = mpy.VideoClip(make_frame, duration=final_dur)

    # Composite the subtitled clip over the original video
    final_clip = CompositeVideoClip([base_clip, subtitled_clip.set_position((0, 0))])
    final_clip = final_clip.set_duration(final_dur)

    out_path = f"sub_{uuid.uuid4().hex}.mp4"
    final_clip.write_videofile(out_path, fps=24, logger=None)

    # Cleanup
    base_clip.close()
    final_clip.close()
    subtitled_clip.close()
    return out_path
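# NOTE: 'arial.ttf' is usually not available in the Linux containers Spaces run on,
# so the Pillow default bitmap font is the common fallback; bundle a .ttf file with
# the app for larger, cleaner subtitles.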
# -------------------------------------------------------------------
# Main Function: Generate Output Video
# -------------------------------------------------------------------
# NOTE: on a ZeroGPU Space the GPU is only attached inside functions decorated with
# @spaces.GPU. The duration below is an assumption; tune it to your workload.
@spaces.GPU(duration=600)
def get_output_video(text, base_model_name, motion_name, num_inference_steps_backend, randomize_seed, seed, width, height):
    """
    Summarize the user prompt, generate a short animated video for each sentence,
    overlay subtitles, and merge everything into a final video with a single audio track.
    """
    print("DEBUG: Starting get_output_video function...")

    # Summarize the input text
    print("DEBUG: Summarizing text...")
    device_local = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device_local)  # Move summarization model to GPU/CPU as needed
    inputs = tokenizer(
        text,
        max_length=1024,
        truncation=True,
        return_tensors="pt"
    ).to(device_local)
    summary_ids = model.generate(inputs["input_ids"])
    summary = tokenizer.batch_decode(
        summary_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    # Split the summary into sentences, dropping empty fragments
    sentences = [s.strip() for s in summary[0].split('.') if s.strip()]
    print(f"DEBUG: Summary generated: {sentences}")
    # Prepare seed based on randomize_seed checkbox
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    # Generate a short video and an audio track for each sentence
    short_videos = []
    mp3_names = []
    mp3_lengths = []
    result_no_audio = f"result_no_audio_{uuid.uuid4().hex}.mp4"  # Dynamic filename for the silent video
    movie_final = f"result_final_{uuid.uuid4().hex}.mp4"  # Dynamic filename for the final video
    merged_audio_path = ""  # To store merged audio path for cleanup

    try:  # Try/finally block to ensure cleanup
        for i, sentence in enumerate(sentences):
            # 1) Generate a short video for this sentence
            prompt_for_animation = f"Generate a realistic video about this: {sentence}"
            print(f"DEBUG: Generating short video {i + 1} of {len(sentences)} ...")
            short_mp4_path = generate_short_animation(
                prompt_text=prompt_for_animation,
                base=base_model_name,
                motion=motion_name,
                step=int(num_inference_steps_backend),
                seed=current_seed + i,  # Increment seed for each sentence for variation
                width=width,
                height=height
            )

            # 2) Generate audio for the sentence
            audio_filename = f"audio_{uuid.uuid4().hex}_{i}.mp3"  # Dynamic audio filename
            tts_obj = gTTS(text=sentence, lang='en', slow=False)
            tts_obj.save(audio_filename)
            audio_info = MP3(audio_filename)
            audio_duration = audio_info.info.length
            mp3_names.append(audio_filename)
            mp3_lengths.append(audio_duration)

            # 3) Overlay subtitles on top of the short video (using Pillow)
            final_clip_duration = audio_duration + 0.5  # half-second pad
            short_subtitled_path = add_subtitles_to_video(
                input_video_path=short_mp4_path,
                text=sentence,
                duration=final_clip_duration
            )
            short_videos.append(short_subtitled_path)

            # Clean up the original short clip (no subtitles)
            os.remove(short_mp4_path)
        # ----------------------------------------------------------------
        # Merge all MP3 files into one
        # ----------------------------------------------------------------
        merged_audio_path = merge_audio_files(mp3_names)

        # ----------------------------------------------------------------
        # Concatenate all short subtitled videos
        # ----------------------------------------------------------------
        print("DEBUG: Concatenating all short videos into a single clip...")
        clip_objects = []
        for vid_path in short_videos:
            clip = mpy.VideoFileClip(vid_path)
            clip_objects.append(clip)
        final_concat = mpy.concatenate_videoclips(clip_objects, method="compose")
        final_concat.write_videofile(result_no_audio, fps=24, logger=None)

        # ----------------------------------------------------------------
        # Combine the concatenated video with the merged audio
        # ----------------------------------------------------------------
        def combine_audio(vidname, audname, outname, fps=24):
            print(f"DEBUG: Combining audio for video: '{vidname}'")
            my_clip = mpy.VideoFileClip(vidname)
            audio_background = mpy.AudioFileClip(audname)
            final_clip = my_clip.set_audio(audio_background)
            final_clip.write_videofile(outname, fps=fps, logger=None)
            my_clip.close()
            final_clip.close()

        combine_audio(result_no_audio, merged_audio_path, movie_final)
    finally:  # Cleanup always executes
        print("DEBUG: Cleaning up temporary files...")
        # Remove short subtitled videos
        for path_ in short_videos:
            if os.path.exists(path_):
                os.remove(path_)
        # Remove mp3 segments
        for f_mp3 in mp3_names:
            if os.path.exists(f_mp3):
                os.remove(f_mp3)
        # Remove merged audio
        if os.path.exists(merged_audio_path):
            os.remove(merged_audio_path)
        # Remove the intermediate no-audio mp4
        if os.path.exists(result_no_audio):
            os.remove(result_no_audio)

    print("DEBUG: get_output_video function completed successfully.")
    return movie_final
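# Minimal sketch of calling the generator outside the UI (argument values are
# illustrative, not defaults taken from this app):
#   final_path = get_output_video(text, "Realistic", "", 4, False, 42, 640, 480)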
# -------------------------------------------------------------------
# Example text (user can override)
# -------------------------------------------------------------------
text = (
    "Once, there was a girl called Laura who went to the supermarket to buy the ingredients to make a cake, "
    "because today is her birthday and her friends are coming to her house to help her prepare the cake."
)
# -------------------------------------------------------------------
# Gradio Interface
# -------------------------------------------------------------------
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(
        """
        # Video Generator ⚡ from stories with Artificial Intelligence

        Enter a story and it is summarized with the DistilBART model.
        Short clips are then generated with AnimateDiff and AnimateDiff-Lightning,
        the narration audio is created with gTTS, and subtitles are rendered with Pillow.
        Everything is combined into a single video.

        **Credits**: Developed by [ruslanmv.com](https://ruslanmv.com).
        """
    )
    with gr.Group():
        with gr.Row():
            input_start_text = gr.Textbox(value=text, label='Prompt')
        with gr.Row():
            select_base = gr.Dropdown(
                label='Base model',
                choices=["Cartoon", "Realistic", "3d", "Anime"],
                value=base_loaded,
                interactive=True
            )
            select_motion = gr.Dropdown(
                label='Motion',
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="",  # default: no motion LoRA
                interactive=True
            )
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[('1-Step', 1), ('2-Step', 2), ('4-Step', 4), ('8-Step', 8)],
                value=4,
                interactive=True
            )
            button_gen_video = gr.Button(
                scale=1,
                variant='primary',
                value="Generate Video"
            )
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=42,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE_720,  # limited to a maximum of 720 pixels (720p cap)
                step=1,
                value=640,  # Default width for 480p 4:3
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE_720,  # limited to a maximum of 720 pixels (720p cap)
                step=1,
                value=480,  # Default height for 480p 4:3
            )
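            # NOTE: the diffusion UNet expects width/height divisible by 8, so pick
            # slider values that are multiples of 8 (the 640x480 defaults are fine).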
    with gr.Column():
        # output_interpolation = gr.Video(label="Generated Video")
        output_interpolation = gr.Video(value="video.mp4", label="Generated Video")  # Show a sample video by default

    button_gen_video.click(
        fn=get_output_video,
        inputs=[input_start_text, select_base, select_motion, select_step, randomize_seed, seed, width, height],
        outputs=output_interpolation
    )

demo.queue().launch(debug=True, share=False)