Jialin Yang committed
Commit 352b049 · 1 Parent(s): 7bbe360

Initial release on Huggingface Spaces with Gradio UI

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +11 -0
  2. .gitignore +213 -0
  3. .gradio/certificate.pem +31 -0
  4. .vscode/settings.json +6 -0
  5. README.md +4 -2
  6. Roboto-VariableFont_wdth,wght.ttf +0 -0
  7. SkeletonDiffusion/__init__.py +0 -0
  8. SkeletonDiffusion/configs/config_eval/config.yaml +53 -0
  9. SkeletonDiffusion/configs/config_eval/config_inferencetime.yaml +43 -0
  10. SkeletonDiffusion/configs/config_eval/dataset/3dpw.yaml +35 -0
  11. SkeletonDiffusion/configs/config_eval/dataset/amass-mano.yaml +76 -0
  12. SkeletonDiffusion/configs/config_eval/dataset/amass.yaml +52 -0
  13. SkeletonDiffusion/configs/config_eval/dataset/freeman.yaml +23 -0
  14. SkeletonDiffusion/configs/config_eval/dataset/h36m.yaml +26 -0
  15. SkeletonDiffusion/configs/config_eval/method_specs/skeleton_diffusion.yaml +1 -0
  16. SkeletonDiffusion/configs/config_eval/method_specs/zerovelocity_alg_baseline.yaml +3 -0
  17. SkeletonDiffusion/configs/config_eval/task/hmp.yaml +4 -0
  18. SkeletonDiffusion/configs/config_train/config_autoencoder.yaml +27 -0
  19. SkeletonDiffusion/configs/config_train/dataset/amass.yaml +48 -0
  20. SkeletonDiffusion/configs/config_train/dataset/freeman.yaml +38 -0
  21. SkeletonDiffusion/configs/config_train/dataset/h36m.yaml +40 -0
  22. SkeletonDiffusion/configs/config_train/model/autoencoder.yaml +57 -0
  23. SkeletonDiffusion/configs/config_train/task/hmp.yaml +11 -0
  24. SkeletonDiffusion/configs/config_train_diffusion/config_diffusion.yaml +25 -0
  25. SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/adjacency.yaml +1 -0
  26. SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/reachability.yaml +3 -0
  27. SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion.yaml +57 -0
  28. SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion_in_noniso_class.yaml +70 -0
  29. SkeletonDiffusion/configs/config_train_diffusion/model/skeleton_diffusion.yaml +69 -0
  30. SkeletonDiffusion/datasets +1 -0
  31. SkeletonDiffusion/environment_inference.yml +19 -0
  32. SkeletonDiffusion/inference.ipynb +343 -0
  33. SkeletonDiffusion/inference_filtered.ipynb +1 -0
  34. SkeletonDiffusion/setup.py +13 -0
  35. SkeletonDiffusion/src/__init__.py +7 -0
  36. SkeletonDiffusion/src/config_utils.py +62 -0
  37. SkeletonDiffusion/src/core/__init__.py +8 -0
  38. SkeletonDiffusion/src/core/diffusion/__init__.py +3 -0
  39. SkeletonDiffusion/src/core/diffusion/base.py +445 -0
  40. SkeletonDiffusion/src/core/diffusion/isotropic.py +104 -0
  41. SkeletonDiffusion/src/core/diffusion/nonisotropic.py +213 -0
  42. SkeletonDiffusion/src/core/diffusion/utils.py +125 -0
  43. SkeletonDiffusion/src/core/diffusion_manager.py +45 -0
  44. SkeletonDiffusion/src/core/network/__init__.py +3 -0
  45. SkeletonDiffusion/src/core/network/layers/__init__.py +3 -0
  46. SkeletonDiffusion/src/core/network/layers/attention.py +138 -0
  47. SkeletonDiffusion/src/core/network/layers/graph_structural.py +133 -0
  48. SkeletonDiffusion/src/core/network/layers/recurrent.py +402 -0
  49. SkeletonDiffusion/src/core/network/nn/__init__.py +2 -0
  50. SkeletonDiffusion/src/core/network/nn/autoencoder.py +105 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.torchscript filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ ./magick filter=lfs diff=lfs merge=lfs -text
+ models/* filter=lfs diff=lfs merge=lfs -text
+ models/nlf_l_multi.torchscript filter=lfs diff=lfs merge=lfs -text
+ models/checkpoint_150.pt filter=lfs diff=lfs merge=lfs -text
+ downloads/* filter=lfs diff=lfs merge=lfs -text
+ outputs/* filter=lfs diff=lfs merge=lfs -text
+ intermediate_results/* filter=lfs diff=lfs merge=lfs -text
+ predictions/* filter=lfs diff=lfs merge=lfs -text
+ predictions/joints3d.npy filter=lfs diff=lfs merge=lfs -text
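These rules route the project's large binaries (TorchScript models, checkpoints, pickles, and the precomputed prediction arrays) through Git LFS. As a quick sanity check, a minimal sketch (the path is taken from the rules above; the helper itself is just illustrative) that asks git which filter applies to a file:

import subprocess

# Print the filter attribute git resolves for a path; LFS-tracked files report "lfs".
def lfs_filter(path: str) -> str:
    out = subprocess.run(
        ["git", "check-attr", "filter", "--", path],
        capture_output=True, text=True, check=True,
    ).stdout
    # check-attr prints e.g. "models/checkpoint_150.pt: filter: lfs"
    return out.strip().rsplit(": ", 1)[-1]

print(lfs_filter("models/checkpoint_150.pt"))  # expected: lfs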
.gitignore ADDED
@@ -0,0 +1,213 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ ENV/
+ env/
+ .env
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Project specific
+ *.pth
+ *.ckpt
+ # *.pt
+ *.bin
+ *.npy
+ *.npz
+ *.mp4
+ *.avi
+ *.mov
+ *.jpg
+ *.jpeg
+ *.png
+
+ # Logs
+ *.log
+ logs/
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ # downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
+ 9622_GRAB/
+ magick
+ outputs/*_obj
+ outputs/
+ intermediate_results/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
.vscode/settings.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "python.analysis.extraPaths": [
+         "./SkeletonDiffusion/src",
+         "./src_joints2smpl_demo/convert_"
+     ]
+ }
README.md CHANGED
@@ -4,10 +4,12 @@ emoji: 💻
  colorFrom: purple
  colorTo: green
  sdk: gradio
- sdk_version: 5.12.0
+ sdk_version: 5.24.0
  app_file: app.py
  pinned: false
  license: mit
+ run: |
+   bash setup.sh
+   python app.py
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
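The new run: block has the Space execute setup.sh before launching app.py. The app itself is outside this 50-file view; as a hedged sketch of the entry point the front matter implies (only the file name and the Gradio SDK are given above, everything else here is hypothetical):

import gradio as gr

# Hypothetical stub: the real app.py presumably wires the SkeletonDiffusion
# pipeline (input video -> 3D joints -> sampled future motions) into the UI.
def predict(video_path):
    return f"received {video_path}"  # placeholder for the actual pipeline

demo = gr.Interface(fn=predict, inputs=gr.Video(), outputs=gr.Textbox())

if __name__ == "__main__":
    demo.launch()  # Spaces serves whatever app.py launches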
Roboto-VariableFont_wdth,wght.ttf ADDED
Binary file (468 kB).
 
SkeletonDiffusion/__init__.py ADDED
File without changes
SkeletonDiffusion/configs/config_eval/config.yaml ADDED
@@ -0,0 +1,53 @@
+ hydra:
+   output_subdir: null
+   run:
+     dir: .
+   job:
+     chdir: False
+
+ dataset_main_path: ./datasets
+ dataset_annotation_path: ${dataset_main_path}/annotations&interm
+ dataset_precomputed_path: ${dataset_main_path}/processed
+ checkpoint_path: ''
+ defaults:
+   - _self_
+   - task: hmp
+   - method_specs: skeleton_diffusion
+   - dataset: amass
+   - override hydra/hydra_logging: disabled
+   - override hydra/job_logging: disabled
+
+ method_name: ${method_specs.method_name}
+ dtype: float32
+ if_noisy_obs: False
+ noise_level: 0.25
+ noise_std: 0.02
+ # num_nodes: ${eval:"int(${dataset.num_joints})-int(not ${task.if_consider_hip})"}
+
+ stats_mode: deterministic # probabilistic, deterministic
+ batch_size: 512
+ metrics_at_cpu: False
+ n_gpu: 1
+ num_samples: 50
+ if_measure_time: False
+
+ seed: 0
+ dataset_split: test
+ silent: False
+
+ obs_length: ${eval:'int(${task.history_sec} * ${dataset.fps})'}
+ pred_length: ${eval:'int(${task.prediction_horizon_sec} * ${dataset.fps})'}
+
+ if_store_output: False
+ store_output_path: ${eval:"'models/final_predictions_storage/${task.task_name}/${method_specs.method_name}/${dataset.dataset_name}/' if not ${if_long_term_test} else 'models/final_predictions_storage/${task.task_name}_longterm/${method_specs.method_name}/${dataset.dataset_name}/'"}
+ if_store_gt: False
+ store_gt_path: ${eval:"'models/final_predictions_storage/${task.task_name}/GT/${dataset.dataset_name}/' if not ${if_long_term_test} else 'models/final_predictions_storage/${task.task_name}_longterm/GT/${dataset.dataset_name}/'"}
+
+ if_compute_apde: ${eval:"False if ${eval:"'${dataset.dataset_name}' in ['freeman', '3dpw', 'nymeria']"} else True"}
+ if_long_term_test: False
+ long_term_factor: 2.5
+
+ if_compute_fid: False
+ if_compute_cmd: False
+
+
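The ${eval:'...'} interpolations used throughout these configs are not built into OmegaConf, so the evaluation code presumably registers a custom resolver before composing the config. A minimal sketch of how such a resolver would make obs_length/pred_length resolve (the eval registration is an assumption, not shown in this view):

from omegaconf import OmegaConf

# Assumed registration; something equivalent must run once at startup.
OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create({
    "task": {"history_sec": 0.5, "prediction_horizon_sec": 2},
    "dataset": {"fps": 60},
    "obs_length": "${eval:'int(${task.history_sec} * ${dataset.fps})'}",
    "pred_length": "${eval:'int(${task.prediction_horizon_sec} * ${dataset.fps})'}",
})
print(cfg.obs_length, cfg.pred_length)  # 30 120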
SkeletonDiffusion/configs/config_eval/config_inferencetime.yaml ADDED
@@ -0,0 +1,43 @@
+ hydra:
+   output_subdir: null
+   run:
+     dir: .
+   job:
+     chdir: False
+
+ dataset_main_path: ./datasets
+ dataset_annotation_path: ${dataset_main_path}/annotations&interm
+ dataset_precomputed_path: ${dataset_main_path}/processed
+ checkpoint_path: ''
+ defaults:
+   - _self_
+   - task: hmp
+   - method_specs: skeldiff
+   - dataset: amass
+   - override hydra/hydra_logging: disabled
+   - override hydra/job_logging: disabled
+
+ method_name: ${method_specs.method_name}
+ dtype: float32
+ if_noisy_obs: False
+ noise_level: 0.25
+ noise_std: 0.02
+
+ mode: stats # vis: visualize results | gen: generate and store all visualizations for a single batch | stats: launch numeric evaluation
+ stats_mode: deterministic # probabilistic, deterministic
+ if_measure_time: True
+ batch_size: 1
+ metrics_at_cpu: False
+ n_gpu: 1
+ num_samples: 50
+
+ seed: 0
+ dataset_split: test
+ silent: False
+
+ obs_length: ${eval:'int(${task.history_sec} * ${dataset.fps})'}
+ pred_length: ${eval:'int(${task.prediction_horizon_sec} * ${dataset.fps})'}
+
+ if_long_term_test: False
+
+
SkeletonDiffusion/configs/config_eval/dataset/3dpw.yaml ADDED
@@ -0,0 +1,35 @@
+ num_joints: 22 # including the hip root joint
+ fps: 60
+
+ multimodal_threshold: 0.4
+ dataset_type: D3PWZeroShotDataset
+ dataset_name: 3dpw
+ precomputed_folder: "${dataset_precomputed_path}/3DPW/${task.task_name}/"
+ annotations_folder: "${dataset_annotation_path}/3DPW/${task.task_name}/"
+ dtype: float32
+
+ data_loader_train_eval:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   da_mirroring: 0.
+   da_rotations: 0.
+   drop_last: False
+   if_load_mmgt: False
+
+ data_loader_valid:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   segments_path: "${dataset.annotations_folder}/segments_valid.csv"
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: False
+
+
+ data_loader_test:
+   shuffle: False
+   segments_path: "${dataset.annotations_folder}/segments_test_zero_shot.csv"
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
SkeletonDiffusion/configs/config_eval/dataset/amass-mano.yaml ADDED
@@ -0,0 +1,76 @@
+ num_joints: 52 # includes the hip even if if_consider_hip=False
+ fps: 60
+
+ multimodal_threshold: 0.4
+ dataset_type: AMASSDataset
+ dataset_name: amass-mano
+ precomputed_folder: "${dataset_precomputed_path}/AMASS-MANO/${task.task_name}/"
+ annotations_folder: "${dataset_annotation_path}/AMASS-MANO/${task.task_name}/"
+ dtype: float32
+
+ # Accordingly, the training set contains the ACCAD, BMLhandball,
+ # BMLmovi, BMLrub, CMU, EKUT, EyesJapanDataset, KIT, PosePrior,
+ # TCDHands, and TotalCapture datasets, and the validation set
+ # contains the HumanEva, HDM05, SFU, and MoSh datasets.
+ # The remaining datasets are all part of the test set: DFaust,
+ # DanceDB, GRAB, HUMAN4D, SOMA, SSM, and Transitions.
+
+
+
+ data_loader_train:
+   stride: 60
+   augmentation: 30
+   shuffle: True
+   datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture'] # from paper
+
+
+
+   # "EyesJapanDataset",
+
+   # ,
+
+   # "HDM05",
+   # "MoSh"
+   da_mirroring: 0.5
+   da_rotations: 1.0
+   drop_last: True
+   if_load_mmgt: False
+
+
+ data_loader_train_eval:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture'] # from paper
+   # datasets: ['ACCAD', "BMLhandball", "BMLmovi", 'CMU', 'KIT', 'TotalCapture'] # decrease evaluation time
+   da_mirroring: 0.
+   da_rotations: 0.
+   drop_last: False
+   if_load_mmgt: False
+
+ data_loader_valid:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   # segments_path: "./dataset_annotation_path/FreeMan/${task.task_name}/segments_valid.csv"
+   datasets: ['HumanEva', 'HDM05', 'SFU', 'MoSh'] # from paper
+   file_idces: "all"
+   drop_last: False
+   if_load_mmgt: False
+
+
+ data_loader_test:
+   shuffle: False
+   segments_path: "${dataset.annotations_folder}/segments_test.csv"
+   # datasets: ['Transitions_mocap', 'SSM_synced'] # DFaust, DanceDB, GRAB, HUMAN4D, SOMA, SSM, and Transitions
+   datasets:
+     - Transitions
+     - SSM
+     - DFaust
+     - DanceDB
+     - GRAB
+     - HUMAN4D
+     - SOMA
+   drop_last: False
+   if_load_mmgt: False
SkeletonDiffusion/configs/config_eval/dataset/amass.yaml ADDED
@@ -0,0 +1,52 @@
+ num_joints: 22 # including the hip root joint
+ fps: 60
+
+ multimodal_threshold: 0.4
+ dataset_type: AMASSDataset
+ dataset_name: "amass"
+ precomputed_folder: "${dataset_precomputed_path}/AMASS/${task.task_name}/"
+ annotations_folder: "${dataset_annotation_path}/AMASS/${task.task_name}/"
+ dtype: float32
+
+ # Accordingly, the training set contains the ACCAD, BMLhandball,
+ # BMLmovi, BMLrub, CMU, EKUT, EyesJapanDataset, KIT, PosePrior,
+ # TCDHands, and TotalCapture datasets, and the validation set
+ # contains the HumanEva, HDM05, SFU, and MoSh datasets.
+ # The remaining datasets are all part of the test set: DFaust,
+ # DanceDB, GRAB, HUMAN4D, SOMA, SSM, and Transitions.
+
+
+
+ data_loader_train_eval:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture']
+   da_mirroring: 0.
+   da_rotations: 0.
+   drop_last: False
+   if_load_mmgt: False
+
+ data_loader_valid:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   datasets: ['HumanEva', 'HDM05', 'SFU', 'MoSh']
+   file_idces: "all"
+   drop_last: False
+   if_load_mmgt: False
+
+
+ data_loader_test:
+   shuffle: False
+   segments_path: ${eval:"'${dataset.annotations_folder}/segments_test.csv' if not ${if_long_term_test} else '${dataset.annotations_folder}/segments_5s_test_long_term_pred.csv'"}
+   datasets:
+     - Transitions
+     - SSM
+     - DFaust
+     - DanceDB
+     - GRAB
+     - HUMAN4D
+     - SOMA
+   drop_last: False
+   if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
SkeletonDiffusion/configs/config_eval/dataset/freeman.yaml ADDED
@@ -0,0 +1,23 @@
+ num_joints: 18 # including the hip root joint
+ fps: 30
+ dataset_type: FreeManDataset
+ dataset_name: "freeman"
+ precomputed_folder: "${dataset_precomputed_path}/FreeMan/${task.task_name}/"
+ annotations_folder: "${dataset_annotation_path}/FreeMan/${task.task_name}/"
+ dtype: float32
+ multimodal_threshold: 0.5
+
+ data_loader_valid:
+   shuffle: False
+   segments_path: "${dataset.annotations_folder}/segments_valid.csv"
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
+
+
+ data_loader_test:
+   shuffle: False
+   segments_path: "${dataset.annotations_folder}/segments_test.csv"
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
SkeletonDiffusion/configs/config_eval/dataset/h36m.yaml ADDED
@@ -0,0 +1,26 @@
+ num_joints: 17 # including the hip root joint
+ fps: 50
+ dataset_type: H36MDataset
+ dataset_name: "h36m"
+ precomputed_folder: "${dataset_precomputed_path}/Human36M/${task.task_name}/"
+ annotations_folder: "${dataset_annotation_path}/Human36M/${task.task_name}"
+ dtype: float32
+ multimodal_threshold: 0.5
+
+ data_loader_valid:
+   augmentation: 0
+   shuffle: False
+   subjects: ["S8"]
+   segments_path: "${dataset.annotations_folder}/segments_valid.csv"
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
+
+ data_loader_test:
+   shuffle: False
+   augmentation: 0
+   segments_path: "${dataset.annotations_folder}/segments_test.csv"
+   subjects: ["S9", "S11"]
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
SkeletonDiffusion/configs/config_eval/method_specs/skeleton_diffusion.yaml ADDED
@@ -0,0 +1 @@
+ method_name: SkeletonDiffusion
SkeletonDiffusion/configs/config_eval/method_specs/zerovelocity_alg_baseline.yaml ADDED
@@ -0,0 +1,3 @@
+ motion_repr_type: "SkeletonCenterPose"
+ method_name: ZeroVelocityBaseline
+ baseline_out_path: ./models/output/baselines
SkeletonDiffusion/configs/config_eval/task/hmp.yaml ADDED
@@ -0,0 +1,4 @@
+ history_sec: 0.5 # ${eval:'float(1) if ${task}=="motpred" else float(0.5)'}
+ prediction_horizon_sec: 2 # ${eval:"float(4) if ${task}=='motpred' else float(2)"}
+ task_name: "hmp"
+ if_consider_hip: False
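With the AMASS settings above (fps: 60), these values resolve through the ${eval:...} interpolations to obs_length = int(0.5 * 60) = 30 observed frames and pred_length = int(2 * 60) = 120 predicted frames, which matches the tensor shapes ([1, 30, 22, 3] in, [1, 50, 120, 21, 3] out) seen in inference.ipynb further down.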
SkeletonDiffusion/configs/config_train/config_autoencoder.yaml ADDED
@@ -0,0 +1,27 @@
+ defaults:
+   - _self_
+   - task: hmp
+   - dataset: h36m
+   - model: autoencoder
+   - override hydra/job_logging: disabled
+   # - override hydra/hydra_logging: disabled
+
+ dataset_main_path: ./datasets
+ dataset_annotation_path: ${dataset_main_path}/annotations&interm
+ dataset_precomputed_path: ${dataset_main_path}/processed
+ if_resume_training: false
+ debug: false
+ device: cuda
+ load: false
+ load_path: ''
+ output_log_path: ../../my_exps/output/${task.task_name}/${dataset.dataset_name}/autoencoder/${now:%B%d_%H-%M-%S}_ID${slurm_id}_${info}
+ slurm_id: 0
+ slurm_first_run: None
+ info: ''
+ hydra:
+   run:
+     dir: ${output_log_path}
+   job:
+     chdir: False
+
+
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ num_joints: 22 #including the hip root joint
2
+ fps: 60
3
+
4
+ multimodal_threshold: 0.4
5
+ dataset_type: AMASSDataset
6
+ dataset_name: amass
7
+ precomputed_folder: "${dataset_precomputed_path}/AMASS/${task.task_name}/"
8
+ annotations_folder: "${dataset_annotation_path}/AMASS/${task.task_name}/"
9
+ dtype: float32
10
+
11
+ # Accordingly, the training set
12
+ # contains the ACCAD, BMLhandball, BMLmovi, BMLrub,
13
+ # CMU, EKUT, EyesJapanDataset, KIT, PosePrior, TCD-
14
+ # Hands, and TotalCapture datasets, and the validation set
15
+ # contains the HumanEva, HDM05, SFU, and MoSh datasets.
16
+ # The remaining datasets are all part of the test set: DFaust,
17
+ # DanceDB, GRAB, HUMAN4D, SOMA, SSM, and Transi-
18
+ # tions.
19
+
20
+ data_loader_train:
21
+ stride: 60
22
+ augmentation: 30
23
+ shuffle: True
24
+ datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture']
25
+ da_mirroring: 0.5
26
+ da_rotations: 1.0
27
+ drop_last: True
28
+ if_load_mmgt: False
29
+
30
+
31
+ data_loader_train_eval:
32
+ stride: 30
33
+ augmentation: 0
34
+ shuffle: False
35
+ datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture']
36
+ da_mirroring: 0.
37
+ da_rotations: 0.
38
+ drop_last: False
39
+ if_load_mmgt: False
40
+
41
+ data_loader_valid:
42
+ stride: 30
43
+ augmentation: 0
44
+ shuffle: False
45
+ datasets: ['HumanEva', 'HDM05', 'SFU', 'MoSh']
46
+ file_idces: "all"
47
+ drop_last: False
48
+ if_load_mmgt: False
SkeletonDiffusion/configs/config_train/dataset/freeman.yaml ADDED
@@ -0,0 +1,38 @@
+ num_joints: 18 # including the hip root joint
+ fps: 30
+
+ multimodal_threshold: 0.5
+ dataset_type: FreeManDataset
+ dataset_name: freeman
+ precomputed_folder: "${dataset_precomputed_path}/FreeMan/${task.task_name}/"
+ annotations_folder: "${dataset_annotation_path}/FreeMan/${task.task_name}/"
+ dtype: float32
+
+ data_loader_train:
+   stride: 10
+   augmentation: 5
+   shuffle: True
+   actions: "all"
+   da_mirroring: 0.5
+   da_rotations: 1.0
+   drop_last: True
+   if_load_mmgt: False
+
+ data_loader_train_eval:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   actions: "all"
+   da_mirroring: 0.
+   da_rotations: 0.
+   drop_last: False
+   if_load_mmgt: False
+
+ data_loader_valid:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   segments_path: "${dataset.annotations_folder}/segments_valid.csv"
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: False
SkeletonDiffusion/configs/config_train/dataset/h36m.yaml ADDED
@@ -0,0 +1,40 @@
+ num_joints: 17 # including the hip root joint
+ fps: 50
+
+ multimodal_threshold: 0.5
+ dataset_type: H36MDataset
+ dataset_name: h36m
+ precomputed_folder: "${dataset_precomputed_path}/Human36M/${task.task_name}/"
+ annotations_folder: "${dataset_annotation_path}/Human36M/${task.task_name}"
+ dtype: float32
+ data_loader_train:
+   stride: 10
+   augmentation: 5
+   shuffle: True
+   subjects: ["S1", "S5", "S6", "S7", "S8"] # train on the validation split as well, as in BeLFusion and CoMusion
+   actions: "all"
+   da_mirroring: 0.5
+   da_rotations: 1.0
+   drop_last: True
+   if_load_mmgt: False
+
+ data_loader_train_eval:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   subjects: ["S1", "S5", "S6", "S7", "S8"]
+   actions: "all"
+   da_mirroring: 0.
+   da_rotations: 0.
+   drop_last: False
+   if_load_mmgt: False
+
+ data_loader_valid:
+   stride: 30
+   augmentation: 0
+   shuffle: False
+   subjects: ["S8"]
+   segments_path: "${dataset.annotations_folder}/segments_valid.csv"
+   actions: "all"
+   drop_last: False
+   if_load_mmgt: False
SkeletonDiffusion/configs/config_train/model/autoencoder.yaml ADDED
@@ -0,0 +1,57 @@
+ batch_size: 64
+ batch_size_eval: 512
+ eval_frequency: 5
+ num_epochs: 200
+ num_iteration_eval: 10
+ num_workers: 4
+ seed: 52345
+
+ use_lr_scheduler: True
+ lr_scheduler_kwargs:
+   lr_scheduler_type: ExponentialLRSchedulerWarmup
+   warmup_duration: 10
+   update_every: 1
+   min_lr: 1.e-4
+   gamma_decay: 0.98
+
+
+ loss_pose_type: l1
+
+
+ lr: 0.5e-2
+
+ latent_size: 96
+ output_size: 3 # 128
+
+
+ z_activation: tanh
+
+ num_iter_perepoch: ${eval:"int(485) if ${eval:"'${dataset.dataset_name}' == 'h36m'"} else 580"}
+
+ obs_length: ${eval:'int(${task.history_sec} * ${dataset.fps})'}
+ prediction_horizon_train: ${model.prediction_horizon}
+ prediction_horizon_eval: ${model.prediction_horizon}
+ prediction_horizon: ${eval:'int(${task.prediction_horizon_sec} * ${dataset.fps})'}
+ pred_length: ${model.prediction_horizon_eval}
+
+
+
+ autoenc_arch:
+   enc_num_layers: 1
+   encoder_hidden_size: 96
+   decoder_hidden_size: 96
+   arch: AutoEncoder
+   recurrent_arch_enc: StaticGraphGRU
+   recurrent_arch_decoder: StaticGraphGRU
+
+
+
+ prediction_horizon_train_min: 10
+ prediction_horizon_train_min_from_epoch: 200
+ curriculum_it: 10
+ random_prediction_horizon: True
+
+
+
+
+
SkeletonDiffusion/configs/config_train/task/hmp.yaml ADDED
@@ -0,0 +1,11 @@
+ if_consider_hip: False
+
+ history_sec: 0.5
+ prediction_horizon_sec: 2
+
+
+ # Joint representation & skeleton specs
+ motion_repr_type: "SkeletonRescalePose"
+ pose_box_size: 1.5 # in meters
+ seq_centering: 0
+ task_name: hmp
SkeletonDiffusion/configs/config_train_diffusion/config_diffusion.yaml ADDED
@@ -0,0 +1,25 @@
+ if_resume_training: false
+ debug: false
+ device: cuda
+ load: false
+ load_path: ''
+ dataset_main_path: ./datasets
+ dataset_annotation_path: ${dataset_main_path}/annotations&interm
+ dataset_precomputed_path: ${dataset_main_path}/processed
+ _load_saved_aoutoenc: hmp-h36m
+
+ output_log_path: ../../my_exps/output/${eval:"'${_load_saved_aoutoenc}'.split('-', 1)[0]"}/${eval:"'${_load_saved_aoutoenc}'.split('-', 1)[1]"}/diffusion/${now:%B%d_%H-%M-%S}_ID${slurm_id}_${info}
+ slurm_id: 0
+ slurm_first_run: None
+ info: ''
+ hydra:
+   run:
+     dir: ${output_log_path}
+   job:
+     chdir: False
+
+ defaults:
+   - _self_
+   - model: skeleton_diffusion
+   - cov_matrix: adjacency
+   - override hydra/job_logging: disabled
SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/adjacency.yaml ADDED
@@ -0,0 +1 @@
+ covariance_matrix_type: adjacency
SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/reachability.yaml ADDED
@@ -0,0 +1,3 @@
+ covariance_matrix_type: reachability
+ reachability_matrix_degree_factor: 0.5
+ reachability_matrix_stop_at: hips # 'hips' or null
SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion.yaml ADDED
@@ -0,0 +1,57 @@
+ pretrained_GM_folder: ${eval:"'./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310' if ${eval:"'${_load_saved_aoutoenc}'.split('-')[1] == 'h36m'"} else './models/final_checkpoints/AMASS/hmp/autoencoder/May11_10-35-09_ID1185354'"}
+ pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/checkpoint_final.pt'
+
+ # These options still have to be checked (ablation)
+ lr: 1.e-3
+ diffusion_objective: pred_x0
+ weight_decay: 0.
+
+
+ if_use_ema: True
+ step_start_ema: 100 # 100 is the default
+ ema_power: ${eval:'2/3'} # or ${eval:'3/4'}
+ ema_update_every: 10
+ ema_min_value: 0.0
+ use_lr_scheduler: True
+ lr_scheduler_kwargs:
+   lr_scheduler_type: ExponentialLRSchedulerWarmup # or SchedulerReduceLROnPlateau
+   warmup_duration: 200
+   update_every: 10
+   min_lr: 1.e-4
+   gamma_decay: 0.98
+
+ # These options are already ablated
+ diffusion_conditioning: True
+ num_epochs: 600
+ num_workers: 4
+ batch_size: 64
+ batch_size_eval: 256 # TO CHECK
+ eval_frequency: 25 # or 1000; in epochs
+ train_pick_best_sample_among_k: 50
+ similarity_space: latent_space # input_space, latent_space or metric_space
+
+ diffusion_activation: identity
+ num_prob_samples: 50
+ diffusion_timesteps: 10
+
+ diffusion_type: IsotropicGaussianDiffusion
+ beta_schedule: cosine
+ diffusion_loss_type: l1
+ num_iter_perepoch: null
+ seed: 63485
+
+ diffusion_arch:
+   arch: Denoiser
+   use_attention: True
+   self_condition: False
+   norm_type: none
+   depth: 1
+   # resnet_block_groups: 8
+   # learned_variance: False
+   # learned_sinusoidal_cond: False
+   # random_fourier_features: False
+   # learned_sinusoidal_dim: 16
+   # sinusoidal_pos_emb_theta: 10000
+   attn_dim_head: 32
+   attn_heads: 4
+   learn_influence: True
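The schedule code behind beta_schedule: cosine is not part of this 50-file view; for reference, a sketch of the conventional cosine schedule (Nichol & Dhariwal, 2021, as implemented in the denoising-diffusion-pytorch dependency pinned in environment_inference.yml) evaluated with the diffusion_timesteps: 10 used here:

import math
import torch

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    # Cumulative alphas follow a squared cosine; betas are derived from their ratios.
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

print(cosine_beta_schedule(10))  # 10 betas, increasing toward the 0.999 clip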
SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion_in_noniso_class.yaml ADDED
@@ -0,0 +1,70 @@
+ # pretrained_GM_folder: ./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310
+ # _pretrained_GM_checkpoint: checkpoint_final
+ # pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'
+
+ _pretrained_GM_checkpoint: checkpoint_final
+ pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'
+ pretrained_GM_folder: ${eval:"'./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310' if ${eval:"'${_load_saved_aoutoenc}'.split('-')[1] == 'h36m'"} else './models/final_checkpoints/AMASS/hmp/autoencoder/May11_10-35-09_ID1185354'"}
+
+
+ # These options still have to be checked (ablation)
+ lr: 1.e-3
+ diffusion_objective: pred_x0
+ weight_decay: 0.
+
+
+ if_use_ema: True
+ step_start_ema: 100 # 100 is the default
+ ema_power: ${eval:'2/3'} # or ${eval:'3/4'}
+ ema_update_every: 10
+ ema_min_value: 0.0
+ use_lr_scheduler: True
+ lr_scheduler_kwargs:
+   lr_scheduler_type: ExponentialLRSchedulerWarmup # or SchedulerReduceLROnPlateau
+   warmup_duration: 200
+   update_every: 10
+   min_lr: 1.e-4
+   gamma_decay: 0.98
+
+ # These options are already ablated
+ diffusion_conditioning: True
+ num_epochs: 800
+ num_workers: 4
+ batch_size: 64
+ batch_size_eval: 256
+ eval_frequency: 25 # or 1000; in epochs
+ train_pick_best_sample_among_k: 50
+ similarity_space: latent_space # input_space, latent_space or metric_space
+
+
+ diffusion_activation: identity
+ num_prob_samples: 50
+ diffusion_timesteps: 10
+
+ diffusion_type: NonisotropicGaussianDiffusion
+ diffusion_loss_type: snr # snr_triangle_inequality, mahalanobis, snr
+ loss_reduction_type: l1
+ if_run_as_isotropic: True
+ if_sigma_n_scale: True
+ diffusion_covariance_type: isotropic # anisotropic, isotropic, skeleton-diffusion
+ gamma_scheduler: cosine # mono_decrease, cosine
+ beta_schedule: cosine
+ sigma_n_scale: spectral
+ num_iter_perepoch: null
+ seed: 63485
+
+ diffusion_arch:
+   arch: Denoiser
+   use_attention: True
+   self_condition: False # True yields better results, but takes longer to train.
+   norm_type: none
+   depth: 1
+   # resnet_block_groups: 8
+   # learned_variance: False
+   # learned_sinusoidal_cond: False
+   # random_fourier_features: False
+   # learned_sinusoidal_dim: 16
+   # sinusoidal_pos_emb_theta: 10000
+   attn_dim_head: 32
+   attn_heads: 4
+   learn_influence: True
SkeletonDiffusion/configs/config_train_diffusion/model/skeleton_diffusion.yaml ADDED
@@ -0,0 +1,69 @@
+ # pretrained_GM_folder: ./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310
+ # _pretrained_GM_checkpoint: checkpoint_final
+ # pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'
+
+ _pretrained_GM_checkpoint: checkpoint_final
+ pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'
+ pretrained_GM_folder: ${eval:"'./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310' if ${eval:"'${_load_saved_aoutoenc}'.split('-')[1] == 'h36m'"} else './models/final_checkpoints/AMASS/hmp/autoencoder/May11_10-35-09_ID1185354'"}
+
+
+ # These options still have to be checked (ablation)
+ lr: 1.e-3
+ diffusion_objective: pred_x0
+ weight_decay: 0.
+
+
+ if_use_ema: True
+ step_start_ema: 100 # 100 is the default
+ ema_power: ${eval:'2/3'} # or ${eval:'3/4'}
+ ema_update_every: 10
+ ema_min_value: 0.0
+ use_lr_scheduler: True
+ lr_scheduler_kwargs:
+   lr_scheduler_type: ExponentialLRSchedulerWarmup # or SchedulerReduceLROnPlateau
+   warmup_duration: 200
+   update_every: 10
+   min_lr: 1.e-4
+   gamma_decay: 0.98
+
+ diffusion_conditioning: True
+ num_epochs: 800
+ num_workers: 4
+ batch_size: 64
+ batch_size_eval: 256
+ eval_frequency: 25 # or 1000; in epochs
+ train_pick_best_sample_among_k: 50
+ similarity_space: latent_space # input_space, latent_space or metric_space
+
+
+ diffusion_activation: identity
+ num_prob_samples: 50
+ diffusion_timesteps: 10
+
+ diffusion_type: NonisotropicGaussianDiffusion
+ diffusion_loss_type: snr # snr_triangle_inequality, mahalanobis, snr
+ loss_reduction_type: l1
+ if_run_as_isotropic: False
+ if_sigma_n_scale: True
+ diffusion_covariance_type: skeleton-diffusion # anisotropic, isotropic, skeleton-diffusion
+ gamma_scheduler: cosine # mono_decrease, cosine
+ beta_schedule: cosine
+ sigma_n_scale: spectral
+ num_iter_perepoch: null
+ seed: 63485
+
+ diffusion_arch:
+   arch: Denoiser
+   use_attention: True
+   self_condition: False # True yields better results, but takes longer to train.
+   norm_type: none
+   depth: 1
+   # resnet_block_groups: 8
+   # learned_variance: False
+   # learned_sinusoidal_cond: False
+   # random_fourier_features: False
+   # learned_sinusoidal_dim: 16
+   # sinusoidal_pos_emb_theta: 10000
+   attn_dim_head: 32
+   attn_heads: 4
+   learn_influence: True
SkeletonDiffusion/datasets ADDED
@@ -0,0 +1 @@
+ ../../motion_must_go_on/datasets/
SkeletonDiffusion/environment_inference.yml ADDED
@@ -0,0 +1,19 @@
+ name: skeldiff_inf
+ channels:
+   - pytorch
+   - nvidia
+ dependencies:
+   - python=3.10.12
+   - pytorch=2.0.1
+   - pytorch-cuda=11.8
+   - torchvision=0.15.2
+   - pyyaml=6.0.1
+   - einops=0.7.0
+   - pip
+   - pip:
+     - denoising-diffusion-pytorch==1.9.4
+     # - pyyaml=6.0.1
+     # - imageio
+     # - ipympl=0.9.3
+     # - ffmpeg
+     # - opencv
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "e927d3c2",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import os\n",
19
+ "os.chdir(r\"/home/stud/yaji/storage/user/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src\") \n",
20
+ "root_path = os.getcwd()\n",
21
+ "print(root_path)\n",
22
+ "\n",
23
+ "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\""
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 2,
29
+ "id": "32b71ab8",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stderr",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "/home/stud/yaji/miniconda3/envs/live_demo/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
37
+ " from .autonotebook import tqdm as notebook_tqdm\n",
38
+ "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src/core/diffusion/base.py:184: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
39
+ " @autocast(enabled = False)\n",
40
+ "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src/core/diffusion/isotropic.py:72: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
41
+ " @autocast(enabled = False)\n",
42
+ "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src/core/diffusion/nonisotropic.py:138: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
43
+ " @autocast(enabled = False)\n"
44
+ ]
45
+ }
46
+ ],
47
+ "source": [
48
+ "from eval_prepare_model import prepare_model, get_prediction, load_model_config_exp\n",
49
+ "from data import create_skeleton\n",
50
+ "import torch\n",
51
+ "import numpy as np\n",
52
+ "import random\n",
53
+ "\n",
54
+ "def set_seed(seed=0):\n",
55
+ " torch.use_deterministic_algorithms(True)\n",
56
+ " torch.backends.cudnn.deterministic = True\n",
57
+ " torch.backends.cudnn.benchmark = False\n",
58
+ " np.random.seed(seed)\n",
59
+ " random.seed(seed)\n",
60
+ " torch.cuda.manual_seed(seed)\n",
61
+ " torch.cuda.manual_seed_all(seed)"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 3,
67
+ "id": "0963a8bd",
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "checkpoint_path = '/usr/wiss/curreli/work/my_exps/checkpoints_release/amass/diffusion/cvpr_release/checkpoints/checkpoint_150.pt'\n",
72
+ "# checkpoint_path = '/usr/wiss/curreli/work/my_exps/checkpoints_release/amass-mano/diffusion/cvpr_release/checkpoints/checkpoint_150.pt'\n",
73
+ "\n",
74
+ "\n",
75
+ "num_samples = 50"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 4,
81
+ "id": "c6ce7b0a",
82
+ "metadata": {},
83
+ "outputs": [
84
+ {
85
+ "name": "stdout",
86
+ "output_type": "stream",
87
+ "text": [
88
+ "> GPU 0 ready: NVIDIA RTX A2000 12GB\n",
89
+ "Loading Autoencoder checkpoint: /usr/wiss/curreli/work/my_exps/checkpoints_release/amass/autoencoder/cvpr_release/checkpoints/checkpoint_300.pt ...\n",
90
+ "Diffusion is_ddim_sampling: False\n",
91
+ "Loading Diffusion checkpoint: /usr/wiss/curreli/work/my_exps/checkpoints_release/amass/diffusion/cvpr_release/checkpoints/checkpoint_150.pt ...\n"
92
+ ]
93
+ }
94
+ ],
95
+ "source": [
96
+ "set_seed(seed=0)\n",
97
+ "\n",
98
+ "config, exp_folder = load_model_config_exp(checkpoint_path)\n",
99
+ "config['checkpoint_path'] = checkpoint_path\n",
100
+ "skeleton = create_skeleton(**config) \n",
101
+ "\n",
102
+ "\n",
103
+ "model, device, *_ = prepare_model(config, skeleton, **config)"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 5,
109
+ "id": "5c4aa1a7",
110
+ "metadata": {},
111
+ "outputs": [
112
+ {
113
+ "data": {
114
+ "text/plain": [
115
+ "torch.Size([1, 30, 22, 3])"
116
+ ]
117
+ },
118
+ "execution_count": 5,
119
+ "metadata": {},
120
+ "output_type": "execute_result"
121
+ }
122
+ ],
123
+ "source": [
124
+ "# prepare input\n",
125
+ "# load input. It should be in meters\n",
126
+ "import numpy as np\n",
127
+ "import torch\n",
128
+ "obs = np.load('/usr/wiss/curreli/work/my_exps/checkpoints_release/amass/exaple_obs.npy') # (t_past, J, 3)\n",
129
+ "# obs = np.load('/usr/wiss/curreli/work/my_exps/checkpoints_release/amass-mano/example_obs.npy') # (t_past, J, 3)\n",
130
+ "\n",
131
+ "\n",
132
+ "obs = torch.from_numpy(obs).to(device)\n",
133
+ "obs = obs.unsqueeze(0) # add bacth size\n",
134
+ "obs.shape"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 6,
140
+ "id": "e782b749",
141
+ "metadata": {},
142
+ "outputs": [
143
+ {
144
+ "name": "stdout",
145
+ "output_type": "stream",
146
+ "text": [
147
+ "torch.Size([1, 50, 120, 21, 3])\n",
148
+ "torch.Size([1, 30, 21, 3])\n"
149
+ ]
150
+ },
151
+ {
152
+ "data": {
153
+ "text/plain": [
154
+ "torch.Size([1, 50, 120, 21, 3])"
155
+ ]
156
+ },
157
+ "execution_count": 6,
158
+ "metadata": {},
159
+ "output_type": "execute_result"
160
+ }
161
+ ],
162
+ "source": [
163
+ "obs_in = skeleton.tranform_to_input_space(obs) # obs sequence contains hip joints, it has not been dropped yet. \n",
164
+ "pred = get_prediction(obs_in, model, num_samples=num_samples, **config) # [batch_size, n_samples, seq_length, num_joints, features]\n",
165
+ "print(pred.shape)\n",
166
+ "pred = skeleton.transform_to_metric_space(pred)\n",
167
+ "print(obs_in.shape)\n",
168
+ "pred.shape"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "id": "d9394774",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "kpts3d = pred.cpu()[0][0]\n",
179
+ "import matplotlib.pyplot as plt\n",
180
+ "for i in range(120):\n",
181
+ " plt.figure()\n",
182
+ " plt.scatter(kpts3d[0, :, 1], kpts3d[0, :, 2])\n",
183
+ " plt.gca().set_aspect('equal')\n",
184
+ " plt.savefig(f'../../vis/kpts3d_{i}.png')"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 14,
190
+ "id": "f91ee849",
191
+ "metadata": {},
192
+ "outputs": [
193
+ {
194
+ "name": "stdout",
195
+ "output_type": "stream",
196
+ "text": [
197
+ "torch.Size([1, 30, 21, 3])\n",
198
+ "torch.Size([1, 50, 120, 21, 3])\n",
199
+ "torch.Size([1, 30, 20, 3])\n",
200
+ "tensor([[0.0123, 0.0127, 0.0128, 0.0128, 0.0131, 0.0131, 0.0137, 0.0144, 0.0148,\n",
201
+ " 0.0149, 0.0151, 0.0153, 0.0154, 0.0158, 0.0164, 0.0168, 0.0170, 0.0171,\n",
202
+ " 0.0171, 0.0173, 0.0189, 0.0193, 0.0224, 0.0229, 0.0235, 0.0237, 0.0244,\n",
203
+ " 0.0248, 0.0261, 0.0269, 0.0270, 0.0271, 0.0274, 0.0275, 0.0280, 0.0282,\n",
204
+ " 0.0292, 0.0293, 0.0307, 0.0339, 0.0346, 0.0350, 0.0351, 0.0367, 0.0379,\n",
205
+ " 0.0386, 0.0391, 0.0395, 0.0439, 0.0523]], device='cuda:0')\n"
206
+ ]
207
+ },
208
+ {
209
+ "data": {
210
+ "text/plain": [
211
+ "tensor([[ 4, 12, 31, 34, 27, 29, 2, 46, 20, 11, 15, 33, 21, 14, 9, 38, 41, 1,\n",
212
+ " 22, 35, 19, 43, 16, 48, 5, 47, 25, 40, 8, 28, 39, 45, 17, 23, 37, 18,\n",
213
+ " 6, 42, 49, 26, 24, 13, 36, 3, 44, 0, 7, 10, 32, 30]],\n",
214
+ " device='cuda:0')"
215
+ ]
216
+ },
217
+ "execution_count": 14,
218
+ "metadata": {},
219
+ "output_type": "execute_result"
220
+ }
221
+ ],
222
+ "source": [
223
+ "# rank predictions according to Limb Stretching. We will visualize first the prediction that have lower limb stretching --> more realistic\n",
224
+ "from metrics.body_realism import limb_stretching_normed_mean, limb_stretching_normed_rmse\n",
225
+ "print(obs_in.shape)\n",
226
+ "print(pred.shape)\n",
227
+ "print(obs_in[..., 1:, :].shape)\n",
228
+ "# limbstretching = limb_stretching_normed_mean(pred, target=obs[..., 1:, :].unsqueeze(1), limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
229
+ "limbstretching = limb_stretching_normed_rmse(pred, target=obs[..., 1:, :].unsqueeze(1), limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
230
+ "limbstretching_sorted, indices = torch.sort(limbstretching.squeeze(1), dim=-1, descending=False) \n",
231
+ "\n",
232
+ "print(limbstretching_sorted)\n",
233
+ "indices\n",
234
+ "\n",
235
+ "# TO DO: index predictions with these indices.\n"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 27,
241
+ "id": "1a9a941a",
242
+ "metadata": {},
243
+ "outputs": [
244
+ {
245
+ "name": "stdout",
246
+ "output_type": "stream",
247
+ "text": [
248
+ "Observation shape: torch.Size([1, 30, 21, 3])\n",
249
+ "Prediction shape: torch.Size([1, 50, 120, 22, 3])\n",
250
+ "tensor([[0.0446, 0.0711, 0.0517, 0.1149, 0.0689, 0.0238, 0.0374, 0.0329, 0.0411,\n",
251
+ " 0.0573, 0.0565, 0.0855, 0.0375, 0.1141, 0.0402, 0.0385, 0.0564, 0.0727,\n",
252
+ " 0.0904, 0.0620, 0.0374, 0.0363, 0.0443, 0.0386, 0.0702, 0.0413, 0.0455,\n",
253
+ " 0.0468, 0.1038, 0.0691, 0.0630, 0.0320, 0.0489, 0.0422, 0.0520, 0.0756,\n",
254
+ " 0.0444, 0.0414, 0.0852, 0.0673, 0.0391, 0.0500, 0.0484, 0.0457, 0.0556,\n",
255
+ " 0.0393, 0.0674, 0.0349, 0.0392, 0.0459]], device='cuda:0')\n"
256
+ ]
257
+ }
258
+ ],
259
+ "source": [
260
+ "# read pred and obs from predictions/joints3d.npy\n",
261
+ "frames_for_half_second = 30\n",
262
+ "# Load the joints3d data from the saved numpy file\n",
263
+ "joints3d = np.load('/home/stud/yaji/storage/user/yaji/NonisotropicSkeletonDiffusion/predictions/joints3d.npy')\n",
264
+ "\n",
265
+ "# Split the data into observation and prediction parts\n",
266
+ "# The first frames_for_half_second frames are observations\n",
267
+ "obs = joints3d[:, 0, :frames_for_half_second, :, :] # [1, num_samples, frames_for_half_second, 22, 3]\n",
268
+ "pred = joints3d[:, :, frames_for_half_second:, :, :] # [1, num_samples, pred_length, 22, 3]\n",
269
+ "\n",
270
+ "# Convert to torch tensors and move to device\n",
271
+ "obs = torch.from_numpy(obs).to(device)\n",
272
+ "pred = torch.from_numpy(pred).to(device)\n",
273
+ "\n",
274
+ "print(\"Observation shape:\", obs[0, ..., 1:, :].unsqueeze(0).shape)\n",
275
+ "print(\"Prediction shape:\", pred.shape)\n",
276
+ "\n",
277
+ "# calculate \n",
278
+ "limbstretching = limb_stretching_normed_rmse(pred[..., 1:, :], target=obs[0, ..., 1:, :].unsqueeze(0), limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
279
+ "limbstretching_sorted, indices = torch.sort(limbstretching.squeeze(1), dim=-1, descending=False) \n",
280
+ "print(limbstretching)\n"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": 11,
286
+ "id": "e51ccbee",
287
+ "metadata": {},
288
+ "outputs": [
289
+ {
290
+ "name": "stdout",
291
+ "output_type": "stream",
292
+ "text": [
293
+ "torch.Size([1, 120, 21, 3])\n",
294
+ "torch.Size([50, 120, 21, 3])\n",
295
+ "torch.Size([120, 21, 3])\n",
296
+ "torch.Size([10, 120, 21, 3])\n",
297
+ "[8, 18, 36, 30, 37, 43, 26, 17, 41, 6]\n"
298
+ ]
299
+ }
300
+ ],
301
+ "source": [
302
+ "from metrics.ranking import get_closest_and_nfurthest_maxapd\n",
303
+ "# If you see problems with the visualizations, you can remove predictions that have limb stretching > 0.04\n",
304
+ "# limbstretching = limb_stretching_normed_mean(pred, target=obs[..., 1:, :], limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
305
+ "# remove batch dimension\n",
306
+ "y_pred = pred.squeeze(0) # [n_samples, seq_length, num_joints, features]\n",
307
+ "#if GT is not there, we use the first sample as GT reference i.e. the most likely closest to GT\n",
308
+ "y_gt = y_pred[0].unsqueeze(0) # [seq_length, num_joints, features]\n",
309
+ "print(y_gt.shape)\n",
310
+ "print(y_pred.shape)\n",
311
+ "pred_closest, sorted_preds, sorted_preds_idxs = get_closest_and_nfurthest_maxapd(y_pred, y_gt, nsamples=10)\n",
312
+ "\n",
313
+ "print(pred_closest.shape)\n",
314
+ "print(sorted_preds.shape)\n",
315
+ "print(sorted_preds_idxs)\n",
316
+ "\n",
317
+ "\n",
318
+ "\n"
319
+ ]
320
+ }
321
+ ],
322
+ "metadata": {
323
+ "kernelspec": {
324
+ "display_name": "live_demo",
325
+ "language": "python",
326
+ "name": "python3"
327
+ },
328
+ "language_info": {
329
+ "codemirror_mode": {
330
+ "name": "ipython",
331
+ "version": 3
332
+ },
333
+ "file_extension": ".py",
334
+ "mimetype": "text/x-python",
335
+ "name": "python",
336
+ "nbconvert_exporter": "python",
337
+ "pygments_lexer": "ipython3",
338
+ "version": "3.10.13"
339
+ }
340
+ },
341
+ "nbformat": 4,
342
+ "nbformat_minor": 5
343
+ }
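The ranking cells above end with a TO DO: reorder the samples by the sorted limb-stretching indices. Continuing from the tensors defined in those cells (`pred` is [batch, n_samples, T, J, 3] and `indices` is [batch, n_samples] from torch.sort), a minimal sketch of that step:

# Reorder each batch element's samples by ascending limb stretching, so that
# pred_ranked[:, 0] is the most realistic sample.
batch_idx = torch.arange(pred.shape[0]).unsqueeze(-1)  # [batch, 1], broadcasts over samples
pred_ranked = pred[batch_idx, indices]                 # [batch, n_samples, T, J, 3]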
SkeletonDiffusion/inference_filtered.ipynb ADDED
@@ -0,0 +1 @@
+
SkeletonDiffusion/setup.py ADDED
@@ -0,0 +1,13 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="SkeletonDiffusion",
+     version="0.1.0",
+     package_dir={"": "src"},
+     packages=find_packages(where="src"),
+     install_requires=[
+         "torch",
+         "numpy",
+     ],
+     python_requires=">=3.8",
+ )
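Given `package_dir={"": "src"}`, running `pip install -e .` inside `SkeletonDiffusion/` would presumably expose the packages found under `src/` (e.g. `core`) as top-level imports; note that other files in this commit import through the `SkeletonDiffusion.src.*` path instead.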
SkeletonDiffusion/src/__init__.py ADDED
@@ -0,0 +1,7 @@
+ import os
+ import sys
+
+ # Add the src directory to the Python path
+ src_path = os.path.dirname(os.path.abspath(__file__))
+ if src_path not in sys.path:
+     sys.path.insert(0, src_path)
SkeletonDiffusion/src/config_utils.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from ignite.handlers import Checkpoint
+ import os
+
+ from SkeletonDiffusion.src.utils.config import init_obj
+ from SkeletonDiffusion.src.utils.reproducibility import seed_worker, seed_eval_worker, RandomStateDict
+ import SkeletonDiffusion.src.data.loaders as dataset_type
+
+ from SkeletonDiffusion.src.utils.load import get_latest_model_path
+ from SkeletonDiffusion.src.inference_utils import create_model
+
+
+ def create_train_dataloaders(batch_size, batch_size_eval, num_workers, skeleton, if_run_validation=True, if_resume_training=False, **config):
+     random_state_manager = RandomStateDict()
+     if if_resume_training:
+         checkpoint = torch.load(config['load_path'])
+         Checkpoint.load_objects(to_load={"random_states": random_state_manager}, checkpoint=checkpoint)
+     dataset_train = init_obj(config, 'dataset_type', dataset_type, split="train", skeleton=skeleton, **(config['data_loader_train']))
+     data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, worker_init_fn=seed_worker, pin_memory=True,
+                                    num_workers=num_workers, generator=random_state_manager.generator)
+
+     if if_run_validation:
+         dataset_eval = init_obj(config, 'dataset_type', dataset_type, split="valid", skeleton=skeleton, **(config['data_loader_valid']))
+         dataset_eval_train = init_obj(config, 'dataset_type', dataset_type, split="train", skeleton=skeleton, **(config['data_loader_train_eval']))
+         data_loader_eval = DataLoader(dataset_eval, shuffle=False, worker_init_fn=seed_eval_worker, batch_size=batch_size_eval, num_workers=1, pin_memory=True)
+         data_loader_train_eval = DataLoader(dataset_eval_train, shuffle=False, worker_init_fn=seed_eval_worker, batch_size=batch_size_eval, num_workers=1, pin_memory=True)
+     else:
+         data_loader_eval = None
+         data_loader_train_eval = None
+     return data_loader_train, data_loader_eval, data_loader_train_eval, random_state_manager
+
+
+ def flat_hydra_config(cfg):
+     """
+     Flatten the main dict categories of the Hydra config object into a single one.
+     """
+     for subconf in ['model', 'task', 'dataset', 'autoenc_arch', 'cov_matrix']:
+         if subconf in cfg:
+             cfg = {**cfg, **cfg[subconf]}
+             cfg.pop(subconf)
+     return cfg
+
+ def resume_training(cfg):
+     # The output folder has already been created.
+     assert 'output_log_path' in cfg
+
+     # Decide whether to start from scratch (default if no checkpoints exist), from the latest save
+     # (if checkpoints exist), or from a given path (if load_path is given).
+     assert len(os.listdir(os.path.join(cfg['output_log_path'], 'checkpoints'))) != 0, "Checkpoints folder is empty. Please provide a valid path to load from."
+
+     if len(cfg['load_path']) == 0:
+         # load the latest model
+         cfg['load_path'] = get_latest_model_path(os.path.join(cfg['output_log_path'], 'checkpoints'))
+         print("Loading latest epoch: ", cfg['load_path'].split('/')[-1])
+     else:
+         output_path = os.path.dirname(os.path.dirname(cfg['load_path']))
+         assert cfg['output_log_path'] == output_path
+         cfg['output_log_path'] = output_path
+     return cfg
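`flat_hydra_config` is pure dict manipulation, so its effect is easy to show on a toy config (the keys below are illustrative, not the real config schema):

```python
# Minimal sketch of what flat_hydra_config does to a nested config dict.
cfg = {
    "seed": 0,
    "model": {"latent_size": 96, "diffusion_timesteps": 10},
    "dataset": {"name": "amass", "batch_size": 64},
}
flat = flat_hydra_config(cfg)
# -> {"seed": 0, "latent_size": 96, "diffusion_timesteps": 10,
#     "name": "amass", "batch_size": 64}
```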
SkeletonDiffusion/src/core/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Import network first
+ from .network import AutoEncoder, Denoiser
+
+ # Then import diffusion manager
+ from .diffusion_manager import DiffusionManager
+
+ # Export all
+ __all__ = ['AutoEncoder', 'Denoiser', 'DiffusionManager']
SkeletonDiffusion/src/core/diffusion/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from SkeletonDiffusion.src.core.diffusion.isotropic import IsotropicGaussianDiffusion
+ from SkeletonDiffusion.src.core.diffusion.nonisotropic import NonisotropicGaussianDiffusion
+ from SkeletonDiffusion.src.core.diffusion.utils import get_cov_from_corr
SkeletonDiffusion/src/core/diffusion/base.py ADDED
@@ -0,0 +1,445 @@
+ import math
+ from random import random
+ from functools import partial
+ from collections import namedtuple
+ from typing import Tuple, Optional, List, Union, Dict
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.cuda.amp import autocast
+
+ from einops import reduce
+
+ from tqdm.auto import tqdm
+
+ # constants
+ ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
+
+ # helper functions
+
+ def identity(t, *args, **kwargs):
+     return t
+
+ def exists(x):
+     return x is not None
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if callable(d) else d
+
+ def extract(a, t, x_shape):
+     b, *_ = t.shape
+     out = a.gather(-1, t)
+     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+ def linear_beta_schedule(timesteps):
+     scale = 1000 / timesteps
+     beta_start = scale * 0.0001
+     beta_end = scale * 0.02
+     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+
+ def cosine_beta_schedule(timesteps, s=0.008):
+     """
+     cosine schedule
+     as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+     """
+     steps = timesteps + 1
+     x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
+     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
+     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+     return torch.clip(betas, 0, 0.999)
+
+ def exp_beta_schedule(timesteps, factor=3.0):
+     steps = timesteps + 1
+     x = torch.linspace(-factor, 0, steps, dtype=torch.float64)
+     betas = torch.exp(x)
+     return torch.clip(betas, 0, 0.999)
+
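The three `*_beta_schedule` helpers differ only in how they shape the noise-variance curve over the (here, very few) diffusion steps. A minimal standalone comparison, with the function bodies copied from the file above and an illustrative printout:

```python
import math, torch

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    ac = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    ac = ac / ac[0]
    return torch.clip(1 - (ac[1:] / ac[:-1]), 0, 0.999)

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)

T = 10  # the default diffusion_timesteps above
for name, betas in [("cosine", cosine_beta_schedule(T)), ("linear", linear_beta_schedule(T))]:
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    # alpha_bar_T close to 0 means the forward process ends near pure noise
    print(f"{name}: beta_1={betas[0]:.4f}, beta_T={betas[-1]:.4f}, alpha_bar_T={alphas_cumprod[-1]:.4f}")
```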
+
+ class LatentDiffusion(nn.Module):
+     def __init__(self,
+                  model: torch.nn.Module, latent_size=96, diffusion_timesteps=10, diffusion_objective="pred_x0", sampling_timesteps=None, diffusion_activation='identity',
+                  silent=True, diffusion_conditioning=False, diffusion_loss_type='mse',
+                  objective='pred_noise',
+                  beta_schedule='cosine',
+                  beta_schedule_factor=3.0,
+                  ddim_sampling_eta=0.,
+                  **kwargs
+                  ):
+         super().__init__()
+
+         if diffusion_activation == "tanh":
+             self.activation = torch.nn.Tanh()
+         elif diffusion_activation == "identity":
+             self.activation = torch.nn.Identity()
+         self.silent = silent
+         self.condition = diffusion_conditioning
+         self.loss_type = diffusion_loss_type
+
+         self.statistics_pred = None
+         self.statistics_obs = None
+
+         timesteps = diffusion_timesteps
+         objective = diffusion_objective
+
+         self.model = model
+         self.channels = self.model.channels
+         self.self_condition = self.model.self_condition
+
+         self.seq_length = latent_size
+         self.objective = objective
+
+         assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
+
+         if beta_schedule == 'linear':
+             betas = linear_beta_schedule(timesteps)
+         elif beta_schedule == 'cosine':
+             betas = cosine_beta_schedule(timesteps)
+         elif beta_schedule == 'exp':
+             betas = exp_beta_schedule(timesteps, beta_schedule_factor)
+         else:
+             raise ValueError(f'unknown beta schedule {beta_schedule}')
+
+         alphas = 1. - betas
+         alphas_cumprod = torch.cumprod(alphas, dim=0)
+         alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
+
+         timesteps, = betas.shape
+         self.num_timesteps = int(timesteps)
+
+         # sampling related parameters
+         self.sampling_timesteps = default(sampling_timesteps, timesteps)  # default num sampling timesteps to number of timesteps at training
+
+         assert self.sampling_timesteps <= timesteps
+         self.is_ddim_sampling = self.sampling_timesteps < timesteps
+         self.ddim_sampling_eta = ddim_sampling_eta
+
+         # helper function to register buffers, casting from float64 to float32
+         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+
+         register_buffer('betas', betas)
+         register_buffer('alphas_cumprod', alphas_cumprod)
+         register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+         register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+
+         print("Diffusion is_ddim_sampling: ", self.is_ddim_sampling)
+
+     def set_normalization_statistics(self, statistics_pred, statistics_obs):
+         self.statistics_pred = statistics_pred
+         self.statistics_obs = statistics_obs
+         print("Setting normalization statistics for diffusion")
+
+     def get_white_noise(self, x, *args, **kwargs):
+         return self.get_noise(x, *args, **kwargs)
+
+     def get_start_noise(self, x, *args, **kwargs):
+         return self.get_white_noise(x, *args, **kwargs)
+
+     def get_noise(self, x, *args, **kwargs):
+         """
+         x is either a tensor or a shape
+         """
+         if torch.is_tensor(x):
+             return torch.randn_like(x)
+         elif isinstance(x, tuple):
+             return torch.randn(*x, *args, **kwargs)
+
+     #######################################################################
+     # TO SUBCLASS
+     #######################################################################
+
+     def predict_start_from_noise(self, x_t, t, noise):
+         assert 0, "Not implemented"
+         ...
+         return x_t
+
+     def predict_noise_from_start(self, x_t, t, x0):
+         assert 0, "Not implemented"
+         ...
+         return x_t
+
+     def predict_v(self, x_start, t, noise):
+         assert 0, "Not implemented"
+         ...
+         return x_start
+
+     def predict_start_from_v(self, x_t, t, v):
+         assert 0, "Not implemented"
+         ...
+         return x_t
+
+     @autocast(enabled=False)
+     def q_sample(self, x_start, t, noise=None):
+         assert 0, "Not implemented"
+         ...
+         return x_start
+
+     def q_posterior(self, x_start, x_t, t):
+         assert 0, "Not implemented"
+         ...
+
+     def p_combine_mean_var_noise(self, model_mean, model_log_variance, noise):
+         assert 0, "Not implemented"
+         ...
+         return model_mean
+
+     def p_interpolate_mean_var_noise(self, model_mean, model_log_variance, noise, node_idx: Optional[int] = None, interpolate_factor=0.0, noise2interpolate=None, **kwargs):
+         assert 0, "Not implemented"
+         ...
+         return model_mean
+
+     def loss_funct(self, model_out, target, *args, **kwargs):
+         if self.loss_type == "mse":
+             loss = F.mse_loss(model_out, target, reduction='none')
+         elif self.loss_type == 'l1':
+             loss = F.l1_loss(model_out, target, reduction='none')
+         else:
+             assert 0, "Not implemented"
+         return loss
+
+     ########################################################################
+     # NETWORK INTERFACE
+     #########################################################################
+
+     def model_predictions(self, x, t, x_self_cond=None, x_cond=None, clip_x_start=False, rederive_pred_noise=False):
+         model_output = self.feed_model(x, t, x_self_cond=x_self_cond, x_cond=x_cond)
+         maybe_clip = partial(torch.clamp, min=-1., max=1.) if clip_x_start else identity
+
+         if self.objective == 'pred_noise':
+             pred_noise = model_output
+             x_start = self.predict_start_from_noise(x, t, pred_noise)
+             x_start = maybe_clip(x_start)
+
+             if clip_x_start and rederive_pred_noise:
+                 pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+         elif self.objective == 'pred_x0':
+             x_start = model_output
+             x_start = maybe_clip(x_start)
+             pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+         elif self.objective == 'pred_v':
+             v = model_output
+             x_start = self.predict_start_from_v(x, t, v)
+             x_start = maybe_clip(x_start)
+             pred_noise = self.predict_noise_from_start(x, t, x_start)
+         return ModelPrediction(pred_noise, x_start)
+
+     def feed_model(self, x, t, x_self_cond=None, x_cond=None):
+         if self.condition:
+             assert x_cond is not None
+             if x.shape[0] > x_cond.shape[0]:
+                 # training with multiple samples
+                 x_cond = x_cond.repeat_interleave(int(x.shape[0] / x_cond.shape[0]), 0)
+             model_in = x
+         else:
+             model_in = x
+
+         model_output = self.model(model_in, t, x_self_cond, x_cond)
+         model_output = self.activation(model_output)
+         return model_output
+
+     ########################################################################
+     # FORWARD PROCESS
+     #########################################################################
+
+     def p_losses(self, x_start, t, noise=None, x_cond=None, n_train_samples=1):
+         b, c, n = x_start.shape
+         if n_train_samples > 1:
+             x_start = x_start.repeat_interleave(n_train_samples, dim=0)
+             t = t.repeat_interleave(n_train_samples, dim=0)
+             if x_cond is not None:
+                 x_cond = x_cond.repeat_interleave(n_train_samples, dim=0)
+         noise = default(noise, lambda: self.get_white_noise(x_start, t))  # noise for timesteps t
+
+         # noise sample
+         x = self.q_sample(x_start=x_start, t=t, noise=noise)
+
+         # if doing self-conditioning, 50% of the time, predict x_start from the current set of times
+         # and condition the unet with that
+         # this technique will slow down training by 25%, but seems to lower FID significantly
+         x_self_cond = None
+         if self.self_condition and random() < 0.5:
+             with torch.no_grad():
+                 x_self_cond = self.model_predictions(x, t, x_cond=x_cond).pred_x_start
+                 x_self_cond.detach_()
+
+         # predict and take gradient step
+         model_out = self.feed_model(x, t, x_self_cond=x_self_cond, x_cond=x_cond)
+
+         if self.objective == 'pred_noise':
+             target = noise
+         elif self.objective == 'pred_x0':
+             target = x_start
+         elif self.objective == 'pred_v':
+             v = self.predict_v(x_start, t, noise)
+             target = v
+         else:
+             raise ValueError(f'unknown objective {self.objective}')
+         loss = self.loss_funct(model_out, target, t)
+         loss = reduce(loss, 'b ... -> b', 'mean')  # [batch*n_train_samples, Nodes, latent_dim] -> [batch*n_train_samples]
+
+         return loss, extract(self.loss_weight, t.view(b, -1)[:, 0], loss.shape[0:1]), model_out
+
+     def forward(self, x, *args, x_cond=None, **kwargs):
+         b, c, n, device, seq_length, = *x.shape, x.device, self.seq_length
+         assert n == seq_length, f'seq length must be {seq_length}'
+         t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
+
+         return self.p_losses(x, t, *args, x_cond=x_cond, **kwargs)
+
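Taken together, `forward` draws one uniform timestep per batch element and `p_losses` returns per-sample losses plus their schedule-dependent weights. A minimal training-step sketch, assuming a toy denoiser with the attributes and call signature `LatentDiffusion` expects, plugged into `IsotropicGaussianDiffusion` (defined further below; all sizes are illustrative):

```python
import torch
from torch import nn

class TinyDenoiser(nn.Module):
    channels = 21            # number of graph nodes
    self_condition = False
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(96, 96)
    def forward(self, x, t, x_self_cond=None, x_cond=None):
        return self.lin(x)   # a real model predicts x0 (or noise)

diffusion = IsotropicGaussianDiffusion(model=TinyDenoiser(), latent_size=96,
                                       diffusion_timesteps=10,
                                       diffusion_objective='pred_x0')
x0 = torch.randn(8, 21, 96)            # [batch, nodes, latent_dim] clean latents
loss, weight, model_out = diffusion(x0)
(loss * weight).mean().backward()      # weighted per-sample losses -> scalar
```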
+     ########################################################################
+     # REVERSE PROCESS
+     #########################################################################
+
+     def p_mean_variance(self, x, t, x_self_cond=None, x_cond=None, clip_denoised=True):
+         preds = self.model_predictions(x, t, x_self_cond, x_cond=x_cond)
+         x_start = preds.pred_x_start
+
+         if clip_denoised:
+             x_start.clamp_(-1., 1.)
+
+         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
+         return model_mean, posterior_variance, posterior_log_variance, x_start
+
+     @torch.no_grad()
+     def p_sample(self, x, t: int, x_self_cond=None, clip_denoised=True, sampling_noise=None, *args, if_interpolate=False, noise2interpolate=None, interpolation_kwargs: Dict = None, **kwargs):
+         b, *_, device = *x.shape, x.device
+         batched_times = torch.full((b,), t, device=x.device, dtype=torch.long)
+         model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, x_self_cond=x_self_cond, clip_denoised=clip_denoised, *args, **kwargs)
+
+         if sampling_noise is not None and t > 0:
+             noise = sampling_noise[:, sampling_noise.shape[1] - t]
+         else:
+             noise = self.get_white_noise(x) if t > 0 else 0.  # no noise if t == 0
+
+         if if_interpolate and t > 0:
+             noise2 = noise2interpolate[:, sampling_noise.shape[1] - t]
+             assert noise2.shape == noise.shape
+             pred_img = self.p_interpolate_mean_var_noise(model_mean, model_log_variance, noise, noise2, **interpolation_kwargs)
+         else:
+             pred_img = self.p_combine_mean_var_noise(model_mean, model_log_variance, noise)
+         return pred_img, x_start, noise, model_mean
+
+     @torch.no_grad()
+     def p_sample_loop(self, shape, x_cond=None, start_noise=None, sampling_noise=None, return_sampling_noise=False, return_timages=False, **kwargs):
+         batch, device = shape[0], self.betas.device
+         if start_noise is not None:
+             assert start_noise.shape == shape, f"Shape mismatch: {start_noise.shape} != {shape}"
+             img = start_noise
+             noise = start_noise.clone()
+         else:
+             img = self.get_start_noise(shape, device=device)
+             noise = img.clone()
+
+         if sampling_noise is not None:
+             assert sampling_noise.shape[2:] == shape[1:], f"Shape mismatch: {sampling_noise.shape} != {shape}"
+             assert sampling_noise.shape[0] == shape[0], f"Shape mismatch: {sampling_noise.shape} != {shape}"
+             assert sampling_noise.shape[1] == self.num_timesteps - 1
+
+         x_start = None
+         imgs = []
+         noise_t = []
+         mean_t = []
+         if not self.silent:
+             print(f"Evaluation with {len(range(0, self.num_timesteps))} diffusion steps")
+         for t in reversed(range(0, self.num_timesteps)):
+             self_cond = x_start if self.self_condition else None
+             img, x_start, nt, model_mean = self.p_sample(img, t, self_cond, x_cond=x_cond, sampling_noise=sampling_noise, **kwargs)
+             if return_sampling_noise and t != 0:
+                 noise_t.append(nt)
+                 mean_t.append(model_mean)
+             if return_timages and t != 0:
+                 imgs.append(img)
+
+         if return_sampling_noise:
+             noise_t = torch.stack(noise_t, dim=1)
+             mean_t = torch.stack(mean_t, dim=1)
+         if return_timages:
+             print("Returning timages")
+             imgs = torch.stack(imgs, dim=1)
+
+         if return_sampling_noise:
+             if return_timages:
+                 noise = (noise, noise_t, imgs)
+             else:
+                 noise = (noise, noise_t, mean_t)
+         else:
+             if return_timages:
+                 noise = (noise, imgs)
+         return img, noise
+
+     @torch.no_grad()
+     def ddim_sample(self, shape, clip_denoised=True, x_cond=None, start_noise=None):
+         batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
+
+         # evenly spaced timesteps between T-1 and -1, as in DDIM
+         times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)
+         times = list(reversed(times.int().tolist()))
+         time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
+
+         if start_noise is not None:
+             assert start_noise.shape == shape
+             img = start_noise
+             noise = start_noise.clone()
+         else:
+             img = torch.randn(shape, device=device)
+             noise = img.clone()
+
+         x_start = None
+         if not self.silent:
+             print(f"Evaluation with {len(time_pairs)} diffusion steps")
+         for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
+             time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
+             self_cond = x_start if self.self_condition else None
+             pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, x_cond=x_cond, clip_x_start=clip_denoised)
+
+             if time_next < 0:
+                 img = x_start
+                 continue
+
+             alpha = self.alphas_cumprod[time]
+             alpha_next = self.alphas_cumprod[time_next]
+             sqrt_alpha_next = self.sqrt_alphas_cumprod[time_next]
+
+             sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()  # eta * beta_t_tilde
+             c = (1 - alpha_next - sigma ** 2).sqrt()
+
+             noise = torch.randn_like(img)
+
+             img = x_start * sqrt_alpha_next + \
+                   c * pred_noise + \
+                   sigma * noise
+
+         return img, noise
+
+     @torch.no_grad()
+     def sample(self, batch_size=16, *args, **kwargs):
+         seq_length, channels = self.seq_length, self.channels
+         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
+         return sample_fn((batch_size, channels, seq_length), *args, **kwargs)
+
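Since `sampling_timesteps` defaults to the training value, `sample` dispatches to the ancestral loop (`p_sample_loop`) rather than DDIM. Continuing the toy setup from the earlier training-step sketch:

```python
samples, start_noise = diffusion.sample(batch_size=4)  # reverse loop over 10 steps
print(samples.shape)                                   # torch.Size([4, 21, 96])
```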
SkeletonDiffusion/src/core/diffusion/isotropic.py ADDED
@@ -0,0 +1,104 @@
+ import torch
+ from torch.cuda.amp import autocast
+
+ from SkeletonDiffusion.src.core.diffusion.base import LatentDiffusion, extract, default
+
+ class IsotropicGaussianDiffusion(LatentDiffusion):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
+         register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - self.alphas_cumprod))
+         register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod))
+         register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod - 1))
+
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
+         alphas = 1. - self.betas
+
+         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+         register_buffer('posterior_variance', posterior_variance)
+
+         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+         register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
+         register_buffer('posterior_mean_coef1', self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod))
+         register_buffer('posterior_mean_coef2', (1. - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - self.alphas_cumprod))
+
+         # calculate loss weight
+         snr = self.alphas_cumprod / (1 - self.alphas_cumprod)
+
+         if self.objective == 'pred_noise':
+             loss_weight = torch.ones_like(snr)
+         elif self.objective == 'pred_x0':
+             loss_weight = snr
+         elif self.objective == 'pred_v':
+             loss_weight = snr / (snr + 1)
+
+         register_buffer('loss_weight', loss_weight)
+
+     ########################################################################
+     # FORWARD PROCESS
+     #########################################################################
+
+     def predict_start_from_noise(self, x_t, t, noise):
+         return (
+             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+         )
+
+     def predict_noise_from_start(self, x_t, t, x0):
+         return (
+             (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
+             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+         )
+
+     def predict_v(self, x_start, t, noise):
+         return (
+             extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+         )
+
+     def predict_start_from_v(self, x_t, t, v):
+         return (
+             extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+             extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+         )
+
+     @autocast(enabled=False)
+     def q_sample(self, x_start, t, noise=None):
+         noise = default(noise, lambda: self.get_white_noise(x_start))
+
+         return (
+             extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+         )
+
+     ########################################################################
+     # REVERSE PROCESS
+     #########################################################################
+
+     def q_posterior(self, x_start, x_t, t):
+         posterior_mean = (
+             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+             extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+         )
+         posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+     def p_combine_mean_var_noise(self, model_mean, model_log_variance, noise):
+         return model_mean + (0.5 * model_log_variance).exp() * noise
+
+     def interpolate_noise(self, noise1, noise2, interpolate_funct=None, **kwargs):
+         interpolated_noise = interpolate_funct(noise1, noise2)
+         return interpolated_noise
+
+     def p_interpolate_mean_var_noise(self, model_mean, model_log_variance, noise, noise2interpolate=None, **kwargs):
+         interpolated_noise = self.interpolate_noise(noise, noise2interpolate, **kwargs)
+         return model_mean + (0.5 * model_log_variance).exp() * interpolated_noise
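`q_sample` implements the closed-form marginal x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so its output statistics can be sanity-checked empirically. A minimal standalone sketch that recomputes the cosine schedule instead of instantiating the class:

```python
import torch

T = 10
s = 0.008
# cosine-schedule alpha_bar, as in base.py
x = torch.linspace(0, T, T + 1, dtype=torch.float64)
ab = torch.cos(((x / T) + s) / (1 + s) * torch.pi * 0.5) ** 2
ab = (ab / ab[0])[1:]                      # alpha_bar_t for t = 1..T

t = 5
x0 = torch.randn(100_000, 1)
eps = torch.randn_like(x0)
xt = ab[t].sqrt() * x0 + (1 - ab[t]).sqrt() * eps
print(xt.var().item())                     # ~1.0: the marginal variance is preserved
```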
SkeletonDiffusion/src/core/diffusion/nonisotropic.py ADDED
@@ -0,0 +1,213 @@
+ import torch
+ from torch.cuda.amp import autocast
+ from .base import LatentDiffusion, extract, default
+
+ def extract_matrix(matrix, t, x_shape):
+     b, *_ = t.shape
+     T, N, *_ = matrix.shape
+     out = torch.index_select(matrix, 0, t)
+     out = out.reshape(b, *out.shape[1:])
+     while len(x_shape) > len(out.shape):
+         out = out.unsqueeze(-1)
+     return out
+
+ def verify_noise_scale(diffusion):
+     N, *_ = diffusion.Lambda_N.shape
+     alphas = 1 - diffusion.betas
+     noise = diffusion.get_noise((2000, diffusion.num_timesteps, N))
+     zeta_noise = torch.sqrt(diffusion.Lambda_t.unsqueeze(0)) * noise
+     print("current: ", (zeta_noise**2).sum(-1).mean(0))
+     print("original standard gaussian diffusion: ", (1 - alphas) * zeta_noise.shape[-1])
+
+ def compute_covariance_matrices(diffusion: torch.nn.Module, Lambda_N: torch.Tensor, diffusion_covariance_type='ani-isotropic', gamma_scheduler='cosine'):
+     N, *_ = Lambda_N.shape
+     alphas = 1. - diffusion.betas
+     def _alpha_sumprod(alphas, t):
+         return torch.sum(torch.cumprod(torch.flip(alphas[:t+1], [0]), dim=0))
+     alphas_sumprod = torch.stack([_alpha_sumprod(alphas, t) for t in range(len(alphas))], dim=0)
+     diffusion.alphas_sumprod = alphas_sumprod
+     if diffusion_covariance_type == 'isotropic':
+         assert (Lambda_N == 0).all()
+         Lambda_t = (1 - alphas).unsqueeze(-1)  # (Tdiff, N)
+         Lambda_bar_t = (1 - diffusion.alphas_cumprod.unsqueeze(-1))
+         Lambda_bar_t_prev = torch.cat([torch.zeros(1).unsqueeze(0), Lambda_bar_t[:-1]], dim=0)
+     elif diffusion_covariance_type == 'anisotropic':
+         Lambda_t = (1 - alphas.unsqueeze(-1)) * Lambda_N  # (Tdiff, N)
+         Lambda_bar_t = (1 - diffusion.alphas_cumprod.unsqueeze(-1)) * Lambda_N
+         Lambda_bar_t_prev = (1 - diffusion.alphas_cumprod_prev.unsqueeze(-1)) * Lambda_N
+     elif diffusion_covariance_type == 'skeleton-diffusion':
+         if gamma_scheduler == 'cosine':
+             gammas = 1 - alphas
+         elif gamma_scheduler == 'mono_decrease':
+             gammas = 1 - torch.arange(0, diffusion.num_timesteps) / diffusion.num_timesteps
+         else:
+             assert 0, "Not implemented"
+         Lambda_I = Lambda_N - 1
+         gammas_bar = (1 - alphas) * gammas
+         gammas_tilde = diffusion.alphas_cumprod * torch.cumsum(gammas_bar / diffusion.alphas_cumprod, dim=-1)
+         Lambda_t = Lambda_I.unsqueeze(0) * gammas_bar.unsqueeze(-1) + (1 - alphas).unsqueeze(-1)  # (Tdiff, N)
+         Lambda_bar_t = Lambda_I.unsqueeze(0) * gammas_tilde.unsqueeze(-1) + (1 - diffusion.alphas_cumprod.unsqueeze(-1))
+         Lambda_bar_t_prev = torch.cat([torch.zeros(N).unsqueeze(0), Lambda_bar_t[:-1]], dim=0)  # we start from a deterministic x_0, so the variance must be zero for t = -1
+     else:
+         assert 0, "Not implemented"
+
+     return Lambda_t, Lambda_bar_t, Lambda_bar_t_prev
+
+
+ class NonisotropicGaussianDiffusion(LatentDiffusion):
+     def __init__(self, Sigma_N: torch.Tensor, Lambda_N: torch.Tensor, U: torch.Tensor, diffusion_covariance_type='skeleton-diffusion', loss_reduction_type='l1', gamma_scheduler='cosine', **kwargs):
+         super().__init__(**kwargs)
+         alphas = 1. - self.betas
+
+         N, _ = Sigma_N.shape
+         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+         register_buffer('Lambda_N', Lambda_N)
+         register_buffer('Sigma_N', Sigma_N)
+         self.set_rotation_matrix(U)
+
+         Lambda_t, Lambda_bar_t, Lambda_bar_t_prev = compute_covariance_matrices(diffusion=self, Lambda_N=Lambda_N,
+                                                                                 diffusion_covariance_type=diffusion_covariance_type, gamma_scheduler=gamma_scheduler)
+
+         def create_diagonal_matrix(diagonal_vector):
+             return torch.stack([torch.diag(diag) for diag in diagonal_vector], dim=0)  # [T, N, N]
+
+         ######### forward, for training and inference #####################
+         # predict_noise_from_start
+         inv_sqrt_Lambda_bar = 1 / torch.sqrt(Lambda_bar_t)
+         inv_sqrt_Lambda_bar_sqrt_alphas_cumprod = (1 / torch.sqrt(Lambda_bar_t)) * self.sqrt_alphas_cumprod.unsqueeze(-1)
+         register_buffer('inv_sqrt_Lambda_bar_mmUt', create_diagonal_matrix(inv_sqrt_Lambda_bar) @ self.U_transposed.unsqueeze(0))
+         register_buffer('inv_sqrt_Lambda_bar_sqrt_alphas_cumprod_mmUt', create_diagonal_matrix(inv_sqrt_Lambda_bar_sqrt_alphas_cumprod) @ self.U_transposed.unsqueeze(0))
+         # predict_start_from_noise
+         sqrt_Lambda_bar = torch.sqrt(Lambda_bar_t)
+         sqrt_Lambda_bar_sqrt_recip_alphas_cumprod = torch.sqrt(Lambda_bar_t / self.alphas_cumprod.unsqueeze(-1))
+         register_buffer('Umm_sqrt_Lambda_bar_t', U.unsqueeze(0) @ create_diagonal_matrix(sqrt_Lambda_bar))
+         register_buffer('Umm_sqrt_Lambda_bar_t_sqrt_recip_alphas_cumprod', U.unsqueeze(0) @ create_diagonal_matrix(sqrt_Lambda_bar_sqrt_recip_alphas_cumprod))
+
+         ######### q_posterior, for the reverse process #####################
+         Lambda_posterior_t = Lambda_t * Lambda_bar_t_prev * (1 / Lambda_bar_t)
+         sqrt_alphas_cumprod_prev = torch.sqrt(self.alphas_cumprod_prev)
+         register_buffer('Lambda_posterior', Lambda_posterior_t)
+         register_buffer('Lambda_posterior_log_variance_clipped', torch.log(Lambda_posterior_t.clamp(min=1e-20)))
+
+         posterior_mean_coef1_x0 = sqrt_alphas_cumprod_prev.unsqueeze(-1).unsqueeze(-1) * (U.unsqueeze(0) @ create_diagonal_matrix((1 / Lambda_bar_t) * Lambda_t) @ self.U_transposed.unsqueeze(0))
+         posterior_mean_coef2_xt = torch.sqrt(alphas).unsqueeze(-1).unsqueeze(-1) * (U.unsqueeze(0) @ create_diagonal_matrix((1 / Lambda_bar_t) * Lambda_bar_t_prev) @ self.U_transposed.unsqueeze(0))
+         register_buffer('posterior_mean_coef1_x0', posterior_mean_coef1_x0)
+         register_buffer('posterior_mean_coef2_xt', posterior_mean_coef2_xt)
+
+         ######### loss #####################
+         self.loss_reduction_type = loss_reduction_type
+         sqrt_recip_Lambda_bar_t = torch.sqrt(1. / Lambda_bar_t)
+         register_buffer('mahalanobis_S_sqrt_recip', create_diagonal_matrix(sqrt_recip_Lambda_bar_t) @ self.U_transposed.unsqueeze(0))
+
+         if self.objective == 'pred_noise':
+             loss_weight = torch.ones_like(alphas)
+         elif self.objective == 'pred_x0':
+             loss_weight = self.alphas_cumprod
+         elif self.objective == 'pred_v':
+             assert 0, "Not implemented"
+         register_buffer('loss_weight', loss_weight)
+
+         assert not len(self.mahalanobis_S_sqrt_recip.shape) == 1
+
+     ########################################################################
+     # CLASS FUNCTIONS
+     #########################################################################
+
+     def set_rotation_matrix(self, U: torch.Tensor):
+         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+         register_buffer('U', U)
+         register_buffer('U_transposed', U.t())
+
+     def check_eigh(self):
+         # U @ diag(Lambda_N) @ U^T must reconstruct Sigma_N
+         return torch.isclose(self.U @ torch.diag(self.Lambda_N) @ self.U_transposed, self.Sigma_N)
+
+     def get_anisotropic_noise(self, x, *args, **kwargs):
+         """
+         x is either a tensor or a shape
+         """
+         return self.get_noise(x, *args, **kwargs) * self.Lambda_N.unsqueeze(-1)
+
+     ########################################################################
+     # FORWARD PROCESS
+     #########################################################################
+
+     @autocast(enabled=False)
+     def q_sample(self, x_start, t, noise=None):
+         noise = default(noise, lambda: self.get_white_noise(x_start))
+
+         return (
+             extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+             extract_matrix(self.Umm_sqrt_Lambda_bar_t, t, x_start.shape) @ noise
+         )
+
+     # for inference
+     def predict_start_from_noise(self, x_t, t, noise):
+         return (
+             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+             extract_matrix(self.Umm_sqrt_Lambda_bar_t_sqrt_recip_alphas_cumprod, t, x_t.shape) @ noise
+         )
+
+     # for inference
+     def predict_noise_from_start(self, x_t, t, x0):
+         return (
+             extract_matrix(self.inv_sqrt_Lambda_bar_mmUt, t, x_t.shape) @ x_t -
+             extract_matrix(self.inv_sqrt_Lambda_bar_sqrt_alphas_cumprod_mmUt, t, x_t.shape) @ x0
+         )
+
+     ########################################################################
+     # LOSS
+     #########################################################################
+
+     def mahalanobis_dist(self, matrix, vector):
+         return (matrix @ vector).abs()
+
+     def loss_funct(self, model_out, target, t):
+         difference = target - model_out if self.objective == 'pred_noise' else model_out - target
+
+         loss = self.mahalanobis_dist(extract_matrix(self.mahalanobis_S_sqrt_recip, t, difference.shape), difference)
+         if self.loss_reduction_type == 'l1':
+             loss = loss
+         elif self.loss_reduction_type == 'mse':
+             loss = loss**2
+         else:
+             assert 0, "Not implemented"
+         return loss
+
+     ########################################################################
+     # REVERSE PROCESS
+     #########################################################################
+
+     def q_posterior_mean(self, x_start, x_t, t):
+         return (
+             extract_matrix(self.posterior_mean_coef1_x0, t, x_t.shape) @ x_start +
+             extract_matrix(self.posterior_mean_coef2_xt, t, x_t.shape) @ x_t
+         )
+
+     def q_posterior(self, x_start, x_t, t):
+         posterior_mean = self.q_posterior_mean(x_start, x_t, t)
+         posterior_variance = extract_matrix(self.Lambda_posterior, t, x_t.shape)
+         posterior_log_variance_clipped = extract_matrix(self.Lambda_posterior_log_variance_clipped, t, x_t.shape)
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+     def p_combine_mean_var_noise(self, model_mean, posterior_log_variance, noise):
+         """The mean lives in the original coordinate system; posterior_log_variance is expressed in the diagonalizing (eigenbasis) coordinates."""
+         return model_mean + self.U @ ((0.5 * posterior_log_variance).exp() * noise)
+
+     ########################################################################
+     # INTERPOLATION
+     #########################################################################
+
+     def interpolate_noise(self, noise1, noise2, posterior_log_variance=None, interpolate_funct=None):
+         noise1 = self.U @ ((0.5 * posterior_log_variance).exp() * noise1)
+         noise2 = self.U @ ((0.5 * posterior_log_variance).exp() * noise2)
+         interpolated_noise = interpolate_funct(noise1, noise2)
+         return interpolated_noise
+
+     def p_interpolate_mean_var_noise(self, model_mean, model_log_variance, noise, noise2interpolate=None, **kwargs):
+         interpolated_noise = self.interpolate_noise(noise, noise2interpolate, posterior_log_variance=model_log_variance, **kwargs)
+         return model_mean + interpolated_noise
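The non-isotropic forward process applies structured noise of the form U * sqrt(Lambda) * eps, where U and Lambda come from the eigendecomposition Sigma_N = U diag(Lambda_N) U^T used throughout the class above. A minimal standalone sketch verifying that this construction indeed yields samples with covariance Sigma_N (toy 3-node matrix):

```python
import torch

# Toy symmetric positive-definite "skeleton" covariance for 3 nodes
Sigma_N = torch.tensor([[1.0, 0.5, 0.2],
                        [0.5, 1.0, 0.5],
                        [0.2, 0.5, 1.0]])
Lambda_N, U = torch.linalg.eigh(Sigma_N)          # Sigma_N = U diag(Lambda_N) U^T

z = torch.randn(200_000, 3)                       # white noise, one row per sample
eps = (U @ (Lambda_N.sqrt() * z).unsqueeze(-1)).squeeze(-1)  # U sqrt(Lambda) z

emp_cov = (eps.T @ eps) / eps.shape[0]
print(torch.allclose(emp_cov, Sigma_N, atol=5e-2))  # True, up to sampling error
```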
SkeletonDiffusion/src/core/diffusion/utils.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+
+ def dim_null_space(matrix):
+     assert matrix.shape[-1] == matrix.shape[-2], "Matrix must be square"
+     # torch.linalg.matrix_rank is not tuned to the accuracy of PyTorch float32:
+     # 1.0 + eps != 1.0, i.e. torch.tensor(1.0) + 0.7e-7 != torch.tensor(1.0)
+     return torch.sum(torch.linalg.eigh(matrix)[0].abs() < 0.7e-7)
+
+ def is_positive_def(matrix):
+     # M is symmetric or Hermitian, and all its eigenvalues are real and positive.
+     assert torch.allclose(matrix.transpose(-1, -2), matrix), "Matrix must be symmetric"
+     eigenvalues = torch.linalg.eigvals(matrix)
+     is_pos_def = (torch.real(eigenvalues) > 0).all()
+     if is_pos_def:
+         assert torch.isreal(eigenvalues).all(), "Eigenvalues must be real"
+     return (torch.real(eigenvalues) > 0).all()
+
+ def make_positive_definite(matrix, epsilon=1e-6, if_submin=False):
+     eigenvalues = torch.linalg.eigvals(matrix)
+     if is_positive_def(matrix):
+         print("Input matrix was positive definite without adding the spectral norm to the diagonal")
+         return matrix
+
+     eigenvalues = torch.real(eigenvalues)
+     if not if_submin:
+         max_eig = eigenvalues.abs().max()
+         pos_def_matrix = matrix + torch.eye(matrix.shape[0]) * (max_eig + epsilon)
+     else:
+         min_eig = eigenvalues.min()
+         pos_def_matrix = matrix + torch.eye(matrix.shape[0]) * (-min_eig + epsilon)
+     assert dim_null_space(pos_def_matrix) == 0
+     return pos_def_matrix
+
+ def normalize_cov(Sigma_N: torch.Tensor, Lambda_N: torch.Tensor, U: torch.Tensor, if_sigma_n_scale=True, sigma_n_scale='spectral', **kwargs):
+     N, _ = Sigma_N.shape
+     assert Lambda_N.shape == (N,)
+     assert U.shape == (N, N)
+
+     if if_sigma_n_scale:
+         # decrease the scale of Sigma_N to make it more similar to the identity matrix
+         if sigma_n_scale == 'spectral':
+             relative_scale_factor = Lambda_N.max()
+         elif sigma_n_scale == 'frob':
+             relative_scale_factor = Lambda_N.sum() / N
+         else:
+             assert 0, "Not implemented"
+
+         Lambda_N = Lambda_N / relative_scale_factor
+         Sigma_N = Sigma_N / relative_scale_factor
+         cond = U @ torch.diag(Lambda_N) @ U.mT
+         assert torch.isclose(Sigma_N, cond, atol=1e-06).all(), "Sigma_N must be equal to U @ Lambda_N @ U.t()"
+         cond = Lambda_N > 0.7e-7
+         assert (cond).all(), f"Lambda_N must be positive definite: {Lambda_N}"
+         assert is_positive_def(Sigma_N), "Sigma_N must be positive definite"
+     return Sigma_N, Lambda_N
+
+
+ def get_cov_from_corr(correlation_matrix: torch.Tensor, if_sigma_n_scale=True, sigma_n_scale='spectral', if_run_as_isotropic=False, diffusion_covariance_type='skeleton-diffusion', **kwargs):
+     N, _ = correlation_matrix.shape
+
+     if if_run_as_isotropic:
+         if diffusion_covariance_type == 'skeleton-diffusion':
+             Lambda_N = torch.ones(N, device=correlation_matrix.device)
+             Sigma_N = torch.zeros_like(correlation_matrix)
+             U = torch.eye(N, device=correlation_matrix.device)
+         elif diffusion_covariance_type == 'anisotropic':
+             Lambda_N = torch.ones(N, device=correlation_matrix.device)
+             Sigma_N = torch.eye(N, device=correlation_matrix.device)
+             U = torch.eye(N, device=correlation_matrix.device)
+         else:
+             Lambda_N = torch.zeros(N, device=correlation_matrix.device)
+             Sigma_N = torch.zeros_like(correlation_matrix)
+             U = torch.eye(N, device=correlation_matrix.device)
+     else:
+         Sigma_N = make_positive_definite(correlation_matrix)
+         Lambda_N, U = torch.linalg.eigh(Sigma_N, UPLO='L')
+
+     Sigma_N, Lambda_N = normalize_cov(Sigma_N=Sigma_N, Lambda_N=Lambda_N, U=U, if_sigma_n_scale=if_sigma_n_scale, sigma_n_scale=sigma_n_scale, **kwargs)
+     return Sigma_N, Lambda_N, U
+
+
+ def verify_noise_scale(diffusion):
+     N, *_ = diffusion.Lambda_N.shape
+     alphas = 1 - diffusion.betas
+     noise = diffusion.get_noise((2000, diffusion.num_timesteps, N))
+     zeta_noise = torch.sqrt(diffusion.Lambda_t.unsqueeze(0)) * noise
+     print("current: ", (zeta_noise**2).sum(-1).mean(0))
+     print("original standard gaussian diffusion: ", (1 - alphas) * zeta_noise.shape[-1])
+
+
+ def plot_matrix(matrix):
+     import matplotlib
+     import matplotlib.pyplot as plt
+     import numpy as np
+
+     Sigma_N = matrix.cpu().clone().numpy()
+     color = 'Purples'
+     cmap = matplotlib.colormaps[color].copy()  # copy so the registered colormap is not mutated
+     cmap.set_bad("white")                      # set_bad returns None, so it cannot be chained
+
+     fig, ax = plt.subplots(1, 1, figsize=(6, 6), sharex=True, subplot_kw=dict(box_aspect=1))
+     vmax = Sigma_N.max()
+     Sigma_N[Sigma_N <= 0.0] = np.nan           # masked entries are drawn white via set_bad
+     im = ax.imshow(Sigma_N, cmap=cmap, vmin=0., vmax=vmax)
+     fig.colorbar(im)
+     plt.show()
+     # fig.savefig("../paper_plots/sigmaN.pdf", format="pdf", bbox_inches="tight")
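`get_cov_from_corr` accepts any symmetric correlation prior; a toy adjacency matrix makes the full pipeline (shift to positive definite, eigendecompose, rescale so the spectral norm is 1) easy to see:

```python
import torch

# Toy 3-node chain skeleton: adjacency with self-loops as the correlation prior
adj = torch.tensor([[1., 1., 0.],
                    [1., 1., 1.],
                    [0., 1., 1.]])
Sigma_N, Lambda_N, U = get_cov_from_corr(adj)
print(Lambda_N)  # eigenvalues, rescaled so max(Lambda_N) == 1
print(torch.allclose(U @ torch.diag(Lambda_N) @ U.mT, Sigma_N, atol=1e-5))
```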
SkeletonDiffusion/src/core/diffusion_manager.py ADDED
@@ -0,0 +1,45 @@
+ from typing import Tuple, Optional, List, Union, Dict, Any
+ import torch
+
+ from SkeletonDiffusion.src.core.network.nn import Denoiser
+ from SkeletonDiffusion.src.core.diffusion import IsotropicGaussianDiffusion, NonisotropicGaussianDiffusion, get_cov_from_corr
+
+
+ class DiffusionManager():
+     def __init__(self, diffusion_type: str = 'IsotropicGaussianDiffusion', skeleton=None, covariance_matrix_type: str = 'adjacency',
+                  reachability_matrix_degree_factor=0.5, reachability_matrix_stop_at=0, if_sigma_n_scale=True, sigma_n_scale='spectral', if_run_as_isotropic=False,
+                  **kwargs):
+
+         model = self.get_network(**kwargs)
+         self.diffusion_type = diffusion_type
+
+         if diffusion_type == 'NonisotropicGaussianDiffusion':
+             # define Sigma_N
+             if covariance_matrix_type == 'adjacency':
+                 correlation_matrix = skeleton.adj_matrix
+             elif covariance_matrix_type == 'reachability':
+                 correlation_matrix = skeleton.reachability_matrix(factor=reachability_matrix_degree_factor, stop_at=reachability_matrix_stop_at)
+             else:
+                 assert 0, "Not implemented"
+             N, *_ = correlation_matrix.shape
+
+             Sigma_N, Lambda_N, U = get_cov_from_corr(correlation_matrix=correlation_matrix, if_sigma_n_scale=if_sigma_n_scale, sigma_n_scale=sigma_n_scale, if_run_as_isotropic=if_run_as_isotropic, **kwargs)
+             self.diffusion = NonisotropicGaussianDiffusion(Sigma_N=Sigma_N, Lambda_N=Lambda_N, U=U, model=model, **kwargs)
+         elif diffusion_type == 'IsotropicGaussianDiffusion':
+             self.diffusion = IsotropicGaussianDiffusion(model=model, **kwargs)
+         else:
+             assert 0, f"{diffusion_type} Not implemented"
+
+     def get_diffusion(self):
+         return self.diffusion
+
+     def get_network(self, num_nodes, diffusion_conditioning=False, latent_size=96, node_types: torch.Tensor = None, diffusion_arch: Optional[Dict[str, Any]] = None, **kwargs):
+         if diffusion_conditioning:
+             cond_dim = latent_size
+         else:
+             cond_dim = 0
+
+         model = Denoiser(dim=latent_size, cond_dim=cond_dim, out_dim=latent_size, channels=num_nodes, num_nodes=num_nodes, node_types=node_types, **(diffusion_arch or {}))
+
+         return model
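A hedged usage sketch of `DiffusionManager`. The `skeleton` object and its `adj_matrix` attribute are assumptions based only on how they are consumed above, and the keyword values are illustrative:

```python
# Hypothetical wiring; argument names follow the signatures above.
manager = DiffusionManager(
    diffusion_type='NonisotropicGaussianDiffusion',
    skeleton=skeleton,                 # must expose .adj_matrix (an N x N tensor)
    covariance_matrix_type='adjacency',
    num_nodes=21, latent_size=96,
    diffusion_arch={},                 # extra Denoiser kwargs, if any
    diffusion_timesteps=10,
)
diffusion = manager.get_diffusion()
```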
SkeletonDiffusion/src/core/network/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .nn import AutoEncoder, Denoiser
+
+ __all__ = ['AutoEncoder', 'Denoiser']
SkeletonDiffusion/src/core/network/layers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .graph_structural import StaticGraphLinear
+ from .recurrent import StaticGraphGRU, GraphGRUState, StaticGraphLSTM, GraphLSTMState
+ from .attention import Attention, ResnetBlock, Residual, PreNorm, RMSNorm
SkeletonDiffusion/src/core/network/layers/attention.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ from torch import nn, einsum
+ import torch.nn.functional as F
+ from einops import rearrange
+
+ from .graph_structural import StaticGraphLinear
+
+
+ class Residual(nn.Module):
+     def __init__(self, fn):
+         super().__init__()
+         self.fn = fn
+
+     def forward(self, x, *args, **kwargs):
+         return self.fn(x, *args, **kwargs) + x
+
+ class LayerNorm(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.norm = torch.nn.LayerNorm((dim), elementwise_affine=True)
+
+     def forward(self, x):
+         x = torch.swapaxes(x, -2, -1)
+         x = self.norm(x)
+         x = torch.swapaxes(x, -2, -1)
+         return x
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.g = nn.Parameter(torch.ones(1, 1, dim))
+
+     def forward(self, x):
+         # F.normalize divides by the L2 norm over the last dim; multiplying by sqrt(dim)
+         # turns this into division by the root mean square, as in RMSNorm.
+         return F.normalize(x, dim=-1) * self.g * (x.shape[-1] ** 0.5)
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.fn = fn
+         self.norm = RMSNorm(dim)
+
+     def forward(self, x):
+         x = self.norm(x)
+         return self.fn(x)
+
+
+ class Block(nn.Module):
+     def __init__(self, dim, dim_out, norm_type='none', act_type='tanh', *args, **kwargs):
+         super().__init__()
+         self.proj = StaticGraphLinear(dim, dim_out, *args, **kwargs)
+         if norm_type == 'none':
+             self.norm = nn.Identity()
+         elif norm_type == 'layer':
+             self.norm = LayerNorm(kwargs['num_nodes'])
+         else:
+             assert 0, f"Norm type {norm_type} not implemented!"
+         if act_type == 'tanh':
+             self.act = nn.Tanh()
+         else:
+             assert 0, f"Activation type {act_type} not implemented!"
+
+     def forward(self, x, scale_shift=None):
+         x = self.proj(x)
+         x = self.norm(x)
+
+         if scale_shift is not None:
+             scale, shift = scale_shift
+             x = x * (scale + 1) + shift
+
+         x = self.act(x)
+         return x
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8, **kwargs):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Tanh(),
+             nn.Linear(time_emb_dim, dim_out * 2)
+         ) if time_emb_dim is not None else None
+
+         self.block1 = Block(dim, dim_out, groups=groups, **kwargs)
+         self.block2 = Block(dim_out, dim_out, groups=groups, **kwargs)
+         self.res_linear = StaticGraphLinear(dim, dim_out, bias=False, **kwargs) if dim != dim_out else nn.Identity()
+
+     def forward(self, x, time_emb=None):
+         scale_shift = None
+         if self.mlp is not None and time_emb is not None:
+             time_emb = self.mlp(time_emb)
+             time_emb = rearrange(time_emb, 'b c -> b 1 c')
+             scale_shift = time_emb.chunk(2, dim=-1)
+
+         h = self.block1(x, scale_shift=scale_shift)
+         h = self.block2(h)
+         return h + self.res_linear(x)
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, dim_out=None, heads=4, dim_head=32, qkv_bias: bool = False, attn_dropout: float = 0., proj_dropout: float = 0., qk_norm: bool = False, norm_layer: nn.Module = nn.Identity, **kwargs):
+         super().__init__()
+         self.scale = dim_head ** -0.5
+         self.heads = heads
+         hidden_dim = dim_head * heads
+         dim_out = dim_out if dim_out is not None else dim
+
+         self.to_qkv = StaticGraphLinear(dim, hidden_dim * 3, bias=qkv_bias, **kwargs)
+         self.to_out = StaticGraphLinear(hidden_dim, dim_out, bias=False, **kwargs)
+         self.attn_dropout = nn.Dropout(attn_dropout)
+         self.out_dropout = nn.Dropout(proj_dropout)
+
+         self.q_norm = norm_layer(dim_head) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(dim_head) if qk_norm else nn.Identity()
+
+     def forward(self, x):
+         b, n, c = x.shape
+         qkv = self.to_qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h c) -> b h c n', h=self.heads), qkv)
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         q = q * self.scale
+         sim = einsum('b h c n, b h c j -> b h n j', q, k)
+         attn = sim.softmax(dim=-1)
+         attn = self.attn_dropout(attn)
+
+         out = einsum('b h n j, b h d j -> b h n d', attn, v)
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.out_dropout(self.to_out(out))
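Note that this `Attention` attends across graph nodes (the `n` axis), not across feature channels: the einsum contracts over the per-head feature dim `c`, giving an `n x n` attention map per head. A quick shape walkthrough with illustrative sizes (assuming `StaticGraphLinear` from the sibling module is importable):

```python
import torch

attn = Attention(dim=96, heads=4, dim_head=32, num_nodes=21)
x = torch.randn(8, 21, 96)   # [batch, nodes, features]
out = attn(x)                # each head computes a 21x21 node-to-node attention map
print(out.shape)             # torch.Size([8, 21, 96])
```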
SkeletonDiffusion/src/core/network/layers/graph_structural.py ADDED
@@ -0,0 +1,133 @@
+ from typing import Tuple, Optional, List, Union
+
+ import torch
+ from torch.nn import *
+ import math
+
+ def gmm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+     return torch.einsum('ndo,bnd->bno', w, x)
+
+
+ class GraphLinear(Module):
+     def __init__(self, in_features: int, out_features: int):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+
+     def reset_parameters(self) -> None:
+         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+         if len(self.weight.shape) == 3:
+             # per-type weights: initialize all node types identically
+             self.weight.data[1:] = self.weight.data[0]
+         if self.bias is not None:
+             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+             bound = 1 / math.sqrt(fan_in)
+             init.uniform_(self.bias, -bound, bound)
+
+     def forward(self, input: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor:
+         if g is None and self.learn_influence:
+             g = torch.nn.functional.normalize(self.G, p=1., dim=1)
+         elif g is None:
+             g = self.G
+         w = self.weight[self.node_type_index]
+         output = self.mm(input, w.transpose(-2, -1))
+         if self.bias is not None:
+             bias = self.bias[self.node_type_index]
+             output += bias
+         output = g.matmul(output)
+         return output
+
+
+ class DynamicGraphLinear(GraphLinear):
+     def __init__(self, num_node_types: int = 1, *args):
+         super().__init__(*args)
+
+     def forward(self, input: torch.Tensor, g: torch.Tensor = None, t: torch.Tensor = None) -> torch.Tensor:
+         assert g is not None or t is not None, "Either the Graph Influence Matrix or the Node Type Vector is needed"
+         if g is None:
+             g = self.G[t][:, t]
+         return super().forward(input, g)
+
+
+ class StaticGraphLinear(GraphLinear):
+     def __init__(self, *args, bias: bool = True, num_nodes: int = None, graph_influence: Union[torch.Tensor, Parameter] = None,
+                  learn_influence: bool = False, node_types: torch.Tensor = None, weights_per_type: bool = False, **kwargs):
+         """
+         :param in_features: Size of each input sample
+         :param out_features: Size of each output sample
+         :param num_nodes: Number of nodes.
+         :param graph_influence: Graph Influence Matrix
+         :param learn_influence: If set to ``False``, the layer will not learn the Graph Influence Matrix.
+         :param node_types: List of Type for each node. All nodes of the same type will share weights.
+                 Default: All nodes have unique types.
+         :param weights_per_type: If set to ``False``, the layer will not learn weights for each node type.
+         :param bias: If set to ``False``, the layer will not learn an additive bias.
+         """
+         super().__init__(*args)
+
+         self.learn_influence = learn_influence
+
+         if graph_influence is not None:
+             assert num_nodes == graph_influence.shape[0] or num_nodes is None, 'Number of Nodes or Graph Influence Matrix has to be given.'
+             num_nodes = graph_influence.shape[0]
+             if type(graph_influence) is Parameter:
+                 assert learn_influence, "Graph Influence Matrix is a Parameter, therefore it must be learnable."
+                 self.G = graph_influence
+             elif learn_influence:
+                 self.G = Parameter(graph_influence)
+             else:
+                 self.register_buffer('G', graph_influence)
+         else:
+             assert num_nodes, 'Number of Nodes or Graph Influence Matrix has to be given.'
+             eye_influence = torch.eye(num_nodes, num_nodes)
+             if learn_influence:
+                 self.G = Parameter(eye_influence)
+             else:
+                 self.register_buffer('G', eye_influence)
+
+         if weights_per_type and node_types is None:
+             node_types = torch.tensor([i for i in range(num_nodes)])
+         if node_types is not None:
+             num_node_types = node_types.max() + 1
+             self.weight = Parameter(torch.Tensor(num_node_types, self.out_features, self.in_features))
+             self.mm = gmm
+             self.node_type_index = node_types
+         else:
+             self.weight = Parameter(torch.Tensor(self.out_features, self.in_features))
+             self.mm = torch.matmul
+             self.node_type_index = None
+
+         if bias:
+             if node_types is not None:
+                 self.bias = Parameter(torch.Tensor(num_node_types, self.out_features))
+             else:
+                 self.bias = Parameter(torch.Tensor(self.out_features))
+         else:
+             self.register_parameter('bias', None)
+
+         self.reset_parameters()
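A short sketch of the two weight-sharing modes of `StaticGraphLinear` (sizes illustrative): without `node_types` a single weight matrix is shared by all nodes, while with `node_types` the `gmm` einsum applies a different matrix per node type.

```python
import torch

# One weight matrix shared by all 21 nodes:
lin = StaticGraphLinear(64, 32, num_nodes=21)
x = torch.randn(8, 21, 64)
print(lin(x).shape)                        # torch.Size([8, 21, 32])

# Per-node-type weights: nodes 0-1 share type 0, node 2 is type 1.
types = torch.tensor([0, 0, 1])
lin_t = StaticGraphLinear(64, 32, num_nodes=3, node_types=types)
print(lin_t(torch.randn(8, 3, 64)).shape)  # torch.Size([8, 3, 32])
```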
SkeletonDiffusion/src/core/network/layers/recurrent.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Optional, List, Union
2
+
3
+ import torch
4
+ from torch.nn import Module, ModuleList, Parameter, Dropout
5
+ import math
6
+
7
+ from .graph_structural import gmm
8
+
9
+ IFDEF_JITSCRIPT = False # scripted cells produce NaNs on long sequences, so JIT scripting stays disabled.
10
+
11
+ GraphLSTMState = Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
12
+
13
+ class StaticGraphLSTMCell_(Module):
14
+ def __init__(self, input_size: int, hidden_size: int, num_nodes: int = None, dropout: float = 0.,
15
+ recurrent_dropout: float = 0., graph_influence: Union[torch.Tensor, Parameter] = None,
16
+ learn_influence: bool = False, additive_graph_influence: Union[torch.Tensor, Parameter] = None,
17
+ learn_additive_graph_influence: bool = False, node_types: torch.Tensor = None,
18
+ weights_per_type: bool = False, clockwork: bool = False, bias: bool = True):
19
+ """
20
+
21
+ :param input_size: The number of expected features in the input `x`
22
+ :param hidden_size: The number of features in the hidden state `h`
23
+ :param num_nodes: Number of graph nodes; may be omitted if `graph_influence` is given.
24
+ :param dropout: Dropout probability on the input-to-hidden transformation.
25
+ :param recurrent_dropout: Dropout probability on the recurrent hidden state.
26
+ :param graph_influence: Initial node-influence matrix `G`. Default: identity.
27
+ :param learn_influence: If ``True``, `G` is learned as a parameter.
28
+ :param additive_graph_influence: Matrix added to `G` after every step. Default: zeros.
29
+ :param learn_additive_graph_influence: If ``True``, the additive matrix is learned.
30
+ :param node_types: Type of each node; nodes of the same type share weights.
31
+ :param weights_per_type: If ``True``, learn one weight set per node type.
32
+ :param bias: If set to ``False``, the layer will not learn an additive bias.
33
+ """
34
+ super().__init__()
35
+ self.input_size = input_size
36
+ self.hidden_size = hidden_size
37
+
38
+ self.learn_influence = learn_influence
39
+ self.learn_additive_graph_influence = learn_additive_graph_influence
40
+ if graph_influence is not None:
41
+ assert num_nodes == graph_influence.shape[0] or num_nodes is None, 'num_nodes must match graph_influence.shape[0].'
42
+ num_nodes = graph_influence.shape[0]
43
+ if type(graph_influence) is Parameter:
44
+ assert learn_influence, "Graph Influence Matrix is a Parameter, therefore it must be learnable."
45
+ self.G = graph_influence
46
+ elif learn_influence:
47
+ self.G = Parameter(graph_influence)
48
+ else:
49
+ self.register_buffer('G', graph_influence)
50
+ else:
51
+ assert num_nodes, 'Number of Nodes or Graph Influence Matrix has to be given.'
52
+ eye_influence = torch.eye(num_nodes, num_nodes)
53
+ if learn_influence:
54
+ self.G = Parameter(eye_influence)
55
+ else:
56
+ self.register_buffer('G', eye_influence)
57
+
58
+ if additive_graph_influence is not None:
59
+ if type(additive_graph_influence) is Parameter:
60
+ self.G_add = additive_graph_influence
61
+ elif learn_additive_graph_influence:
62
+ self.G_add = Parameter(additive_graph_influence)
63
+ else:
64
+ self.register_buffer('G_add', additive_graph_influence)
65
+ else:
66
+ if learn_additive_graph_influence:
67
+ self.G_add = Parameter(torch.zeros_like(self.G))
68
+ else:
69
+ self.G_add = 0.
70
+
71
+ if weights_per_type and node_types is None:
72
+ node_types = torch.arange(num_nodes)
73
+ if node_types is not None:
74
+ num_node_types = node_types.max() + 1
75
+ self.weight_ih = Parameter(torch.Tensor(num_node_types, 4 * hidden_size, input_size))
76
+ self.weight_hh = Parameter(torch.Tensor(num_node_types, 4 * hidden_size, hidden_size))
77
+ self.mm = gmm
78
+ self.register_buffer('node_type_index', node_types)
79
+ else:
80
+ self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
81
+ self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
82
+ self.mm = torch.matmul
83
+ self.register_buffer('node_type_index', None)
84
+
85
+ if bias:
86
+ if node_types is not None:
87
+ self.bias_ih = Parameter(torch.Tensor(num_node_types, 4 * hidden_size))
88
+ self.bias_hh = Parameter(torch.Tensor(num_node_types, 4 * hidden_size))
89
+ else:
90
+ self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
91
+ self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
92
+ else:
93
+ self.bias_ih = None
94
+ self.bias_hh = None
95
+
96
+ self.clockwork = clockwork
97
+ if clockwork:
98
+ phase = torch.arange(0., hidden_size)
99
+ phase = phase - phase.min()
100
+ phase = (phase / phase.max()) * 8.
101
+ phase += 1.
102
+ phase = torch.floor(phase)
103
+ self.register_buffer('phase', phase)
104
+ else:
105
+ phase = torch.ones(hidden_size)
106
+ self.register_buffer('phase', phase)
107
+
108
+ self.dropout = Dropout(dropout)
109
+ self.r_dropout = Dropout(recurrent_dropout)
110
+
111
+ self.num_nodes = num_nodes
112
+
113
+ self.init_weights()
114
+
115
+ def init_weights(self):
116
+ stdv = 1.0 / math.sqrt(self.hidden_size)
117
+ for weight in self.parameters():
118
+ if weight is self.G:
119
+ continue
120
+ if weight is self.G_add:
121
+ continue
122
+ weight.data.uniform_(-stdv, stdv)
123
+ if (weight is self.weight_hh or weight is self.weight_ih) and len(self.weight_ih.shape) == 3: # tie per-type weights at init
124
+ weight.data[1:] = weight.data[0]
125
+
126
+ def forward(self, input: torch.Tensor, state: GraphLSTMState, t: int = 0) -> Tuple[torch.Tensor, GraphLSTMState]:
127
+ hx, cx, gx = state
128
+ if hx is None:
129
+ hx = torch.zeros(input.shape[0], self.num_nodes, self.hidden_size, dtype=input.dtype, device=input.device)
130
+ if cx is None:
131
+ cx = torch.zeros(input.shape[0], self.num_nodes, self.hidden_size, dtype=input.dtype, device=input.device)
132
+ if gx is None and self.learn_influence:
133
+ gx = torch.nn.functional.normalize(self.G, p=1., dim=1)
134
+ #gx = torch.softmax(self.G, dim=1)
135
+ elif gx is None:
136
+ gx = self.G
137
+
138
+ hx = self.r_dropout(hx)
139
+
140
+ weight_ih = self.weight_ih[self.node_type_index] # a None index adds a leading broadcast dim (shared-weight case)
141
+ weight_hh = self.weight_hh[self.node_type_index]
142
+ if self.bias_hh is not None:
143
+ bias_hh = self.bias_hh[self.node_type_index]
144
+ else:
145
+ bias_hh = 0.
146
+
147
+ c_mask = (torch.remainder(torch.tensor(t + 1., device=input.device), self.phase) < 0.01).type_as(cx)
148
+
149
+ gates = (self.dropout(self.mm(input, weight_ih.transpose(-2, -1))) +
150
+ self.mm(hx, weight_hh.transpose(-2, -1)) + bias_hh)
151
+ gates = torch.matmul(gx, gates)
152
+ ingate, forgetgate, cellgate, outgate = gates.chunk(4, 2)
153
+
154
+ ingate = torch.sigmoid(ingate)
155
+ forgetgate = torch.sigmoid(forgetgate)
156
+ cellgate = torch.tanh(cellgate)
157
+ outgate = torch.sigmoid(outgate)
158
+
159
+ cy = c_mask * ((forgetgate * cx) + (ingate * cellgate)) + (1 - c_mask) * cx
160
+ hy = outgate * torch.tanh(cy)
161
+
162
+ gx = gx + self.G_add
163
+ if self.learn_influence or self.learn_additive_graph_influence:
164
+ gx = torch.nn.functional.normalize(gx, p=1., dim=1)
165
+ #gx = torch.softmax(gx, dim=1)
166
+
167
+ return hy, (hy, cy, gx)
168
+
169
+
170
+ class StaticGraphLSTM_(Module):
171
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, layer_dropout: float = 0.0, **kwargs):
172
+ super().__init__()
173
+ self.layers = ModuleList([StaticGraphLSTMCell_(input_size, hidden_size, **kwargs)]
174
+ + [StaticGraphLSTMCell_(hidden_size, hidden_size, **kwargs) for _ in range(num_layers - 1)])
175
+ self.dropout = Dropout(layer_dropout)
176
+
177
+ def forward(self, input: torch.Tensor, states: Optional[List[GraphLSTMState]] = None, t_i: int = 0) -> Tuple[torch.Tensor, List[GraphLSTMState]]:
178
+ if states is None:
179
+ n: Optional[torch.Tensor] = None
180
+ states = [(n, n, n)] * len(self.layers)
181
+
182
+ output_states: List[GraphLSTMState] = []
183
+ output = input
184
+ i = 0
185
+ for rnn_layer in self.layers:
186
+ state = states[i]
187
+ inputs = output.unbind(1)
188
+ outputs: List[torch.Tensor] = []
189
+ for t, input in enumerate(inputs):
190
+ out, state = rnn_layer(input, state, t_i+t)
191
+ outputs += [out]
192
+ output = torch.stack(outputs, dim=1)
193
+ output = self.dropout(output)
194
+ output_states += [state]
195
+ i += 1
196
+ return output, output_states
197
+
198
+
199
+ def StaticGraphLSTM(*args, **kwargs):
200
+ if IFDEF_JITSCRIPT:
201
+ return torch.jit.script(StaticGraphLSTM_(*args, **kwargs))
202
+ else:
203
+ return StaticGraphLSTM_(*args, **kwargs)
204
+
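The `clockwork` option above follows clockwork-RNN-style updates: each hidden unit is assigned an integer phase, and its cell state only changes at steps `t` where `(t + 1) % phase == 0` (the `c_mask` in `forward`). A short illustration using the same phase construction as the cell (the `hidden_size` value is illustrative):

```python
import torch

hidden_size = 8
phase = torch.arange(0., hidden_size)
phase = (phase / phase.max()) * 8. + 1.
phase = torch.floor(phase)                # per-unit update periods, here 1..9
for t in range(4):
    c_mask = (torch.remainder(torch.tensor(t + 1.), phase) < 0.01).float()
    print(t, c_mask.tolist())             # phase-1 units update every step
```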
205
+ GraphGRUState = Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]
206
+
207
+
208
+ class StaticGraphGRUCell_(Module):
209
+ def __init__(self, input_size: int, hidden_size: int, num_nodes: int = None, dropout: float = 0.,
210
+ recurrent_dropout: float = 0., graph_influence: Union[torch.Tensor, Parameter] = None,
211
+ learn_influence: bool = False, additive_graph_influence: Union[torch.Tensor, Parameter] = None,
212
+ learn_additive_graph_influence: bool = False, node_types: torch.Tensor = None,
213
+ weights_per_type: bool = False, clockwork: bool = False, bias: bool = True):
214
+ """
215
+
216
+ :param input_size: The number of expected features in the input `x`
217
+ :param hidden_size: The number of features in the hidden state `h`
218
+ :param num_nodes: Number of graph nodes; may be omitted if `graph_influence` is given.
219
+ :param dropout: Dropout probability on the input-to-hidden transformation.
220
+ :param recurrent_dropout: Dropout probability on the recurrent hidden state.
221
+ :param graph_influence: Initial node-influence matrix `G`. Default: identity.
222
+ :param learn_influence: If ``True``, `G` is learned as a parameter.
223
+ :param additive_graph_influence: Matrix added to `G` after every step. Default: zeros.
224
+ :param learn_additive_graph_influence: If ``True``, the additive matrix is learned.
225
+ :param node_types: Type of each node; nodes of the same type share weights.
226
+ :param weights_per_type: If ``True``, learn one weight set per node type.
227
+ :param bias: If set to ``False``, the layer will not learn an additive bias.
228
+ """
229
+ super().__init__()
230
+ self.input_size = input_size
231
+ self.hidden_size = hidden_size
232
+
233
+ self.learn_influence = learn_influence
234
+ self.learn_additive_graph_influence = learn_additive_graph_influence
235
+ if graph_influence is not None:
236
+ assert num_nodes == graph_influence.shape[0] or num_nodes is None, 'num_nodes must match graph_influence.shape[0].'
237
+ num_nodes = graph_influence.shape[0]
238
+ if type(graph_influence) is Parameter:
239
+ assert learn_influence, "Graph Influence Matrix is a Parameter, therefore it must be learnable."
240
+ self.G = graph_influence
241
+ elif learn_influence:
242
+ self.G = Parameter(graph_influence)
243
+ else:
244
+ self.register_buffer('G', graph_influence)
245
+ else:
246
+ assert num_nodes, 'Number of Nodes or Graph Influence Matrix has to be given.'
247
+ eye_influence = torch.eye(num_nodes, num_nodes)
248
+ if learn_influence:
249
+ self.G = Parameter(eye_influence)
250
+ else:
251
+ self.register_buffer('G', eye_influence)
252
+
253
+ if additive_graph_influence is not None:
254
+ if type(additive_graph_influence) is Parameter:
255
+ self.G_add = additive_graph_influence
256
+ elif learn_additive_graph_influence:
257
+ self.G_add = Parameter(additive_graph_influence)
258
+ else:
259
+ self.register_buffer('G_add', additive_graph_influence)
260
+ else:
261
+ if learn_additive_graph_influence:
262
+ self.G_add = Parameter(torch.zeros_like(self.G))
263
+ else:
264
+ self.G_add = 0.
265
+
266
+ if weights_per_type and node_types is None:
267
+ node_types = torch.arange(num_nodes)
268
+ if node_types is not None:
269
+ num_node_types = node_types.max() + 1
270
+ self.weight_ih = Parameter(torch.Tensor(num_node_types, 3 * hidden_size, input_size))
271
+ self.weight_hh = Parameter(torch.Tensor(num_node_types, 3 * hidden_size, hidden_size))
272
+ self.mm = gmm
273
+ self.register_buffer('node_type_index', node_types)
274
+ else:
275
+ self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size))
276
+ self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size))
277
+ self.mm = torch.matmul
278
+ self.register_buffer('node_type_index', None)
279
+
280
+ if bias:
281
+ if node_types is not None:
282
+ self.bias_ih = Parameter(torch.Tensor(num_node_types, 3 * hidden_size))
283
+ self.bias_hh = Parameter(torch.Tensor(num_node_types, 3 * hidden_size))
284
+ else:
285
+ self.bias_ih = Parameter(torch.Tensor(3 * hidden_size))
286
+ self.bias_hh = Parameter(torch.Tensor(3 * hidden_size))
287
+ else:
288
+ self.bias_ih = None
289
+ self.bias_hh = None
290
+
291
+ self.clockwork = clockwork
292
+ if clockwork:
293
+ phase = torch.arange(0., hidden_size)
294
+ phase = phase - phase.min()
295
+ phase = (phase / phase.max()) * 8.
296
+ phase += 1.
297
+ phase = torch.floor(phase)
298
+ self.register_buffer('phase', phase)
299
+ else:
300
+ phase = torch.ones(hidden_size)
301
+ self.register_buffer('phase', phase)
302
+
303
+ self.dropout = Dropout(dropout)
304
+ self.r_dropout = Dropout(recurrent_dropout)
305
+
306
+ self.num_nodes = num_nodes
307
+
308
+ self.init_weights()
309
+
310
+ def init_weights(self):
311
+ stdv = 1.0 / math.sqrt(self.hidden_size)
312
+ for weight in self.parameters():
313
+ if weight is self.G:
314
+ continue
315
+ if weight is self.G_add:
316
+ continue
317
+ weight.data.uniform_(-stdv, stdv)
318
+ #if weight is self.weight_hh or weight is self.weight_ih and len(self.weight_ih.shape) == 3:
319
+ # weight.data[1:] = weight.data[0]
320
+
321
+ def forward(self, input: torch.Tensor, state: GraphGRUState, t: int = 0) -> Tuple[torch.Tensor, GraphGRUState]:
322
+ hx, gx = state
323
+ if hx is None:
324
+ hx = torch.zeros(input.shape[0], self.num_nodes, self.hidden_size, dtype=input.dtype, device=input.device)
325
+ if gx is None and self.learn_influence:
326
+ gx = torch.nn.functional.normalize(self.G, p=1., dim=1)
327
+ #gx = torch.softmax(self.G, dim=1)
328
+ elif gx is None:
329
+ gx = self.G
330
+
331
+ hx = self.r_dropout(hx)
332
+
333
+ weight_ih = self.weight_ih[self.node_type_index] # a None index adds a leading broadcast dim (shared-weight case)
334
+ weight_hh = self.weight_hh[self.node_type_index]
335
+ if self.bias_hh is not None:
336
+ bias_hh = self.bias_hh[self.node_type_index]
337
+ else:
338
+ bias_hh = 0.
339
+ if self.bias_ih is not None:
340
+ bias_ih = self.bias_ih[self.node_type_index]
341
+ else:
342
+ bias_ih = 0.
343
+
344
+ c_mask = (torch.remainder(torch.tensor(t + 1., device=input.device), self.phase) < 0.01).type_as(hx)
345
+
346
+ x_results = self.dropout(self.mm(input, weight_ih.transpose(-2, -1))) + bias_ih
347
+ h_results = self.mm(hx, weight_hh.transpose(-2, -1)) + bias_hh
348
+ x_results = torch.matmul(gx, x_results)
349
+ h_results = torch.matmul(gx, h_results)
350
+
351
+ i_r, i_z, i_n = x_results.chunk(3, 2)
352
+ h_r, h_z, h_n = h_results.chunk(3, 2)
353
+
354
+ r = torch.sigmoid(i_r + h_r)
355
+ z = torch.sigmoid(i_z + h_z)
356
+ n = torch.tanh(i_n + r * h_n)
357
+
358
+ hy = n - torch.mul(n, z) + torch.mul(z, hx)
359
+ hy = c_mask * hy + (1 - c_mask) * hx
360
+
361
+ gx = gx + self.G_add
362
+ if self.learn_influence or self.learn_additive_graph_influence:
363
+ gx = torch.nn.functional.normalize(gx, p=1., dim=1)
364
+ #gx = torch.softmax(gx, dim=1)
365
+
366
+ return hy, (hy, gx)
367
+
368
+
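The update `hy = n - n*z + z*hx` in the cell above is the standard GRU interpolation `(1 - z) * n + z * hx`, written without materializing `(1 - z)`; a quick check:

```python
import torch

n = torch.randn(5)
z = torch.sigmoid(torch.randn(5))
hx = torch.randn(5)
assert torch.allclose(n - n * z + z * hx, (1 - z) * n + z * hx)
```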
369
+ class StaticGraphGRU_(Module):
370
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, layer_dropout: float = 0.0, **kwargs):
371
+ super().__init__()
372
+ self.layers = ModuleList([StaticGraphGRUCell_(input_size, hidden_size, **kwargs)]
373
+ + [StaticGraphGRUCell_(hidden_size, hidden_size, **kwargs) for _ in range(num_layers - 1)])
374
+ self.dropout = Dropout(layer_dropout)
375
+
376
+ def forward(self, input: torch.Tensor, states: Optional[List[GraphGRUState]] = None, t_i: int = 0) -> Tuple[torch.Tensor, List[GraphGRUState]]:
377
+ if states is None:
378
+ n: Optional[torch.Tensor] = None
379
+ states = [(n, n)] * len(self.layers)
380
+
381
+ output_states: List[GraphGRUState] = []
382
+ output = input
383
+ i = 0
384
+ for rnn_layer in self.layers:
385
+ state = states[i]
386
+ inputs = output.unbind(1)
387
+ outputs: List[torch.Tensor] = []
388
+ for t, input in enumerate(inputs):
389
+ out, state = rnn_layer(input, state, t_i+t)
390
+ outputs += [out]
391
+ output = torch.stack(outputs, dim=1)
392
+ output = self.dropout(output)
393
+ output_states += [state]
394
+ i += 1
395
+ return output, output_states
396
+
397
+
398
+ def StaticGraphGRU(*args, **kwargs):
399
+ if IFDEF_JITSCRIPT:
400
+ return torch.jit.script(StaticGraphGRU_(*args, **kwargs))
401
+ else:
402
+ return StaticGraphGRU_(*args, **kwargs)
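A minimal usage sketch for the factory above (shapes inferred from the cell's `forward`; the import path and the sizes `B`, `T`, `N`, `D`, `H` are illustrative). The module consumes sequences of graph-structured frames and returns the output sequence plus one `(hidden, influence)` state per layer:

```python
import torch
from SkeletonDiffusion.src.core.network.layers.recurrent import StaticGraphGRU

B, T, N, D, H = 2, 10, 17, 3, 32   # batch, time, nodes, features, hidden
gru = StaticGraphGRU(D, H, num_layers=2, num_nodes=N, learn_influence=True)

x = torch.randn(B, T, N, D)        # sequence of graph-structured frames
out, states = gru(x, states=None, t_i=0)
print(out.shape)                   # torch.Size([2, 10, 17, 32])
hx, gx = states[-1]                # final hidden state and influence matrix
print(hx.shape, gx.shape)          # (2, 17, 32), (17, 17)
```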
SkeletonDiffusion/src/core/network/nn/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .generator import Denoiser
2
+ from .autoencoder import AutoEncoder
SkeletonDiffusion/src/core/network/nn/autoencoder.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from ..layers import StaticGraphLinear
7
+ from .decoder import Decoder
8
+ from .encoder import Encoder
9
+
10
+
11
+ class AutoEncoder(nn.Module):
12
+ def __init__(self,
13
+ num_nodes: int,
14
+ encoder_hidden_size: int,
15
+ decoder_hidden_size: int,
16
+ latent_size: int,
17
+ node_types: torch.Tensor = None,
18
+ input_size: int = 3,
19
+ z_activation: str = 'tanh',
20
+ enc_num_layers: int = 1,
21
+ loss_pose_type: str = 'l1',
22
+ **kwargs):
23
+ super().__init__()
24
+ self.param_groups = [{}]
25
+ self.latent_size = latent_size
26
+ self.loss_pose_type = loss_pose_type
27
+
28
+ self.encoder = Encoder(num_nodes=num_nodes,
29
+ input_size=input_size,
30
+ hidden_size=encoder_hidden_size,
31
+ output_size=latent_size,
32
+ node_types=node_types,
33
+ enc_num_layers=enc_num_layers,
34
+ recurrent_arch=kwargs['recurrent_arch_enc'])
35
+
36
+
37
+ assert kwargs['output_size'] == input_size
38
+ self.decoder = Decoder( num_nodes=num_nodes,
39
+ input_size=latent_size,
40
+ feature_size=input_size,
41
+ hidden_size=decoder_hidden_size,
42
+ node_types=node_types,
43
+ param_groups=self.param_groups,
44
+ **kwargs
45
+ )
46
+ assert z_activation in ['tanh', 'identity'], f"z_activation must be either 'tanh' or 'identity', but got {z_activation}"
47
+ self.z_activation = nn.Tanh() if z_activation == "tanh" else nn.Identity()
48
+
49
+
50
+ def forward(self, x):
51
+ h, _ = self.encoder(x)
52
+ return h
53
+
54
+ def get_past_embedding(self, past, state=None):
55
+ with torch.no_grad():
56
+ h_hat_embedding = self(past)
57
+ z_past = self.z_activation(h_hat_embedding)
58
+ return z_past
59
+
60
+ def get_embedding(self, future, state=None):
61
+ z = self.forward(future)
62
+ return z
63
+
64
+ def get_train_embeddings(self, y, past, state=None):
65
+ z_past = self.get_past_embedding(past, state=state)
66
+ z = self.get_embedding(y, state=state)
67
+ return z_past, z
68
+
69
+ def decode(self, x: torch.Tensor, h: torch.Tensor, z: torch.Tensor, ph=1, state=None):
70
+ x_tiled = x[:, -2:]
71
+ out, _ = self.decoder(x=x_tiled,
72
+ h=h,
73
+ z=z,
74
+ ph=ph,
75
+ state=state) # [B * Z, T, N, D]
76
+ return out
77
+
78
+ def autoencode(self, y, past, ph=1, state=None):
79
+ z_past, z = self.get_train_embeddings(y, past, state=state)
80
+ out = self.decode(past, z, z_past, ph)
81
+ return out, z_past, z
82
+
83
+ def loss(self, y_pred, y, type=None, reduction="mean", **kwargs):
84
+ type = self.loss_pose_type if type is None else type
85
+ if type=="mse":
86
+ out = torch.nn.MSELoss(reduction="none")(y_pred,y)
87
+ elif type in ["l1", "L1"]:
88
+ out = torch.nn.L1Loss(reduction="none")(y_pred,y)
89
+ else:
90
+ assert 0, "Not implemnted"
91
+ loss = (out.sum(-1) #spatial size
92
+ .mean(-1) #keypoints
93
+ .mean(-1) # timesteps
94
+ )
95
+ if reduction == "mean":
96
+ return loss.mean()
97
+ elif reduction == "none":
98
+ return loss
99
+ else:
100
+ assert 0, "Not implemnted"
101
+ return loss
102
+
103
+
104
+
105
+
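For clarity, the reduction performed by `AutoEncoder.loss` can be reproduced standalone (sum over the 3 spatial coordinates, then mean over joints, then mean over timesteps, then optionally mean over the batch); the tensor sizes below are illustrative:

```python
import torch

B, T, N, D = 2, 25, 17, 3                            # batch, timesteps, joints, coords
y_pred, y = torch.randn(B, T, N, D), torch.randn(B, T, N, D)

out = torch.nn.L1Loss(reduction="none")(y_pred, y)   # (B, T, N, D)
loss = out.sum(-1).mean(-1).mean(-1)                 # (B,): per-sample loss
print(loss.shape, loss.mean())                       # reduction="mean" -> scalar
```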