Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .base_architecture import BaseArchitecture | |
| from ..builder import ( | |
| ARCHITECTURES, | |
| build_architecture, | |
| build_submodule, | |
| build_loss | |
| ) | |
| from ..utils.gaussian_diffusion import ( | |
| GaussianDiffusion, get_named_beta_schedule, create_named_schedule_sampler, | |
| ModelMeanType, ModelVarType, LossType, space_timesteps, SpacedDiffusion | |
| ) | |
def build_diffusion(cfg):
    """Create a Gaussian diffusion process described by ``cfg``.

    Args:
        cfg: mapping with keys ``'beta_scheduler'`` (schedule name passed to
            ``get_named_beta_schedule``), ``'diffusion_steps'`` (number of
            steps), ``'model_mean_type'`` (one of ``'start_x'``,
            ``'previous_x'``, ``'epsilon'``) and ``'model_var_type'`` (one of
            ``'learned'``, ``'fixed_small'``, ``'fixed_large'``,
            ``'learned_range'``). An optional ``'respace'`` entry, when present
            and not None, is passed to ``space_timesteps`` to select a subset
            of timesteps.

    Returns:
        A ``SpacedDiffusion`` when ``'respace'`` is set, otherwise a
        ``GaussianDiffusion``. The loss type is always ``LossType.MSE``.

    Raises:
        KeyError: if a required key is missing or a mean/var type name is not
            one of the listed options.
    """
    schedule_name = cfg['beta_scheduler']
    num_steps = cfg['diffusion_steps']

    # Config string -> enum member lookup tables.
    mean_types = {
        'start_x': ModelMeanType.START_X,
        'previous_x': ModelMeanType.PREVIOUS_X,
        'epsilon': ModelMeanType.EPSILON,
    }
    var_types = {
        'learned': ModelVarType.LEARNED,
        'fixed_small': ModelVarType.FIXED_SMALL,
        'fixed_large': ModelVarType.FIXED_LARGE,
        'learned_range': ModelVarType.LEARNED_RANGE,
    }

    # Constructor arguments shared by both diffusion variants.
    shared_kwargs = dict(
        betas=get_named_beta_schedule(schedule_name, num_steps),
        model_mean_type=mean_types[cfg['model_mean_type']],
        model_var_type=var_types[cfg['model_var_type']],
        loss_type=LossType.MSE,
    )

    respace = cfg.get('respace', None)
    if respace is None:
        return GaussianDiffusion(**shared_kwargs)
    return SpacedDiffusion(
        use_timesteps=space_timesteps(num_steps, respace),
        **shared_kwargs,
    )
# NOTE(review): ARCHITECTURES is imported by this file but this class is not
# decorated with @ARCHITECTURES.register_module(); if it must be buildable from
# config via build_architecture, confirm it is registered somewhere else.
class MotionDiffusion(BaseArchitecture):
    """Text-conditioned motion generator built on a Gaussian diffusion process.

    In training mode ``forward`` returns a masked reconstruction loss. In
    eval mode it samples motions with DDPM or DDIM.
    """

    def __init__(self,
                 model=None,
                 loss_recon=None,
                 diffusion_train=None,
                 diffusion_test=None,
                 init_cfg=None,
                 inference_type='ddpm',
                 **kwargs):
        """Build the denoiser, the loss and the train/test diffusion processes.

        Args:
            model: config passed to ``build_submodule`` for the denoising model.
            loss_recon: config passed to ``build_loss``; the resulting loss
                must accept ``reduction_override='none'``.
            diffusion_train: config dict for ``build_diffusion`` used in
                training. Despite the ``None`` default it is required, because
                ``build_diffusion`` indexes into it.
            diffusion_test: config dict for ``build_diffusion`` used at
                inference. Also effectively required.
            init_cfg: forwarded to ``BaseArchitecture``.
            inference_type: ``'ddpm'`` selects ``p_sample_loop``; any other
                value selects ``ddim_sample_loop`` with ``eta=0``.
            **kwargs: forwarded to ``BaseArchitecture``.
        """
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test = build_diffusion(diffusion_test)
        # Timesteps are drawn uniformly during training.
        self.sampler = create_named_schedule_sampler('uniform', self.diffusion_train)
        self.inference_type = inference_type

    def forward(self, **kwargs):
        """Compute the training loss, or sample motions at inference.

        Expected keyword arguments:
            motion: tensor whose first two dims are (B, T) and whose last dim
                is the pose size; presumably (B, T, D) — TODO confirm.
            motion_mask: tensor multiplied element-wise with the per-frame
                loss; presumably (B, T) — TODO confirm.
            motion_metas: sequence of at least B dicts, each with a 'text' key.
            motion_length: required in training mode only.
            sample_idx, clip_feat: optional, forwarded to the model.
            inference_kwargs: optional dict of extra sampler arguments
                (inference only).

        Returns:
            Training mode: ``{'recon_loss': scalar tensor}``.
            Eval mode: ``self.split_results(kwargs)`` after ``kwargs`` gains a
            ``'pred_motion'`` entry holding the sampled motion.
        """
        motion = kwargs['motion'].float()
        motion_mask = kwargs['motion_mask'].float()
        # One caption per sample in the batch.
        text = [kwargs['motion_metas'][i]['text'] for i in range(motion.shape[0])]
        if self.training:
            return self._forward_train(motion, motion_mask, text, kwargs)
        return self._forward_test(motion, motion_mask, text, kwargs)

    def _forward_train(self, motion, motion_mask, text, kwargs):
        """Sample timesteps and return the mask-weighted reconstruction loss."""
        batch_size = motion.shape[0]
        t, _ = self.sampler.sample(batch_size, motion.device)
        output = self.diffusion_train.training_losses(
            model=self.model,
            x_start=motion,
            t=t,
            model_kwargs={
                'motion_mask': motion_mask,
                'motion_length': kwargs['motion_length'],
                'text': text,
                'clip_feat': kwargs.get('clip_feat', None),
                'sample_idx': kwargs.get('sample_idx', None),
            },
        )
        pred, target = output['pred'], output['target']
        recon_loss = self.loss_recon(pred, target, reduction_override='none')
        # Average over the last dim, then take a mean weighted by motion_mask.
        recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum()
        return {'recon_loss': recon_loss}

    def _forward_test(self, motion, motion_mask, text, kwargs):
        """Sample motions with the test diffusion process and package results."""
        batch_size, num_frames = motion.shape[:2]
        dim_pose = motion.shape[-1]
        sample_shape = (batch_size, num_frames, dim_pose)

        model_kwargs = self.model.get_precompute_condition(
            device=motion.device, text=text, **kwargs)
        model_kwargs['motion_mask'] = motion_mask
        model_kwargs['sample_idx'] = kwargs.get('sample_idx', None)
        inference_kwargs = kwargs.get('inference_kwargs', {})

        if self.inference_type == 'ddpm':
            output = self.diffusion_test.p_sample_loop(
                self.model,
                sample_shape,
                clip_denoised=False,
                progress=False,
                model_kwargs=model_kwargs,
                **inference_kwargs
            )
        else:
            output = self.diffusion_test.ddim_sample_loop(
                self.model,
                sample_shape,
                clip_denoised=False,
                progress=False,
                model_kwargs=model_kwargs,
                eta=0,
                **inference_kwargs
            )

        # post_process is optional. Without a default, getattr would raise
        # AttributeError on models that lack it, instead of skipping the step.
        post_process = getattr(self.model, 'post_process', None)
        if post_process is not None:
            output = post_process(output)

        results = kwargs
        results['pred_motion'] = output
        return self.split_results(results)