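# Distributed inference / visualization script for a landmark-driven, reference-image
# conditioned video diffusion pipeline (ReferenceNet + 3D UNet + landmark Guider).
# For each validation sample it writes the generated clip on its own ('*_oo.mp4') and a
# side-by-side comparison with the driving landmarks and the reference frame ('*_all.mp4').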
import os
import imageio
import argparse
import numpy as np
from PIL import Image
from omegaconf import OmegaConf

import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torchvision.transforms as T

from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPVisionModelWithProjection

from models.guider import Guider
from models.referencenet import ReferenceNet2DConditionModel
from models.unet import UNet3DConditionModel
from models.video_pipeline import VideoPipeline
from dataset.val_dataset import ValDataset, val_collate_fn


def load_model_state_dict(model, model_ckpt_path, name):
    # Load only the checkpoint entries whose keys exist in the model,
    # then report how many parameters were loaded and which keys are still missing.
    ckpt = torch.load(model_ckpt_path, map_location="cpu")
    model_state_dict = model.state_dict()
    model_new_sd = {}
    count = 0
    for k, v in ckpt.items():
        if k in model_state_dict:
            count += 1
            model_new_sd[k] = v
    miss, _ = model.load_state_dict(model_new_sd, strict=False)
    print(f'load {name} from {model_ckpt_path}\n - load params: {count}\n - miss params: {miss}')


def visualize(dataloader, pipeline, generator, W, H, video_length, num_inference_steps, guidance_scale, output_dir, limit=1):
    for i, batch in enumerate(dataloader):
        ref_frame = batch['ref_frame'][0]
        clip_image = batch['clip_image'][0]
        motions = batch['motions'][0]
        file_name = batch['file_name'][0]
        if motions is None:
            continue
        if 'lmk_name' in batch:
            lmk_name = batch['lmk_name'][0].split('.')[0]
        else:
            lmk_name = 'lmk'
        print(file_name, lmk_name)

        # reference frame tensor in [-1, 1] -> PIL image
        ref_frame = torch.clamp((ref_frame + 1.0) / 2.0, min=0, max=1)
        ref_frame = ref_frame.permute((1, 2, 3, 0)).squeeze()
        ref_frame = (ref_frame * 255).cpu().numpy().astype(np.uint8)
        ref_image = Image.fromarray(ref_frame)

        # landmark motion tensors -> list of PIL images
        motions = motions.permute((1, 2, 3, 0))
        motions = (motions * 255).cpu().numpy().astype(np.uint8)
        lmk_images = [Image.fromarray(motion) for motion in motions]

        preds = pipeline(ref_image=ref_image,
                         lmk_images=lmk_images,
                         width=W,
                         height=H,
                         video_length=video_length,
                         num_inference_steps=num_inference_steps,
                         guidance_scale=guidance_scale,
                         generator=generator,
                         clip_image=clip_image,
                         ).videos
        preds = preds.permute((0, 2, 3, 4, 1)).squeeze(0)
        preds = (preds * 255).cpu().numpy().astype(np.uint8)

        # generated frames only
        mp4_path = os.path.join(output_dir, lmk_name + '_' + file_name.split('.')[0] + '_oo.mp4')
        mp4_writer = imageio.get_writer(mp4_path, fps=7)
        for pred in preds:
            mp4_writer.append_data(pred)
        mp4_writer.close()

        # side-by-side comparison: [ground-truth frame |] landmarks | reference | prediction
        mp4_path = os.path.join(output_dir, lmk_name + '_' + file_name.split('.')[0] + '_all.mp4')
        mp4_writer = imageio.get_writer(mp4_path, fps=8)
        if 'frames' in batch:
            frames = batch['frames'][0]
            frames = torch.clamp((frames + 1.0) / 2.0, min=0, max=1)
            frames = frames.permute((1, 2, 3, 0))
            frames = (frames * 255).cpu().numpy().astype(np.uint8)
            for frame, motion, pred in zip(frames, motions, preds):
                out = np.concatenate((frame, motion, ref_frame, pred), axis=1)
                mp4_writer.append_data(out)
        else:
            for motion, pred in zip(motions, preds):
                out = np.concatenate((motion, ref_frame, pred), axis=1)
                mp4_writer.append_data(out)
        mp4_writer.close()

        if i >= limit:
            break


def main(args, config):
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    if dist.get_rank() == 0:
        os.makedirs(args.output_path, exist_ok=True)

    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(f"Unsupported weight dtype: {config.weight_dtype}")
    # init model
    print('init model')
    vae = AutoencoderKL.from_pretrained(config.vae_model_path).to(dtype=weight_dtype, device="cuda")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(dtype=weight_dtype, device="cuda")
    referencenet = ReferenceNet2DConditionModel.from_pretrained_2d(
        config.base_model_path,
        referencenet_additional_kwargs=config.model.referencenet_additional_kwargs).to(device="cuda")
    unet = UNet3DConditionModel.from_pretrained_2d(
        config.base_model_path,
        motion_module_path=config.motion_module_path,
        unet_additional_kwargs=config.model.unet_additional_kwargs).to(device="cuda")
    lmk_guider = Guider(conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)).to(device="cuda")

    # load model
    print('load model')
    load_model_state_dict(referencenet, f'{config.init_checkpoint}/referencenet.pth', 'referencenet')
    load_model_state_dict(unet, f'{config.init_checkpoint}/unet.pth', 'unet')
    load_model_state_dict(lmk_guider, f'{config.init_checkpoint}/lmk_guider.pth', 'lmk_guider')

    if config.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            referencenet.enable_xformers_memory_efficient_attention()
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    unet.set_reentrant(use_reentrant=False)
    referencenet.set_reentrant(use_reentrant=False)

    vae.eval()
    image_encoder.eval()
    unet.eval()
    referencenet.eval()
    lmk_guider.eval()

    # noise scheduler
    print('init noise scheduler')
    sched_kwargs = OmegaConf.to_container(config.scheduler)
    if config.enable_zero_snr:
        sched_kwargs.update(rescale_betas_zero_snr=True,
                            timestep_spacing="trailing",
                            prediction_type="v_prediction")
    noise_scheduler = DDIMScheduler(**sched_kwargs)

    # pipeline
    pipeline = VideoPipeline(vae=vae,
                             image_encoder=image_encoder,
                             referencenet=referencenet,
                             unet=unet,
                             lmk_guider=lmk_guider,
                             scheduler=noise_scheduler).to(vae.device, dtype=weight_dtype)

    # dataset creation
    print('init dataset')
    val_dataset = ValDataset(
        input_path=args.input_path,
        lmk_path=args.lmk_path,
        resolution_h=config.resolution_h,
        resolution_w=config.resolution_w
    )
    print(len(val_dataset))
    sampler = DistributedSampler(val_dataset, shuffle=False)

    # DataLoaders creation:
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        sampler=sampler,
        collate_fn=val_collate_fn,
    )

    generator = torch.Generator(device=vae.device)
    generator.manual_seed(config.seed)

    # run visualize
    print('run visualize')
    with torch.no_grad():
        visualize(val_dataloader,
                  pipeline,
                  generator,
                  W=config.resolution_w,
                  H=config.resolution_h,
                  video_length=config.video_length,
                  num_inference_steps=30,
                  guidance_scale=3.5,
                  output_dir=args.output_path,
                  limit=100000000)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--model_step', type=int, required=True)
    parser.add_argument('--output_path', type=str, required=True)
    parser.add_argument('--input_path', type=str, required=True)
    parser.add_argument('--lmk_path', type=str, required=True)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    config.init_checkpoint = args.model_path
    config.init_num = args.model_step
    # main() reads config.seed; wire the otherwise unused --seed flag into the config.
    config.seed = args.seed

    main(args, config)
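# Example launch (a minimal sketch; the script name and all paths below are placeholders,
# not taken from this repository). torchrun provides the LOCAL_RANK environment variable
# that dist.init_process_group / DistributedSampler rely on:
#
#   torchrun --nproc_per_node=1 inference.py \
#       --config configs/inference.yaml \
#       --model_path checkpoints/stage2 \
#       --model_step 30000 \
#       --input_path data/ref_images \
#       --lmk_path data/landmarks \
#       --output_path outputs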