| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
import os
from typing import Callable, Dict, Tuple, Union

import numpy as np
import torch
from data_loaders.get_data import get_dataset_loader, load_local_data
from diffusion.respace import SpacedDiffusion
from model.cfg_sampler import ClassifierFreeSampleModel
from model.diffusion import FiLMTransformer
from torch.utils.data import DataLoader
from utils.diff_parser_utils import generate_args
from utils.misc import fixseed, prGreen
from utils.model_util import create_model_and_diffusion, get_person_num, load_model


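# Build the file-name/print templates used when saving rendered sample videos;
# the unconstrained variant names outputs by row/column instead of
# sample/repetition.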
def _construct_template_variables(unconstrained: bool) -> Tuple[str, ...]:
    row_file_template = "sample{:02d}.mp4"
    all_file_template = "samples_{:02d}_to_{:02d}.mp4"
    if unconstrained:
        sample_file_template = "row{:02d}_col{:02d}.mp4"
        sample_print_template = "[{} row #{:02d} column #{:02d} | -> {}]"
        row_file_template = row_file_template.replace("sample", "row")
        row_print_template = "[{} row #{:02d} | all columns | -> {}]"
        all_file_template = all_file_template.replace("samples", "rows")
        all_print_template = "[rows {:02d} to {:02d} | -> {}]"
    else:
        sample_file_template = "sample{:02d}_rep{:02d}.mp4"
        sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]'
        row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]'
        all_print_template = "[samples {:02d} to {:02d} | all repetitions | -> {}]"
    return (
        sample_print_template,
        row_print_template,
        all_print_template,
        sample_file_template,
        row_file_template,
        all_file_template,
    )


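# Generate guide keyframes from audio with the pre-trained keyframe
# transformer and decode them with the tokenizer, so they can replace the
# ground-truth keyframes during conditional sampling.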
def _replace_keyframes(
    model_kwargs: Dict[str, Dict[str, torch.Tensor]],
    model: Union[FiLMTransformer, ClassifierFreeSampleModel],
) -> torch.Tensor:
    B, T = (
        model_kwargs["y"]["keyframes"].shape[0],
        model_kwargs["y"]["keyframes"].shape[1],
    )
    with torch.no_grad():
        tokens = model.transformer.generate(
            model_kwargs["y"]["audio"],
            T,
            layers=model.tokenizer.residual_depth,
            n_sequences=B,
        )
    tokens = tokens.reshape((B, -1, model.tokenizer.residual_depth))
    pred = model.tokenizer.decode(tokens).detach().cpu()
    assert (
        model_kwargs["y"]["keyframes"].shape == pred.shape
    ), f"{model_kwargs['y']['keyframes'].shape} vs {pred.shape}"
    return pred


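# Run one DDIM sampling pass, then map the sample, conditioning audio,
# keyframes, and ground truth back to the original data space with the
# dataset's inverse transform.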
def _run_single_diffusion(
    args,
    model_kwargs: Dict[str, Dict[str, torch.Tensor]],
    diffusion: SpacedDiffusion,
    model: Union[FiLMTransformer, ClassifierFreeSampleModel],
    inv_transform: Callable,
    gt: torch.Tensor,
) -> Tuple[torch.Tensor, np.ndarray, torch.Tensor, torch.Tensor]:
    if args.data_format == "pose" and args.resume_trans is not None:
        model_kwargs["y"]["keyframes"] = _replace_keyframes(model_kwargs, model)
    sample_fn = diffusion.ddim_sample_loop
    with torch.no_grad():
        sample = sample_fn(
            model,
            (args.batch_size, model.nfeats, 1, args.curr_seq_length),
            clip_denoised=False,
            model_kwargs=model_kwargs,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
    sample = inv_transform(sample.cpu().permute(0, 2, 3, 1), args.data_format).permute(
        0, 3, 1, 2
    )
    curr_audio = inv_transform(model_kwargs["y"]["audio"].cpu().numpy(), "audio")
    keyframes = inv_transform(model_kwargs["y"]["keyframes"], args.data_format)
    gt_seq = inv_transform(gt.cpu().permute(0, 2, 3, 1), args.data_format).permute(
        0, 3, 1, 2
    )
    return sample, curr_audio, keyframes, gt_seq


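# Sample num_repetitions batches (attaching the classifier-free guidance scale
# to each one) and gather motions, audio, ground truth, lengths, and keyframes
# into a single results dictionary.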
def _generate_sequences(
    args,
    model_kwargs: Dict[str, Dict[str, torch.Tensor]],
    diffusion: SpacedDiffusion,
    model: Union[FiLMTransformer, ClassifierFreeSampleModel],
    test_data: DataLoader,
    gt: torch.Tensor,
) -> Dict[str, np.ndarray]:
    all_motions = []
    all_lengths = []
    all_audio = []
    all_gt = []
    all_keyframes = []
    for rep_i in range(args.num_repetitions):
        print(f"### Sampling [repetition #{rep_i}]")
        # add CFG scale to batch
        if args.guidance_param != 1:
            model_kwargs["y"]["scale"] = (
                torch.ones(args.batch_size, device=args.device) * args.guidance_param
            )
        model_kwargs["y"] = {
            key: val.to(args.device) if torch.is_tensor(val) else val
            for key, val in model_kwargs["y"].items()
        }
        sample, curr_audio, keyframes, gt_seq = _run_single_diffusion(
            args, model_kwargs, diffusion, model, test_data.dataset.inv_transform, gt
        )
        all_motions.append(sample.cpu().numpy())
        all_audio.append(curr_audio)
        all_keyframes.append(keyframes.cpu().numpy())
        all_gt.append(gt_seq.cpu().numpy())
        all_lengths.append(model_kwargs["y"]["lengths"].cpu().numpy())
        print(f"created {len(all_motions) * args.batch_size} samples")
    return {
        "motions": np.concatenate(all_motions, axis=0),
        "audio": np.concatenate(all_audio, axis=0),
        "gt": np.concatenate(all_gt, axis=0),
        "lengths": np.concatenate(all_lengths, axis=0),
        "keyframes": np.concatenate(all_keyframes, axis=0),
    }


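# Render each generated sample to video with the CA body renderer, pairing the
# body motion with the separately generated face codes and, optionally, the
# ground truth for comparison.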
def _render_pred(
    args,
    data_block: Dict[str, np.ndarray],
    sample_file_template: str,
    audio_per_frame: int,
) -> None:
    from visualize.render_codes import BodyRenderer

    face_codes = None
    if args.face_codes is not None:
        face_codes = np.load(args.face_codes, allow_pickle=True).item()
        face_motions = face_codes["motions"]
        face_gts = face_codes["gt"]
        face_audio = face_codes["audio"]
    config_base = f"./checkpoints/ca_body/data/{get_person_num(args.data_root)}"
    body_renderer = BodyRenderer(
        config_base=config_base,
        render_rgb=True,
    )
    for sample_i in range(args.num_samples):
        for rep_i in range(args.num_repetitions):
            idx = rep_i * args.batch_size + sample_i
            save_file = sample_file_template.format(sample_i, rep_i)
            animation_save_path = os.path.join(args.output_dir, save_file)
            # format data
            length = data_block["lengths"][idx]
            body_motion = (
                data_block["motions"][idx].transpose(2, 0, 1)[:length].squeeze(-1)
            )
            face_motion = face_motions[idx].transpose(2, 0, 1)[:length].squeeze(-1)
            assert np.array_equal(
                data_block["audio"][idx], face_audio[idx]
            ), "face audio is not the same"
            audio = data_block["audio"][idx, : length * audio_per_frame, :].T
            # set up render data block to pass into renderer
            render_data_block = {
                "audio": audio,
                "body_motion": body_motion,
                "face_motion": face_motion,
            }
            if args.render_gt:
                gt_body = data_block["gt"][idx].transpose(2, 0, 1)[:length].squeeze(-1)
                gt_face = face_gts[idx].transpose(2, 0, 1)[:length].squeeze(-1)
                render_data_block["gt_body"] = gt_body
                render_data_block["gt_face"] = gt_face
            body_renderer.render_full_video(
                render_data_block,
                animation_save_path,
                audio_sr=audio_per_frame * 30,
                render_gt=args.render_gt,
            )


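# Normalize the sampling arguments: resolve the sequence length, derive the
# output directory from the checkpoint name, iteration, and seed, and cap the
# batch size at the number of requested samples.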
def _reset_sample_args(args) -> None:
    # set the sequence length to match the one specified by the user
    name = os.path.basename(os.path.dirname(args.model_path))
    niter = os.path.basename(args.model_path).replace("model", "").replace(".pt", "")
    args.curr_seq_length = (
        args.curr_seq_length
        if args.curr_seq_length is not None
        else args.max_seq_length
    )
    # add the resume predictor model path
    resume_trans_name = ""
    if args.data_format == "pose" and args.resume_trans is not None:
        resume_trans_parts = args.resume_trans.split("/")
        resume_trans_name = f"{resume_trans_parts[1]}_{resume_trans_parts[-1]}"
    # reformat the output directory
    args.output_dir = os.path.join(
        os.path.dirname(args.model_path),
        "samples_{}_{}_seed{}_{}".format(name, niter, args.seed, resume_trans_name),
    )
    assert (
        args.num_samples <= args.batch_size
    ), f"Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})"
    # set the batch size to match the number of samples to generate
    args.batch_size = args.num_samples


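# Load the local data dictionary (1600 audio samples per motion frame) and
# wrap its test split in a chunked DataLoader.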
def _setup_dataset(args) -> DataLoader:
    data_root = args.data_root
    data_dict = load_local_data(
        data_root,
        audio_per_frame=1600,
        flip_person=args.flip_person,
    )
    test_data = get_dataset_loader(
        args=args,
        data_dict=data_dict,
        split="test",
        chunk=True,
    )
    return test_data


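# Build the model and diffusion, load the checkpoint, and wrap the model for
# classifier-free guidance when a guidance scale is used, before moving it to
# the target device in eval mode.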
def _setup_model(
    args,
) -> Tuple[Union[FiLMTransformer, ClassifierFreeSampleModel], SpacedDiffusion]:
    model, diffusion = create_model_and_diffusion(args, split_type="test")
    print(f"Loading checkpoints from [{args.model_path}]...")
    state_dict = torch.load(args.model_path, map_location="cpu")
    load_model(model, state_dict)
    if not args.unconstrained:
        assert args.guidance_param != 1
    if args.guidance_param != 1:
        prGreen("[CFS] wrapping model in classifier free sample")
        model = ClassifierFreeSampleModel(model)
    model.to(args.device)
    model.eval()
    return model, diffusion


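# Entry point. Example invocation (illustrative only; the accepted flags are
# defined by generate_args() in utils/diff_parser_utils.py):
#   python generate.py --model_path checkpoints/<run>/model<iter>.pt \
#       --num_samples 4 --num_repetitions 1 --plot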
def main():
    args = generate_args()
    fixseed(args.seed)
    _reset_sample_args(args)
    print("Loading dataset...")
    test_data = _setup_dataset(args)
    iterator = iter(test_data)
    print("Creating model and diffusion...")
    model, diffusion = _setup_model(args)

    if args.pose_codes is None:
        # generate sequences
        gt, model_kwargs = next(iterator)
        data_block = _generate_sequences(
            args, model_kwargs, diffusion, model, test_data, gt
        )
        os.makedirs(args.output_dir, exist_ok=True)
        npy_path = os.path.join(args.output_dir, "results.npy")
        print(f"saving results file to [{npy_path}]")
        np.save(npy_path, data_block)
    else:
        # load the pre-generated results
        data_block = np.load(args.pose_codes, allow_pickle=True).item()

    # plot only if face codes exist and we are doing pose prediction
    if args.plot:
        assert args.face_codes is not None, "need body and faces"
        assert (
            args.data_format == "pose"
        ), "currently only supporting plotting for pose data"
        print(f"saving visualizations to [{args.output_dir}]...")
        _, _, _, sample_file_template, _, _ = _construct_template_variables(
            args.unconstrained
        )
        _render_pred(
            args,
            data_block,
            sample_file_template,
            test_data.dataset.audio_per_frame,
        )


if __name__ == "__main__":
    main()