	initial commit
This view is limited to 50 files because it contains too many changes.
- .gitignore +1 -0
 - README.md +6 -9
 - app.py +123 -0
 - configs/_base_/datasets/human_ml3d_bs128.py +60 -0
 - configs/_base_/datasets/kit_ml_bs128.py +60 -0
 - configs/mdm/mdm_t2m_official.py +67 -0
 - configs/motiondiffuse/motiondiffuse_kit.py +89 -0
 - configs/motiondiffuse/motiondiffuse_t2m.py +90 -0
 - configs/remodiffuse/remodiffuse_kit.py +141 -0
 - configs/remodiffuse/remodiffuse_t2m.py +141 -0
 - data/database/t2m_text_train.npz +3 -0
 - data/datasets/human_ml3d/mean.npy +3 -0
 - data/datasets/human_ml3d/std.npy +3 -0
 - data/datasets/kit_ml/mean.npy +3 -0
 - data/datasets/kit_ml/std.npy +3 -0
 - logs/mdm/mdm_t2m/latest.pth +3 -0
 - logs/motiondiffuse/motiondiffuse_t2m/latest.pth +3 -0
 - logs/remodiffuse/remodiffuse_t2m/latest.pth +3 -0
 - mogen/__init__.py +56 -0
 - mogen/apis/__init__.py +13 -0
 - mogen/apis/test.py +160 -0
 - mogen/apis/train.py +165 -0
 - mogen/core/__init__.py +0 -0
 - mogen/core/distributed_wrapper.py +136 -0
 - mogen/core/evaluation/__init__.py +4 -0
 - mogen/core/evaluation/builder.py +29 -0
 - mogen/core/evaluation/eval_hooks.py +138 -0
 - mogen/core/evaluation/evaluators/__init__.py +0 -0
 - mogen/core/evaluation/evaluators/base_evaluator.py +144 -0
 - mogen/core/evaluation/evaluators/diversity_evaluator.py +52 -0
 - mogen/core/evaluation/evaluators/fid_evaluator.py +58 -0
 - mogen/core/evaluation/evaluators/matching_score_evaluator.py +71 -0
 - mogen/core/evaluation/evaluators/multimodality_evaluator.py +63 -0
 - mogen/core/evaluation/evaluators/precision_evaluator.py +74 -0
 - mogen/core/evaluation/get_model.py +46 -0
 - mogen/core/evaluation/utils.py +130 -0
 - mogen/core/optimizer/__init__.py +3 -0
 - mogen/core/optimizer/builder.py +52 -0
 - mogen/datasets/__init__.py +11 -0
 - mogen/datasets/base_dataset.py +117 -0
 - mogen/datasets/builder.py +113 -0
 - mogen/datasets/dataset_wrappers.py +42 -0
 - mogen/datasets/pipelines/__init__.py +18 -0
 - mogen/datasets/pipelines/compose.py +42 -0
 - mogen/datasets/pipelines/formatting.py +134 -0
 - mogen/datasets/pipelines/transforms.py +120 -0
 - mogen/datasets/samplers/__init__.py +3 -0
 - mogen/datasets/samplers/distributed_sampler.py +42 -0
 - mogen/datasets/text_motion_dataset.py +93 -0
 - mogen/models/__init__.py +7 -0
 
    	
.gitignore ADDED
@@ -0,0 +1 @@
+**__pycache__**
    	
README.md CHANGED
@@ -1,12 +1,9 @@
-
-
-
-
-colorTo: gray
+title: MotionDiffuse
+emoji: 🏢
+colorFrom: blue
+colorTo: red
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.44.1
 app_file: app.py
 pinned: false
-
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+license: mit
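Applied, the resulting front matter reads in full (reconstructed directly from the hunk above):

title: MotionDiffuse
emoji: 🏢
colorFrom: blue
colorTo: red
sdk: gradio
sdk_version: 3.44.1
app_file: app.py
pinned: false
license: mit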
    	
app.py ADDED
@@ -0,0 +1,123 @@
+import os
+import sys
+import gradio as gr
+
+os.makedirs("outputs", exist_ok=True)
+sys.path.insert(0, '.')
+
+import argparse
+import os.path as osp
+import mmcv
+import numpy as np
+import torch
+from mogen.models import build_architecture
+from mmcv.runner import load_checkpoint
+from mmcv.parallel import MMDataParallel
+from mogen.utils.plot_utils import (
+    recover_from_ric,
+    plot_3d_motion,
+    t2m_kinematic_chain
+)
+from scipy.ndimage import gaussian_filter
+from IPython.display import Image
+
+
+def motion_temporal_filter(motion, sigma=1):
+    motion = motion.reshape(motion.shape[0], -1)
+    for i in range(motion.shape[1]):
+        motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
+    return motion.reshape(motion.shape[0], -1, 3)
+
+
+def plot_t2m(data, result_path, npy_path, caption):
+    joint = recover_from_ric(torch.from_numpy(data).float(), 22).numpy()
+    joint = motion_temporal_filter(joint, sigma=2.5)
+    plot_3d_motion(result_path, t2m_kinematic_chain, joint, title=caption, fps=20)
+    if npy_path is not None:
+        np.save(npy_path, joint)
+
+def create_remodiffuse():
+    config_path = "configs/remodiffuse/remodiffuse_t2m.py"
+    ckpt_path = "logs/remodiffuse/remodiffuse_t2m/latest.pth"
+    cfg = mmcv.Config.fromfile(config_path)
+    model = build_architecture(cfg.model)
+    load_checkpoint(model, ckpt_path, map_location='cpu')
+    model.cpu()
+    model.eval()
+    return model
+
+def create_motiondiffuse():
+    config_path = "configs/motiondiffuse/motiondiffuse_t2m.py"
+    ckpt_path = "logs/motiondiffuse/motiondiffuse_t2m/latest.pth"
+    cfg = mmcv.Config.fromfile(config_path)
+    model = build_architecture(cfg.model)
+    load_checkpoint(model, ckpt_path, map_location='cpu')
+    model.cpu()
+    model.eval()
+    return model
+
+def create_mdm():
+    config_path = "configs/mdm/mdm_t2m_official.py"
+    ckpt_path = "logs/mdm/mdm_t2m/latest.pth"
+    cfg = mmcv.Config.fromfile(config_path)
+    model = build_architecture(cfg.model)
+    load_checkpoint(model, ckpt_path, map_location='cpu')
+    model.cpu()
+    model.eval()
+    return model
+
+model_remodiffuse = create_remodiffuse()
+# model_motiondiffuse = create_motiondiffuse()
+# model_mdm = create_mdm()
+
+mean_path = "data/datasets/human_ml3d/mean.npy"
+std_path = "data/datasets/human_ml3d/std.npy"
+mean = np.load(mean_path)
+std = np.load(std_path)
+
+
+def show_generation_result(model, text, motion_length, result_path):
+    device = 'cpu'
+    motion = torch.zeros(1, motion_length, 263).to(device)
+    motion_mask = torch.ones(1, motion_length).to(device)
+    motion_length = torch.Tensor([motion_length]).long().to(device)
+    model = model.to(device)
+    input = {
+        'motion': motion,
+        'motion_mask': motion_mask,
+        'motion_length': motion_length,
+        'motion_metas': [{'text': text}],
+    }
+
+    all_pred_motion = []
+    with torch.no_grad():
+        input['inference_kwargs'] = {}
+        output_list = []
+        output = model(**input)[0]['pred_motion']
+        pred_motion = output.cpu().detach().numpy()
+        pred_motion = pred_motion * std + mean
+
+    plot_t2m(pred_motion, result_path, None, text)
+
+def generate(prompt, length):
+    if not os.path.exists("outputs"):
+        os.mkdir("outputs")
+    result_path = "outputs/" + str(hash(prompt)) + ".mp4"
+    show_generation_result(model_remodiffuse, prompt, length, result_path)
+    return result_path
+
+demo = gr.Interface(
+    fn=generate,
+    inputs=["text", gr.Slider(20, 196, value=60)],
+    examples=[
+        ["the man throws a punch with each hand.", 58],
+        ["a person spins quickly and takes off running.", 29],
+        ["a person quickly waves with their right hand", 46],
+        ["a person performing a slight bow", 89],
+    ],
+    outputs="video",
+    title="ReMoDiffuse: Retrieval-Augmented Motion Diffusion Model",
+    description="This is an interactive demo for ReMoDiffuse. For more information, feel free to visit our project page(https://mingyuan-zhang.github.io/projects/ReMoDiffuse.html).")
+
+demo.queue()
+demo.launch()
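Outside the Gradio UI, the same entry point can be exercised directly as a smoke test. A minimal sketch, assuming the checkpoint and mean/std files referenced above are in place (the prompt string is just an example):

# Hypothetical smoke test; generate() is defined in app.py above and
# writes outputs/<hash(prompt)>.mp4 before returning its path.
video_path = generate("a person walks forward and waves", 60)
print("rendered:", video_path)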
    	
configs/_base_/datasets/human_ml3d_bs128.py ADDED
@@ -0,0 +1,60 @@
+# dataset settings
+data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat']
+meta_keys = ['text', 'token']
+train_pipeline = [
+    dict(
+        type='Normalize',
+        mean_path='data/datasets/human_ml3d/mean.npy',
+        std_path='data/datasets/human_ml3d/std.npy'),
+    dict(type='Crop', crop_size=196),
+    dict(type='ToTensor', keys=data_keys),
+    dict(type='Collect', keys=data_keys, meta_keys=meta_keys)
+]
+
+data = dict(
+    samples_per_gpu=128,
+    workers_per_gpu=1,
+    train=dict(
+        type='RepeatDataset',
+        dataset=dict(
+            type='TextMotionDataset',
+            dataset_name='human_ml3d',
+            data_prefix='data',
+            pipeline=train_pipeline,
+            ann_file='train.txt',
+            motion_dir='motions',
+            text_dir='texts',
+            token_dir='tokens',
+            clip_feat_dir='clip_feats',
+        ),
+        times=200
+    ),
+    test=dict(
+        type='TextMotionDataset',
+        dataset_name='human_ml3d',
+        data_prefix='data',
+        pipeline=train_pipeline,
+        ann_file='test.txt',
+        motion_dir='motions',
+        text_dir='texts',
+        token_dir='tokens',
+        clip_feat_dir='clip_feats',
+        eval_cfg=dict(
+            shuffle_indexes=True,
+            replication_times=20,
+            replication_reduction='statistics',
+            text_encoder_name='human_ml3d',
+            text_encoder_path='data/evaluators/human_ml3d/finest.tar',
+            motion_encoder_name='human_ml3d',
+            motion_encoder_path='data/evaluators/human_ml3d/finest.tar',
+            metrics=[
+                dict(type='R Precision', batch_size=32, top_k=3),
+                dict(type='Matching Score', batch_size=32),
+                dict(type='FID'),
+                dict(type='Diversity', num_samples=300),
+                dict(type='MultiModality', num_samples=100, num_repeats=30, num_picks=10)
+            ]
+        ),
+        test_mode=True
+    )
+)
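These dataset files follow the mmcv config convention, so they can be loaded and inspected standalone. A minimal sketch, assuming mogen.datasets exposes a build_dataset helper (plausible given mogen/datasets/builder.py in the file list above, but not confirmed here):

import mmcv
from mogen.datasets import build_dataset  # assumed export; see mogen/datasets/builder.py

cfg = mmcv.Config.fromfile('configs/_base_/datasets/human_ml3d_bs128.py')
train_set = build_dataset(cfg.data.train)  # RepeatDataset wrapping TextMotionDataset
print(len(train_set))                      # underlying length scaled by times=200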
    	
configs/_base_/datasets/kit_ml_bs128.py ADDED
@@ -0,0 +1,60 @@
+# dataset settings
+data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat']
+meta_keys = ['text', 'token']
+train_pipeline = [
+    dict(type='Crop', crop_size=196),
+    dict(
+        type='Normalize',
+        mean_path='data/datasets/kit_ml/mean.npy',
+        std_path='data/datasets/kit_ml/std.npy'),
+    dict(type='ToTensor', keys=data_keys),
+    dict(type='Collect', keys=data_keys, meta_keys=meta_keys)
+]
+
+data = dict(
+    samples_per_gpu=128,
+    workers_per_gpu=1,
+    train=dict(
+        type='RepeatDataset',
+        dataset=dict(
+            type='TextMotionDataset',
+            dataset_name='kit_ml',
+            data_prefix='data',
+            pipeline=train_pipeline,
+            ann_file='train.txt',
+            motion_dir='motions',
+            text_dir='texts',
+            token_dir='tokens',
+            clip_feat_dir='clip_feats',
+        ),
+        times=100
+    ),
+    test=dict(
+        type='TextMotionDataset',
+        dataset_name='kit_ml',
+        data_prefix='data',
+        pipeline=train_pipeline,
+        ann_file='test.txt',
+        motion_dir='motions',
+        text_dir='texts',
+        token_dir='tokens',
+        clip_feat_dir='clip_feats',
+        eval_cfg=dict(
+            shuffle_indexes=True,
+            replication_times=20,
+            replication_reduction='statistics',
+            text_encoder_name='kit_ml',
+            text_encoder_path='data/evaluators/kit_ml/finest.tar',
+            motion_encoder_name='kit_ml',
+            motion_encoder_path='data/evaluators/kit_ml/finest.tar',
+            metrics=[
+                dict(type='R Precision', batch_size=32, top_k=3),
+                dict(type='Matching Score', batch_size=32),
+                dict(type='FID'),
+                dict(type='Diversity', num_samples=300),
+                dict(type='MultiModality', num_samples=50, num_repeats=30, num_picks=10)
+            ]
+        ),
+        test_mode=True
+    )
+)
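Beyond the dataset name and paths, this config differs from the HumanML3D one in three ways: the pipeline crops before normalizing (the reverse order), the training set is repeated 100x rather than 200x, and MultiModality evaluates 50 rather than 100 samples. A quick standalone check of the pipeline order:

import mmcv

a = mmcv.Config.fromfile('configs/_base_/datasets/human_ml3d_bs128.py')
b = mmcv.Config.fromfile('configs/_base_/datasets/kit_ml_bs128.py')
# HumanML3D: ['Normalize', 'Crop', ...]; KIT-ML: ['Crop', 'Normalize', ...]
print([step['type'] for step in a.train_pipeline])
print([step['type'] for step in b.train_pipeline])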
    	
configs/mdm/mdm_t2m_official.py ADDED
@@ -0,0 +1,67 @@
+_base_ = ['../_base_/datasets/human_ml3d_bs128.py']
+
+# checkpoint saving
+checkpoint_config = dict(interval=1)
+
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+# optimizer
+optimizer = dict(type='Adam', lr=1e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[])
+runner = dict(type='EpochBasedRunner', max_epochs=50)
+
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+
+input_feats = 263
+max_seq_len = 196
+latent_dim = 512
+time_embed_dim = 2048
+text_latent_dim = 256
+ff_size = 1024
+num_layers = 8
+num_heads = 4
+dropout = 0.1
+cond_mask_prob = 0.1
+# model settings
+model = dict(
+    type='MotionDiffusion',
+    model=dict(
+        type='MDMTransformer',
+        input_feats=input_feats,
+        latent_dim=latent_dim,
+        ff_size=ff_size,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        dropout=dropout,
+        time_embed_dim=time_embed_dim,
+        cond_mask_prob=cond_mask_prob,
+        guide_scale=2.5,
+        clip_version='ViT-B/32',
+        use_official_ckpt=True
+    ),
+    loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
+    diffusion_train=dict(
+        beta_scheduler='cosine',
+        diffusion_steps=1000,
+        model_mean_type='start_x',
+        model_var_type='fixed_small',
+    ),
+    diffusion_test=dict(
+        beta_scheduler='cosine',
+        diffusion_steps=1000,
+        model_mean_type='start_x',
+        model_var_type='fixed_small',
+    ),
+    inference_type='ddpm'
+)
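Here cond_mask_prob=0.1 drops the text condition with probability 0.1 during training, which is what makes guide_scale=2.5 usable at sampling time via classifier-free guidance. A sketch of the standard combination rule (the real implementation lives in the repo's diffusion code; the function below is illustrative only):

def classifier_free_guidance(pred_cond, pred_uncond, guide_scale=2.5):
    # Standard classifier-free guidance: extrapolate from the unconditional
    # prediction toward the conditional one. With model_mean_type='start_x',
    # the predictions here would be x0 estimates rather than noise.
    return pred_uncond + guide_scale * (pred_cond - pred_uncond)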
    	
configs/motiondiffuse/motiondiffuse_kit.py ADDED
@@ -0,0 +1,89 @@
+_base_ = ['../_base_/datasets/kit_ml_bs128.py']
+
+# checkpoint saving
+checkpoint_config = dict(interval=1)
+
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+# optimizer
+optimizer = dict(type='Adam', lr=2e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[])
+runner = dict(type='EpochBasedRunner', max_epochs=50)
+
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+
+input_feats = 251
+max_seq_len = 196
+latent_dim = 512
+time_embed_dim = 2048
+text_latent_dim = 256
+ff_size = 1024
+num_heads = 8
+dropout = 0
+# model settings
+model = dict(
+    type='MotionDiffusion',
+    model=dict(
+        type='MotionDiffuseTransformer',
+        input_feats=input_feats,
+        max_seq_len=max_seq_len,
+        latent_dim=latent_dim,
+        time_embed_dim=time_embed_dim,
+        num_layers=8,
+        sa_block_cfg=dict(
+            type='EfficientSelfAttention',
+            latent_dim=latent_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            time_embed_dim=time_embed_dim
+        ),
+        ca_block_cfg=dict(
+            type='EfficientCrossAttention',
+            latent_dim=latent_dim,
+            text_latent_dim=text_latent_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            time_embed_dim=time_embed_dim
+        ),
+        ffn_cfg=dict(
+            latent_dim=latent_dim,
+            ffn_dim=ff_size,
+            dropout=dropout,
+            time_embed_dim=time_embed_dim
+        ),
+        text_encoder=dict(
+            pretrained_model='clip',
+            latent_dim=text_latent_dim,
+            num_layers=4,
+            num_heads=4,
+            ff_size=2048,
+            dropout=dropout,
+            use_text_proj=True
+        )
+    ),
+    loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
+    diffusion_train=dict(
+        beta_scheduler='linear',
+        diffusion_steps=1000,
+        model_mean_type='epsilon',
+        model_var_type='fixed_small',
+    ),
+    diffusion_test=dict(
+        beta_scheduler='linear',
+        diffusion_steps=1000,
+        model_mean_type='epsilon',
+        model_var_type='fixed_small',
+    ),
+    inference_type='ddpm'
+)
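Note model_mean_type='epsilon' here versus 'start_x' in the MDM config: MotionDiffuse's network predicts the added noise rather than the clean motion. The two parameterizations are interchangeable via the standard DDPM identity, sketched below (alpha_bar_t is the cumulative product of the noise schedule, as a tensor in (0, 1)):

import torch

def x0_from_epsilon(x_t, eps, alpha_bar_t):
    # DDPM forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps,
    # so a noise prediction yields a clean-sample estimate:
    return (x_t - torch.sqrt(1.0 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)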
    	
configs/motiondiffuse/motiondiffuse_t2m.py ADDED
@@ -0,0 +1,90 @@
+_base_ = ['../_base_/datasets/human_ml3d_bs128.py']
+
+# checkpoint saving
+checkpoint_config = dict(interval=1)
+
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+# optimizer
+optimizer = dict(type='Adam', lr=2e-4)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(policy='step', step=[])
+runner = dict(type='EpochBasedRunner', max_epochs=50)
+
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        # dict(type='TensorboardLoggerHook')
+    ])
+
+input_feats = 263
+max_seq_len = 196
+latent_dim = 512
+time_embed_dim = 2048
+text_latent_dim = 256
+ff_size = 1024
+num_heads = 8
+dropout = 0
+# model settings
+model = dict(
+    type='MotionDiffusion',
+    model=dict(
+        type='MotionDiffuseTransformer',
+        input_feats=input_feats,
+        max_seq_len=max_seq_len,
+        latent_dim=latent_dim,
+        time_embed_dim=time_embed_dim,
+        num_layers=8,
+        sa_block_cfg=dict(
+            type='EfficientSelfAttention',
+            latent_dim=latent_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            time_embed_dim=time_embed_dim
+        ),
+        ca_block_cfg=dict(
+            type='EfficientCrossAttention',
+            latent_dim=latent_dim,
+            text_latent_dim=text_latent_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            time_embed_dim=time_embed_dim
+        ),
+        ffn_cfg=dict(
+            latent_dim=latent_dim,
+            ffn_dim=ff_size,
+            dropout=dropout,
+            time_embed_dim=time_embed_dim
+        ),
+        text_encoder=dict(
+            pretrained_model='clip',
+            latent_dim=text_latent_dim,
+            num_layers=4,
+            num_heads=4,
+            ff_size=2048,
+            dropout=dropout,
+            use_text_proj=True
+        )
+    ),
+    loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
+    diffusion_train=dict(
+        beta_scheduler='linear',
+        diffusion_steps=1000,
+        model_mean_type='epsilon',
+        model_var_type='fixed_small',
+    ),
+    diffusion_test=dict(
+        beta_scheduler='linear',
+        diffusion_steps=1000,
+        model_mean_type='epsilon',
+        model_var_type='fixed_small',
+    ),
+    inference_type='ddpm'
+)
+data = dict(samples_per_gpu=128)
    	
        configs/remodiffuse/remodiffuse_kit.py
    ADDED
    
@@ -0,0 +1,141 @@
_base_ = ['../_base_/datasets/kit_ml_bs128.py']

# checkpoint saving
checkpoint_config = dict(interval=1)

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

# optimizer
optimizer = dict(type='Adam', lr=2e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=20)

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

input_feats = 251
max_seq_len = 196
latent_dim = 512
time_embed_dim = 2048
text_latent_dim = 256
ff_size = 1024
num_heads = 8
dropout = 0

def scale_func(timestep):
    import random
    w = (1 - (1000 - timestep) / 1000) * 4.0 + 1
    if timestep > 100:
        if random.randint(0, 1) == 0:
            output = {
                'both_coef': w,
                'text_coef': 0,
                'retr_coef': 1 - w,
                'none_coef': 0
            }
        else:
            output = {
                'both_coef': 0,
                'text_coef': w,
                'retr_coef': 0,
                'none_coef': 1 - w
            }
    else:
        both_coef = 0.78123
        text_coef = 0.39284
        retr_coef = -0.12475
        none_coef = 1 - both_coef - text_coef - retr_coef
        output = {
            'both_coef': both_coef,
            'text_coef': text_coef,
            'retr_coef': retr_coef,
            'none_coef': none_coef
        }
    return output

# model settings
model = dict(
    type='MotionDiffusion',
    model=dict(
        type='ReMoDiffuseTransformer',
        input_feats=input_feats,
        max_seq_len=max_seq_len,
        latent_dim=latent_dim,
        time_embed_dim=time_embed_dim,
        num_layers=4,
        ca_block_cfg=dict(
            type='SemanticsModulatedAttention',
            latent_dim=latent_dim,
            text_latent_dim=text_latent_dim,
            num_heads=num_heads,
            dropout=dropout,
            time_embed_dim=time_embed_dim
        ),
        ffn_cfg=dict(
            latent_dim=latent_dim,
            ffn_dim=ff_size,
            dropout=dropout,
            time_embed_dim=time_embed_dim
        ),
        text_encoder=dict(
            pretrained_model='clip',
            latent_dim=text_latent_dim,
            num_layers=2,
            ff_size=2048,
            dropout=dropout,
            use_text_proj=False
        ),
        retrieval_cfg=dict(
            num_retrieval=2,
            stride=4,
            num_layers=2,
            num_motion_layers=2,
            kinematic_coef=0.1,
            topk=2,
            retrieval_file='data/database/kit_text_train.npz',
            latent_dim=latent_dim,
            output_dim=latent_dim,
            max_seq_len=max_seq_len,
            num_heads=num_heads,
            ff_size=ff_size,
            dropout=dropout,
            ffn_cfg=dict(
                latent_dim=latent_dim,
                ffn_dim=ff_size,
                dropout=dropout,
            ),
            sa_block_cfg=dict(
                type='EfficientSelfAttention',
                latent_dim=latent_dim,
                num_heads=num_heads,
                dropout=dropout
            ),
        ),
        scale_func=scale_func
    ),
    loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
    diffusion_train=dict(
        beta_scheduler='linear',
        diffusion_steps=1000,
        model_mean_type='start_x',
        model_var_type='fixed_large',
    ),
    diffusion_test=dict(
        beta_scheduler='linear',
        diffusion_steps=1000,
        model_mean_type='start_x',
        model_var_type='fixed_large',
        respace='15,15,8,6,6',
    ),
    inference_type='ddim'
)
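Note on scale_func: it implements ReMoDiffuse's timestep-dependent guidance mixing, blending four conditional predictions (text + retrieval, text only, retrieval only, unconditioned) with coefficients that always sum to 1. A minimal sketch of how a sampler might apply the returned coefficients; the out_* names below are illustrative placeholders, not identifiers from this commit:

    # Hedged sketch: combine the denoiser's four conditional outputs.
    # out_both, out_text, out_retr, out_none are assumed to be predictions
    # under (text + retrieval), text-only, retrieval-only, and no condition.
    coefs = scale_func(timestep=900)
    pred = (coefs['both_coef'] * out_both
            + coefs['text_coef'] * out_text
            + coefs['retr_coef'] * out_retr
            + coefs['none_coef'] * out_none)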
        configs/remodiffuse/remodiffuse_t2m.py
    ADDED
    
@@ -0,0 +1,141 @@
_base_ = ['../_base_/datasets/human_ml3d_bs128.py']

# checkpoint saving
checkpoint_config = dict(interval=1)

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]

# optimizer
optimizer = dict(type='Adam', lr=2e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=40)

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

input_feats = 263
max_seq_len = 196
latent_dim = 512
time_embed_dim = 2048
text_latent_dim = 256
ff_size = 1024
num_heads = 8
dropout = 0

def scale_func(timestep):
    import random
    w = (1 - (1000 - timestep) / 1000) * 6.5 + 1
    if timestep > 100:
        if random.randint(0, 1) == 0:
            output = {
                'both_coef': w,
                'text_coef': 0,
                'retr_coef': 1 - w,
                'none_coef': 0
            }
        else:
            output = {
                'both_coef': 0,
                'text_coef': w,
                'retr_coef': 0,
                'none_coef': 1 - w
            }
    else:
        both_coef = 0.52351
        text_coef = -0.28419
        retr_coef = 2.39872
        none_coef = 1 - both_coef - text_coef - retr_coef
        output = {
            'both_coef': both_coef,
            'text_coef': text_coef,
            'retr_coef': retr_coef,
            'none_coef': none_coef
        }
    return output

# model settings
model = dict(
    type='MotionDiffusion',
    model=dict(
        type='ReMoDiffuseTransformer',
        input_feats=input_feats,
        max_seq_len=max_seq_len,
        latent_dim=latent_dim,
        time_embed_dim=time_embed_dim,
        num_layers=4,
        ca_block_cfg=dict(
            type='SemanticsModulatedAttention',
            latent_dim=latent_dim,
            text_latent_dim=text_latent_dim,
            num_heads=num_heads,
            dropout=dropout,
            time_embed_dim=time_embed_dim
        ),
        ffn_cfg=dict(
            latent_dim=latent_dim,
            ffn_dim=ff_size,
            dropout=dropout,
            time_embed_dim=time_embed_dim
        ),
        text_encoder=dict(
            pretrained_model='clip',
            latent_dim=text_latent_dim,
            num_layers=2,
            ff_size=2048,
            dropout=dropout,
            use_text_proj=False
        ),
        retrieval_cfg=dict(
            num_retrieval=2,
            stride=4,
            num_layers=2,
            num_motion_layers=2,
            kinematic_coef=0.1,
            topk=2,
            retrieval_file='data/database/t2m_text_train.npz',
            latent_dim=latent_dim,
            output_dim=latent_dim,
            max_seq_len=max_seq_len,
            num_heads=num_heads,
            ff_size=ff_size,
            dropout=dropout,
            ffn_cfg=dict(
                latent_dim=latent_dim,
                ffn_dim=ff_size,
                dropout=dropout,
            ),
            sa_block_cfg=dict(
                type='EfficientSelfAttention',
                latent_dim=latent_dim,
                num_heads=num_heads,
                dropout=dropout
            ),
        ),
        scale_func=scale_func
    ),
    loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'),
    diffusion_train=dict(
        beta_scheduler='linear',
        diffusion_steps=1000,
        model_mean_type='start_x',
        model_var_type='fixed_large',
    ),
    diffusion_test=dict(
        beta_scheduler='linear',
        diffusion_steps=1000,
        model_mean_type='start_x',
        model_var_type='fixed_large',
        respace='15,15,8,6,6',
    ),
    inference_type='ddim'
)
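Both ReMoDiffuse configs are plain mmcv-style Python config files, so they can be loaded and inspected programmatically. A minimal sketch, assuming the mmcv 1.x Config API (consistent with the mmcv version bounds asserted in mogen/__init__.py below); the printed values come straight from the config above:

    from mmcv import Config

    cfg = Config.fromfile('configs/remodiffuse/remodiffuse_t2m.py')
    print(cfg.model.model.type)      # 'ReMoDiffuseTransformer'
    print(cfg.runner.max_epochs)     # 40
    print(cfg.model.inference_type)  # 'ddim'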
        data/database/t2m_text_train.npz
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ae3575b686e29623f9e1715345b052726650f53c5bfcc770d9fb87a827a60249
size 1462801786
        data/datasets/human_ml3d/mean.npy
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d73483a5b53e017b4044fe363164d7c185082a02ae7f69525ea70c5ccfd4a85
size 1180
        data/datasets/human_ml3d/std.npy
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a6d720e004b6da18e8033d739de6078cbc7c1c8fad0ff62eee86f173e4430a2
size 1180
        data/datasets/kit_ml/mean.npy
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e23fac51db2215ab5666324226be48f27efd6a6e7b22ebd17c28e0f056a7c22
size 2136
        data/datasets/kit_ml/std.npy
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:296a60656cea07e65ee64512d73d47c0412df0698b35194116330661be32fa90
size 2136
        logs/mdm/mdm_t2m/latest.pth
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8810255fb8df9eed6211537de9826f07ff73862f367cbf91532d84fd4c9a497e
size 81791550
        logs/motiondiffuse/motiondiffuse_t2m/latest.pth
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:521baa6ba60865710bc75b99f393b133e45dc18083229a2258a16e5dc65f904a
size 348728194
        logs/remodiffuse/remodiffuse_t2m/latest.pth
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aaa34b3328942769478e96283678424c95c4b817ca6f7162c4cf1fc512d4951b
size 187939375
        mogen/__init__.py
    ADDED
    
@@ -0,0 +1,56 @@
import warnings

import mmcv
from packaging.version import parse

from .version import __version__


def digit_version(version_str: str, length: int = 4):
    """Convert a version string into a tuple of integers.
    This method is usually used for comparing two versions. For pre-release
    versions: alpha < beta < rc.
    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.
    Returns:
        tuple[int]: The version info in digits (integers).
    """
    version = parse(version_str)
    assert version.release, f'failed to parse version {version_str}'
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        mapping = {'a': -3, 'b': -2, 'rc': -1}
        val = -4
        # version.pre can be None
        if version.pre:
            if version.pre[0] not in mapping:
                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
                              'version checking may go wrong')
            else:
                val = mapping[version.pre[0]]
            release.extend([val, version.pre[-1]])
        else:
            release.extend([val, 0])

    elif version.is_postrelease:
        release.extend([1, version.post])
    else:
        release.extend([0, 0])
    return tuple(release)


mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.9.0'
mmcv_version = digit_version(mmcv.__version__)


assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

__all__ = ['__version__', 'digit_version']
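As a quick reference for the ordering digit_version produces, pre-release tags map to negative sentinel values so that plain tuple comparison sorts them below the corresponding final release:

    assert digit_version('1.5.0') == (1, 5, 0, 0, 0, 0)
    assert digit_version('1.5.0rc1') == (1, 5, 0, 0, -1, 1)   # 'rc' -> -1
    assert digit_version('1.5.0rc1') < digit_version('1.5.0')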
        mogen/apis/__init__.py
    ADDED
    
@@ -0,0 +1,13 @@
from mogen.apis import test, train
from mogen.apis.test import (
    collect_results_cpu,
    collect_results_gpu,
    multi_gpu_test,
    single_gpu_test,
)
from mogen.apis.train import set_random_seed, train_model

__all__ = [
    'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
    'single_gpu_test', 'set_random_seed', 'train_model'
]
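A minimal usage sketch of these exports; model and data_loader are placeholders to be built with the repo's builders elsewhere in this commit, and single_gpu_test is defined in mogen/apis/test.py just below:

    from mogen.apis import set_random_seed, single_gpu_test

    set_random_seed(0)  # fix random seeds for reproducibility (mmcv-style helper)
    results = single_gpu_test(model, data_loader)  # returns a list of per-sample outputs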
        mogen/apis/test.py
    ADDED
    
@@ -0,0 +1,160 @@
import os.path as osp
import pickle
import shutil
import tempfile
import time

import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def single_gpu_test(model, data_loader):
    """Test with a single gpu."""
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)

        if isinstance(result, list):
            results.extend(result)
        else:
            results.append(result)

        # one progress-bar tick per sample in the batch
        batch_size = data['motion'].size(0)
        for _ in range(batch_size):
            prog_bar.update()
    return results


def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.
    This method tests a model with multiple gpus and collects the results
    under two different modes: gpu and cpu. By setting 'gpu_collect=True'
    it encodes results to gpu tensors and uses gpu communication for result
    collection. In cpu mode it saves the results on different gpus to
    'tmpdir' and collects them by the rank 0 worker.
    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect results.
    Returns:
        list: The prediction results.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        # Check if tmpdir is valid for cpu_collect
        if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)):
            raise OSError((f'The tmpdir {tmpdir} already exists.',
                           ' Since tmpdir will be deleted after testing,',
                           ' please make sure you specify an empty one.'))
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, **data)
        if isinstance(result, list):
            results.extend(result)
        else:
            results.append(result)

        if rank == 0:
            batch_size = data['motion'].size(0)
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        results = collect_results_gpu(results, len(dataset))
    else:
        results = collect_results_cpu(results, len(dataset), tmpdir)
    return results


def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results in cpu."""
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_result = mmcv.load(part_file)
            part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    """Collect results in gpu."""
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
    part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
            part_list.append(part_result)
        # sort the results
| 155 | 
         
            +
                    ordered_results = []
         
     | 
| 156 | 
         
            +
                    for res in zip(*part_list):
         
     | 
| 157 | 
         
            +
                        ordered_results.extend(list(res))
         
     | 
| 158 | 
         
            +
                    # the dataloader may pad some samples
         
     | 
| 159 | 
         
            +
                    ordered_results = ordered_results[:size]
         
     | 
| 160 | 
         
            +
                    return ordered_results
         
     | 
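
Both collectors share the same contract: every rank passes in its local result list together with the full dataset length, and only rank 0 gets back the merged, de-padded list (all other ranks get None). Below is a minimal sketch of how a distributed test loop might call them; `model`, `data_loader`, and `gpu_collect` are illustrative names from a hypothetical driver script, not part of this file:

    # hypothetical driver; assumes torch.distributed is already initialized
    results = []
    for data in data_loader:
        with torch.no_grad():
            results.extend(model(return_loss=False, **data))

    if gpu_collect:
        # ranks exchange pickled byte tensors via all_gather
        results = collect_results_gpu(results, len(data_loader.dataset))
    else:
        # ranks rendezvous through a shared temporary directory
        results = collect_results_cpu(results, len(data_loader.dataset))

    if results is not None:  # true only on rank 0
        print(f'collected {len(results)} results')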
    	
        mogen/apis/train.py
    ADDED
    
         @@ -0,0 +1,165 @@ 
import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (
    DistSamplerSeedHook,
    Fp16OptimizerHook,
    OptimizerHook,
    build_runner,
)

from mogen.core.distributed_wrapper import DistributedDataParallelWrapper
from mogen.core.evaluation import DistEvalHook, EvalHook
from mogen.core.optimizer import build_optimizers
from mogen.datasets import build_dataloader, build_dataset
from mogen.utils import get_root_logger


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                device='cuda',
                meta=None):
    """Main api for training model."""
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            round_up=True,
            seed=cfg.seed) for ds in dataset
    ]

    # determine whether to use an adversarial training process or not
    use_adversarial_train = cfg.get('use_adversarial_train', False)

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', True)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        if use_adversarial_train:
            # Use DistributedDataParallelWrapper for adversarial training
            model = DistributedDataParallelWrapper(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        if device == 'cuda':
            model = MMDataParallel(
                model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
        elif device == 'cpu':
            model = model.cpu()
        else:
            raise ValueError(f'unsupported device name {device}.')

    # build runner
    optimizer = build_optimizers(model, cfg.optimizer)

    if cfg.get('runner') is None:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    if use_adversarial_train:
        # The optimizer step process is included in the train_step function
        # of the model, so the runner should NOT include an optimizer hook.
        optimizer_config = None
    else:
        # fp16 setting
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
        elif distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=cfg.data.samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            round_up=True)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
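
For context, a trimmed training entry point would look roughly like the sketch below. It assumes an mmcv-style config file from this repo and a model builder in mogen.models; the name build_architecture is an assumption here, so check mogen/models for the actual export:

    # hypothetical tools/train.py, reduced to the essentials
    from mmcv import Config

    from mogen.apis import train_model
    from mogen.datasets import build_dataset
    from mogen.models import build_architecture  # builder name assumed

    cfg = Config.fromfile('configs/remodiffuse/remodiffuse_t2m.py')
    cfg.work_dir = './work_dirs/remodiffuse_t2m'
    cfg.gpu_ids = [0]
    cfg.seed = 42
    cfg.resume_from = None  # set explicitly so train_model can read them
    cfg.load_from = None

    model = build_architecture(cfg.model)
    datasets = [build_dataset(cfg.data.train)]
    train_model(model, datasets, cfg, distributed=False, validate=True)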
    	
        mogen/core/__init__.py
    ADDED
    
File without changes

mogen/core/distributed_wrapper.py
ADDED
@@ -0,0 +1,136 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index


@MODULE_WRAPPERS.register_module()
class DistributedDataParallelWrapper(nn.Module):
    """A DistributedDataParallel wrapper for models in the 3D mesh estimation task.

    In the 3D mesh estimation task, there is a need to wrap different modules
    in the models with separate DistributedDataParallel. Otherwise, it will
    cause errors for GAN training.
    More specifically, the GAN model usually has two sub-modules:
    generator and discriminator. If we wrap both of them in one
    standard DistributedDataParallel, it will cause errors during training,
    because when we update the parameters of the generator (or discriminator),
    the parameters of the discriminator (or generator) are not updated, which
    is not allowed for DistributedDataParallel.
    So we design this wrapper to separately wrap DistributedDataParallel
    for the generator and the discriminator.
    In this wrapper, we perform two operations:
    1. Wrap the modules in the models with separate MMDistributedDataParallel.
        Note that only modules with parameters will be wrapped.
    2. Do scatter operation for 'forward', 'train_step' and 'val_step'.
    Note that the arguments of this wrapper are the same as those in
    `torch.nn.parallel.distributed.DistributedDataParallel`.
    Args:
        module (nn.Module): Module that needs to be wrapped.
        device_ids (list[int | `torch.device`]): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
        dim (int, optional): Same as that in the official scatter function in
            pytorch. Defaults to 0.
        broadcast_buffers (bool): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Defaults to False.
        find_unused_parameters (bool, optional): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Traverse the autograd graph of all tensors contained in the
            returned value of the wrapped module's forward function.
            Defaults to False.
        kwargs (dict): Other arguments used in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
    """

    def __init__(self,
                 module,
                 device_ids,
                 dim=0,
                 broadcast_buffers=False,
                 find_unused_parameters=False,
                 **kwargs):
        super().__init__()
        assert len(device_ids) == 1, (
            'Currently, DistributedDataParallelWrapper only supports one '
            'single CUDA device for each process. '
            f'The length of device_ids must be 1, but got {len(device_ids)}.')
        self.module = module
        self.dim = dim
        self.to_ddp(
            device_ids=device_ids,
            dim=dim,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            **kwargs)
        self.output_device = _get_device_index(device_ids[0], True)

    def to_ddp(self, device_ids, dim, broadcast_buffers,
               find_unused_parameters, **kwargs):
        """Wrap models with separate MMDistributedDataParallel.

        It only wraps the modules with parameters.
        """
        for name, module in self.module._modules.items():
            if next(module.parameters(), None) is None:
                module = module.cuda()
            elif all(not p.requires_grad for p in module.parameters()):
                module = module.cuda()
            else:
                module = MMDistributedDataParallel(
                    module.cuda(),
                    device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            self.module._modules[name] = module

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter function.

        Args:
            inputs (Tensor): Input Tensor.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
            device_ids (int): Device id.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """Forward function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        """Train step function.

        Args:
            inputs (Tensor): Input Tensor.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        """Validation step function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for ``scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output
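
The wrapper only does something interesting when the top-level module is a container whose children need different DDP treatment. A minimal sketch with a hypothetical two-part model, showing which children end up wrapped; it must run inside an already-initialized distributed process group:

    import torch
    import torch.nn as nn

    class TinyGAN(nn.Module):  # hypothetical stand-in for a real model
        def __init__(self):
            super().__init__()
            self.generator = nn.Linear(16, 16)       # trainable -> DDP-wrapped
            self.discriminator = nn.Linear(16, 1)    # trainable -> DDP-wrapped
            self.frozen_encoder = nn.Linear(16, 16)  # frozen -> plain .cuda()
            for p in self.frozen_encoder.parameters():
                p.requires_grad = False

    model = DistributedDataParallelWrapper(
        TinyGAN(), device_ids=[torch.cuda.current_device()])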
    	
        mogen/core/evaluation/__init__.py
    ADDED
    
@@ -0,0 +1,4 @@
from mogen.core.evaluation.eval_hooks import DistEvalHook, EvalHook
from mogen.core.evaluation.builder import build_evaluator

__all__ = ["DistEvalHook", "EvalHook", "build_evaluator"]
    	
        mogen/core/evaluation/builder.py
    ADDED
    
@@ -0,0 +1,29 @@
import copy
import numpy as np
from mmcv.utils import Registry
from .evaluators.precision_evaluator import PrecisionEvaluator
from .evaluators.matching_score_evaluator import MatchingScoreEvaluator
from .evaluators.fid_evaluator import FIDEvaluator
from .evaluators.diversity_evaluator import DiversityEvaluator
from .evaluators.multimodality_evaluator import MultiModalityEvaluator

EVALUATORS = Registry('evaluators')

EVALUATORS.register_module(name='R Precision', module=PrecisionEvaluator)
EVALUATORS.register_module(name='Matching Score', module=MatchingScoreEvaluator)
EVALUATORS.register_module(name='FID', module=FIDEvaluator)
EVALUATORS.register_module(name='Diversity', module=DiversityEvaluator)
EVALUATORS.register_module(name='MultiModality', module=MultiModalityEvaluator)


def build_evaluator(metric, eval_cfg, data_len, eval_indexes):
    cfg = copy.deepcopy(eval_cfg)
    cfg.update(metric)
    cfg.pop('metrics')
    cfg['data_len'] = data_len
    cfg['eval_indexes'] = eval_indexes
    evaluator = EVALUATORS.build(cfg)
    if evaluator.append_indexes is not None:
        for i in range(eval_cfg['replication_times']):
            eval_indexes[i] = np.concatenate(
                (eval_indexes[i], evaluator.append_indexes[i]), axis=0)
    return evaluator, eval_indexes
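
Each entry of the config's metrics list is one metric dict; it is merged into a copy of the shared evaluation config before the registry build, so the evaluator class receives type, replication_times, data_len, and eval_indexes as keyword arguments. A hedged sketch of the expected shapes follows; any keyword beyond those actually read above is illustrative:

    import numpy as np

    eval_cfg = dict(
        metrics=[dict(type='FID'), dict(type='Diversity')],
        replication_times=3,
    )
    data_len = 1000
    # one permutation of sample indices per replication
    eval_indexes = [np.random.permutation(data_len) for _ in range(3)]

    evaluator, eval_indexes = build_evaluator(
        eval_cfg['metrics'][0], eval_cfg, data_len, eval_indexes)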
    	
        mogen/core/evaluation/eval_hooks.py
    ADDED
    
@@ -0,0 +1,138 @@
| 1 | 
         
            +
            # Copyright (c) OpenMMLab. All rights reserved.
         
     | 
| 2 | 
         
            +
            import tempfile
         
     | 
| 3 | 
         
            +
            import warnings
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            from mmcv.runner import DistEvalHook as BaseDistEvalHook
         
     | 
| 6 | 
         
            +
            from mmcv.runner import EvalHook as BaseEvalHook
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            mogen_GREATER_KEYS = []
         
     | 
| 9 | 
         
            +
            mogen_LESS_KEYS = []
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            class EvalHook(BaseEvalHook):
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
                def __init__(self,
         
     | 
| 15 | 
         
            +
                             dataloader,
         
     | 
| 16 | 
         
            +
                             start=None,
         
     | 
| 17 | 
         
            +
                             interval=1,
         
     | 
| 18 | 
         
            +
                             by_epoch=True,
         
     | 
| 19 | 
         
            +
                             save_best=None,
         
     | 
| 20 | 
         
            +
                             rule=None,
         
     | 
| 21 | 
         
            +
                             test_fn=None,
         
     | 
| 22 | 
         
            +
                             greater_keys=mogen_GREATER_KEYS,
         
     | 
| 23 | 
         
            +
                             less_keys=mogen_LESS_KEYS,
         
     | 
| 24 | 
         
            +
                             **eval_kwargs):
         
     | 
| 25 | 
         
            +
                    if test_fn is None:
         
     | 
| 26 | 
         
            +
                        from mogen.apis import single_gpu_test
         
     | 
| 27 | 
         
            +
                        test_fn = single_gpu_test
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
                    # remove "gpu_collect" from eval_kwargs
         
     | 
| 30 | 
         
            +
                    if 'gpu_collect' in eval_kwargs:
         
     | 
| 31 | 
         
            +
                        warnings.warn(
         
     | 
| 32 | 
         
            +
                            '"gpu_collect" will be deprecated in EvalHook.'
         
     | 
| 33 | 
         
            +
                            'Please remove it from the config.', DeprecationWarning)
         
     | 
| 34 | 
         
            +
                        _ = eval_kwargs.pop('gpu_collect')
         
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
                    # update "save_best" according to "key_indicator" and remove the
         
     | 
| 37 | 
         
            +
                    # latter from eval_kwargs
         
     | 
| 38 | 
         
            +
                    if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
         
     | 
| 39 | 
         
            +
                        warnings.warn(
         
     | 
| 40 | 
         
            +
                            '"key_indicator" will be deprecated in EvalHook.'
         
     | 
| 41 | 
         
            +
                            'Please use "save_best" to specify the metric key,'
         
     | 
| 42 | 
         
            +
                            'e.g., save_best="pa-mpjpe".', DeprecationWarning)
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                        key_indicator = eval_kwargs.pop('key_indicator', None)
         
     | 
| 45 | 
         
            +
                        if save_best is True and key_indicator is None:
         
     | 
| 46 | 
         
            +
                            raise ValueError('key_indicator should not be None, when '
         
     | 
| 47 | 
         
            +
                                             'save_best is set to True.')
         
     | 
| 48 | 
         
            +
                        save_best = key_indicator
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                    super().__init__(dataloader, start, interval, by_epoch, save_best,
         
     | 
| 51 | 
         
            +
                                     rule, test_fn, greater_keys, less_keys, **eval_kwargs)
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
                def evaluate(self, runner, results):
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                    with tempfile.TemporaryDirectory() as tmp_dir:
         
     | 
| 56 | 
         
            +
                        eval_res = self.dataloader.dataset.evaluate(
         
     | 
| 57 | 
         
            +
                            results,
         
     | 
| 58 | 
         
            +
                            work_dir=tmp_dir,
         
     | 
| 59 | 
         
            +
                            logger=runner.logger,
         
     | 
| 60 | 
         
            +
                            **self.eval_kwargs)
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                    for name, val in eval_res.items():
         
     | 
| 63 | 
         
            +
                        runner.log_buffer.output[name] = val
         
     | 
| 64 | 
         
            +
                    runner.log_buffer.ready = True
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    if self.save_best is not None:
         
     | 
| 67 | 
         
            +
                        if self.key_indicator == 'auto':
         
     | 
| 68 | 
         
            +
                            self._init_rule(self.rule, list(eval_res.keys())[0])
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                        return eval_res[self.key_indicator]
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                    return None
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
            class DistEvalHook(BaseDistEvalHook):
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 by_epoch=True,
                 save_best=None,
                 rule=None,
                 test_fn=None,
                 greater_keys=mogen_GREATER_KEYS,
                 less_keys=mogen_LESS_KEYS,
                 broadcast_bn_buffer=True,
                 tmpdir=None,
                 gpu_collect=False,
                 **eval_kwargs):

        if test_fn is None:
            from mogen.apis import multi_gpu_test
            test_fn = multi_gpu_test

        # update "save_best" according to "key_indicator" and remove the
        # latter from eval_kwargs
        if 'key_indicator' in eval_kwargs or isinstance(save_best, bool):
            warnings.warn(
                '"key_indicator" will be deprecated in EvalHook. '
                'Please use "save_best" to specify the metric key, '
                'e.g., save_best="pa-mpjpe".', DeprecationWarning)

            key_indicator = eval_kwargs.pop('key_indicator', None)
            if save_best is True and key_indicator is None:
                raise ValueError('key_indicator should not be None, when '
                                 'save_best is set to True.')
            save_best = key_indicator

        super().__init__(dataloader, start, interval, by_epoch, save_best,
                         rule, test_fn, greater_keys, less_keys,
                         broadcast_bn_buffer, tmpdir, gpu_collect,
                         **eval_kwargs)

    def evaluate(self, runner, results):
        """Evaluate the results.

        Args:
            runner (:obj:`mmcv.Runner`): The underlying training runner.
            results (list): Output results.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            eval_res = self.dataloader.dataset.evaluate(
                results,
                work_dir=tmp_dir,
                logger=runner.logger,
                **self.eval_kwargs)

        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        runner.log_buffer.ready = True

        if self.save_best is not None:
            if self.key_indicator == 'auto':
                # infer from eval_results
                self._init_rule(self.rule, list(eval_res.keys())[0])
            return eval_res[self.key_indicator]

        return None
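The deprecation branch above makes the old key_indicator spelling and the new save_best spelling equivalent. A sketch of the two call patterns, assuming this is the distributed evaluation hook; DistEvalHook and val_loader are placeholder names for illustration, not confirmed by this diff:

# Hedged sketch: both constructions end up with save_best='FID (mean)'.
hook_new = DistEvalHook(val_loader, interval=1, save_best='FID (mean)')
hook_old = DistEvalHook(val_loader, interval=1, save_best=True,
                        key_indicator='FID (mean)')  # warns, then remaps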
    	
mogen/core/evaluation/evaluators/__init__.py
ADDED
File without changes
    	
mogen/core/evaluation/evaluators/base_evaluator.py
ADDED
@@ -0,0 +1,144 @@
import torch
import numpy as np
from ..utils import get_metric_statistics


class BaseEvaluator(object):

    def __init__(self,
                 batch_size=None,
                 drop_last=False,
                 replication_times=1,
                 replication_reduction='statistics',
                 eval_begin_idx=None,
                 eval_end_idx=None):
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.replication_times = replication_times
        self.replication_reduction = replication_reduction
        assert replication_reduction in ['statistics', 'mean', 'concat']
        self.eval_begin_idx = eval_begin_idx
        self.eval_end_idx = eval_end_idx

    def evaluate(self, results):
        total_len = len(results)
        partial_len = total_len // self.replication_times
        all_metrics = []
        for replication_idx in range(self.replication_times):
            partial_results = results[
                replication_idx * partial_len: (replication_idx + 1) * partial_len]
            if self.batch_size is not None:
                batch_metrics = []
                for batch_start in range(self.eval_begin_idx, self.eval_end_idx, self.batch_size):
                    batch_results = partial_results[batch_start: batch_start + self.batch_size]
                    if len(batch_results) < self.batch_size and self.drop_last:
                        continue
                    batch_metrics.append(self.single_evaluate(batch_results))
                all_metrics.append(self.concat_batch_metrics(batch_metrics))
            else:
                batch_results = partial_results[self.eval_begin_idx: self.eval_end_idx]
                all_metrics.append(self.single_evaluate(batch_results))
        all_metrics = np.stack(all_metrics, axis=0)
        if self.replication_reduction == 'statistics':
            values = get_metric_statistics(all_metrics, self.replication_times)
        elif self.replication_reduction == 'mean':
            values = np.mean(all_metrics, axis=0)
        elif self.replication_reduction == 'concat':
            values = all_metrics
        return self.parse_values(values)

    def prepare_results(self, results):
        text = []
        pred_motion = []
        pred_motion_length = []
        pred_motion_mask = []
        motion = []
        motion_length = []
        motion_mask = []
        token = []
        # count the maximum motion length
        T = max([result['motion'].shape[0] for result in results])
        for result in results:
            cur_motion = result['motion']
            if cur_motion.shape[0] < T:
                padding_values = torch.zeros((T - cur_motion.shape[0], cur_motion.shape[1]))
                # pad with the same dtype/device as the motion being padded
                # (the original called type_as(pred_motion), which at this
                # point is still a Python list)
                padding_values = padding_values.type_as(cur_motion)
                cur_motion = torch.cat([cur_motion, padding_values], dim=0)
            motion.append(cur_motion)
            cur_pred_motion = result['pred_motion']
            if cur_pred_motion.shape[0] < T:
                padding_values = torch.zeros((T - cur_pred_motion.shape[0], cur_pred_motion.shape[1]))
                padding_values = padding_values.type_as(cur_pred_motion)
                cur_pred_motion = torch.cat([cur_pred_motion, padding_values], dim=0)
            pred_motion.append(cur_pred_motion)
            cur_motion_mask = result['motion_mask']
            if cur_motion_mask.shape[0] < T:
                padding_values = torch.zeros((T - cur_motion_mask.shape[0]))
                padding_values = padding_values.type_as(cur_motion_mask)
                cur_motion_mask = torch.cat([cur_motion_mask, padding_values], dim=0)
            motion_mask.append(cur_motion_mask)
            cur_pred_motion_mask = result['pred_motion_mask']
            if cur_pred_motion_mask.shape[0] < T:
                padding_values = torch.zeros((T - cur_pred_motion_mask.shape[0]))
                padding_values = padding_values.type_as(cur_pred_motion_mask)
                cur_pred_motion_mask = torch.cat([cur_pred_motion_mask, padding_values], dim=0)
            pred_motion_mask.append(cur_pred_motion_mask)
            motion_length.append(result['motion_length'].item())
            pred_motion_length.append(result['pred_motion_length'].item())
            if 'text' in result.keys():
                text.append(result['text'])
            if 'token' in result.keys():
                token.append(result['token'])

        motion = torch.stack(motion, dim=0)
        pred_motion = torch.stack(pred_motion, dim=0)
        motion_mask = torch.stack(motion_mask, dim=0)
        pred_motion_mask = torch.stack(pred_motion_mask, dim=0)
        motion_length = torch.Tensor(motion_length).to(motion.device).long()
        pred_motion_length = torch.Tensor(pred_motion_length).to(motion.device).long()
        output = {
            'pred_motion': pred_motion,
            'pred_motion_mask': pred_motion_mask,
            'pred_motion_length': pred_motion_length,
            'motion': motion,
            'motion_mask': motion_mask,
            'motion_length': motion_length,
            'text': text,
            'token': token
        }
        return output

    def to_device(self, device):
        for model in self.model_list:
            model.to(device)

    def motion_encode(self, motion, motion_length, motion_mask, device):
        N = motion.shape[0]
        motion_emb = []
        batch_size = 32
        cur_idx = 0
        with torch.no_grad():
            while cur_idx < N:
                cur_motion = motion[cur_idx: cur_idx + batch_size].to(device)
                cur_motion_length = motion_length[cur_idx: cur_idx + batch_size].to(device)
                cur_motion_mask = motion_mask[cur_idx: cur_idx + batch_size].to(device)
                cur_motion_emb = self.motion_encoder(cur_motion, cur_motion_length, cur_motion_mask)
                motion_emb.append(cur_motion_emb)
                cur_idx += batch_size
        motion_emb = torch.cat(motion_emb, dim=0)
        return motion_emb

    def text_encode(self, text, token, device):
        N = len(text)
        text_emb = []
        batch_size = 32
        cur_idx = 0
        with torch.no_grad():
            while cur_idx < N:
                cur_text = text[cur_idx: cur_idx + batch_size]
                cur_token = token[cur_idx: cur_idx + batch_size]
                cur_text_emb = self.text_encoder(cur_text, cur_token, device)
                text_emb.append(cur_text_emb)
                cur_idx += batch_size
        text_emb = torch.cat(text_emb, dim=0)
        return text_emb
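BaseEvaluator.evaluate delegates to three methods that subclasses supply: single_evaluate for one chunk of results, concat_batch_metrics when batching, and parse_values to name the reduced values. A minimal hypothetical subclass illustrating that contract (MeanNormEvaluator and its metric are invented for this sketch, not part of the repo):

class MeanNormEvaluator(BaseEvaluator):  # hypothetical example

    def __init__(self, data_len=0, **kwargs):
        super().__init__(eval_begin_idx=0, eval_end_idx=data_len, **kwargs)
        self.model_list = []  # this toy metric needs no encoder

    def single_evaluate(self, results):
        results = self.prepare_results(results)
        # mean L2 norm of the generated motions, just to return a scalar
        return results['pred_motion'].norm(dim=-1).mean().item()

    def parse_values(self, values):
        # with replication_reduction='statistics', values is (mean, conf)
        return {'MeanNorm (mean)': values[0], 'MeanNorm (conf)': values[1]}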
    	
mogen/core/evaluation/evaluators/diversity_evaluator.py
ADDED
@@ -0,0 +1,52 @@
import numpy as np
import torch

from ..get_model import get_motion_model
from .base_evaluator import BaseEvaluator
from ..utils import calculate_diversity


class DiversityEvaluator(BaseEvaluator):

    def __init__(self,
                 data_len=0,
                 motion_encoder_name=None,
                 motion_encoder_path=None,
                 num_samples=300,
                 batch_size=None,
                 drop_last=False,
                 replication_times=1,
                 replication_reduction='statistics',
                 **kwargs):
        super().__init__(
            replication_times=replication_times,
            replication_reduction=replication_reduction,
            batch_size=batch_size,
            drop_last=drop_last,
            eval_begin_idx=0,
            eval_end_idx=data_len
        )
        self.num_samples = num_samples
        self.append_indexes = None
        self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
        self.model_list = [self.motion_encoder]

    def single_evaluate(self, results):
        results = self.prepare_results(results)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        motion = results['motion']
        pred_motion = results['pred_motion']
        pred_motion_length = results['pred_motion_length']
        pred_motion_mask = results['pred_motion_mask']
        self.motion_encoder.to(device)
        self.motion_encoder.eval()
        with torch.no_grad():
            pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
            diversity = calculate_diversity(pred_motion_emb, self.num_samples)
        return diversity

    def parse_values(self, values):
        metrics = {}
        metrics['Diversity (mean)'] = values[0]
        metrics['Diversity (conf)'] = values[1]
        return metrics
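The metric itself lives in ..utils. The standard diversity definition in this literature is the mean distance between randomly paired generated samples; the following is a sketch of that definition under the assumption that calculate_diversity follows it, not a copy of the repo's implementation:

# Assumed behavior of calculate_diversity (sketch only):
import numpy as np

def diversity_sketch(emb, num_samples, seed=0):
    rng = np.random.default_rng(seed)
    first = rng.choice(len(emb), num_samples, replace=False)
    second = rng.choice(len(emb), num_samples, replace=False)
    # average Euclidean distance between two random subsets of embeddings
    return np.linalg.norm(emb[first] - emb[second], axis=1).mean()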
    	
mogen/core/evaluation/evaluators/fid_evaluator.py
ADDED
@@ -0,0 +1,58 @@
import numpy as np
import torch

from ..get_model import get_motion_model
from .base_evaluator import BaseEvaluator
from ..utils import (
    calculate_activation_statistics,
    calculate_frechet_distance)


class FIDEvaluator(BaseEvaluator):

    def __init__(self,
                 data_len=0,
                 motion_encoder_name=None,
                 motion_encoder_path=None,
                 batch_size=None,
                 drop_last=False,
                 replication_times=1,
                 replication_reduction='statistics',
                 **kwargs):
        super().__init__(
            replication_times=replication_times,
            replication_reduction=replication_reduction,
            batch_size=batch_size,
            drop_last=drop_last,
            eval_begin_idx=0,
            eval_end_idx=data_len
        )
        self.append_indexes = None
        self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
        self.model_list = [self.motion_encoder]

    def single_evaluate(self, results):
        results = self.prepare_results(results)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        pred_motion = results['pred_motion']
        pred_motion_length = results['pred_motion_length']
        pred_motion_mask = results['pred_motion_mask']
        motion = results['motion']
        motion_length = results['motion_length']
        motion_mask = results['motion_mask']
        self.motion_encoder.to(device)
        self.motion_encoder.eval()
        with torch.no_grad():
            pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
            gt_motion_emb = self.motion_encode(motion, motion_length, motion_mask, device).cpu().detach().numpy()
        gt_mu, gt_cov = calculate_activation_statistics(gt_motion_emb)
        pred_mu, pred_cov = calculate_activation_statistics(pred_motion_emb)
        fid = calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov)
        return fid

    def parse_values(self, values):
        metrics = {}
        metrics['FID (mean)'] = values[0]
        metrics['FID (conf)'] = values[1]
        return metrics
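For reference, the Fréchet distance between the Gaussians fitted to the two sets of embeddings has a standard closed form; a sketch of that formula follows (the repo's own calculate_frechet_distance in ..utils may differ in numerical details):

# Standard Frechet distance between N(mu1, C1) and N(mu2, C2):
# d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrtm(C1 @ C2))
import numpy as np
from scipy import linalg

def frechet_distance_sketch(mu1, cov1, mu2, cov2):
    diff = mu1 - mu2
    covmean = linalg.sqrtm(cov1 @ cov2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from sqrtm
    return diff @ diff + np.trace(cov1 + cov2 - 2 * covmean)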
    	
mogen/core/evaluation/evaluators/matching_score_evaluator.py
ADDED
@@ -0,0 +1,71 @@
import numpy as np
import torch

from ..get_model import get_motion_model, get_text_model
from .base_evaluator import BaseEvaluator
from ..utils import calculate_top_k, euclidean_distance_matrix


class MatchingScoreEvaluator(BaseEvaluator):

    def __init__(self,
                 data_len=0,
                 text_encoder_name=None,
                 text_encoder_path=None,
                 motion_encoder_name=None,
                 motion_encoder_path=None,
                 top_k=3,
                 batch_size=32,
                 drop_last=False,
                 replication_times=1,
                 replication_reduction='statistics',
                 **kwargs):
        super().__init__(
            replication_times=replication_times,
            replication_reduction=replication_reduction,
            batch_size=batch_size,
            drop_last=drop_last,
            eval_begin_idx=0,
            eval_end_idx=data_len
        )
        self.append_indexes = None
        self.text_encoder = get_text_model(text_encoder_name, text_encoder_path)
        self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
        self.top_k = top_k
        self.model_list = [self.text_encoder, self.motion_encoder]

    def single_evaluate(self, results):
        results = self.prepare_results(results)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        motion = results['motion']
        pred_motion = results['pred_motion']
        pred_motion_length = results['pred_motion_length']
        pred_motion_mask = results['pred_motion_mask']
        text = results['text']
        token = results['token']
        self.text_encoder.to(device)
        self.motion_encoder.to(device)
        self.text_encoder.eval()
        self.motion_encoder.eval()
        with torch.no_grad():
            word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy()
            motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
            dist_mat = euclidean_distance_matrix(word_emb, motion_emb)
            matching_score = dist_mat.trace()
            all_size = word_emb.shape[0]
        return matching_score, all_size

    def concat_batch_metrics(self, batch_metrics):
        matching_score_sum = 0
        all_size = 0
        for batch_matching_score, batch_all_size in batch_metrics:
            matching_score_sum += batch_matching_score
            all_size += batch_all_size
        matching_score = matching_score_sum / all_size
        return matching_score

    def parse_values(self, values):
        metrics = {}
        metrics['Matching Score (mean)'] = values[0]
        metrics['Matching Score (conf)'] = values[1]
        return metrics
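Because text i is paired with motion i, only the diagonal of the distance matrix holds matched pairs, so dist_mat.trace() sums the matched text-to-motion distances and concat_batch_metrics turns the running sum into a mean. A tiny illustration of that identity (illustrative numbers only, computed with plain numpy rather than the repo's euclidean_distance_matrix):

import numpy as np

word_emb = np.array([[0.0, 0.0], [1.0, 0.0]])
motion_emb = np.array([[0.0, 1.0], [1.0, 2.0]])
# dist_mat[i, j] = ||word_emb[i] - motion_emb[j]||; diagonal = matched pairs
dist_mat = np.linalg.norm(word_emb[:, None] - motion_emb[None, :], axis=-1)
matched = np.linalg.norm(word_emb - motion_emb, axis=1)
assert np.isclose(dist_mat.trace(), matched.sum())
matching_score = dist_mat.trace() / len(word_emb)  # mean matched distance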
    	
mogen/core/evaluation/evaluators/multimodality_evaluator.py
ADDED
@@ -0,0 +1,63 @@
import numpy as np
import torch

from ..get_model import get_motion_model
from .base_evaluator import BaseEvaluator
from ..utils import calculate_multimodality


class MultiModalityEvaluator(BaseEvaluator):

    def __init__(self,
                 data_len=0,
                 motion_encoder_name=None,
                 motion_encoder_path=None,
                 num_samples=100,
                 num_repeats=30,
                 num_picks=10,
                 batch_size=None,
                 drop_last=False,
                 replication_times=1,
                 replication_reduction='statistics',
                 **kwargs):
        super().__init__(
            replication_times=replication_times,
            replication_reduction=replication_reduction,
            batch_size=batch_size,
            drop_last=drop_last,
            eval_begin_idx=data_len,
            eval_end_idx=data_len + num_samples * num_repeats
        )
        self.num_samples = num_samples
        self.num_repeats = num_repeats
        self.num_picks = num_picks
        self.append_indexes = []
        for i in range(replication_times):
            append_indexes = []
            selected_indexs = np.random.choice(data_len, self.num_samples)
            for index in selected_indexs:
                append_indexes = append_indexes + [index] * self.num_repeats
            self.append_indexes.append(np.array(append_indexes))
        self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
        self.model_list = [self.motion_encoder]

    def single_evaluate(self, results):
        results = self.prepare_results(results)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        motion = results['motion']
        pred_motion = results['pred_motion']
        pred_motion_length = results['pred_motion_length']
        pred_motion_mask = results['pred_motion_mask']
        self.motion_encoder.to(device)
        self.motion_encoder.eval()
        with torch.no_grad():
            pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
        pred_motion_emb = pred_motion_emb.reshape((self.num_samples, self.num_repeats, -1))
        multimodality = calculate_multimodality(pred_motion_emb, self.num_picks)
        return multimodality

    def parse_values(self, values):
        metrics = {}
        metrics['MultiModality (mean)'] = values[0]
        metrics['MultiModality (conf)'] = values[1]
        return metrics
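Here each replication appends num_samples prompts, each repeated num_repeats times, after the regular evaluation data; the embeddings are then reshaped to (num_samples, num_repeats, dim) so the metric can compare repeated generations of the same prompt. A sketch of the usual multimodality definition, assuming calculate_multimodality follows it (not the repo's code):

# Assumed behavior of calculate_multimodality (sketch only): for each
# prompt, average the distance between num_picks random pairs of repeats.
import numpy as np

def multimodality_sketch(emb, num_picks, seed=0):
    # emb: (num_samples, num_repeats, dim)
    rng = np.random.default_rng(seed)
    first = rng.choice(emb.shape[1], num_picks, replace=False)
    second = rng.choice(emb.shape[1], num_picks, replace=False)
    return np.linalg.norm(emb[:, first] - emb[:, second], axis=-1).mean()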
    	
mogen/core/evaluation/evaluators/precision_evaluator.py
ADDED
@@ -0,0 +1,74 @@
import numpy as np
import torch

from ..get_model import get_motion_model, get_text_model
from .base_evaluator import BaseEvaluator
from ..utils import calculate_top_k, euclidean_distance_matrix


class PrecisionEvaluator(BaseEvaluator):

    def __init__(self,
                 data_len=0,
                 text_encoder_name=None,
                 text_encoder_path=None,
                 motion_encoder_name=None,
                 motion_encoder_path=None,
                 top_k=3,
                 batch_size=32,
                 drop_last=False,
                 replication_times=1,
                 replication_reduction='statistics',
                 **kwargs):
        super().__init__(
            replication_times=replication_times,
            replication_reduction=replication_reduction,
            batch_size=batch_size,
            drop_last=drop_last,
            eval_begin_idx=0,
            eval_end_idx=data_len
        )
        self.append_indexes = None
        self.text_encoder = get_text_model(text_encoder_name, text_encoder_path)
        self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path)
        self.top_k = top_k
        self.model_list = [self.text_encoder, self.motion_encoder]

    def single_evaluate(self, results):
        results = self.prepare_results(results)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        motion = results['motion']
        pred_motion = results['pred_motion']
        pred_motion_length = results['pred_motion_length']
        pred_motion_mask = results['pred_motion_mask']
        text = results['text']
        token = results['token']
        self.text_encoder.to(device)
        self.motion_encoder.to(device)
        self.text_encoder.eval()
        self.motion_encoder.eval()
        with torch.no_grad():
            word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy()
            motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy()
            dist_mat = euclidean_distance_matrix(word_emb, motion_emb)
            argsmax = np.argsort(dist_mat, axis=1)
            top_k_mat = calculate_top_k(argsmax, top_k=self.top_k)
            top_k_count = top_k_mat.sum(axis=0)
            all_size = word_emb.shape[0]
        return top_k_count, all_size

    def concat_batch_metrics(self, batch_metrics):
        top_k_count = 0
        all_size = 0
        for batch_top_k_count, batch_all_size in batch_metrics:
            top_k_count += batch_top_k_count
            all_size += batch_all_size
        R_precision = top_k_count / all_size
        return R_precision

    def parse_values(self, values):
        metrics = {}
        for top_k in range(self.top_k):
            metrics['R_precision Top %d (mean)' % (top_k + 1)] = values[0][top_k]
            metrics['R_precision Top %d (conf)' % (top_k + 1)] = values[1][top_k]
        return metrics
     | 
    	
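For intuition, the R-precision pipeline above reduces to a few lines of NumPy. A minimal sketch with toy embeddings standing in for the text/motion encoder outputs (the names here are illustrative, not part of the codebase):

```python
import numpy as np

rng = np.random.default_rng(0)
word_emb = rng.normal(size=(4, 8))                      # 4 texts, 8-d embeddings
motion_emb = word_emb + 0.01 * rng.normal(size=(4, 8))  # near-matching motions

# Row i ranks every motion by distance to text i (cf. euclidean_distance_matrix).
dist_mat = np.sqrt(((word_emb[:, None] - motion_emb[None]) ** 2).sum(-1))
ranking = np.argsort(dist_mat, axis=1)

# Entry (i, k) is True iff the ground-truth motion i is among the k+1 closest.
top_k = 3
hits = np.cumsum(ranking == np.arange(4)[:, None], axis=1).astype(bool)[:, :top_k]

r_precision = hits.sum(axis=0) / len(word_emb)  # one value per k, as in the evaluator
print(r_precision)  # -> [1. 1. 1.] for these near-identical embeddings
```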
mogen/core/evaluation/get_model.py
ADDED
@@ -0,0 +1,46 @@
from mogen.models import build_submodule


def get_motion_model(name, ckpt_path):
    if name == 'kit_ml':
        model = build_submodule(dict(
            type='T2MMotionEncoder',
            input_size=251,
            movement_hidden_size=512,
            movement_latent_size=512,
            motion_hidden_size=1024,
            motion_latent_size=512,
        ))
    else:
        model = build_submodule(dict(
            type='T2MMotionEncoder',
            input_size=263,
            movement_hidden_size=512,
            movement_latent_size=512,
            motion_hidden_size=1024,
            motion_latent_size=512,
        ))
    model.load_pretrained(ckpt_path)
    return model

def get_text_model(name, ckpt_path):
    if name == 'kit_ml':
        model = build_submodule(dict(
            type='T2MTextEncoder',
            word_size=300,
            pos_size=15,
            hidden_size=512,
            output_size=512,
            max_text_len=20
        ))
    else:
        model = build_submodule(dict(
            type='T2MTextEncoder',
            word_size=300,
            pos_size=15,
            hidden_size=512,
            output_size=512,
            max_text_len=20
        ))
    model.load_pretrained(ckpt_path)
    return model
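Both helpers branch only on the dataset name: 'kit_ml' selects the 251-dimensional motion representation, anything else the 263-dimensional HumanML3D layout. Note that `get_text_model` builds the identical config in both branches; only the motion encoder's `input_size` actually differs. A hedged usage sketch, with placeholder checkpoint paths rather than files guaranteed to ship with this Space:

```python
# Placeholder paths, for illustration only.
text_encoder = get_text_model('human_ml3d', 'data/evaluators/text_encoder.pth')
motion_encoder = get_motion_model('human_ml3d', 'data/evaluators/motion_encoder.pth')
```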
mogen/core/evaluation/utils.py
ADDED
@@ -0,0 +1,130 @@
import numpy as np
from scipy import linalg


def get_metric_statistics(values, replication_times):
    mean = np.mean(values, axis=0)
    std = np.std(values, axis=0)
    conf_interval = 1.96 * std / np.sqrt(replication_times)
    return mean, conf_interval


# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
def euclidean_distance_matrix(matrix1, matrix2):
    """
        Params:
        -- matrix1: N1 x D
        -- matrix2: N2 x D
        Returns:
        -- dist: N1 x N2
        dist[i, j] == distance(matrix1[i], matrix2[j])
    """
    assert matrix1.shape[1] == matrix2.shape[1]
    d1 = -2 * np.dot(matrix1, matrix2.T)    # shape (num_test, num_train)
    d2 = np.sum(np.square(matrix1), axis=1, keepdims=True)    # shape (num_test, 1)
    d3 = np.sum(np.square(matrix2), axis=1)     # shape (num_train, )
    dists = np.sqrt(d1 + d2 + d3)  # broadcasting
    return dists


def calculate_top_k(mat, top_k):
    size = mat.shape[0]
    gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
    bool_mat = (mat == gt_mat)
    correct_vec = False
    top_k_list = []
    for i in range(top_k):
        correct_vec = (correct_vec | bool_mat[:, i])
        top_k_list.append(correct_vec[:, None])
    top_k_mat = np.concatenate(top_k_list, axis=1)
    return top_k_mat


def calculate_activation_statistics(activations):
    """
    Params:
    -- activations: num_samples x dim_feat
    Returns:
    -- mu: dim_feat
    -- sigma: dim_feat x dim_feat
    """
    mu = np.mean(activations, axis=0)
    cov = np.cov(activations, rowvar=False)
    return mu, cov

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.
    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)


def calculate_diversity(activation, diversity_times):
    assert len(activation.shape) == 2
    assert activation.shape[0] > diversity_times
    num_samples = activation.shape[0]

    first_indices = np.random.choice(num_samples, diversity_times, replace=False)
    second_indices = np.random.choice(num_samples, diversity_times, replace=False)
    dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
    return dist.mean()


def calculate_multimodality(activation, multimodality_times):
    assert len(activation.shape) == 3
    assert activation.shape[1] > multimodality_times
    num_per_sent = activation.shape[1]

    first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
    second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
    dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
    return dist.mean()
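Two quick sanity checks on the helpers above: the Fréchet distance between a distribution and itself should be numerically zero, and `get_metric_statistics` collapses per-replication metric values into a mean and a 95% confidence half-width (the 1.96 factor). A sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
acts = rng.normal(size=(512, 16))  # fake activations: 512 samples, 16-d features

mu, cov = calculate_activation_statistics(acts)
print(calculate_frechet_distance(mu, cov, mu, cov))  # ~0 up to numerical error

# Aggregate a metric replicated 20 times into (mean, 95% confidence interval).
values = rng.normal(loc=3.0, scale=0.1, size=20)
mean, conf = get_metric_statistics(values, replication_times=20)
```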
mogen/core/optimizer/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .builder import OPTIMIZERS, build_optimizers

__all__ = ['build_optimizers', 'OPTIMIZERS']
mogen/core/optimizer/builder.py
ADDED
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import build_optimizer
from mmcv.utils import Registry

OPTIMIZERS = Registry('optimizers')


def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs. If `cfgs` contains several dicts
    for optimizers, then a dict of constructed optimizers will be
    returned. If `cfgs` only contains one optimizer config, the constructed
    optimizer itself will be returned. For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    optimizers = {}
    if hasattr(model, 'module'):
        model = model.module
    # determine whether 'cfgs' has several dicts for optimizers
    if all(isinstance(v, dict) for v in cfgs.values()):
        for key, cfg in cfgs.items():
            cfg_ = cfg.copy()
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers

    return build_optimizer(model, cfgs)
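A sketch of both call patterns against a toy two-part module (the module and learning rates are illustrative; `type='SGD'` and `type='Adam'` resolve through mmcv's optimizer registry):

```python
import torch.nn as nn

class TwoPartModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = nn.Linear(8, 8)
        self.model2 = nn.Linear(8, 8)

model = TwoPartModel()

# Dict of configs -> dict of optimizers keyed by submodule attribute name.
opts = build_optimizers(model, dict(
    model1=dict(type='SGD', lr=0.01),
    model2=dict(type='Adam', lr=1e-3)))

# Single config -> one optimizer over all parameters of the model.
opt = build_optimizers(model, dict(type='SGD', lr=0.01))
```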
mogen/datasets/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .base_dataset import BaseMotionDataset
from .text_motion_dataset import TextMotionDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .pipelines import Compose
from .samplers import DistributedSampler


__all__ = [
    'BaseMotionDataset', 'TextMotionDataset', 'DATASETS', 'PIPELINES', 'build_dataloader',
    'build_dataset', 'Compose', 'DistributedSampler'
]
mogen/datasets/base_dataset.py
ADDED
@@ -0,0 +1,117 @@
import os
import copy
from typing import Optional, Union

import numpy as np
from torch.utils.data import Dataset

from .pipelines import Compose
from .builder import DATASETS
from mogen.core.evaluation import build_evaluator


@DATASETS.register_module()
class BaseMotionDataset(Dataset):
    """Base motion dataset.
    Args:
        data_prefix (str): the prefix of data path.
        pipeline (list): a list of dicts, where each element represents
            an operation defined in `mogen.datasets.pipelines`.
        ann_file (str | None, optional): the annotation file. When ann_file is
            str, the subclass is expected to read from the ann_file. When
            ann_file is None, the subclass is expected to read according
            to data_prefix.
        test_mode (bool): in train mode or test mode. Default: False.
        dataset_name (str | None, optional): the name of dataset. It is used
            to identify the type of evaluation metric. Default: None.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 dataset_name: Optional[Union[str, None]] = None,
                 fixed_length: Optional[Union[int, None]] = None,
                 ann_file: Optional[Union[str, None]] = None,
                 motion_dir: Optional[Union[str, None]] = None,
                 eval_cfg: Optional[Union[dict, None]] = None,
                 test_mode: Optional[bool] = False):
        super(BaseMotionDataset, self).__init__()

        self.data_prefix = data_prefix
        self.pipeline = Compose(pipeline)
        self.dataset_name = dataset_name
        self.fixed_length = fixed_length
        self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, ann_file)
        self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, motion_dir)
        self.eval_cfg = copy.deepcopy(eval_cfg)
        self.test_mode = test_mode

        self.load_annotations()
        if self.test_mode:
            self.prepare_evaluation()

    def load_anno(self, name):
        motion_path = os.path.join(self.motion_dir, name + '.npy')
        motion_data = np.load(motion_path)
        return {'motion': motion_data}

    def load_annotations(self):
        """Load annotations from ``ann_file`` to ``data_infos``."""
        self.data_infos = []
        for line in open(self.ann_file, 'r').readlines():
            line = line.strip()
            self.data_infos.append(self.load_anno(line))

    def prepare_data(self, idx: int):
        """Prepare raw data for the ``idx``-th data sample."""
        results = copy.deepcopy(self.data_infos[idx])
        results['dataset_name'] = self.dataset_name
        results['sample_idx'] = idx
        return self.pipeline(results)

    def __len__(self):
        """Return the length of the current dataset."""
        if self.test_mode:
            return len(self.eval_indexes)
        elif self.fixed_length is not None:
            return self.fixed_length
        return len(self.data_infos)

    def __getitem__(self, idx: int):
        """Prepare data for the ``idx``-th data sample.
        As for a video dataset, we can first parse raw data for each frame,
        then combine annotations from all frames. This interface is used to
        simplify the logic of video datasets and other special datasets.
        """
        if self.test_mode:
            idx = self.eval_indexes[idx]
        elif self.fixed_length is not None:
            idx = idx % len(self.data_infos)
        return self.prepare_data(idx)

    def prepare_evaluation(self):
        self.evaluators = []
        self.eval_indexes = []
        for _ in range(self.eval_cfg['replication_times']):
            eval_indexes = np.arange(len(self.data_infos))
            if self.eval_cfg.get('shuffle_indexes', False):
                np.random.shuffle(eval_indexes)
            self.eval_indexes.append(eval_indexes)
        for metric in self.eval_cfg['metrics']:
            evaluator, self.eval_indexes = build_evaluator(
                metric, self.eval_cfg, len(self.data_infos), self.eval_indexes)
            self.evaluators.append(evaluator)

        self.eval_indexes = np.concatenate(self.eval_indexes)

    def evaluate(self, results, work_dir, logger=None):
        metrics = {}
        device = results[0]['motion'].device
        for evaluator in self.evaluators:
            evaluator.to_device(device)
            metrics.update(evaluator.evaluate(results))
        if logger is not None:
            logger.info(metrics)
        return metrics
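Subclasses normally only override `load_anno` to attach extra annotation fields per sample (this is what `TextMotionDataset` does for captions and tokens). A minimal hypothetical sketch, reusing the motion loading above and adding a label file whose layout is assumed, not part of this repo:

```python
@DATASETS.register_module()
class LabeledMotionDataset(BaseMotionDataset):
    """Hypothetical subclass: each sample gets a text label next to its motion."""

    def load_anno(self, name):
        results = super().load_anno(name)  # {'motion': np.ndarray}
        # Assumed layout: one .txt label per sample under data_prefix/labels/.
        label_path = os.path.join(self.data_prefix, 'labels', name + '.txt')
        with open(label_path) as f:
            results['label'] = f.read().strip()
        return results
```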
mogen/datasets/builder.py
ADDED
@@ -0,0 +1,113 @@
import platform
import random
from functools import partial
from typing import Optional, Union

import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from .samplers import DistributedSampler

if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')


def build_dataset(cfg: Union[dict, list, tuple],
                  default_args: Optional[Union[dict, None]] = None):
    """Build a dataset from the given config."""
    from .dataset_wrappers import (
        ConcatDataset,
        RepeatDataset,
    )
    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
    elif cfg['type'] == 'RepeatDataset':
        dataset = RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset


def build_dataloader(dataset: Dataset,
                     samples_per_gpu: int,
                     workers_per_gpu: int,
                     num_gpus: Optional[int] = 1,
                     dist: Optional[bool] = True,
                     shuffle: Optional[bool] = True,
                     round_up: Optional[bool] = True,
                     seed: Optional[Union[int, None]] = None,
                     persistent_workers: Optional[bool] = True,
                     **kwargs):
    """Build a PyTorch DataLoader.
    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.
    Args:
        dataset (:obj:`Dataset`): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int, optional): Number of GPUs. Only used in non-distributed
            training.
        dist (bool, optional): Distributed training/test or not. Default: True.
        shuffle (bool, optional): Whether to shuffle the data at every epoch.
            Default: True.
        round_up (bool, optional): Whether to round up the length of dataset by
            adding extra samples to make it evenly divisible. Default: True.
        kwargs: any keyword arguments used to initialize the DataLoader.
    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    if dist:
        sampler = DistributedSampler(
            dataset, world_size, rank, shuffle=shuffle, round_up=round_up)
        shuffle = False
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        shuffle=shuffle,
        worker_init_fn=init_fn,
        persistent_workers=persistent_workers,
        **kwargs)

    return data_loader


def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """Init the random seed for each worker."""
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
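Putting the two builders together; the config values below are illustrative (they are not read from the repo's configs), and `dist=False` exercises the non-distributed branch, where the effective batch size is `num_gpus * samples_per_gpu`:

```python
dataset = build_dataset(dict(
    type='BaseMotionDataset',
    data_prefix='data',
    dataset_name='human_ml3d',
    ann_file='train.txt',   # illustrative annotation file name
    motion_dir='motions',   # illustrative motion directory name
    pipeline=[]))           # real configs list transforms here

loader = build_dataloader(
    dataset,
    samples_per_gpu=128,
    workers_per_gpu=1,
    num_gpus=1,
    dist=False,
    shuffle=True,
    seed=42)  # each worker is seeded with num_workers * rank + worker_id + seed
```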
        mogen/datasets/dataset_wrappers.py
    ADDED
    
    | 
         @@ -0,0 +1,42 @@ 
     | 
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+from torch.utils.data.dataset import Dataset
+
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+    """A wrapper of concatenated dataset.
+
+    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but adds a
+    `get_cat_ids` function.
+
+    Args:
+        datasets (list[:obj:`Dataset`]): A list of datasets.
+    """
+
+    def __init__(self, datasets: list):
+        super(ConcatDataset, self).__init__(datasets)
+
+
+@DATASETS.register_module()
+class RepeatDataset(object):
+    """A wrapper of repeated dataset.
+
+    The length of the repeated dataset will be `times` larger than the
+    original dataset. This is useful when the data loading time is long but
+    the dataset is small. Using RepeatDataset can reduce the data loading
+    time between epochs.
+
+    Args:
+        dataset (:obj:`Dataset`): The dataset to be repeated.
+        times (int): Repeat times.
+    """
+
+    def __init__(self, dataset: Dataset, times: int):
+        self.dataset = dataset
+        self.times = times
+        self._ori_len = len(self.dataset)
+
+    def __getitem__(self, idx: int):
+        return self.dataset[idx % self._ori_len]
+
+    def __len__(self):
+        return self.times * self._ori_len
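Because `RepeatDataset` only remaps indices, its behavior is easy to check in isolation. A minimal sketch against a throwaway list-backed dataset (the `ToyDataset` class is hypothetical, used only for illustration):

    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        # Hypothetical 3-item dataset, used only to illustrate RepeatDataset.
        def __init__(self):
            self.items = ['a', 'b', 'c']
        def __getitem__(self, idx):
            return self.items[idx]
        def __len__(self):
            return len(self.items)

    repeated = RepeatDataset(ToyDataset(), times=2)
    assert len(repeated) == 6
    assert repeated[4] == 'b'  # 4 % 3 == 1, indices wrap around the original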
    	
mogen/datasets/pipelines/__init__.py
ADDED
@@ -0,0 +1,18 @@
+from .compose import Compose
+from .formatting import (
+    to_tensor,
+    ToTensor,
+    Transpose,
+    Collect,
+    WrapFieldsToLists
+)
+from .transforms import (
+    Crop,
+    RandomCrop,
+    Normalize
+)
+
+__all__ = [
+    'Compose', 'to_tensor', 'Transpose', 'Collect', 'WrapFieldsToLists',
+    'ToTensor', 'Crop', 'RandomCrop', 'Normalize'
+]
    	
mogen/datasets/pipelines/compose.py
ADDED
@@ -0,0 +1,42 @@
+from collections.abc import Sequence
+
+from mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+    """Compose a data pipeline with a sequence of transforms.
+
+    Args:
+        transforms (list[dict | callable]):
+            Either config dicts of transforms or transform objects.
+    """
+
+    def __init__(self, transforms):
+        assert isinstance(transforms, Sequence)
+        self.transforms = []
+        for transform in transforms:
+            if isinstance(transform, dict):
+                transform = build_from_cfg(transform, PIPELINES)
+                self.transforms.append(transform)
+            elif callable(transform):
+                self.transforms.append(transform)
+            else:
+                raise TypeError('transform must be callable or a dict, but got'
+                                f' {type(transform)}')
+
+    def __call__(self, data):
+        for t in self.transforms:
+            data = t(data)
+            if data is None:
+                return None
+        return data
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += f'\n    {t}'
+        format_string += '\n)'
+        return format_string
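`Compose` accepts a mix of registry config dicts and plain callables, and it short-circuits to `None` as soon as any transform drops a sample. A minimal sketch using only callables, so it runs without the `PIPELINES` registry (illustrative, not part of the commit):

    pipeline = Compose([
        lambda r: {**r, 'motion_length': len(r['motion'])},  # plain callable
        lambda r: r if r['motion_length'] > 0 else None,     # filtering step
    ])
    assert pipeline({'motion': [0.1, 0.2]})['motion_length'] == 2
    assert pipeline({'motion': []}) is None  # dropped samples propagate as None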
    	
mogen/datasets/pipelines/formatting.py
ADDED
@@ -0,0 +1,134 @@
+from collections.abc import Sequence
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import DataContainer as DC
+from PIL import Image
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+    """Convert objects of various python types to :obj:`torch.Tensor`.
+
+    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+    :class:`Sequence`, :class:`int` and :class:`float`.
+    """
+    if isinstance(data, torch.Tensor):
+        return data
+    elif isinstance(data, np.ndarray):
+        return torch.from_numpy(data)
+    elif isinstance(data, Sequence) and not mmcv.is_str(data):
+        return torch.tensor(data)
+    elif isinstance(data, int):
+        return torch.LongTensor([data])
+    elif isinstance(data, float):
+        return torch.FloatTensor([data])
+    else:
+        raise TypeError(
+            f'Type {type(data)} cannot be converted to tensor. '
+            'Supported types are: `numpy.ndarray`, `torch.Tensor`, '
+            '`Sequence`, `int` and `float`')
+
+
+@PIPELINES.register_module()
+class ToTensor(object):
+
+    def __init__(self, keys):
+        self.keys = keys
+
+    def __call__(self, results):
+        for key in self.keys:
+            results[key] = to_tensor(results[key])
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+
+    def __init__(self, keys, order):
+        self.keys = keys
+        self.order = order
+
+    def __call__(self, results):
+        for key in self.keys:
+            results[key] = results[key].transpose(self.order)
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class Collect(object):
+    """Collect data from the loader relevant to the specific task.
+
+    This is usually the last stage of the data loader pipeline.
+
+    Args:
+        keys (Sequence[str]): Keys of results to be collected in ``data``.
+        meta_keys (Sequence[str], optional): Meta keys to be converted to
+            ``mmcv.DataContainer`` and collected in ``data['motion_metas']``.
+            Default: ``('filename', 'ori_filename', 'ori_shape',
+            'motion_shape', 'motion_mask')``.
+
+    Returns:
+        dict: The result dict contains the following keys:
+            - keys in ``self.keys``
+            - ``motion_metas`` if available
+    """
+
+    def __init__(self,
+                 keys,
+                 meta_keys=('filename', 'ori_filename', 'ori_shape',
+                            'motion_shape', 'motion_mask')):
+        self.keys = keys
+        self.meta_keys = meta_keys
+
+    def __call__(self, results):
+        data = {}
+        motion_meta = {}
+        for key in self.meta_keys:
+            if key in results:
+                motion_meta[key] = results[key]
+        data['motion_metas'] = DC(motion_meta, cpu_only=True)
+        for key in self.keys:
+            data[key] = results[key]
+        return data
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            f'(keys={self.keys}, meta_keys={self.meta_keys})'
+
+
+@PIPELINES.register_module()
+class WrapFieldsToLists(object):
+    """Wrap fields of the data dictionary into lists for evaluation.
+
+    This class can be used as the last step of a test or validation
+    pipeline for single-sample evaluation or inference.
+
+    Example:
+        >>> test_pipeline = [
+        >>>    dict(type='LoadImageFromFile'),
+        >>>    dict(type='Normalize',
+                    mean=[123.675, 116.28, 103.53],
+                    std=[58.395, 57.12, 57.375],
+                    to_rgb=True),
+        >>>    dict(type='ImageToTensor', keys=['img']),
+        >>>    dict(type='Collect', keys=['img']),
+        >>>    dict(type='WrapFieldsToLists')
+        >>> ]
+    """
+
+    def __call__(self, results):
+        # Wrap each dict field into a single-element list
+        for key, val in results.items():
+            results[key] = [val]
+        return results
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}()'
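`Collect` is the boundary between array land and model land: the listed `keys` pass through unchanged, while the `meta_keys` are packed into a CPU-only `DataContainer`. A minimal sketch of `to_tensor` plus `Collect` (the 196x263 shape is illustrative, not part of the commit):

    import numpy as np

    results = {
        'motion': to_tensor(np.zeros((196, 263), dtype=np.float32)),
        'motion_shape': (196, 263),
        'motion_mask': np.ones(196, dtype=np.float32),
    }
    data = Collect(keys=['motion'],
                   meta_keys=('motion_shape', 'motion_mask'))(results)
    # data['motion'] is a torch.Tensor; data['motion_metas'] is a cpu_only
    # DataContainer holding the two metadata entries.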
    	
mogen/datasets/pipelines/transforms.py
ADDED
@@ -0,0 +1,120 @@
+import math
+import random
+
+import mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+import torch
+from typing import Optional
+
+
+@PIPELINES.register_module()
+class Crop(object):
+    r"""Crop motion sequences.
+
+    Args:
+        crop_size (int): The size of the cropped motion sequence.
+    """
+    def __init__(self, crop_size: Optional[int] = None):
+        self.crop_size = crop_size
+        assert self.crop_size is not None
+
+    def __call__(self, results):
+        motion = results['motion']
+        length = len(motion)
+        if length >= self.crop_size:
+            idx = random.randint(0, length - self.crop_size)
+            motion = motion[idx: idx + self.crop_size]
+            results['motion_length'] = self.crop_size
+        else:
+            padding_length = self.crop_size - length
+            D = motion.shape[1:]
+            padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
+            motion = np.concatenate([motion, padding_zeros], axis=0)
+            results['motion_length'] = length
+        assert len(motion) == self.crop_size
+        results['motion'] = motion
+        results['motion_shape'] = motion.shape
+        if length >= self.crop_size:
+            results['motion_mask'] = torch.ones(self.crop_size).numpy()
+        else:
+            results['motion_mask'] = torch.cat(
+                (torch.ones(length),
+                 torch.zeros(self.crop_size - length))).numpy()
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCrop(object):
+    r"""Randomly crop motion sequences. Each sequence is padded with zeros
+    to the maximum length.
+
+    Args:
+        min_size (int): The minimum size of the cropped motion sequence
+            (inclusive).
+        max_size (int): The maximum size of the cropped motion sequence
+            (inclusive).
+    """
+    def __init__(self,
+                 min_size: Optional[int] = None,
+                 max_size: Optional[int] = None):
+        self.min_size = min_size
+        self.max_size = max_size
+        assert self.min_size is not None
+        assert self.max_size is not None
+
+    def __call__(self, results):
+        motion = results['motion']
+        length = len(motion)
+        crop_size = random.randint(self.min_size, self.max_size)
+        if length > crop_size:
+            idx = random.randint(0, length - crop_size)
+            motion = motion[idx: idx + crop_size]
+            results['motion_length'] = crop_size
+        else:
+            results['motion_length'] = length
+        padding_length = self.max_size - min(crop_size, length)
+        if padding_length > 0:
+            D = motion.shape[1:]
+            padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
+            motion = np.concatenate([motion, padding_zeros], axis=0)
+        results['motion'] = motion
+        results['motion_shape'] = motion.shape
+        if length >= self.max_size and crop_size == self.max_size:
+            results['motion_mask'] = torch.ones(self.max_size).numpy()
+        else:
+            results['motion_mask'] = torch.cat((
+                torch.ones(min(length, crop_size)),
+                torch.zeros(self.max_size - min(length, crop_size))),
+                dim=0).numpy()
+        assert len(motion) == self.max_size
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__ + f'(min_size={self.min_size}'
+        repr_str += f', max_size={self.max_size})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class Normalize(object):
+    """Normalize motion sequences.
+
+    Args:
+        mean_path (str): Path of mean file.
+        std_path (str): Path of std file.
+        eps (float): A small value added to std to avoid division by zero.
+    """
+    def __init__(self, mean_path, std_path, eps=1e-9):
+        self.mean = np.load(mean_path)
+        self.std = np.load(std_path)
+        self.eps = eps
+
+    def __call__(self, results):
+        motion = results['motion']
+        motion = (motion - self.mean) / (self.std + self.eps)
+        results['motion'] = motion
+        results['motion_norm_mean'] = self.mean
+        results['motion_norm_std'] = self.std
+        return results
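The part downstream models rely on here is the padding/mask convention: sequences shorter than `crop_size` are zero-padded on the right, and `motion_mask` flags the real frames. A minimal sketch (illustrative, not part of the commit):

    import numpy as np

    crop = Crop(crop_size=8)
    out = crop({'motion': np.ones((5, 3), dtype=np.float32)})
    assert out['motion'].shape == (8, 3)   # zero-padded up to crop_size
    assert out['motion_length'] == 5       # true length is preserved
    assert out['motion_mask'].sum() == 5   # 1 for real frames, 0 for padding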
    	
mogen/datasets/samplers/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .distributed_sampler import DistributedSampler
+
+__all__ = ['DistributedSampler']
    	
mogen/datasets/samplers/distributed_sampler.py
ADDED
@@ -0,0 +1,42 @@
+import torch
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+class DistributedSampler(_DistributedSampler):
+
+    def __init__(self,
+                 dataset,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 round_up=True):
+        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+        self.shuffle = shuffle
+        self.round_up = round_up
+        if self.round_up:
+            self.total_size = self.num_samples * self.num_replicas
+        else:
+            self.total_size = len(self.dataset)
+
+    def __iter__(self):
+        # deterministically shuffle based on the current epoch
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make the index list evenly divisible
+        if self.round_up:
+            indices = (
+                indices *
+                int(self.total_size / len(indices) + 1))[:self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample: rank r takes every num_replicas-th index starting at r
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        if self.round_up:
+            assert len(indices) == self.num_samples
+
+        return iter(indices)
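The `round_up` path pads the shuffled index list by repetition so that every replica draws exactly `num_samples` indices. A minimal sketch of the arithmetic for 10 samples split across 4 replicas (plain Python, not part of the commit):

    # num_samples = ceil(10 / 4) = 3, so total_size = 12 when round_up=True.
    indices = list(range(10))
    total_size = 12
    indices = (indices * int(total_size / len(indices) + 1))[:total_size]
    assert indices == list(range(10)) + [0, 1]   # wrapped around to fill 12
    # Replica r then takes indices[r:12:4], i.e. 3 indices each, no overlap.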
    	
mogen/datasets/text_motion_dataset.py
ADDED
@@ -0,0 +1,93 @@
+import json
+import os
+import os.path
+from abc import ABCMeta
+from collections import OrderedDict
+from typing import Any, List, Optional, Union
+
+import mmcv
+import copy
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+from .base_dataset import BaseMotionDataset
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class TextMotionDataset(BaseMotionDataset):
+    """TextMotion dataset.
+
+    Args:
+        text_dir (str): Path to the directory containing the text files.
+    """
+    def __init__(self,
+                 data_prefix: str,
+                 pipeline: list,
+                 dataset_name: Optional[str] = None,
+                 fixed_length: Optional[int] = None,
+                 ann_file: Optional[str] = None,
+                 motion_dir: Optional[str] = None,
+                 text_dir: Optional[str] = None,
+                 token_dir: Optional[str] = None,
+                 clip_feat_dir: Optional[str] = None,
+                 eval_cfg: Optional[dict] = None,
+                 fine_mode: bool = False,
+                 test_mode: bool = False):
+        self.text_dir = os.path.join(
+            data_prefix, 'datasets', dataset_name, text_dir)
+        if token_dir is not None:
+            self.token_dir = os.path.join(
+                data_prefix, 'datasets', dataset_name, token_dir)
+        else:
+            self.token_dir = None
+        if clip_feat_dir is not None:
+            self.clip_feat_dir = os.path.join(
+                data_prefix, 'datasets', dataset_name, clip_feat_dir)
+        else:
+            self.clip_feat_dir = None
+        self.fine_mode = fine_mode
+        super(TextMotionDataset, self).__init__(
+            data_prefix=data_prefix,
+            pipeline=pipeline,
+            dataset_name=dataset_name,
+            fixed_length=fixed_length,
+            ann_file=ann_file,
+            motion_dir=motion_dir,
+            eval_cfg=eval_cfg,
+            test_mode=test_mode)
+
+    def load_anno(self, name):
+        results = super().load_anno(name)
+        text_path = os.path.join(self.text_dir, name + '.txt')
+        with open(text_path, 'r') as f:
+            results['text'] = [line.strip() for line in f]
+        if self.token_dir is not None:
+            token_path = os.path.join(self.token_dir, name + '.txt')
+            with open(token_path, 'r') as f:
+                results['token'] = [line.strip() for line in f]
+        if self.clip_feat_dir is not None:
+            clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy')
+            results['clip_feat'] = torch.from_numpy(np.load(clip_feat_path))
+        return results
+
+    def prepare_data(self, idx: int):
+        """Prepare raw data for the ``idx``-th sample."""
+        results = copy.deepcopy(self.data_infos[idx])
+        text_list = results['text']
+        # Draw one caption at random; keep ``idx`` as the dataset sample index.
+        text_idx = np.random.randint(0, len(text_list))
+        if self.fine_mode:
+            results['text'] = json.loads(text_list[text_idx])
+        else:
+            results['text'] = text_list[text_idx]
+        if 'clip_feat' in results:
+            results['clip_feat'] = results['clip_feat'][text_idx]
+        if 'token' in results:
+            results['token'] = results['token'][text_idx]
+        results['dataset_name'] = self.dataset_name
+        results['sample_idx'] = idx
+        return self.pipeline(results)
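For orientation, a hypothetical construction sketch. The names `train.txt`, `motions`, and `texts` are illustrative, not taken from the commit; only the `datasets/<dataset_name>/<subdir>` layout is fixed by the `os.path.join` calls above, and `data/datasets/human_ml3d/mean.npy` and `std.npy` do ship with this Space:

    # A plausible pipeline order: normalize first so padding zeros stay zeros.
    dataset = TextMotionDataset(
        data_prefix='data',
        dataset_name='human_ml3d',
        ann_file='train.txt',          # list of sample names, one per line
        motion_dir='motions',          # <name>.npy motion features
        text_dir='texts',              # <name>.txt, one caption per line
        pipeline=[
            dict(type='Normalize',
                 mean_path='data/datasets/human_ml3d/mean.npy',
                 std_path='data/datasets/human_ml3d/std.npy'),
            dict(type='Crop', crop_size=196),
            dict(type='ToTensor', keys=['motion']),
            dict(type='Collect', keys=['motion', 'text'],
                 meta_keys=('motion_shape', 'motion_mask')),
        ])
    sample = dataset[0]  # one caption drawn at random from texts/<name>.txt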
    	
mogen/models/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from .architectures import *
+from .losses import *
+from .rnns import *
+from .transformers import *
+from .attentions import *
+from .builder import *
+from .utils import *