# NOTE(review): removed extraction artifacts ("Spaces:", "Runtime error" x2) that were not part of the source file.
| # -*- coding: utf-8 -*- | |
| # Copyright (c) Alibaba, Inc. and its affiliates. | |
| import argparse | |
| import os | |
| import random | |
| import time | |
| import torch | |
| import numpy as np | |
| from models.ltx.ltx_vace import LTXVace | |
| from annotators.utils import save_one_video, save_one_image, get_annotator | |
# Upper bounds on output resolution / clip length.
# NOTE(review): these constants are not referenced anywhere in this script;
# presumably they are enforced inside the pipeline — verify against LTXVace.
MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257
def get_total_gpu_memory():
    """Return the total memory of CUDA device 0 in GiB, or None if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    device_props = torch.cuda.get_device_properties(0)
    return device_props.total_memory / (1024 ** 3)
def seed_everething(seed: int):
    """Seed the Python, NumPy and torch RNGs (CUDA too, when available) for reproducibility.

    NOTE(review): the misspelled name ('everething') is kept for caller compatibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
def get_parser():
    """Build the argparse parser for the VACE-LTX video generation CLI.

    Returns:
        argparse.ArgumentParser with model paths, source video/mask/reference
        inputs, sampling hyper-parameters, output geometry and prompt options.
        Only --prompt is required; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="Load models from separate directories and run the pipeline."
    )
    # Directories
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default='models/VACE-LTX-Video-0.9/ltx-video-2b-v0.9.safetensors',
        help="Path to a safetensors file that contains all model parts.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default='models/VACE-LTX-Video-0.9',
        help="Path to a safetensors file that contains all model parts.",
    )
    parser.add_argument(
        "--src_video",
        type=str,
        default=None,
        help="The file of the source video. Default None.")
    parser.add_argument(
        "--src_mask",
        type=str,
        default=None,
        help="The file of the source mask. Default None.")
    parser.add_argument(
        "--src_ref_images",
        type=str,
        default=None,
        help="The file list of the source reference images. Separated by ','. Default None.")
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Path to the folder to save output video, if None will save in results/ directory.",
    )
    # Fix: default was the *string* "42"; argparse coerces string defaults via
    # `type`, but an int default is the idiomatic, robust form.
    parser.add_argument("--seed", type=int, default=42)
    # Pipeline parameters
    parser.add_argument(
        "--num_inference_steps", type=int, default=40, help="Number of inference steps"
    )
    parser.add_argument(
        "--num_images_per_prompt",
        type=int,
        default=1,
        help="Number of images per prompt",
    )
    parser.add_argument(
        "--context_scale",
        type=float,
        default=1.0,
        help="Context scale for the pipeline",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3,
        help="Guidance scale for the pipeline",
    )
    parser.add_argument(
        "--stg_scale",
        type=float,
        default=1,
        help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.",
    )
    parser.add_argument(
        "--stg_rescale",
        type=float,
        default=0.7,
        help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.",
    )
    parser.add_argument(
        "--stg_mode",
        type=str,
        default="stg_a",
        help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.",
    )
    parser.add_argument(
        "--stg_skip_layers",
        type=str,
        default="19",
        help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.",
    )
    parser.add_argument(
        "--image_cond_noise_scale",
        type=float,
        default=0.15,
        help="Amount of noise to add to the conditioned image",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=512,
        help="The height of the output video only if src_video is empty.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help="The width of the output video only if src_video is empty.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=97,
        help="The frames of the output video only if src_video is empty.",
    )
    parser.add_argument(
        "--frame_rate", type=int, default=25, help="Frame rate for the output video"
    )
    parser.add_argument(
        "--precision",
        choices=["bfloat16", "mixed_precision"],
        default="bfloat16",
        help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.",
    )
    # VAE noise augmentation
    parser.add_argument(
        "--decode_timestep",
        type=float,
        default=0.05,
        help="Timestep for decoding noise",
    )
    parser.add_argument(
        "--decode_noise_scale",
        type=float,
        default=0.025,
        help="Noise level for decoding noise",
    )
    # Prompts
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt to guide generation",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        help="Negative prompt for undesired features",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )
    parser.add_argument(
        "--use_prompt_extend",
        default='plain',
        choices=['plain', 'ltx_en', 'ltx_en_ds'],
        help="Whether to use prompt extend."
    )
    return parser
def _tensor_to_uint8_frames(video, zero_centered=True):
    """Convert a (C, F, H, W) float tensor to a uint8 (F, H, W, C) numpy array.

    When ``zero_centered`` is True, values are rescaled from [-1, 1] to [0, 1]
    before clamping (videos); masks are already in [0, 1] and skip the rescale.
    """
    if zero_centered:
        video = video / 2 + 0.5
    clamped = torch.clamp(video, min=0.0, max=1.0)
    return (clamped.permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)


def main(args):
    """Run the VACE-LTX generation pipeline and save all produced artifacts.

    Args:
        args: argparse.Namespace (or a dict of the same keys) produced by
            ``get_parser()``.

    Returns:
        dict mapping artifact names ('out_video', 'src_video', 'src_mask',
        'src_ref_image_<i>') to the paths of the saved files.

    Raises:
        FileNotFoundError: if the checkpoint or text-encoder path is missing.
    """
    args = argparse.Namespace(**args) if isinstance(args, dict) else args
    print(f"Running generation with arguments: {args}")
    seed_everething(args.seed)
    # Only offload when requested AND the GPU is small (< 30 GiB total memory).
    offload_to_cpu = False if not args.offload_to_cpu else get_total_gpu_memory() < 30
    # Fix: explicit checks instead of `assert` — asserts are stripped under
    # `python -O` and would silently skip this validation.
    for model_path in (args.ckpt_path, args.text_encoder_path):
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model path does not exist: {model_path}")
    ltx_vace = LTXVace(ckpt_path=args.ckpt_path,
                       text_encoder_path=args.text_encoder_path,
                       precision=args.precision,
                       stg_skip_layers=args.stg_skip_layers,
                       stg_mode=args.stg_mode,
                       offload_to_cpu=offload_to_cpu)
    src_ref_images = args.src_ref_images.split(',') if args.src_ref_images is not None else []
    if args.use_prompt_extend and args.use_prompt_extend != 'plain':
        prompt = get_annotator(config_type='prompt', config_task=args.use_prompt_extend, return_dict=False).forward(args.prompt)
        print(f"Prompt extended from '{args.prompt}' to '{prompt}'")
    else:
        prompt = args.prompt
    output = ltx_vace.generate(src_video=args.src_video,
                               src_mask=args.src_mask,
                               src_ref_images=src_ref_images,
                               prompt=prompt,
                               negative_prompt=args.negative_prompt,
                               seed=args.seed,
                               num_inference_steps=args.num_inference_steps,
                               num_images_per_prompt=args.num_images_per_prompt,
                               context_scale=args.context_scale,
                               guidance_scale=args.guidance_scale,
                               stg_scale=args.stg_scale,
                               stg_rescale=args.stg_rescale,
                               frame_rate=args.frame_rate,
                               image_cond_noise_scale=args.image_cond_noise_scale,
                               decode_timestep=args.decode_timestep,
                               decode_noise_scale=args.decode_noise_scale,
                               output_height=args.height,
                               output_width=args.width,
                               num_frames=args.num_frames)
    if args.save_dir is None:
        save_dir = os.path.join('results', 'vace_ltxv', time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())))
    else:
        save_dir = args.save_dir
    # Fix: exist_ok=True avoids the exists()/makedirs() race.
    os.makedirs(save_dir, exist_ok=True)
    frame_rate = output['info']['frame_rate']
    ret_data = {}
    # Save each produced video; the mask is already in [0, 1] so it is not rescaled.
    for key, zero_centered in (('out_video', True), ('src_video', True), ('src_mask', False)):
        if output[key] is not None:
            save_path = os.path.join(save_dir, f'{key}.mp4')
            save_one_video(save_path, _tensor_to_uint8_frames(output[key], zero_centered), fps=frame_rate)
            print(f"Save {key} to {save_path}")
            ret_data[key] = save_path
    if output['src_ref_images'] is not None:
        for i, ref_img in enumerate(output['src_ref_images']):  # [C, F=1, H, W]
            save_path = os.path.join(save_dir, f'src_ref_image_{i}.png')
            # Reference images are already in [0, 1]; drop the singleton frame dim.
            ref_img = (torch.clamp(ref_img.squeeze(1), min=0.0, max=1.0).permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
            save_one_image(save_path, ref_img, use_type='pil')
            print(f"Save src_ref_image_{i} to {save_path}")
            ret_data[f'src_ref_image_{i}'] = save_path
    return ret_data
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the generation pipeline.
    cli_args = get_parser().parse_args()
    main(cli_args)