# pylint: disable=E1101
# scripts/inference.py
"""
This script contains the main inference pipeline for processing audio and image inputs
to generate a video output.

The script imports the necessary packages and classes, defines a neural network model,
and contains functions for processing audio embeddings and performing inference.

The main inference process is outlined in the following steps:
1. Initialize the configuration.
2. Set up runtime variables.
3. Prepare the input data for inference (source image, face mask, and face embeddings).
4. Process the audio embeddings.
5. Build and freeze the model and scheduler.
6. Run the inference loop and save the result.

Usage:
This script can be run from the command line with the following arguments:
- config: Path to the inference configuration YAML file.
- source_image: Path to the source image.
- driving_audio: Path to the driving audio file.
- output: Path to save the output video.
- pose_weight, face_weight, lip_weight: Motion weights for the pose, face, and lip regions.
- face_expand_ratio: Expansion ratio for the detected face region.
- checkpoint: Optional directory containing the trained audio checkpoint (net.pth).

Example:
python scripts/inference.py --source_image image.jpg --driving_audio audio.wav --output output.mp4
"""
import argparse
import os

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from torch import nn

from hallo.animate.face_animate import FaceAnimatePipeline
from hallo.datasets.audio_processor import AudioProcessor
from hallo.datasets.image_processor import ImageProcessor
from hallo.models.audio_proj import AudioProjModel
from hallo.models.face_locator import FaceLocator
from hallo.models.image_proj import ImageProjModel
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.util import tensor_to_video

class Net(nn.Module):
    """
    The Net class combines all the modules required for the inference process.

    Args:
        reference_unet (UNet2DConditionModel): The 2D UNet that encodes the reference image features.
        denoising_unet (UNet3DConditionModel): The 3D UNet that denoises the video latents.
        face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image.
        imageproj (nn.Module): The ImageProjModel that projects the face embedding into conditioning tokens.
        audioproj (nn.Module): The AudioProjModel that projects the audio embeddings into conditioning tokens.
    """

    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(self,):
        """
        Empty forward to satisfy the abstract interface of nn.Module;
        the submodules are invoked individually during inference.
        """

    def get_modules(self):
        """
        Simple accessor to avoid the too-few-public-methods pylint error.
        """
        return {
            "reference_unet": self.reference_unet,
            "denoising_unet": self.denoising_unet,
            "face_locator": self.face_locator,
            "imageproj": self.imageproj,
            "audioproj": self.audioproj,
        }

def process_audio_emb(audio_emb):
    """
    Expand each audio embedding with its temporal neighbours so that every frame
    carries a window of five consecutive embeddings (i-2 .. i+2), clamped at the
    sequence boundaries.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor of shape (num_frames, ...).

    Returns:
        torch.Tensor: The stacked tensor of shape (num_frames, 5, ...).
    """
    concatenated_tensors = []

    for i in range(audio_emb.shape[0]):
        # Collect the embeddings at indices i-2 .. i+2, clamped to the valid range.
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))

    audio_emb = torch.stack(concatenated_tensors, dim=0)

    return audio_emb

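# Illustrative example (assuming the raw embedding has shape (num_frames, 12, 768)):
# an input of shape (100, 12, 768) becomes (100, 5, 12, 768), where window i stacks the
# embeddings at the clamped indices [i-2, i-1, i, i+1, i+2].
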
def inference_process(args: argparse.Namespace):
    """
    Perform inference processing.

    Args:
        args (argparse.Namespace): Command-line arguments.

    This function initializes the configuration for the inference process. It sets up the necessary
    modules and variables to prepare for the upcoming inference steps.
    """
    # 1. init config
    config = OmegaConf.load(args.config)
    config = OmegaConf.merge(config, vars(args))
    source_image_path = config.source_image
    driving_audio_path = config.driving_audio
    save_path = config.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]

    if args.checkpoint is not None:
        config.audio_ckpt_dir = args.checkpoint
    # 2. runtime variables
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    if config.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif config.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif config.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        weight_dtype = torch.float32

    # 3. prepare inference data
    # 3.1 prepare source image, face mask, face embeddings
    img_size = (config.data.source_image.width,
                config.data.source_image.height)
    clip_length = config.data.n_sample_frames
    face_analysis_model_path = config.face_analysis.model_path
    with ImageProcessor(img_size, face_analysis_model_path) as image_processor:
        source_image_pixels, \
            source_image_face_region, \
            source_image_face_emb, \
            source_image_full_mask, \
            source_image_face_mask, \
            source_image_lip_mask = image_processor.preprocess(
                source_image_path, save_path, config.face_expand_ratio)
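    # The preprocessing step yields the source image tensor, the cropped face region,
    # the face identity embedding, and the full/face/lip masks that are repeated across
    # the clip frames further below.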
    # 3.2 prepare audio embeddings
    sample_rate = config.data.driving_audio.sample_rate
    assert sample_rate == 16000, "audio sample rate must be 16000"
    fps = config.data.export_video.fps
    wav2vec_model_path = config.wav2vec.model_path
    wav2vec_only_last_features = config.wav2vec.features == "last"
    audio_separator_model_file = config.audio_separator.model_path
    with AudioProcessor(
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_only_last_features,
        os.path.dirname(audio_separator_model_file),
        os.path.basename(audio_separator_model_file),
        os.path.join(save_path, "audio_preprocess")
    ) as audio_processor:
        audio_emb = audio_processor.preprocess(driving_audio_path)
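    # audio_emb is assumed to hold one wav2vec-based embedding per output video frame
    # at the export fps; its first dimension is used as the frame count below.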
    # 4. build modules
    sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
    if config.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    vae = AutoencoderKL.from_pretrained(config.vae.model_path)
    reference_unet = UNet2DConditionModel.from_pretrained(
        config.base_model_path, subfolder="unet")
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            config.unet_additional_kwargs),
        use_landmark=False,
    )
    face_locator = FaceLocator(conditioning_embedding_channels=320)
    image_proj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    )

    audio_proj = AudioProjModel(
        seq_len=5,
        blocks=12,  # use 12 layers' hidden states of wav2vec
        channels=768,  # audio embedding channel
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device=device, dtype=weight_dtype)
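    # AudioProjModel is expected to map each (seq_len=5, blocks=12, channels=768) audio
    # window produced by process_audio_emb to 32 context tokens of dimension 768 for
    # cross-attention conditioning.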
    audio_ckpt_dir = config.audio_ckpt_dir

    # Freeze
    vae.requires_grad_(False)
    image_proj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audio_proj.requires_grad_(False)

    reference_unet.enable_gradient_checkpointing()
    denoising_unet.enable_gradient_checkpointing()

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        image_proj,
        audio_proj,
    )
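    # nn.Module.load_state_dict returns the (missing_keys, unexpected_keys) lists;
    # both must be empty for the checkpoint to be considered complete.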
    m, u = net.load_state_dict(
        torch.load(
            os.path.join(audio_ckpt_dir, "net.pth"),
            map_location="cpu",
        ),
    )
    assert len(m) == 0 and len(u) == 0, "Failed to load the checkpoint correctly."
    print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth"))
    # 5. inference
    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=net.reference_unet,
        denoising_unet=net.denoising_unet,
        face_locator=net.face_locator,
        scheduler=val_noise_scheduler,
        image_proj=net.imageproj,
    )
    pipeline.to(device=device, dtype=weight_dtype)

    audio_emb = process_audio_emb(audio_emb)

    source_image_pixels = source_image_pixels.unsqueeze(0)
    source_image_face_region = source_image_face_region.unsqueeze(0)
    source_image_face_emb = source_image_face_emb.reshape(1, -1)
    source_image_face_emb = torch.tensor(source_image_face_emb)

    source_image_full_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_full_mask
    ]
    source_image_face_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_face_mask
    ]
    source_image_lip_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_lip_mask
    ]

    times = audio_emb.shape[0] // clip_length
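    # Note: the integer division drops any trailing audio frames that do not fill a
    # complete clip of clip_length frames.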
    tensor_result = []

    generator = torch.manual_seed(42)

    for t in range(times):
        if len(tensor_result) == 0:
            # The first iteration
            motion_zeros = source_image_pixels.repeat(
                config.data.n_motion_frames, 1, 1, 1)
            motion_zeros = motion_zeros.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
        else:
            motion_frames = tensor_result[-1][0]
            motion_frames = motion_frames.permute(1, 0, 2, 3)
            motion_frames = motion_frames[-config.data.n_motion_frames:]
            motion_frames = motion_frames * 2.0 - 1.0
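            # The pipeline output is assumed to lie in [0, 1]; rescale the reused motion
            # frames back to the [-1, 1] range of the preprocessed reference image tensor.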
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

        audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
        audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(
            device=net.audioproj.device, dtype=net.audioproj.dtype)
        audio_tensor = net.audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=img_size[0],
            height=img_size[1],
            video_length=clip_length,
            num_inference_steps=config.inference_steps,
            guidance_scale=config.cfg_scale,
            generator=generator,
            motion_scale=motion_scale,
        )

        tensor_result.append(pipeline_output.videos)

    tensor_result = torch.cat(tensor_result, dim=2)
    tensor_result = tensor_result.squeeze(0)
    output_file = config.output
    # save the result after all iterations
    tensor_to_video(tensor_result, output_file, driving_audio_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", default="configs/inference/default.yaml")
    parser.add_argument("--source_image", type=str, required=False,
                        help="source image", default="test_data/source_images/6.jpg")
    parser.add_argument("--driving_audio", type=str, required=False,
                        help="driving audio", default="test_data/driving_audios/singing/sing_4.wav")
    parser.add_argument(
        "--output", type=str, help="output video file name", default=".cache/output.mp4")
    parser.add_argument(
        "--pose_weight", type=float, help="weight of pose", default=1.0)
    parser.add_argument(
        "--face_weight", type=float, help="weight of face", default=1.0)
    parser.add_argument(
        "--lip_weight", type=float, help="weight of lip", default=1.0)
    parser.add_argument(
        "--face_expand_ratio", type=float, help="expand ratio of the face region", default=1.2)
    parser.add_argument(
        "--checkpoint", type=str, help="checkpoint directory that overrides audio_ckpt_dir in the config", default=None)

    command_line_args = parser.parse_args()

    inference_process(command_line_args)