Spaces: Running on Zero

Jialin Yang committed
Commit 352b049 · Parent(s): 7bbe360
Initial release on Hugging Face Spaces with Gradio UI

Browse files (this view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set)
- .gitattributes +11 -0
- .gitignore +213 -0
- .gradio/certificate.pem +31 -0
- .vscode/settings.json +6 -0
- README.md +4 -2
- Roboto-VariableFont_wdth,wght.ttf +0 -0
- SkeletonDiffusion/__init__.py +0 -0
- SkeletonDiffusion/configs/config_eval/config.yaml +53 -0
- SkeletonDiffusion/configs/config_eval/config_inferencetime.yaml +43 -0
- SkeletonDiffusion/configs/config_eval/dataset/3dpw.yaml +35 -0
- SkeletonDiffusion/configs/config_eval/dataset/amass-mano.yaml +76 -0
- SkeletonDiffusion/configs/config_eval/dataset/amass.yaml +52 -0
- SkeletonDiffusion/configs/config_eval/dataset/freeman.yaml +23 -0
- SkeletonDiffusion/configs/config_eval/dataset/h36m.yaml +26 -0
- SkeletonDiffusion/configs/config_eval/method_specs/skeleton_diffusion.yaml +1 -0
- SkeletonDiffusion/configs/config_eval/method_specs/zerovelocity_alg_baseline.yaml +3 -0
- SkeletonDiffusion/configs/config_eval/task/hmp.yaml +4 -0
- SkeletonDiffusion/configs/config_train/config_autoencoder.yaml +27 -0
- SkeletonDiffusion/configs/config_train/dataset/amass.yaml +48 -0
- SkeletonDiffusion/configs/config_train/dataset/freeman.yaml +38 -0
- SkeletonDiffusion/configs/config_train/dataset/h36m.yaml +40 -0
- SkeletonDiffusion/configs/config_train/model/autoencoder.yaml +57 -0
- SkeletonDiffusion/configs/config_train/task/hmp.yaml +11 -0
- SkeletonDiffusion/configs/config_train_diffusion/config_diffusion.yaml +25 -0
- SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/adjacency.yaml +1 -0
- SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/reachability.yaml +3 -0
- SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion.yaml +57 -0
- SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion_in_noniso_class.yaml +70 -0
- SkeletonDiffusion/configs/config_train_diffusion/model/skeleton_diffusion.yaml +69 -0
- SkeletonDiffusion/datasets +1 -0
- SkeletonDiffusion/environment_inference.yml +19 -0
- SkeletonDiffusion/inference.ipynb +343 -0
- SkeletonDiffusion/inference_filtered.ipynb +1 -0
- SkeletonDiffusion/setup.py +13 -0
- SkeletonDiffusion/src/__init__.py +7 -0
- SkeletonDiffusion/src/config_utils.py +62 -0
- SkeletonDiffusion/src/core/__init__.py +8 -0
- SkeletonDiffusion/src/core/diffusion/__init__.py +3 -0
- SkeletonDiffusion/src/core/diffusion/base.py +445 -0
- SkeletonDiffusion/src/core/diffusion/isotropic.py +104 -0
- SkeletonDiffusion/src/core/diffusion/nonisotropic.py +213 -0
- SkeletonDiffusion/src/core/diffusion/utils.py +125 -0
- SkeletonDiffusion/src/core/diffusion_manager.py +45 -0
- SkeletonDiffusion/src/core/network/__init__.py +3 -0
- SkeletonDiffusion/src/core/network/layers/__init__.py +3 -0
- SkeletonDiffusion/src/core/network/layers/attention.py +138 -0
- SkeletonDiffusion/src/core/network/layers/graph_structural.py +133 -0
- SkeletonDiffusion/src/core/network/layers/recurrent.py +402 -0
- SkeletonDiffusion/src/core/network/nn/__init__.py +2 -0
- SkeletonDiffusion/src/core/network/nn/autoencoder.py +105 -0
.gitattributes
CHANGED

@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.torchscript filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+./magick filter=lfs diff=lfs merge=lfs -text
+models/* filter=lfs diff=lfs merge=lfs -text
+models/nlf_l_multi.torchscript filter=lfs diff=lfs merge=lfs -text
+models/checkpoint_150.pt filter=lfs diff=lfs merge=lfs -text
+downloads/* filter=lfs diff=lfs merge=lfs -text
+outputs/* filter=lfs diff=lfs merge=lfs -text
+intermediate_results/* filter=lfs diff=lfs merge=lfs -text
+predictions/* filter=lfs diff=lfs merge=lfs -text
+predictions/joints3d.npy filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
venv/
ENV/
env/
.env

# IDE
.idea/
.vscode/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# Project specific
*.pth
*.ckpt
# *.pt
*.bin
*.npy
*.npz
*.mp4
*.avi
*.mov
*.jpg
*.jpeg
*.png

# Logs
*.log
logs/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
# downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc
9622_GRAB/
magick
outputs/*_obj
outputs/
intermediate_results/
.gradio/certificate.pem
ADDED

-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
.vscode/settings.json
ADDED

{
    "python.analysis.extraPaths": [
        "./SkeletonDiffusion/src",
        "./src_joints2smpl_demo/convert_"
    ]
}
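The extraPaths entry above only affects editor analysis (Pylance); at runtime the same directories have to be importable. A minimal sketch of the runtime equivalent, assuming the repository root as working directory (the second path is copied verbatim from the setting, including its trailing underscore):

import os
import sys

# Runtime counterpart of python.analysis.extraPaths (an editor-only setting).
for p in ("SkeletonDiffusion/src", "src_joints2smpl_demo/convert_"):
    sys.path.insert(0, os.path.abspath(p))  # make these source trees importable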
README.md
CHANGED

@@ -4,10 +4,12 @@ emoji: 💻
 colorFrom: purple
 colorTo: green
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.24.0
 app_file: app.py
 pinned: false
 license: mit
+run: |
+  bash setup.sh
+  python app.py
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Roboto-VariableFont_wdth,wght.ttf
ADDED

Binary file (468 kB).
SkeletonDiffusion/__init__.py
ADDED

File without changes (empty file).
SkeletonDiffusion/configs/config_eval/config.yaml
ADDED

hydra:
  output_subdir: null
  run:
    dir: .
  job:
    chdir: False

dataset_main_path: ./datasets
dataset_annotation_path: ${dataset_main_path}/annotations&interm
dataset_precomputed_path: ${dataset_main_path}/processed
checkpoint_path: ''
defaults:
  - _self_
  - task: hmp
  - method_specs: skeleton_diffusion
  - dataset: amass
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

method_name: ${method_specs.method_name}
dtype: float32
if_noisy_obs: False
noise_level: 0.25
noise_std: 0.02
# num_nodes: ${eval:"int(${dataset.num_joints})-int(not ${task.if_consider_hip})"}

stats_mode: deterministic # probabilistic or deterministic
batch_size: 512
metrics_at_cpu: False
n_gpu: 1
num_samples: 50
if_measure_time: False

seed: 0
dataset_split: test
silent: False

obs_length: ${eval:'int(${task.history_sec} * ${dataset.fps})'}
pred_length: ${eval:'int(${task.prediction_horizon_sec} * ${dataset.fps})'}

if_store_output: False
store_output_path: ${eval:"'models/final_predictions_storage/${task.task_name}/${method_specs.method_name}/${dataset.dataset_name}/' if not ${if_long_term_test} else 'models/final_predictions_storage/${task.task_name}_longterm/${method_specs.method_name}/${dataset.dataset_name}/'"}
if_store_gt: False
store_gt_path: ${eval:"'models/final_predictions_storage/${task.task_name}/GT/${dataset.dataset_name}/' if not ${if_long_term_test} else 'models/final_predictions_storage/${task.task_name}_longterm/GT/${dataset.dataset_name}/'"}

if_compute_apde: ${eval:"False if ${eval:"'${dataset.dataset_name}' in ['freeman', '3dpw', 'nymeria']"} else True"}
if_long_term_test: False
long_term_factor: 2.5

if_compute_fid: False
if_compute_cmd: False
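These configs lean on a custom ${eval:...} resolver, which is not part of stock Hydra/OmegaConf, so the project presumably registers one before composing. A minimal sketch of loading this config, assuming such a project-side registration (the resolver name and config path come from this commit; the registration call itself is an assumption):

from hydra import compose, initialize
from omegaconf import OmegaConf

# Assumed registration of the ${eval:...} resolver used throughout these configs;
# without it, interpolations like obs_length cannot resolve.
OmegaConf.register_new_resolver("eval", eval, replace=True)

with initialize(version_base=None, config_path="SkeletonDiffusion/configs/config_eval"):
    cfg = compose(config_name="config")
    print(cfg.obs_length, cfg.pred_length)  # resolved from the task and dataset groups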
SkeletonDiffusion/configs/config_eval/config_inferencetime.yaml
ADDED

hydra:
  output_subdir: null
  run:
    dir: .
  job:
    chdir: False

dataset_main_path: ./datasets
dataset_annotation_path: ${dataset_main_path}/annotations&interm
dataset_precomputed_path: ${dataset_main_path}/processed
checkpoint_path: ''
defaults:
  - _self_
  - task: hmp
  - method_specs: skeldiff
  - dataset: amass
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

method_name: ${method_specs.method_name}
dtype: float32
if_noisy_obs: False
noise_level: 0.25
noise_std: 0.02

mode: stats # vis: visualize results | gen: generate and store all visualizations for a single batch | stats: launch numeric evaluation
stats_mode: deterministic # probabilistic or deterministic
if_measure_time: True
batch_size: 1
metrics_at_cpu: False
n_gpu: 1
num_samples: 50

seed: 0
dataset_split: test
silent: False

obs_length: ${eval:'int(${task.history_sec} * ${dataset.fps})'}
pred_length: ${eval:'int(${task.prediction_horizon_sec} * ${dataset.fps})'}

if_long_term_test: False
SkeletonDiffusion/configs/config_eval/dataset/3dpw.yaml
ADDED

num_joints: 22 # including the hip root joint
fps: 60

multimodal_threshold: 0.4
dataset_type: D3PWZeroShotDataset
dataset_name: 3dpw
precomputed_folder: "${dataset_precomputed_path}/3DPW/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/3DPW/${task.task_name}/"
dtype: float32

data_loader_train_eval:
  stride: 30
  augmentation: 0
  shuffle: False
  da_mirroring: 0.
  da_rotations: 0.
  drop_last: False
  if_load_mmgt: False

data_loader_valid:
  stride: 30
  augmentation: 0
  shuffle: False
  segments_path: "${dataset.annotations_folder}/segments_valid.csv"
  actions: "all"
  drop_last: False
  if_load_mmgt: False

data_loader_test:
  shuffle: False
  segments_path: "${dataset.annotations_folder}/segments_test_zero_shot.csv"
  actions: "all"
  drop_last: False
  if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
SkeletonDiffusion/configs/config_eval/dataset/amass-mano.yaml
ADDED

num_joints: 52 # hip included even if if_consider_hip=False
fps: 60

multimodal_threshold: 0.4
dataset_type: AMASSDataset
dataset_name: amass-mano
precomputed_folder: "${dataset_precomputed_path}/AMASS-MANO/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/AMASS-MANO/${task.task_name}/"
dtype: float32

# The training set contains the ACCAD, BMLhandball, BMLmovi, BMLrub, CMU, EKUT,
# EyesJapanDataset, KIT, PosePrior, TCDHands, and TotalCapture datasets, and the
# validation set contains the HumanEva, HDM05, SFU, and MoSh datasets.
# The remaining datasets are all part of the test set: DFaust, DanceDB, GRAB,
# HUMAN4D, SOMA, SSM, and Transitions.

data_loader_train:
  stride: 60
  augmentation: 30
  shuffle: True
  datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture'] # from paper
  da_mirroring: 0.5
  da_rotations: 1.0
  drop_last: True
  if_load_mmgt: False

data_loader_train_eval:
  stride: 30
  augmentation: 0
  shuffle: False
  datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture'] # from paper
  # datasets: ['ACCAD', "BMLhandball", "BMLmovi", 'CMU', 'KIT', 'TotalCapture'] # decrease evaluation time
  da_mirroring: 0.
  da_rotations: 0.
  drop_last: False
  if_load_mmgt: False

data_loader_valid:
  stride: 30
  augmentation: 0
  shuffle: False
  datasets: ['HumanEva', 'HDM05', 'SFU', 'MoSh'] # from paper
  file_idces: "all"
  drop_last: False
  if_load_mmgt: False

data_loader_test:
  shuffle: False
  segments_path: "${dataset.annotations_folder}/segments_test.csv"
  # test set: DFaust, DanceDB, GRAB, HUMAN4D, SOMA, SSM, and Transitions
  datasets:
    - Transitions
    - SSM
    - DFaust
    - DanceDB
    - GRAB
    - HUMAN4D
    - SOMA
  drop_last: False
  if_load_mmgt: False
SkeletonDiffusion/configs/config_eval/dataset/amass.yaml
ADDED

num_joints: 22 # including the hip root joint
fps: 60

multimodal_threshold: 0.4
dataset_type: AMASSDataset
dataset_name: "amass"
precomputed_folder: "${dataset_precomputed_path}/AMASS/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/AMASS/${task.task_name}/"
dtype: float32

# The training set contains the ACCAD, BMLhandball, BMLmovi, BMLrub, CMU, EKUT,
# EyesJapanDataset, KIT, PosePrior, TCDHands, and TotalCapture datasets, and the
# validation set contains the HumanEva, HDM05, SFU, and MoSh datasets.
# The remaining datasets are all part of the test set: DFaust, DanceDB, GRAB,
# HUMAN4D, SOMA, SSM, and Transitions.

data_loader_train_eval:
  stride: 30
  augmentation: 0
  shuffle: False
  datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture']
  da_mirroring: 0.
  da_rotations: 0.
  drop_last: False
  if_load_mmgt: False

data_loader_valid:
  stride: 30
  augmentation: 0
  shuffle: False
  datasets: ['HumanEva', 'HDM05', 'SFU', 'MoSh']
  file_idces: "all"
  drop_last: False
  if_load_mmgt: False

data_loader_test:
  shuffle: False
  segments_path: ${eval:"'${dataset.annotations_folder}/segments_test.csv' if not ${if_long_term_test} else '${dataset.annotations_folder}/segments_5s_test_long_term_pred.csv'"}
  datasets:
    - Transitions
    - SSM
    - DFaust
    - DanceDB
    - GRAB
    - HUMAN4D
    - SOMA
  drop_last: False
  if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
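For concreteness, the if_load_mmgt interpolation in the test loader reduces to an ordinary Python conditional once ${stats_mode} is substituted. A sketch of what the eval resolver ends up evaluating, for the two stats_mode values named in config.yaml:

# What ${eval:'True if "probabilistic" in "${stats_mode}" else False'} becomes
# after string interpolation.
for stats_mode in ("deterministic", "probabilistic"):
    if_load_mmgt = eval(f'True if "probabilistic" in "{stats_mode}" else False')
    print(stats_mode, "->", if_load_mmgt)  # deterministic -> False, probabilistic -> True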
SkeletonDiffusion/configs/config_eval/dataset/freeman.yaml
ADDED

num_joints: 18 # including the hip root joint
fps: 30
dataset_type: FreeManDataset
dataset_name: "freeman"
precomputed_folder: "${dataset_precomputed_path}/FreeMan/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/FreeMan/${task.task_name}/"
dtype: float32
multimodal_threshold: 0.5

data_loader_valid:
  shuffle: False
  segments_path: "${dataset.annotations_folder}/segments_valid.csv"
  actions: "all"
  drop_last: False
  if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}

data_loader_test:
  shuffle: False
  segments_path: "${dataset.annotations_folder}/segments_test.csv"
  actions: "all"
  drop_last: False
  if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
SkeletonDiffusion/configs/config_eval/dataset/h36m.yaml
ADDED

num_joints: 17 # including the hip root joint
fps: 50
dataset_type: H36MDataset
dataset_name: "h36m"
precomputed_folder: "${dataset_precomputed_path}/Human36M/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/Human36M/${task.task_name}"
dtype: float32
multimodal_threshold: 0.5

data_loader_valid:
  augmentation: 0
  shuffle: False
  subjects: ["S8"]
  segments_path: "${dataset.annotations_folder}/segments_valid.csv"
  actions: "all"
  drop_last: False
  if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}

data_loader_test:
  shuffle: False
  augmentation: 0
  segments_path: "${dataset.annotations_folder}/segments_test.csv"
  subjects: ["S9", "S11"]
  actions: "all"
  drop_last: False
  if_load_mmgt: ${eval:'True if "probabilistic" in "${stats_mode}" else False'}
SkeletonDiffusion/configs/config_eval/method_specs/skeleton_diffusion.yaml
ADDED

method_name: SkeletonDiffusion
SkeletonDiffusion/configs/config_eval/method_specs/zerovelocity_alg_baseline.yaml
ADDED

motion_repr_type: "SkeletonCenterPose"
method_name: ZeroVelocityBaseline
baseline_out_path: ./models/output/baselines
SkeletonDiffusion/configs/config_eval/task/hmp.yaml
ADDED

history_sec: 0.5 # ${eval:'float(1) if ${task}=="motpred" else float(0.5)'}
prediction_horizon_sec: 2 # ${eval:"float(4) if ${task}=='motpred' else float(2)"}
task_name: "hmp"
if_consider_hip: False
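The task file fixes the observation and prediction windows in seconds; the frame counts come from the obs_length/pred_length interpolations in config.yaml. A worked example for the hmp task combined with the AMASS dataset (60 fps):

# obs_length / pred_length as resolved for task=hmp, dataset=amass.
history_sec = 0.5            # from task/hmp.yaml
prediction_horizon_sec = 2   # from task/hmp.yaml
fps = 60                     # from dataset/amass.yaml

obs_length = int(history_sec * fps)              # 30 observed frames
pred_length = int(prediction_horizon_sec * fps)  # 120 predicted frames
print(obs_length, pred_length)  # 30 120, matching the tensor shapes printed in inference.ipynb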
SkeletonDiffusion/configs/config_train/config_autoencoder.yaml
ADDED

defaults:
  - _self_
  - task: hmp
  - dataset: h36m
  - model: autoencoder
  - override hydra/job_logging: disabled
  # - override hydra/hydra_logging: disabled

dataset_main_path: ./datasets
dataset_annotation_path: ${dataset_main_path}/annotations&interm
dataset_precomputed_path: ${dataset_main_path}/processed
if_resume_training: false
debug: false
device: cuda
load: false
load_path: ''
output_log_path: ../../my_exps/output/${task.task_name}/${dataset.dataset_name}/autoencoder/${now:%B%d_%H-%M-%S}_ID${slurm_id}_${info}
slurm_id: 0
slurm_first_run: None
info: ''
hydra:
  run:
    dir: ${output_log_path}
  job:
    chdir: False
SkeletonDiffusion/configs/config_train/dataset/amass.yaml
ADDED

num_joints: 22 # including the hip root joint
fps: 60

multimodal_threshold: 0.4
dataset_type: AMASSDataset
dataset_name: amass
precomputed_folder: "${dataset_precomputed_path}/AMASS/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/AMASS/${task.task_name}/"
dtype: float32

# The training set contains the ACCAD, BMLhandball, BMLmovi, BMLrub, CMU, EKUT,
# EyesJapanDataset, KIT, PosePrior, TCDHands, and TotalCapture datasets, and the
# validation set contains the HumanEva, HDM05, SFU, and MoSh datasets.
# The remaining datasets are all part of the test set: DFaust, DanceDB, GRAB,
# HUMAN4D, SOMA, SSM, and Transitions.

data_loader_train:
  stride: 60
  augmentation: 30
  shuffle: True
  datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture']
  da_mirroring: 0.5
  da_rotations: 1.0
  drop_last: True
  if_load_mmgt: False

data_loader_train_eval:
  stride: 30
  augmentation: 0
  shuffle: False
  datasets: ['ACCAD', "BMLhandball", "BMLmovi", "BMLrub", 'EKUT', 'CMU', 'EyesJapanDataset', 'KIT', "PosePrior", 'TCDHands', 'TotalCapture']
  da_mirroring: 0.
  da_rotations: 0.
  drop_last: False
  if_load_mmgt: False

data_loader_valid:
  stride: 30
  augmentation: 0
  shuffle: False
  datasets: ['HumanEva', 'HDM05', 'SFU', 'MoSh']
  file_idces: "all"
  drop_last: False
  if_load_mmgt: False
SkeletonDiffusion/configs/config_train/dataset/freeman.yaml
ADDED

num_joints: 18 # including the hip root joint
fps: 30

multimodal_threshold: 0.5
dataset_type: FreeManDataset
dataset_name: freeman
precomputed_folder: "${dataset_precomputed_path}/FreeMan/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/FreeMan/${task.task_name}/"
dtype: float32

data_loader_train:
  stride: 10
  augmentation: 5
  shuffle: True
  actions: "all"
  da_mirroring: 0.5
  da_rotations: 1.0
  drop_last: True
  if_load_mmgt: False

data_loader_train_eval:
  stride: 30
  augmentation: 0
  shuffle: False
  actions: "all"
  da_mirroring: 0.
  da_rotations: 0.
  drop_last: False
  if_load_mmgt: False

data_loader_valid:
  stride: 30
  augmentation: 0
  shuffle: False
  segments_path: "${dataset.annotations_folder}/segments_valid.csv"
  actions: "all"
  drop_last: False
  if_load_mmgt: False
SkeletonDiffusion/configs/config_train/dataset/h36m.yaml
ADDED

num_joints: 17 # including the hip root joint
fps: 50

multimodal_threshold: 0.5
dataset_type: H36MDataset
dataset_name: h36m
precomputed_folder: "${dataset_precomputed_path}/Human36M/${task.task_name}/"
annotations_folder: "${dataset_annotation_path}/Human36M/${task.task_name}"
dtype: float32
data_loader_train:
  stride: 10
  augmentation: 5
  shuffle: True
  subjects: ["S1", "S5", "S6", "S7", "S8"] # train on the validation split as well, as in BeLFusion and CoMusion
  actions: "all"
  da_mirroring: 0.5
  da_rotations: 1.0
  drop_last: True
  if_load_mmgt: False

data_loader_train_eval:
  stride: 30
  augmentation: 0
  shuffle: False
  subjects: ["S1", "S5", "S6", "S7", "S8"]
  actions: "all"
  da_mirroring: 0.
  da_rotations: 0.
  drop_last: False
  if_load_mmgt: False

data_loader_valid:
  stride: 30
  augmentation: 0
  shuffle: False
  subjects: ["S8"]
  segments_path: "${dataset.annotations_folder}/segments_valid.csv"
  actions: "all"
  drop_last: False
  if_load_mmgt: False
SkeletonDiffusion/configs/config_train/model/autoencoder.yaml
ADDED

batch_size: 64
batch_size_eval: 512
eval_frequency: 5
num_epochs: 200
num_iteration_eval: 10
num_workers: 4
seed: 52345

use_lr_scheduler: True
lr_scheduler_kwargs:
  lr_scheduler_type: ExponentialLRSchedulerWarmup
  warmup_duration: 10
  update_every: 1
  min_lr: 1.e-4
  gamma_decay: 0.98

loss_pose_type: l1

lr: 0.5e-2

latent_size: 96
output_size: 3 # 128

z_activation: tanh

num_iter_perepoch: ${eval:"int(485) if ${eval:"'${dataset.dataset_name}' == 'h36m'"} else 580"}

obs_length: ${eval:'int(${task.history_sec} * ${dataset.fps})'}
prediction_horizon_train: ${model.prediction_horizon}
prediction_horizon_eval: ${model.prediction_horizon}
prediction_horizon: ${eval:'int(${task.prediction_horizon_sec} * ${dataset.fps})'}
pred_length: ${model.prediction_horizon_eval}

autoenc_arch:
  enc_num_layers: 1
  encoder_hidden_size: 96
  decoder_hidden_size: 96
  arch: AutoEncoder
  recurrent_arch_enc: StaticGraphGRU
  recurrent_arch_decoder: StaticGraphGRU

prediction_horizon_train_min: 10
prediction_horizon_train_min_from_epoch: 200
curriculum_it: 10
random_prediction_horizon: True
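ExponentialLRSchedulerWarmup is a project class whose implementation is not part of this commit. Under a plausible reading of its kwargs (linear warmup for warmup_duration epochs, then exponential decay by gamma_decay, floored at min_lr), the schedule would behave like this sketch; the semantics are an assumption, not the project's code:

def lr_at_epoch(epoch: int, base_lr: float = 0.5e-2, warmup: int = 10,
                gamma: float = 0.98, min_lr: float = 1e-4) -> float:
    """Assumed semantics of ExponentialLRSchedulerWarmup's kwargs."""
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup                 # linear warmup
    return max(base_lr * gamma ** (epoch - warmup), min_lr)  # exponential decay, floored

print([round(lr_at_epoch(e), 5) for e in (0, 9, 10, 100, 199)])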
SkeletonDiffusion/configs/config_train/task/hmp.yaml
ADDED

if_consider_hip: False

history_sec: 0.5
prediction_horizon_sec: 2

# Joint representation & skeleton specs
motion_repr_type: "SkeletonRescalePose"
pose_box_size: 1.5 # in meters
seq_centering: 0
task_name: hmp
SkeletonDiffusion/configs/config_train_diffusion/config_diffusion.yaml
ADDED

if_resume_training: false
debug: false
device: cuda
load: false
load_path: ''
dataset_main_path: ./datasets
dataset_annotation_path: ${dataset_main_path}/annotations&interm
dataset_precomputed_path: ${dataset_main_path}/processed
_load_saved_aoutoenc: hmp-h36m

output_log_path: ../../my_exps/output/${eval:"'${_load_saved_aoutoenc}'.split('-', 1)[0]"}/${eval:"'${_load_saved_aoutoenc}'.split('-', 1)[1]"}/diffusion/${now:%B%d_%H-%M-%S}_ID${slurm_id}_${info}
slurm_id: 0
slurm_first_run: None
info: ''
hydra:
  run:
    dir: ${output_log_path}
  job:
    chdir: False

defaults:
  - _self_
  - model: skeleton_diffusion
  - cov_matrix: adjacency
  - override hydra/job_logging: disabled
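The output_log_path interpolation decomposes _load_saved_aoutoenc into a task and a dataset segment. (The key name is misspelled in the config, but the model configs below reference the same misspelling, so it must stay as-is.) A sketch of the split:

_load_saved_aoutoenc = "hmp-h36m"
task_name = _load_saved_aoutoenc.split("-", 1)[0]     # 'hmp'
dataset_name = _load_saved_aoutoenc.split("-", 1)[1]  # 'h36m'
print(f".../output/{task_name}/{dataset_name}/diffusion/...")

The maxsplit=1 matters: it keeps hyphenated dataset names such as amass-mano intact in the second segment.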
SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/adjacency.yaml
ADDED

covariance_matrix_type: adjacency
SkeletonDiffusion/configs/config_train_diffusion/cov_matrix/reachability.yaml
ADDED

covariance_matrix_type: reachability
reachability_matrix_degree_factor: 0.5
reachability_matrix_stop_at: hips # 'hips' or null
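A sketch (not the project's code) of what a reachability matrix with degree factor 0.5 plausibly looks like on a toy kinematic tree: direct neighbours get weight 1, and entries decay by the factor for each additional hop along the skeleton graph. The 5-joint tree below is hypothetical:

import numpy as np
from collections import deque

edges = [(0, 1), (1, 2), (2, 3), (2, 4)]  # hypothetical 5-joint kinematic tree
n, factor = 5, 0.5                         # factor = reachability_matrix_degree_factor

neighbours = {i: [] for i in range(n)}
for i, j in edges:
    neighbours[i].append(j)
    neighbours[j].append(i)

reach = np.zeros((n, n))
for src in range(n):              # BFS hop counts from every joint
    dist = {src: 0}
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for v in neighbours[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    for dst, h in dist.items():
        if h > 0:
            reach[src, dst] = factor ** (h - 1)  # 1 for neighbours, decaying with distance

print(reach)  # symmetric; the adjacency variant would keep only the h == 1 entries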
SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion.yaml
ADDED

pretrained_GM_folder: ${eval:"'./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310' if ${eval:"'${_load_saved_aoutoenc}'.split('-')[1] == 'h36m'"} else './models/final_checkpoints/AMASS/hmp/autoencoder/May11_10-35-09_ID1185354'"}
pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/checkpoint_final.pt'

# These options still have to be checked (ablation)
lr: 1.e-3
diffusion_objective: pred_x0
weight_decay: 0.

if_use_ema: True
step_start_ema: 100 # 100 is the default
ema_power: ${eval:'2/3'} # or ${eval:'3/4'}
ema_update_every: 10
ema_min_value: 0.0
use_lr_scheduler: True
lr_scheduler_kwargs:
  lr_scheduler_type: ExponentialLRSchedulerWarmup # or SchedulerReduceLROnPlateau
  warmup_duration: 200
  update_every: 10
  min_lr: 1.e-4
  gamma_decay: 0.98

# These options are already ablated
diffusion_conditioning: True
num_epochs: 600
num_workers: 4
batch_size: 64
batch_size_eval: 256 # TO CHECK
eval_frequency: 25 # in epochs
train_pick_best_sample_among_k: 50
similarity_space: latent_space # input_space, latent_space or metric_space

diffusion_activation: identity
num_prob_samples: 50
diffusion_timesteps: 10

diffusion_type: IsotropicGaussianDiffusion
beta_schedule: cosine
diffusion_loss_type: l1
num_iter_perepoch: null
seed: 63485

diffusion_arch:
  arch: Denoiser
  use_attention: True
  self_condition: False
  norm_type: none
  depth: 1
  # resnet_block_groups: 8
  # learned_variance: False
  # learned_sinusoidal_cond: False
  # random_fourier_features: False
  # learned_sinusoidal_dim: 16
  # sinusoidal_pos_emb_theta: 10000
  attn_dim_head: 32
  attn_heads: 4
  learn_influence: True
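beta_schedule: cosine is the standard cosine noise schedule of Nichol & Dhariwal (2021); a reference sketch for the diffusion_timesteps: 10 setting above (this is the textbook formula, not necessarily the project's exact implementation):

import torch

def cosine_betas(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine noise schedule (Nichol & Dhariwal, 2021)."""
    t = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    f = torch.cos((t + s) / (1 + s) * torch.pi / 2) ** 2
    alphas_cumprod = f / f[0]                          # cumulative signal level
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return betas.clamp(0, 0.999)

print(cosine_betas(10))  # 10 betas, increasing towards 1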
SkeletonDiffusion/configs/config_train_diffusion/model/isotropic_diffusion_in_noniso_class.yaml
ADDED

# pretrained_GM_folder: ./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310
# _pretrained_GM_checkpoint: checkpoint_final
# pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'

_pretrained_GM_checkpoint: checkpoint_final
pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'
pretrained_GM_folder: ${eval:"'./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310' if ${eval:"'${_load_saved_aoutoenc}'.split('-')[1] == 'h36m'"} else './models/final_checkpoints/AMASS/hmp/autoencoder/May11_10-35-09_ID1185354'"}

# These options still have to be checked (ablation)
lr: 1.e-3
diffusion_objective: pred_x0
weight_decay: 0.

if_use_ema: True
step_start_ema: 100 # 100 is the default
ema_power: ${eval:'2/3'} # or ${eval:'3/4'}
ema_update_every: 10
ema_min_value: 0.0
use_lr_scheduler: True
lr_scheduler_kwargs:
  lr_scheduler_type: ExponentialLRSchedulerWarmup # or SchedulerReduceLROnPlateau
  warmup_duration: 200
  update_every: 10
  min_lr: 1.e-4
  gamma_decay: 0.98

# These options are already ablated
diffusion_conditioning: True
num_epochs: 800
num_workers: 4
batch_size: 64
batch_size_eval: 256
eval_frequency: 25 # in epochs
train_pick_best_sample_among_k: 50
similarity_space: latent_space # input_space, latent_space or metric_space

diffusion_activation: identity
num_prob_samples: 50
diffusion_timesteps: 10

diffusion_type: NonisotropicGaussianDiffusion
diffusion_loss_type: snr # snr_triangle_inequality, mahalanobis, snr
loss_reduction_type: l1
if_run_as_isotropic: True
if_sigma_n_scale: True
diffusion_covariance_type: isotropic # anisotropic, isotropic, skeleton-diffusion
gamma_scheduler: cosine # mono_decrease, cosine
beta_schedule: cosine
sigma_n_scale: spectral
num_iter_perepoch: null
seed: 63485

diffusion_arch:
  arch: Denoiser
  use_attention: True
  self_condition: False # True gives better results but takes longer to train
  norm_type: none
  depth: 1
  # resnet_block_groups: 8
  # learned_variance: False
  # learned_sinusoidal_cond: False
  # random_fourier_features: False
  # learned_sinusoidal_dim: 16
  # sinusoidal_pos_emb_theta: 10000
  attn_dim_head: 32
  attn_heads: 4
  learn_influence: True
SkeletonDiffusion/configs/config_train_diffusion/model/skeleton_diffusion.yaml
ADDED

# pretrained_GM_folder: ./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310
# _pretrained_GM_checkpoint: checkpoint_final
# pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'

_pretrained_GM_checkpoint: checkpoint_final
pretrained_autoencoder_path: '${model.pretrained_GM_folder}/checkpoints/${model._pretrained_GM_checkpoint}.pt'
pretrained_GM_folder: ${eval:"'./models/final_checkpoints/H36M/hmp/autoencoder/January19_19-24-04_ID1137310' if ${eval:"'${_load_saved_aoutoenc}'.split('-')[1] == 'h36m'"} else './models/final_checkpoints/AMASS/hmp/autoencoder/May11_10-35-09_ID1185354'"}

# These options still have to be checked (ablation)
lr: 1.e-3
diffusion_objective: pred_x0
weight_decay: 0.

if_use_ema: True
step_start_ema: 100 # 100 is the default
ema_power: ${eval:'2/3'} # or ${eval:'3/4'}
ema_update_every: 10
ema_min_value: 0.0
use_lr_scheduler: True
lr_scheduler_kwargs:
  lr_scheduler_type: ExponentialLRSchedulerWarmup # or SchedulerReduceLROnPlateau
  warmup_duration: 200
  update_every: 10
  min_lr: 1.e-4
  gamma_decay: 0.98

diffusion_conditioning: True
num_epochs: 800
num_workers: 4
batch_size: 64
batch_size_eval: 256
eval_frequency: 25 # in epochs
train_pick_best_sample_among_k: 50
similarity_space: latent_space # input_space, latent_space or metric_space

diffusion_activation: identity
num_prob_samples: 50
diffusion_timesteps: 10

diffusion_type: NonisotropicGaussianDiffusion
diffusion_loss_type: snr # snr_triangle_inequality, mahalanobis, snr
loss_reduction_type: l1
if_run_as_isotropic: False
if_sigma_n_scale: True
diffusion_covariance_type: skeleton-diffusion # anisotropic, isotropic, skeleton-diffusion
gamma_scheduler: cosine # mono_decrease, cosine
beta_schedule: cosine
sigma_n_scale: spectral
num_iter_perepoch: null
seed: 63485

diffusion_arch:
  arch: Denoiser
  use_attention: True
  self_condition: False # True gives better results but takes longer to train
  norm_type: none
  depth: 1
  # resnet_block_groups: 8
  # learned_variance: False
  # learned_sinusoidal_cond: False
  # random_fourier_features: False
  # learned_sinusoidal_dim: 16
  # sinusoidal_pos_emb_theta: 10000
  attn_dim_head: 32
  attn_heads: 4
  learn_influence: True
SkeletonDiffusion/datasets
ADDED

../../motion_must_go_on/datasets/
SkeletonDiffusion/environment_inference.yml
ADDED

name: skeldiff_inf
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.10.12
  - pytorch=2.0.1
  - pytorch-cuda=11.8
  - torchvision=0.15.2
  - pyyaml=6.0.1
  - einops=0.7.0
  - pip
  - pip:
      - denoising-diffusion-pytorch==1.9.4
      # - pyyaml=6.0.1
      # - imageio
      # - ipympl=0.9.3
      # - ffmpeg
      # - opencv
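A quick sanity check, as a sketch, that an activated skeldiff_inf environment resolved to the versions pinned above:

import einops
import torch
import torchvision

# Verify the conda solver landed on the pinned builds from environment_inference.yml.
assert torch.__version__.startswith("2.0.1"), torch.__version__
assert torchvision.__version__.startswith("0.15.2"), torchvision.__version__
assert einops.__version__ == "0.7.0", einops.__version__
print("CUDA available:", torch.cuda.is_available())  # expects the pytorch-cuda=11.8 build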
SkeletonDiffusion/inference.ipynb
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e927d3c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.chdir(r\"/home/stud/yaji/storage/user/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src\")\n",
    "root_path = os.getcwd()\n",
    "print(root_path)\n",
    "\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "32b71ab8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/stud/yaji/miniconda3/envs/live_demo/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src/core/diffusion/base.py:184: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  @autocast(enabled = False)\n",
      "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src/core/diffusion/isotropic.py:72: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  @autocast(enabled = False)\n",
      "/storage/user/yaji/yaji/NonisotropicSkeletonDiffusion/SkeletonDiffusion/src/core/diffusion/nonisotropic.py:138: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  @autocast(enabled = False)\n"
     ]
    }
   ],
   "source": [
    "from eval_prepare_model import prepare_model, get_prediction, load_model_config_exp\n",
    "from data import create_skeleton\n",
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "def set_seed(seed=0):\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    torch.manual_seed(seed)  # also seed the CPU RNG, not only CUDA\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0963a8bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = '/usr/wiss/curreli/work/my_exps/checkpoints_release/amass/diffusion/cvpr_release/checkpoints/checkpoint_150.pt'\n",
    "# checkpoint_path = '/usr/wiss/curreli/work/my_exps/checkpoints_release/amass-mano/diffusion/cvpr_release/checkpoints/checkpoint_150.pt'\n",
    "\n",
    "num_samples = 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c6ce7b0a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> GPU 0 ready: NVIDIA RTX A2000 12GB\n",
      "Loading Autoencoder checkpoint: /usr/wiss/curreli/work/my_exps/checkpoints_release/amass/autoencoder/cvpr_release/checkpoints/checkpoint_300.pt ...\n",
      "Diffusion is_ddim_sampling: False\n",
      "Loading Diffusion checkpoint: /usr/wiss/curreli/work/my_exps/checkpoints_release/amass/diffusion/cvpr_release/checkpoints/checkpoint_150.pt ...\n"
     ]
    }
   ],
   "source": [
    "set_seed(seed=0)\n",
    "\n",
    "config, exp_folder = load_model_config_exp(checkpoint_path)\n",
    "config['checkpoint_path'] = checkpoint_path\n",
    "skeleton = create_skeleton(**config)\n",
    "\n",
    "model, device, *_ = prepare_model(config, skeleton, **config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5c4aa1a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 30, 22, 3])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# prepare input\n",
    "# load input. It should be in meters\n",
    "import numpy as np\n",
    "import torch\n",
    "obs = np.load('/usr/wiss/curreli/work/my_exps/checkpoints_release/amass/exaple_obs.npy')  # (t_past, J, 3)\n",
    "# obs = np.load('/usr/wiss/curreli/work/my_exps/checkpoints_release/amass-mano/example_obs.npy')  # (t_past, J, 3)\n",
    "\n",
    "obs = torch.from_numpy(obs).to(device)\n",
    "obs = obs.unsqueeze(0)  # add batch dimension\n",
    "obs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e782b749",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 50, 120, 21, 3])\n",
      "torch.Size([1, 30, 21, 3])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 50, 120, 21, 3])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "obs_in = skeleton.tranform_to_input_space(obs)  # the obs sequence still contains the hip joint; it has not been dropped yet\n",
    "pred = get_prediction(obs_in, model, num_samples=num_samples, **config)  # [batch_size, n_samples, seq_length, num_joints, features]\n",
    "print(pred.shape)\n",
    "pred = skeleton.transform_to_metric_space(pred)\n",
    "print(obs_in.shape)\n",
    "pred.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9394774",
   "metadata": {},
   "outputs": [],
   "source": [
    "kpts3d = pred.cpu()[0][0]\n",
    "import matplotlib.pyplot as plt\n",
    "for i in range(120):\n",
    "    plt.figure()\n",
    "    plt.scatter(kpts3d[i, :, 1], kpts3d[i, :, 2])  # plot frame i\n",
    "    plt.gca().set_aspect('equal')\n",
    "    plt.savefig(f'../../vis/kpts3d_{i}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f91ee849",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 30, 21, 3])\n",
      "torch.Size([1, 50, 120, 21, 3])\n",
      "torch.Size([1, 30, 20, 3])\n",
      "tensor([[0.0123, 0.0127, 0.0128, 0.0128, 0.0131, 0.0131, 0.0137, 0.0144, 0.0148,\n",
      "         0.0149, 0.0151, 0.0153, 0.0154, 0.0158, 0.0164, 0.0168, 0.0170, 0.0171,\n",
      "         0.0171, 0.0173, 0.0189, 0.0193, 0.0224, 0.0229, 0.0235, 0.0237, 0.0244,\n",
      "         0.0248, 0.0261, 0.0269, 0.0270, 0.0271, 0.0274, 0.0275, 0.0280, 0.0282,\n",
      "         0.0292, 0.0293, 0.0307, 0.0339, 0.0346, 0.0350, 0.0351, 0.0367, 0.0379,\n",
      "         0.0386, 0.0391, 0.0395, 0.0439, 0.0523]], device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 4, 12, 31, 34, 27, 29,  2, 46, 20, 11, 15, 33, 21, 14,  9, 38, 41,  1,\n",
       "         22, 35, 19, 43, 16, 48,  5, 47, 25, 40,  8, 28, 39, 45, 17, 23, 37, 18,\n",
       "          6, 42, 49, 26, 24, 13, 36,  3, 44,  0,  7, 10, 32, 30]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# rank predictions according to limb stretching. We first visualize the predictions with lower limb stretching --> more realistic\n",
    "from metrics.body_realism import limb_stretching_normed_mean, limb_stretching_normed_rmse\n",
    "print(obs_in.shape)\n",
    "print(pred.shape)\n",
    "print(obs_in[..., 1:, :].shape)\n",
    "# limbstretching = limb_stretching_normed_mean(pred, target=obs[..., 1:, :].unsqueeze(1), limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
    "limbstretching = limb_stretching_normed_rmse(pred, target=obs[..., 1:, :].unsqueeze(1), limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
    "limbstretching_sorted, indices = torch.sort(limbstretching.squeeze(1), dim=-1, descending=False)\n",
    "\n",
    "print(limbstretching_sorted)\n",
    "indices\n",
    "\n",
    "# TODO: index predictions with these indices.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "1a9a941a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Observation shape: torch.Size([1, 30, 21, 3])\n",
      "Prediction shape: torch.Size([1, 50, 120, 22, 3])\n",
      "tensor([[0.0446, 0.0711, 0.0517, 0.1149, 0.0689, 0.0238, 0.0374, 0.0329, 0.0411,\n",
      "         0.0573, 0.0565, 0.0855, 0.0375, 0.1141, 0.0402, 0.0385, 0.0564, 0.0727,\n",
      "         0.0904, 0.0620, 0.0374, 0.0363, 0.0443, 0.0386, 0.0702, 0.0413, 0.0455,\n",
      "         0.0468, 0.1038, 0.0691, 0.0630, 0.0320, 0.0489, 0.0422, 0.0520, 0.0756,\n",
      "         0.0444, 0.0414, 0.0852, 0.0673, 0.0391, 0.0500, 0.0484, 0.0457, 0.0556,\n",
      "         0.0393, 0.0674, 0.0349, 0.0392, 0.0459]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# read pred and obs from predictions/joints3d.npy\n",
    "frames_for_half_second = 30\n",
    "# Load the joints3d data from the saved numpy file\n",
    "joints3d = np.load('/home/stud/yaji/storage/user/yaji/NonisotropicSkeletonDiffusion/predictions/joints3d.npy')\n",
    "\n",
    "# Split the data into observation and prediction parts\n",
    "# The first frames_for_half_second frames are observations\n",
    "obs = joints3d[:, 0, :frames_for_half_second, :, :]  # [1, frames_for_half_second, 22, 3]\n",
    "pred = joints3d[:, :, frames_for_half_second:, :, :]  # [1, num_samples, pred_length, 22, 3]\n",
    "\n",
    "# Convert to torch tensors and move to device\n",
    "obs = torch.from_numpy(obs).to(device)\n",
    "pred = torch.from_numpy(pred).to(device)\n",
    "\n",
    "print(\"Observation shape:\", obs[0, ..., 1:, :].unsqueeze(0).shape)\n",
    "print(\"Prediction shape:\", pred.shape)\n",
    "\n",
    "# calculate limb stretching of the loaded predictions w.r.t. the observation\n",
    "limbstretching = limb_stretching_normed_rmse(pred[..., 1:, :], target=obs[0, ..., 1:, :].unsqueeze(0), limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
    "limbstretching_sorted, indices = torch.sort(limbstretching.squeeze(1), dim=-1, descending=False)\n",
    "print(limbstretching)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e51ccbee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 120, 21, 3])\n",
      "torch.Size([50, 120, 21, 3])\n",
      "torch.Size([120, 21, 3])\n",
      "torch.Size([10, 120, 21, 3])\n",
      "[8, 18, 36, 30, 37, 43, 26, 17, 41, 6]\n"
     ]
    }
   ],
   "source": [
    "from metrics.ranking import get_closest_and_nfurthest_maxapd\n",
    "# If you see problems with the visualizations, you can remove predictions that have limb stretching > 0.04\n",
    "# limbstretching = limb_stretching_normed_mean(pred, target=obs[..., 1:, :], limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
    "# remove batch dimension\n",
    "y_pred = pred.squeeze(0)  # [n_samples, seq_length, num_joints, features]\n",
    "# If GT is not available, we use the first sample as the GT reference, i.e. the one most likely closest to GT\n",
    "y_gt = y_pred[0].unsqueeze(0)  # [1, seq_length, num_joints, features]\n",
    "print(y_gt.shape)\n",
    "print(y_pred.shape)\n",
    "pred_closest, sorted_preds, sorted_preds_idxs = get_closest_and_nfurthest_maxapd(y_pred, y_gt, nsamples=10)\n",
    "\n",
    "print(pred_closest.shape)\n",
    "print(sorted_preds.shape)\n",
    "print(sorted_preds_idxs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "live_demo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
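The last ranking cell of the notebook leaves indexing as a TODO. A minimal sketch of how the sorted indices could be applied, assuming the `pred` and `indices` shapes printed above; `rank_predictions` is a hypothetical helper, not part of this commit:

```python
import torch

def rank_predictions(pred: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # pred: [batch, n_samples, seq_len, num_joints, 3]; indices: [batch, n_samples] sort order
    b, s = indices.shape
    # expand the per-batch sample order so it can be gathered along dim 1 (the samples dimension)
    idx = indices.view(b, s, 1, 1, 1).expand(-1, -1, *pred.shape[2:])
    return torch.gather(pred, 1, idx)

# usage: pred_ranked = rank_predictions(pred, indices)
# pred_ranked[:, 0] is then the sample with the lowest limb stretching
```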
SkeletonDiffusion/inference_filtered.ipynb
ADDED
@@ -0,0 +1 @@
SkeletonDiffusion/setup.py
ADDED
@@ -0,0 +1,13 @@
from setuptools import setup, find_packages

setup(
    name="SkeletonDiffusion",
    version="0.1.0",
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires=[
        "torch",
        "numpy",
    ],
    python_requires=">=3.8",
)
SkeletonDiffusion/src/__init__.py
ADDED
@@ -0,0 +1,7 @@
import os
import sys

# Add the src directory to the Python path
src_path = os.path.dirname(os.path.abspath(__file__))
if src_path not in sys.path:
    sys.path.insert(0, src_path)
SkeletonDiffusion/src/config_utils.py
ADDED
@@ -0,0 +1,62 @@
import torch
from torch.utils.data import DataLoader
from ignite.handlers import Checkpoint
import os

from SkeletonDiffusion.src.utils.config import init_obj
from SkeletonDiffusion.src.utils.reproducibility import seed_worker, seed_eval_worker, RandomStateDict
import SkeletonDiffusion.src.data.loaders as dataset_type

from SkeletonDiffusion.src.utils.load import get_latest_model_path
from SkeletonDiffusion.src.inference_utils import create_model


def create_train_dataloaders(batch_size, batch_size_eval, num_workers, skeleton, if_run_validation=True, if_resume_training=False, **config):
    random_state_manager = RandomStateDict()
    if if_resume_training:
        checkpoint = torch.load(config['load_path'])
        Checkpoint.load_objects(to_load={"random_states": random_state_manager}, checkpoint=checkpoint)
    dataset_train = init_obj(config, 'dataset_type', dataset_type, split="train", skeleton=skeleton, **(config['data_loader_train']))
    data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, worker_init_fn=seed_worker, pin_memory=True,
                                   num_workers=num_workers, generator=random_state_manager.generator)

    if if_run_validation:
        dataset_eval = init_obj(config, 'dataset_type', dataset_type, split="valid", skeleton=skeleton, **(config['data_loader_valid']))
        dataset_eval_train = init_obj(config, 'dataset_type', dataset_type, split="train", skeleton=skeleton, **(config['data_loader_train_eval']))
        data_loader_eval = DataLoader(dataset_eval, shuffle=False, worker_init_fn=seed_eval_worker, batch_size=batch_size_eval, num_workers=1, pin_memory=True)
        data_loader_train_eval = DataLoader(dataset_eval_train, shuffle=False, worker_init_fn=seed_eval_worker, batch_size=batch_size_eval, num_workers=1, pin_memory=True)
    else:
        data_loader_eval = None
        data_loader_train_eval = None
    return data_loader_train, data_loader_eval, data_loader_train_eval, random_state_manager


def flat_hydra_config(cfg):
    """
    Flatten the main dict categories of the Hydra config object into a single one.
    """
    for subconf in ['model', 'task', 'dataset', 'autoenc_arch', 'cov_matrix']:
        if subconf in cfg:
            cfg = {**cfg, **cfg[subconf]}
            cfg.pop(subconf)
    return cfg


def resume_training(cfg):
    # The output folder has already been created.
    assert 'output_log_path' in cfg

    # Decide whether to resume from the latest checkpoint (default) or from a given path (if load_path is given).
    assert len(os.listdir(os.path.join(cfg['output_log_path'], 'checkpoints'))) != 0, "Checkpoints folder is empty. Please provide a valid path to load from."

    if len(cfg['load_path']) == 0:
        # load latest model
        cfg['load_path'] = get_latest_model_path(os.path.join(cfg['output_log_path'], 'checkpoints'))
        print("Loading latest epoch: ", cfg['load_path'].split('/')[-1])
    else:
        output_path = os.path.dirname(os.path.dirname(cfg['load_path']))
        assert cfg['output_log_path'] == output_path
        cfg['output_log_path'] = output_path
    return cfg
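`flat_hydra_config` merges selected sub-dicts into the top level of the config; a small sketch of its effect on a hypothetical config, using the function defined above:

```python
cfg = {
    'batch_size': 64,
    'model': {'diffusion_timesteps': 10},
    'dataset': {'dataset_type': 'amass'},
}
flat = flat_hydra_config(cfg)
# the 'model' and 'dataset' sub-dicts are dissolved into the top level
assert flat == {'batch_size': 64, 'diffusion_timesteps': 10, 'dataset_type': 'amass'}
```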
SkeletonDiffusion/src/core/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Import network first
from .network import AutoEncoder, Denoiser

# Then import diffusion manager
from .diffusion_manager import DiffusionManager

# Export all
__all__ = ['AutoEncoder', 'Denoiser', 'DiffusionManager']
SkeletonDiffusion/src/core/diffusion/__init__.py
ADDED
@@ -0,0 +1,3 @@
from SkeletonDiffusion.src.core.diffusion.isotropic import IsotropicGaussianDiffusion
from SkeletonDiffusion.src.core.diffusion.nonisotropic import NonisotropicGaussianDiffusion
from SkeletonDiffusion.src.core.diffusion.utils import get_cov_from_corr
SkeletonDiffusion/src/core/diffusion/base.py
ADDED
@@ -0,0 +1,445 @@
import math
from random import random
from functools import partial
from collections import namedtuple
from typing import Tuple, Optional, List, Union, Dict

import torch
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from einops import reduce

from tqdm.auto import tqdm


# constants
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])

# helper functions

def identity(t, *args, **kwargs):
    return t

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

def exp_beta_schedule(timesteps, factor=3.0):
    steps = timesteps + 1
    x = torch.linspace(-factor, 0, steps, dtype=torch.float64)
    betas = torch.exp(x)
    return torch.clip(betas, 0, 0.999)


class LatentDiffusion(nn.Module):
    def __init__(self,
                 model: torch.nn.Module, latent_size=96, diffusion_timesteps=10, diffusion_objective="pred_x0", sampling_timesteps=None, diffusion_activation='identity',
                 silent=True, diffusion_conditioning=False, diffusion_loss_type='mse',
                 objective='pred_noise',
                 beta_schedule='cosine',
                 beta_schedule_factor=3.0,
                 ddim_sampling_eta=0.,
                 **kwargs
                 ):

        super().__init__()

        if diffusion_activation == "tanh":
            self.activation = torch.nn.Tanh()
        elif diffusion_activation == "identity":
            self.activation = torch.nn.Identity()
        else:
            raise ValueError(f'unknown diffusion activation {diffusion_activation}')
        self.silent = silent
        self.condition = diffusion_conditioning
        self.loss_type = diffusion_loss_type

        self.statistics_pred = None
        self.statistics_obs = None

        timesteps = diffusion_timesteps
        objective = diffusion_objective

        self.model = model
        self.channels = self.model.channels
        self.self_condition = self.model.self_condition

        self.seq_length = latent_size
        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == 'exp':
            betas = exp_beta_schedule(timesteps, beta_schedule_factor)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps)  # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))

        print("Diffusion is_ddim_sampling: ", self.is_ddim_sampling)

    def set_normalization_statistics(self, statistics_pred, statistics_obs):
        self.statistics_pred = statistics_pred
        self.statistics_obs = statistics_obs
        print("Setting normalization statistics for diffusion")

    def get_white_noise(self, x, *args, **kwargs):
        return self.get_noise(x, *args, **kwargs)

    def get_start_noise(self, x, *args, **kwargs):
        return self.get_white_noise(x, *args, **kwargs)

    def get_noise(self, x, *args, **kwargs):
        """
        x is either tensor or shape
        """
        if torch.is_tensor(x):
            return torch.randn_like(x)
        elif isinstance(x, tuple):
            return torch.randn(*x, *args, **kwargs)

    #######################################################################
    # TO SUBCLASS
    #######################################################################

    def predict_start_from_noise(self, x_t, t, noise):
        assert 0, "Not implemented"
        return x_t

    def predict_noise_from_start(self, x_t, t, x0):
        assert 0, "Not implemented"
        return x_t

    def predict_v(self, x_start, t, noise):
        assert 0, "Not implemented"
        return x_start

    def predict_start_from_v(self, x_t, t, v):
        assert 0, "Not implemented"
        return x_t

    @autocast(enabled=False)
    def q_sample(self, x_start, t, noise=None):
        assert 0, "Not implemented"
        return x_start

    def q_posterior(self, x_start, x_t, t):
        assert 0, "Not implemented"

    def p_combine_mean_var_noise(self, model_mean, model_log_variance, noise):
        assert 0, "Not implemented"
        return model_mean

    def p_interpolate_mean_var_noise(self, model_mean, model_log_variance, noise, node_idx: Optional[int] = None, interpolate_factor=0.0, noise2interpolate=None, **kwargs):
        assert 0, "Not implemented"
        return model_mean

    def loss_funct(self, model_out, target, *args, **kwargs):
        if self.loss_type == "mse":
            loss = F.mse_loss(model_out, target, reduction='none')
        elif self.loss_type == 'l1':
            loss = F.l1_loss(model_out, target, reduction='none')
        else:
            assert 0, "Not implemented"
        return loss

    ########################################################################
    # NETWORK INTERFACE
    #########################################################################

    def model_predictions(self, x, t, x_self_cond=None, x_cond=None, clip_x_start=False, rederive_pred_noise=False):
        model_output = self.feed_model(x, t, x_self_cond=x_self_cond, x_cond=x_cond)
        maybe_clip = partial(torch.clamp, min=-1., max=1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)
        return ModelPrediction(pred_noise, x_start)

    def feed_model(self, x, t, x_self_cond=None, x_cond=None):
        if self.condition:
            assert x_cond is not None
            if x.shape[0] > x_cond.shape[0]:
                # training with multiple samples
                x_cond = x_cond.repeat_interleave(int(x.shape[0] / x_cond.shape[0]), 0)
        model_in = x

        model_output = self.model(model_in, t, x_self_cond, x_cond)
        model_output = self.activation(model_output)
        return model_output

    ########################################################################
    # FORWARD PROCESS
    #########################################################################

    def p_losses(self, x_start, t, noise=None, x_cond=None, n_train_samples=1):
        b, c, n = x_start.shape
        if n_train_samples > 1:
            x_start = x_start.repeat_interleave(n_train_samples, dim=0)
            t = t.repeat_interleave(n_train_samples, dim=0)
            if x_cond is not None:
                x_cond = x_cond.repeat_interleave(n_train_samples, dim=0)
        noise = default(noise, self.get_white_noise(x_start, t))  # noise for timesteps t

        # noise sample

        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        # if doing self-conditioning, 50% of the time, predict x_start from current set of times
        # and condition the unet with that
        # this technique will slow down training by 25%, but seems to lower FID significantly

        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t, x_cond=x_cond).pred_x_start
                x_self_cond.detach_()

        # predict and take gradient step
        model_out = self.feed_model(x, t, x_self_cond=x_self_cond, x_cond=x_cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')
        loss = self.loss_funct(model_out, target, t)
        loss = reduce(loss, 'b ... -> b', 'mean')  # [batch*n_train_samples, Nodes, latent_dim] -> [batch*n_train_samples]

        return loss, extract(self.loss_weight, t.view(b, -1)[:, 0], loss.shape[0:1]), model_out

    def forward(self, x, *args, x_cond=None, **kwargs):
        b, c, n, device, seq_length = *x.shape, x.device, self.seq_length
        assert n == seq_length, f'seq length must be {seq_length}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        return self.p_losses(x, t, *args, x_cond=x_cond, **kwargs)

    ########################################################################
    # REVERSE PROCESS
    #########################################################################

    def p_mean_variance(self, x, t, x_self_cond=None, x_cond=None, clip_denoised=True):
        preds = self.model_predictions(x, t, x_self_cond, x_cond=x_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, x_self_cond=None, clip_denoised=True, sampling_noise=None, *args, if_interpolate=False, noise2interpolate=None, interpolation_kwargs: Dict = None, **kwargs):
        b, *_, device = *x.shape, x.device
        batched_times = torch.full((b,), t, device=x.device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, x_self_cond=x_self_cond, clip_denoised=clip_denoised, *args, **kwargs)

        if sampling_noise is not None and t > 0:
            noise = sampling_noise[:, sampling_noise.shape[1] - t]
        else:
            noise = self.get_white_noise(x) if t > 0 else 0.  # no noise if t == 0

        if if_interpolate and t > 0:
            noise2 = noise2interpolate[:, sampling_noise.shape[1] - t]
            assert noise2.shape == noise.shape
            pred_img = self.p_interpolate_mean_var_noise(model_mean, model_log_variance, noise, noise2, **interpolation_kwargs)
        else:
            pred_img = self.p_combine_mean_var_noise(model_mean, model_log_variance, noise)
        return pred_img, x_start, noise, model_mean

    @torch.no_grad()
    def p_sample_loop(self, shape, x_cond=None, start_noise=None, sampling_noise=None, return_sampling_noise=False, return_timages=False, **kwargs):
        batch, device = shape[0], self.betas.device
        if start_noise is not None:
            assert start_noise.shape == shape, f"Shape mismatch: {start_noise.shape} != {shape}"
            img = start_noise
            noise = start_noise.clone()
        else:
            img = self.get_start_noise(shape, device=device)
            noise = img.clone()

        if sampling_noise is not None:
            assert sampling_noise.shape[2:] == shape[1:], f"Shape mismatch: {sampling_noise.shape} != {shape}"
            assert sampling_noise.shape[0] == shape[0], f"Shape mismatch: {sampling_noise.shape} != {shape}"
            assert sampling_noise.shape[1] == self.num_timesteps - 1

        x_start = None
        imgs = []
        noise_t = []
        mean_t = []
        if not self.silent:
            print(f"Evaluation with {len(range(0, self.num_timesteps))} diffusion steps")
        for t in reversed(range(0, self.num_timesteps)):
            self_cond = x_start if self.self_condition else None
            img, x_start, nt, model_mean = self.p_sample(img, t, self_cond, x_cond=x_cond, sampling_noise=sampling_noise, **kwargs)
            if return_sampling_noise and t != 0:
                noise_t.append(nt)
                mean_t.append(model_mean)
            if return_timages and t != 0:
                imgs.append(img)

        if return_sampling_noise:
            noise_t = torch.stack(noise_t, dim=1)
            mean_t = torch.stack(mean_t, dim=1)
        if return_timages:
            print("Returning timages")
            imgs = torch.stack(imgs, dim=1)

        if return_sampling_noise:
            if return_timages:
                noise = (noise, noise_t, imgs)
            else:
                noise = (noise, noise_t, mean_t)
        else:
            if return_timages:
                noise = (noise, imgs)
        return img, noise

    @torch.no_grad()
    def ddim_sample(self, shape, clip_denoised=True, x_cond=None, start_noise=None):
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  # [-1, 0, 1, ..., T-1]
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        if start_noise is not None:
            assert start_noise.shape == shape
            img = start_noise
            noise = start_noise.clone()
        else:
            img = torch.randn(shape, device=device)
            noise = img.clone()

        x_start = None
        if not self.silent:
            print(f"Evaluation with {len(time_pairs)} diffusion steps")
        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, x_cond=x_cond, clip_x_start=clip_denoised)

            if time_next < 0:
                img = x_start
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sqrt_alpha_next = self.sqrt_alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()  # eta * beta_t_tilde
            c = (1 - alpha_next - sigma ** 2).sqrt()

            step_noise = torch.randn_like(img)  # fresh noise per step; `noise` keeps the start noise

            img = x_start * sqrt_alpha_next + \
                  c * pred_noise + \
                  sigma * step_noise

        return img, noise

    @torch.no_grad()
    def sample(self, batch_size=16, *args, **kwargs):
        seq_length, channels = self.seq_length, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn((batch_size, channels, seq_length), *args, **kwargs)
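The three beta schedules above only differ in how noise is distributed over the (here few) diffusion steps; a quick sanity sketch, assuming base.py is importable under the repository's package path:

```python
import torch
from SkeletonDiffusion.src.core.diffusion.base import cosine_beta_schedule

timesteps = 10
betas = cosine_beta_schedule(timesteps)            # one beta per diffusion step
alphas_cumprod = torch.cumprod(1. - betas, dim=0)  # retained signal fraction at each step
assert betas.shape == (timesteps,)
assert (betas > 0).all() and (betas <= 0.999).all()
assert (alphas_cumprod[1:] <= alphas_cumprod[:-1]).all()  # signal decays monotonically
```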
SkeletonDiffusion/src/core/diffusion/isotropic.py
ADDED
@@ -0,0 +1,104 @@
import torch
from torch.cuda.amp import autocast

from SkeletonDiffusion.src.core.diffusion.base import LatentDiffusion, extract, default

class IsotropicGaussianDiffusion(LatentDiffusion):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - self.alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / self.alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        alphas = 1. - self.betas

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer('posterior_mean_coef1', self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - self.alphas_cumprod))

        # calculate loss weight
        snr = self.alphas_cumprod / (1 - self.alphas_cumprod)

        if self.objective == 'pred_noise':
            loss_weight = torch.ones_like(snr)
        elif self.objective == 'pred_x0':
            loss_weight = snr
        elif self.objective == 'pred_v':
            loss_weight = snr / (snr + 1)

        register_buffer('loss_weight', loss_weight)

    ########################################################################
    # FORWARD PROCESS
    #########################################################################

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    @autocast(enabled=False)
    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: self.get_white_noise(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    ########################################################################
    # REVERSE PROCESS
    #########################################################################

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_combine_mean_var_noise(self, model_mean, model_log_variance, noise):
        return model_mean + (0.5 * model_log_variance).exp() * noise

    def interpolate_noise(self, noise1, noise2, interpolate_funct=None, **kwargs):
        interpolated_noise = interpolate_funct(noise1, noise2)
        return interpolated_noise

    def p_interpolate_mean_var_noise(self, model_mean, model_log_variance, noise, noise2interpolate=None, **kwargs):
        interpolated_noise = self.interpolate_noise(noise, noise2interpolate, **kwargs)
        return model_mean + (0.5 * model_log_variance).exp() * interpolated_noise
SkeletonDiffusion/src/core/diffusion/nonisotropic.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.cuda.amp import autocast
|
| 3 |
+
from .base import LatentDiffusion, extract, default
|
| 4 |
+
|
| 5 |
+
def extract_matrix(matrix, t, x_shape):
|
| 6 |
+
b, *_ = t.shape
|
| 7 |
+
T, N, *_ = matrix.shape
|
| 8 |
+
out = torch.index_select(matrix, 0, t)
|
| 9 |
+
out = out.reshape(b, *out.shape[1:])
|
| 10 |
+
while len(x_shape) > len(out.shape):
|
| 11 |
+
out = out.unsqueeze(-1)
|
| 12 |
+
return out
|
| 13 |
+
|
| 14 |
+
def verify_noise_scale(diffusion):
|
| 15 |
+
N, *_ = diffusion.Lambda_N.shape
|
| 16 |
+
alphas = 1 - diffusion.betas
|
| 17 |
+
noise = diffusion.get_noise((2000, diffusion.num_timesteps, N))
|
| 18 |
+
zeta_noise = torch.sqrt(diffusion.Lambda_t.unsqueeze(0)) * noise
|
| 19 |
+
print("current: ", (zeta_noise**2).sum(-1).mean(0))
|
| 20 |
+
print("original standard gaussian diffusion: ",(1-alphas) * zeta_noise.shape[-1])
|
| 21 |
+
|
| 22 |
+
def compute_covariance_matrices(diffusion: torch.nn.Module, Lambda_N: torch.Tensor, diffusion_covariance_type='ani-isotropic', gamma_scheduler = 'cosine'):
|
| 23 |
+
N, *_ = Lambda_N.shape
|
| 24 |
+
alphas = 1. - diffusion.betas
|
| 25 |
+
def _alpha_sumprod(alphas, t):
|
| 26 |
+
return torch.sum(torch.cumprod(torch.flip(alphas[:t+1], [0]), dim=0))
|
| 27 |
+
alphas_sumprod = torch.stack([_alpha_sumprod(alphas, t) for t in range(len(alphas))], dim=0)
|
| 28 |
+
diffusion.alphas_sumprod = alphas_sumprod
|
| 29 |
+
if diffusion_covariance_type == 'isotropic':
|
| 30 |
+
assert (Lambda_N == 0).all()
|
| 31 |
+
Lambda_t = (1-alphas).unsqueeze(-1) # (Tdiff, N)
|
| 32 |
+
Lambda_bar_t = (1-diffusion.alphas_cumprod.unsqueeze(-1))
|
| 33 |
+
Lambda_bar_t_prev = torch.cat([torch.zeros(1).unsqueeze(0), Lambda_bar_t[:-1]], dim=0)
|
| 34 |
+
elif diffusion_covariance_type == 'anisotropic':
|
| 35 |
+
Lambda_t = (1-alphas.unsqueeze(-1))*Lambda_N # (Tdiff, N)
|
| 36 |
+
Lambda_bar_t = (1-diffusion.alphas_cumprod.unsqueeze(-1))*Lambda_N
|
| 37 |
+
Lambda_bar_t_prev = (1-diffusion.alphas_cumprod_prev.unsqueeze(-1))*Lambda_N
|
| 38 |
+
elif diffusion_covariance_type == 'skeleton-diffusion':
|
| 39 |
+
if gamma_scheduler== 'cosine':
|
| 40 |
+
gammas = 1 - alphas
|
| 41 |
+
elif gamma_scheduler == 'mono_decrease':
|
| 42 |
+
gammas = 1 - torch.arange(0, diffusion.num_timesteps)/diffusion.num_timesteps
|
| 43 |
+
else:
|
| 44 |
+
assert 0, "Not implemented"
|
| 45 |
+
Lambda_I = Lambda_N - 1
|
| 46 |
+
gammas_bar = (1-alphas)*gammas
|
| 47 |
+
gammas_tilde = diffusion.alphas_cumprod*torch.cumsum(gammas_bar/diffusion.alphas_cumprod, dim=-1)
|
| 48 |
+
Lambda_t = Lambda_I.unsqueeze(0)*gammas_bar.unsqueeze(-1) + (1-alphas).unsqueeze(-1) # (Tdiff, N)
|
| 49 |
+
Lambda_bar_t = Lambda_I.unsqueeze(0)*gammas_tilde.unsqueeze(-1) + (1-diffusion.alphas_cumprod.unsqueeze(-1))
|
| 50 |
+
Lambda_bar_t_prev = torch.cat([torch.zeros(N).unsqueeze(0), Lambda_bar_t[:-1]], dim=0) # we start from det so it must be zero for t=-1
|
| 51 |
+
else:
|
| 52 |
+
assert 0, "Not implemented"
|
| 53 |
+
|
| 54 |
+
return Lambda_t, Lambda_bar_t, Lambda_bar_t_prev
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class NonisotropicGaussianDiffusion(LatentDiffusion):
|
| 58 |
+
def __init__(self, Sigma_N: torch.Tensor, Lambda_N: torch.Tensor, U: torch.Tensor, diffusion_covariance_type='skeleton-diffusion', loss_reduction_type='l1', gamma_scheduler = 'cosine', **kwargs):
|
| 59 |
+
super().__init__( **kwargs)
|
| 60 |
+
alphas = 1. - self.betas
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
N, _ = Sigma_N.shape
|
| 64 |
+
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
| 65 |
+
register_buffer('Lambda_N', Lambda_N)
|
| 66 |
+
register_buffer('Sigma_N', Sigma_N)
|
| 67 |
+
self.set_rotation_matrix(U)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
Lambda_t, Lambda_bar_t, Lambda_bar_t_prev = compute_covariance_matrices(diffusion=self, Lambda_N=Lambda_N,
|
| 71 |
+
diffusion_covariance_type=diffusion_covariance_type, gamma_scheduler=gamma_scheduler)
|
| 72 |
+
|
| 73 |
+
def create_diagonal_matrix(diagonal_vector):
|
| 74 |
+
return torch.stack([torch.diag(diag) for diag in diagonal_vector], dim=0) # [T, N, N]
|
| 75 |
+
######### forward , for training and inference #####################
|
| 76 |
+
#predict_noise_from_start
|
| 77 |
+
inv_sqrt_Lambda_bar = 1/torch.sqrt(Lambda_bar_t)
|
| 78 |
+
inv_sqrt_Lambda_bar_sqrt_alphas_cumprod = (1/torch.sqrt(Lambda_bar_t))*self.sqrt_alphas_cumprod.unsqueeze(-1)
|
| 79 |
+
register_buffer('inv_sqrt_Lambda_bar_mmUt', create_diagonal_matrix(inv_sqrt_Lambda_bar)@self.U_transposed.unsqueeze(0))
|
| 80 |
+
register_buffer('inv_sqrt_Lambda_bar_sqrt_alphas_cumprod_mmUt', create_diagonal_matrix(inv_sqrt_Lambda_bar_sqrt_alphas_cumprod)@self.U_transposed.unsqueeze(0))
|
| 81 |
+
#predict_start_from_noise
|
| 82 |
+
sqrt_Lambda_bar = torch.sqrt(Lambda_bar_t)
|
| 83 |
+
sqrt_Lambda_bar_sqrt_recip_alphas_cumprod = torch.sqrt(Lambda_bar_t/self.alphas_cumprod.unsqueeze(-1))
|
| 84 |
+
register_buffer('Umm_sqrt_Lambda_bar_t', U.unsqueeze(0)@create_diagonal_matrix(sqrt_Lambda_bar))
|
| 85 |
+
register_buffer('Umm_sqrt_Lambda_bar_t_sqrt_recip_alphas_cumprod', U.unsqueeze(0)@create_diagonal_matrix(sqrt_Lambda_bar_sqrt_recip_alphas_cumprod))
|
| 86 |
+
|
| 87 |
+
######### q_posterior , for reverse process #####################
|
| 88 |
+
#q_posterior
|
| 89 |
+
Lambda_posterior_t = Lambda_t*Lambda_bar_t_prev*(1/Lambda_bar_t)
|
| 90 |
+
sqrt_alphas_cumprod_prev = torch.sqrt(self.alphas_cumprod_prev)
|
| 91 |
+
register_buffer('Lambda_posterior', Lambda_posterior_t)
|
| 92 |
+
register_buffer('Lambda_posterior_log_variance_clipped', torch.log(Lambda_posterior_t.clamp(min =1e-20)))
|
| 93 |
+
|
| 94 |
+
posterior_mean_coef1_x0 = sqrt_alphas_cumprod_prev.unsqueeze(-1).unsqueeze(-1)*(U.unsqueeze(0)@create_diagonal_matrix((1/Lambda_bar_t)*Lambda_t)@self.U_transposed.unsqueeze(0))
|
| 95 |
+
posterior_mean_coef2_xt = torch.sqrt(alphas).unsqueeze(-1).unsqueeze(-1)*(U.unsqueeze(0)@create_diagonal_matrix((1/Lambda_bar_t)*Lambda_bar_t_prev)@self.U_transposed.unsqueeze(0))
|
| 96 |
+
register_buffer('posterior_mean_coef1_x0', posterior_mean_coef1_x0)
|
| 97 |
+
        register_buffer('posterior_mean_coef2_xt', posterior_mean_coef2_xt)

        ######### loss #####################
        self.loss_reduction_type = loss_reduction_type
        sqrt_recip_Lambda_bar_t = torch.sqrt(1. / Lambda_bar_t)
        register_buffer('mahalanobis_S_sqrt_recip', create_diagonal_matrix(sqrt_recip_Lambda_bar_t) @ self.U_transposed.unsqueeze(0))

        if self.objective == 'pred_noise':
            loss_weight = torch.ones_like(alphas)
        elif self.objective == 'pred_x0':
            loss_weight = self.alphas_cumprod
        elif self.objective == 'pred_v':
            assert 0, "Not implemented"
            # loss_weight = snr / (snr + 1)
        register_buffer('loss_weight', loss_weight)

        assert not len(self.mahalanobis_S_sqrt_recip.shape) == 1

    ########################################################################
    # CLASS FUNCTIONS
    ########################################################################

    def set_rotation_matrix(self, U: torch.Tensor):
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
        register_buffer('U', U)
        register_buffer('U_transposed', U.t())

    def check_eigh(self):
        return torch.isclose(self.U @ torch.diag(self.Lambda_N) @ self.U_transposed, self.Sigma_N)  # .all(), "U@Lambda_N@U^T must be equal to Sigma_N"

    def get_anisotropic_noise(self, x, *args, **kwargs):
        """
        x is either a tensor or a shape.
        """
        return self.get_noise(x, *args, **kwargs) * self.Lambda_N.unsqueeze(-1)

    ########################################################################
    # FORWARD PROCESS
    ########################################################################

    @autocast(enabled=False)
    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: self.get_white_noise(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_matrix(self.Umm_sqrt_Lambda_bar_t, t, x_start.shape) @ noise
        )

    # for inference
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract_matrix(self.Umm_sqrt_Lambda_bar_t_sqrt_recip_alphas_cumprod, t, x_t.shape) @ noise
        )

    # for inference
    def predict_noise_from_start(self, x_t, t, x0):
        return (
            extract_matrix(self.inv_sqrt_Lambda_bar_mmUt, t, x_t.shape) @ x_t -
            extract_matrix(self.inv_sqrt_Lambda_bar_sqrt_alphas_cumprod_mmUt, t, x_t.shape) @ x0
        )

    ########################################################################
    # LOSS
    ########################################################################

    def mahalanobis_dist(self, matrix, vector):
        return (matrix @ vector).abs()  # check shape

    def loss_funct(self, model_out, target, t):
        difference = target - model_out if self.objective == 'pred_noise' else model_out - target

        loss = self.mahalanobis_dist(extract_matrix(self.mahalanobis_S_sqrt_recip, t, difference.shape), difference)
        if self.loss_reduction_type == 'l1':
            pass  # the Mahalanobis distance above is already an absolute value
        elif self.loss_reduction_type == 'mse':
            loss = loss ** 2
        else:
            assert 0, "Not implemented"
        return loss

    ########################################################################
    # REVERSE PROCESS
    ########################################################################

    def q_posterior_mean(self, x_start, x_t, t):
        return (
            extract_matrix(self.posterior_mean_coef1_x0, t, x_t.shape) @ x_start +
            extract_matrix(self.posterior_mean_coef2_xt, t, x_t.shape) @ x_t
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = self.q_posterior_mean(x_start, x_t, t)
        posterior_variance = extract_matrix(self.Lambda_posterior, t, x_t.shape)
        posterior_log_variance_clipped = extract_matrix(self.Lambda_posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_combine_mean_var_noise(self, model_mean, posterior_log_variance, noise):
        """The mean lives in the original (non-diagonal) coordinate system; posterior_log_variance lives in the diagonal eigenbasis."""
        return model_mean + self.U @ ((0.5 * posterior_log_variance).exp() * noise)

    ########################################################################
    # INTERPOLATION
    ########################################################################

    def interpolate_noise(self, noise1, noise2, posterior_log_variance=None, interpolate_funct=None):
        noise1 = self.U @ ((0.5 * posterior_log_variance).exp() * noise1)
        noise2 = self.U @ ((0.5 * posterior_log_variance).exp() * noise2)
        interpolated_noise = interpolate_funct(noise1, noise2)
        return interpolated_noise

    def p_interpolate_mean_var_noise(self, model_mean, model_log_variance, noise, noise2interpolate=None, **kwargs):
        interpolated_noise = self.interpolate_noise(noise, noise2interpolate, posterior_log_variance=model_log_variance, **kwargs)
        return model_mean + interpolated_noise  # (0.5 * model_log_variance).exp()
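Editor's note: q_sample above perturbs a clean latent as x_t = sqrt(alpha_bar_t) * x_0 + U @ sqrt(Lambda_bar_t) * eps, so the marginal covariance of x_t given x_0 is U @ diag(Lambda_bar_t) @ U^T. Below is a minimal standalone sketch of that identity, assuming only torch; the 3x3 covariance and the one-step Lambda_bar_t are illustrative stand-ins for the class buffers, not the repo's actual schedule.

import torch

torch.manual_seed(0)
N = 3
A = torch.randn(N, N)
Sigma_N = A @ A.T + N * torch.eye(N)          # SPD stand-in for the skeleton covariance
Lambda_N, U = torch.linalg.eigh(Sigma_N)

alpha_bar_t = 0.5
Lambda_bar_t = (1 - alpha_bar_t) * Lambda_N   # illustrative stand-in for the schedule buffer

eps = torch.randn(200_000, N)                 # white noise, one sample per row
# with x_0 = 0 the sqrt(alpha_bar_t) * x_0 term vanishes;
# rows of x_t are samples of U @ (sqrt(Lambda_bar_t) * eps)
x_t = (eps * torch.sqrt(Lambda_bar_t)) @ U.T
emp_cov = x_t.T @ x_t / x_t.shape[0]
target = U @ torch.diag(Lambda_bar_t) @ U.T
print((emp_cov - target).abs().max())         # small, shrinks with more samples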
SkeletonDiffusion/src/core/diffusion/utils.py
ADDED
@@ -0,0 +1,125 @@
import torch

def dim_null_space(matrix):
    assert matrix.shape[-1] == matrix.shape[-2], "Matrix must be square"
    # torch.linalg.matrix_rank(matrix) is not tuned to the accuracy of PyTorch float32:
    # 1.0 + eps != 1.0, but torch.tensor(1.0) + 0.7e-7 != torch.tensor(1.0)
    return torch.sum(torch.linalg.eigh(matrix)[0].abs() < 0.7e-7)

def is_positive_def(matrix):
    # M is symmetric or Hermitian, and all its eigenvalues are real and positive.
    assert torch.allclose(matrix.transpose(-1, -2), matrix), "Matrix must be symmetric"
    eigenvalues = torch.linalg.eigvals(matrix)
    is_pos_def = (torch.real(eigenvalues) > 0).all()
    if is_pos_def:
        assert torch.isreal(eigenvalues).all(), "Eigenvalues must be real"
    return is_pos_def

def make_positive_definite(matrix, epsilon=1e-6, if_submin=False):
    eigenvalues = torch.linalg.eigvals(matrix)
    # assert torch.isreal(eigenvalues).all()
    if is_positive_def(matrix):
        print("Input matrix was positive definite without adding the spectral norm to the diagonal")
        return matrix

    eigenvalues = torch.real(eigenvalues)
    if not if_submin:
        max_eig = eigenvalues.abs().max()
        pos_def_matrix = matrix + torch.eye(matrix.shape[0]) * (max_eig + epsilon)
    else:
        min_eig = eigenvalues.min()
        pos_def_matrix = matrix + torch.eye(matrix.shape[0]) * (-min_eig + epsilon)
    assert dim_null_space(pos_def_matrix) == 0
    return pos_def_matrix

def normalize_cov(Sigma_N: torch.Tensor, Lambda_N: torch.Tensor, U: torch.Tensor, if_sigma_n_scale=True, sigma_n_scale='spectral', **kwargs):
    N, _ = Sigma_N.shape
    assert Lambda_N.shape == (N,)
    assert U.shape == (N, N)

    if if_sigma_n_scale:
        # decrease the scale of Sigma_N to make it more similar to the identity matrix
        if sigma_n_scale == 'spectral':
            relative_scale_factor = Lambda_N.max()
        elif sigma_n_scale == 'frob':
            relative_scale_factor = Lambda_N.sum() / N
        else:
            assert 0, "Not implemented"

        Lambda_N = Lambda_N / relative_scale_factor
        Sigma_N = Sigma_N / relative_scale_factor
    cond = U @ torch.diag(Lambda_N) @ U.mT
    assert torch.isclose(Sigma_N, cond, atol=1e-06).all(), "Sigma_N must be equal to U @ Lambda_N @ U.t()"
    # Sigma_N[Sigma_N>0.] = (Sigma_N + Sigma_N.t())[Sigma_N>0.]/2
    cond = Lambda_N > 0.7e-7
    assert (cond).all(), f"Lambda_N must be positive definite: {Lambda_N}"
    assert is_positive_def(Sigma_N), "Sigma_N must be positive definite"
    # print("Frobenius Norm of SigmaN: ", torch.linalg.matrix_norm(Sigma_N, ord='fro').mean().item(), "Spectral Norm of SigmaN: ", Lambda_N.max(dim=-1)[0].mean().item())
    return Sigma_N, Lambda_N


def get_cov_from_corr(correlation_matrix: torch.Tensor, if_sigma_n_scale=True, sigma_n_scale='spectral', if_run_as_isotropic=False, diffusion_covariance_type='skeleton-diffusion', **kwargs):
    N, _ = correlation_matrix.shape

    if if_run_as_isotropic:
        if diffusion_covariance_type == 'skeleton-diffusion':
            Lambda_N = torch.ones(N, device=correlation_matrix.device)
            Sigma_N = torch.zeros_like(correlation_matrix)
            U = torch.eye(N, device=correlation_matrix.device)
        elif diffusion_covariance_type == 'anisotropic':
            Lambda_N = torch.ones(N, device=correlation_matrix.device)
            Sigma_N = torch.eye(N, device=correlation_matrix.device)
            U = torch.eye(N, device=correlation_matrix.device)
        else:
            Lambda_N = torch.zeros(N, device=correlation_matrix.device)
            Sigma_N = torch.zeros_like(correlation_matrix)
            U = torch.eye(N, device=correlation_matrix.device)
    else:
        Sigma_N = make_positive_definite(correlation_matrix)
        Lambda_N, U = torch.linalg.eigh(Sigma_N, UPLO='L')

    Sigma_N, Lambda_N = normalize_cov(Sigma_N=Sigma_N, Lambda_N=Lambda_N, U=U, if_sigma_n_scale=if_sigma_n_scale, sigma_n_scale=sigma_n_scale, **kwargs)
    return Sigma_N, Lambda_N, U


def verify_noise_scale(diffusion):
    N, *_ = diffusion.Lambda_N.shape
    alphas = 1 - diffusion.betas
    noise = diffusion.get_noise((2000, diffusion.num_timesteps, N))
    zeta_noise = torch.sqrt(diffusion.Lambda_t.unsqueeze(0)) * noise
    print("current: ", (zeta_noise ** 2).sum(-1).mean(0))
    print("original standard gaussian diffusion: ", (1 - alphas) * zeta_noise.shape[-1])


def plot_matrix(matrix):
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    import numpy as np

    Sigma_N = matrix.cpu().clone().numpy()
    color = 'Purples'
    cmap = matplotlib.colormaps[color].set_bad("white")
    # colormap_r = ListedColormap(cmap.colors[::-1])

    fig, ax = plt.subplots(1, 1, figsize=(6, 6), sharex=True, subplot_kw=dict(box_aspect=1))
    # cax = fig.add_axes([0.93, 0.15, 0.01, 0.7])  # Adjust the position and size of the colorbar
    # for i, ax in enumerate(axes):
    vmax = Sigma_N.max()
    Sigma_N[Sigma_N <= 0.0000] = np.nan
    im = ax.imshow(Sigma_N, cmap=color, vmin=0., vmax=vmax)
    # ax.set_xticks(np.arange(len(Sigma_N)))
    # ax.set_xticklabels(labels=list(skeleton.node_dict.values()), rotation=45, ha="right", rotation_mode="anchor")
    # ax.set_yticks(np.arange(len(Sigma_N)))
    # ax.set_yticklabels(labels=list(skeleton.node_dict.values()), rotation=45, ha="right", rotation_mode="anchor")
    # ax.set_title(list(method2sigman.keys())[i])
    fig.colorbar(im, cmap=cmap)
    # plt.title('Adjacency Matrix')
    plt.show()
    # fig.savefig("../paper_plots/sigmaN.pdf", format="pdf", bbox_inches="tight")
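Editor's note: a minimal usage sketch for get_cov_from_corr, assuming the package is importable; the 4-node chain adjacency below is an illustrative stand-in for skeleton.adj_matrix. With the default spectral scaling, the returned Sigma_N is symmetric positive definite with largest eigenvalue 1 and factorizes as U @ diag(Lambda_N) @ U^T.

import torch
from SkeletonDiffusion.src.core.diffusion.utils import get_cov_from_corr

# chain graph 0-1-2-3 (hypothetical stand-in for a skeleton adjacency)
adj = torch.tensor([[0., 1., 0., 0.],
                    [1., 0., 1., 0.],
                    [0., 1., 0., 1.],
                    [0., 0., 1., 0.]])
Sigma_N, Lambda_N, U = get_cov_from_corr(adj)   # spectral scaling by default
assert torch.allclose(U @ torch.diag(Lambda_N) @ U.mT, Sigma_N, atol=1e-5)
print(Lambda_N.max())                            # tensor(1.) after spectral normalization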
SkeletonDiffusion/src/core/diffusion_manager.py
ADDED
@@ -0,0 +1,45 @@
from typing import Tuple, Optional, List, Union, Dict, Any
import torch

from SkeletonDiffusion.src.core.network.nn import Denoiser
from SkeletonDiffusion.src.core.diffusion import IsotropicGaussianDiffusion, NonisotropicGaussianDiffusion, get_cov_from_corr


class DiffusionManager():
    def __init__(self, diffusion_type: str = 'IsotropicGaussianDiffusion', skeleton=None, covariance_matrix_type: str = 'adjacency',
                 reachability_matrix_degree_factor=0.5, reachability_matrix_stop_at=0, if_sigma_n_scale=True, sigma_n_scale='spectral', if_run_as_isotropic=False,
                 **kwargs):

        model = self.get_network(**kwargs)
        self.diffusion_type = diffusion_type

        if diffusion_type == 'NonisotropicGaussianDiffusion':
            # define Sigma_N
            if covariance_matrix_type == 'adjacency':
                correlation_matrix = skeleton.adj_matrix
            elif covariance_matrix_type == 'reachability':
                correlation_matrix = skeleton.reachability_matrix(factor=reachability_matrix_degree_factor, stop_at=reachability_matrix_stop_at)
            else:
                assert 0, "Not implemented"
            N, *_ = correlation_matrix.shape

            Sigma_N, Lambda_N, U = get_cov_from_corr(correlation_matrix=correlation_matrix, if_sigma_n_scale=if_sigma_n_scale, sigma_n_scale=sigma_n_scale, if_run_as_isotropic=if_run_as_isotropic, **kwargs)
            self.diffusion = NonisotropicGaussianDiffusion(Sigma_N=Sigma_N, Lambda_N=Lambda_N, U=U, model=model, **kwargs)
        elif diffusion_type == 'IsotropicGaussianDiffusion':
            self.diffusion = IsotropicGaussianDiffusion(model=model, **kwargs)
        else:
            assert 0, f"{diffusion_type} Not implemented"

    def get_diffusion(self):
        return self.diffusion

    def get_network(self, num_nodes, diffusion_conditioning=False, latent_size=96, node_types: torch.Tensor = None, diffusion_arch: Dict[str, Any] = None, **kwargs):
        # diffusion_arch must be a dict of Denoiser architecture kwargs
        if diffusion_conditioning:
            cond_dim = latent_size
        else:
            cond_dim = 0

        model = Denoiser(dim=latent_size, cond_dim=cond_dim, out_dim=latent_size, channels=num_nodes, num_nodes=num_nodes, node_types=node_types, **diffusion_arch)

        return model
SkeletonDiffusion/src/core/network/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .nn import AutoEncoder, Denoiser

__all__ = ['AutoEncoder', 'Denoiser']
SkeletonDiffusion/src/core/network/layers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .graph_structural import StaticGraphLinear
from .recurrent import StaticGraphGRU, GraphGRUState, StaticGraphLSTM, GraphLSTMState
from .attention import Attention, ResnetBlock, Residual, PreNorm, RMSNorm
SkeletonDiffusion/src/core/network/layers/attention.py
ADDED
@@ -0,0 +1,138 @@
import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from einops import rearrange

from .graph_structural import StaticGraphLinear


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = torch.nn.LayerNorm((dim), elementwise_affine=True)

    def forward(self, x):
        x = torch.swapaxes(x, -2, -1)
        x = self.norm(x)
        x = torch.swapaxes(x, -2, -1)
        return x


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, 1, dim))

    def forward(self, x):
        # normalize divides by the maximum-norm element; different from the original,
        # in which we take the max norm and not the sum of squared elements.
        return F.normalize(x, dim=-1) * self.g * (x.shape[-1] ** 0.5)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = RMSNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class Block(nn.Module):
    def __init__(self, dim, dim_out, norm_type='none', act_type='tanh', *args, **kwargs):
        super().__init__()
        self.proj = StaticGraphLinear(dim, dim_out, *args, **kwargs)
        # num_nodes=num_nodes, node_types=T)
        if norm_type == 'none':
            self.norm = nn.Identity()  # nn.GroupNorm(groups, dim_out)
        elif norm_type == 'layer':
            self.norm = LayerNorm(kwargs['num_nodes'])
        else:
            assert 0, f"Norm type {norm_type} not implemented!"
        if act_type == 'tanh':
            self.act = nn.Tanh()
        else:
            assert 0, f"Activation type {act_type} not implemented!"

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8, **kwargs):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Tanh(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if time_emb_dim is not None else None

        self.block1 = Block(dim, dim_out, groups=groups, **kwargs)
        self.block2 = Block(dim_out, dim_out, groups=groups, **kwargs)
        self.res_linear = StaticGraphLinear(dim, dim_out, bias=False, **kwargs) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if self.mlp is not None and time_emb is not None:
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b 1 c')
            scale_shift = time_emb.chunk(2, dim=-1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)

        return h + self.res_linear(x)


# We need default num_heads: int = 8
class Attention(nn.Module):
    def __init__(self, dim, dim_out=None, heads=4, dim_head=32, qkv_bias: bool = False, attn_dropout: float = 0., proj_dropout: float = 0., qk_norm: bool = False, norm_layer: nn.Module = nn.Identity, **kwargs):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        dim_out = dim_out if dim_out is not None else dim

        self.to_qkv = StaticGraphLinear(dim, hidden_dim * 3, bias=qkv_bias, **kwargs)
        self.to_out = StaticGraphLinear(hidden_dim, dim_out, bias=False, **kwargs)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.out_dropout = nn.Dropout(proj_dropout)

        self.q_norm = norm_layer(dim_head) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(dim_head) if qk_norm else nn.Identity()

    def forward(self, x):
        b, n, c = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h c) -> b h c n', h=self.heads), qkv)
        q, k = self.q_norm(q), self.k_norm(k)

        q = q * self.scale
        sim = einsum('b h c n, b h c j -> b h n j', q, k)
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = einsum('b h n j, b h d j -> b h n d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.out_dropout(self.to_out(out))
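Editor's note: a shape-level smoke test for the Attention block, assuming the package is importable; the sizes (batch 2, 7 nodes, 32 features per node) are arbitrary. The num_nodes kwarg is forwarded through **kwargs to the StaticGraphLinear projections.

import torch
from SkeletonDiffusion.src.core.network.layers import Attention

attn = Attention(dim=32, heads=4, dim_head=8, num_nodes=7)
x = torch.randn(2, 7, 32)   # (batch, nodes, features)
out = attn(x)
print(out.shape)            # torch.Size([2, 7, 32]) -- dim_out defaults to dim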
SkeletonDiffusion/src/core/network/layers/graph_structural.py
ADDED
@@ -0,0 +1,133 @@
from typing import Tuple, Optional, List, Union

import torch
from torch.nn import *
import math

def gmm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.einsum('ndo,bnd->bno', w, x)


class GraphLinear(Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        # stdv = 1. / math.sqrt(self.weight.size(1))
        # self.weight.data.uniform_(-stdv, stdv)
        # if self.learn_influence:
        #     self.G.data.uniform_(-stdv, stdv)
        if len(self.weight.shape) == 3:
            self.weight.data[1:] = self.weight.data[0]
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor:
        if g is None and self.learn_influence:
            g = torch.nn.functional.normalize(self.G, p=1., dim=1)
            # g = torch.softmax(self.G, dim=1)
        elif g is None:
            g = self.G
        w = self.weight[self.node_type_index]
        output = self.mm(input, w.transpose(-2, -1))
        if self.bias is not None:
            bias = self.bias[self.node_type_index]
            output += bias
        output = g.matmul(output)

        return output


class DynamicGraphLinear(GraphLinear):
    def __init__(self, num_node_types: int = 1, *args):
        super().__init__(*args)

    def forward(self, input: torch.Tensor, g: torch.Tensor = None, t: torch.Tensor = None) -> torch.Tensor:
        assert g is not None or t is not None, "Either the Graph Influence Matrix or the Node Type Vector is needed"
        if g is None:
            g = self.G[t][:, t]
        return super().forward(input, g)


class StaticGraphLinear(GraphLinear):
    def __init__(self, *args, bias: bool = True, num_nodes: int = None, graph_influence: Union[torch.Tensor, Parameter] = None,
                 learn_influence: bool = False, node_types: torch.Tensor = None, weights_per_type: bool = False, **kwargs):
        """
        :param in_features: Size of each input sample
        :param out_features: Size of each output sample
        :param num_nodes: Number of nodes.
        :param graph_influence: Graph Influence Matrix
        :param learn_influence: If set to ``False``, the layer will not learn the Graph Influence Matrix.
        :param node_types: Type of each node. All nodes of the same type share weights.
                           Default: all nodes have unique types.
        :param weights_per_type: If set to ``False``, the layer will not learn separate weights for each node type.
        :param bias: If set to ``False``, the layer will not learn an additive bias.
        """
        super().__init__(*args)

        self.learn_influence = learn_influence

        if graph_influence is not None:
            assert num_nodes == graph_influence.shape[0] or num_nodes is None, 'Number of Nodes or Graph Influence Matrix has to be given.'
            num_nodes = graph_influence.shape[0]
            if type(graph_influence) is Parameter:
                assert learn_influence, "Graph Influence Matrix is a Parameter, therefore it must be learnable."
                self.G = graph_influence
            elif learn_influence:
                self.G = Parameter(graph_influence)
            else:
                self.register_buffer('G', graph_influence)
        else:
            assert num_nodes, 'Number of Nodes or Graph Influence Matrix has to be given.'
            eye_influence = torch.eye(num_nodes, num_nodes)
            if learn_influence:
                self.G = Parameter(eye_influence)
            else:
                self.register_buffer('G', eye_influence)

        if weights_per_type and node_types is None:
            node_types = torch.tensor([i for i in range(num_nodes)])
        if node_types is not None:
            num_node_types = node_types.max() + 1
            self.weight = Parameter(torch.Tensor(num_node_types, self.out_features, self.in_features))
            self.mm = gmm
            self.node_type_index = node_types
        else:
            self.weight = Parameter(torch.Tensor(self.out_features, self.in_features))
            self.mm = torch.matmul
            self.node_type_index = None

        if bias:
            if node_types is not None:
                self.bias = Parameter(torch.Tensor(num_node_types, self.out_features))
            else:
                self.bias = Parameter(torch.Tensor(self.out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()


# class BN(Module):
#     def __init__(self, num_nodes, num_features):
#         super().__init__()
#         self.num_nodes = num_nodes
#         self.num_features = num_features
#         self.bn = BatchNorm1d(num_nodes * num_features)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         return self.bn(x.view(-1, self.num_nodes * self.num_features)).view(-1, self.num_nodes, self.num_features)

# class LinearX(Module):
#     def __init__(self):
#         super().__init__()

#     def forward(self, input: torch.Tensor) -> torch.Tensor:
#         return input
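Editor's note: a minimal sketch of StaticGraphLinear with type-shared weights, assuming the package is importable; the node_types vector below is illustrative (in the repo it comes from the skeleton definition). Nodes with the same type index share one weight slice, and learn_influence=True makes the graph influence matrix a trainable Parameter.

import torch
from SkeletonDiffusion.src.core.network.layers import StaticGraphLinear

node_types = torch.tensor([0, 0, 1, 1, 1])   # 5 nodes, 2 node types
layer = StaticGraphLinear(3, 8, num_nodes=5, weights_per_type=True,
                          node_types=node_types, learn_influence=True)
x = torch.randn(4, 5, 3)                     # (batch, nodes, in_features)
print(layer(x).shape)                        # torch.Size([4, 5, 8])
print(layer.weight.shape)                    # torch.Size([2, 8, 3]) -- one slice per type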
SkeletonDiffusion/src/core/network/layers/recurrent.py
ADDED
@@ -0,0 +1,402 @@
from typing import Tuple, Optional, List, Union

import torch
from torch.nn import *
import math

from .graph_structural import gmm

IFDEF_JITSCRIPT = False  # torch.jit.script generates NaNs for long sequences.

GraphLSTMState = Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]

class StaticGraphLSTMCell_(Module):
    def __init__(self, input_size: int, hidden_size: int, num_nodes: int = None, dropout: float = 0.,
                 recurrent_dropout: float = 0., graph_influence: Union[torch.Tensor, Parameter] = None,
                 learn_influence: bool = False, additive_graph_influence: Union[torch.Tensor, Parameter] = None,
                 learn_additive_graph_influence: bool = False, node_types: torch.Tensor = None,
                 weights_per_type: bool = False, clockwork: bool = False, bias: bool = True):
        """
        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_nodes:
        :param dropout:
        :param recurrent_dropout:
        :param graph_influence:
        :param learn_influence:
        :param additive_graph_influence:
        :param learn_additive_graph_influence:
        :param node_types:
        :param weights_per_type:
        :param bias:
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.learn_influence = learn_influence
        self.learn_additive_graph_influence = learn_additive_graph_influence
        if graph_influence is not None:
            assert num_nodes == graph_influence.shape[0] or num_nodes is None, 'Number of Nodes or Graph Influence Matrix has to be given.'
            num_nodes = graph_influence.shape[0]
            if type(graph_influence) is Parameter:
                assert learn_influence, "Graph Influence Matrix is a Parameter, therefore it must be learnable."
                self.G = graph_influence
            elif learn_influence:
                self.G = Parameter(graph_influence)
            else:
                self.register_buffer('G', graph_influence)
        else:
            assert num_nodes, 'Number of Nodes or Graph Influence Matrix has to be given.'
            eye_influence = torch.eye(num_nodes, num_nodes)
            if learn_influence:
                self.G = Parameter(eye_influence)
            else:
                self.register_buffer('G', eye_influence)

        if additive_graph_influence is not None:
            if type(additive_graph_influence) is Parameter:
                self.G_add = additive_graph_influence
            elif learn_additive_graph_influence:
                self.G_add = Parameter(additive_graph_influence)
            else:
                self.register_buffer('G_add', additive_graph_influence)
        else:
            if learn_additive_graph_influence:
                self.G_add = Parameter(torch.zeros_like(self.G))
            else:
                self.G_add = 0.

        if weights_per_type and node_types is None:
            node_types = torch.tensor([i for i in range(num_nodes)])
        if node_types is not None:
            num_node_types = node_types.max() + 1
            self.weight_ih = Parameter(torch.Tensor(num_node_types, 4 * hidden_size, input_size))
            self.weight_hh = Parameter(torch.Tensor(num_node_types, 4 * hidden_size, hidden_size))
            self.mm = gmm
            self.register_buffer('node_type_index', node_types)
        else:
            self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
            self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
            self.mm = torch.matmul
            self.register_buffer('node_type_index', None)

        if bias:
            if node_types is not None:
                self.bias_ih = Parameter(torch.Tensor(num_node_types, 4 * hidden_size))
                self.bias_hh = Parameter(torch.Tensor(num_node_types, 4 * hidden_size))
            else:
                self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
                self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
        else:
            self.bias_ih = None
            self.bias_hh = None

        self.clockwork = clockwork
        if clockwork:
            phase = torch.arange(0., hidden_size)
            phase = phase - phase.min()
            phase = (phase / phase.max()) * 8.
            phase += 1.
            phase = torch.floor(phase)
            self.register_buffer('phase', phase)
        else:
            phase = torch.ones(hidden_size)
            self.register_buffer('phase', phase)

        self.dropout = Dropout(dropout)
        self.r_dropout = Dropout(recurrent_dropout)

        self.num_nodes = num_nodes

        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            if weight is self.G:
                continue
            if weight is self.G_add:
                continue
            weight.data.uniform_(-stdv, stdv)
            # share-initialize per-type weight slices only when weights are 3D (one slice per node type)
            if (weight is self.weight_hh or weight is self.weight_ih) and len(self.weight_ih.shape) == 3:
                weight.data[1:] = weight.data[0]

    def forward(self, input: torch.Tensor, state: GraphLSTMState, t: int = 0) -> Tuple[torch.Tensor, GraphLSTMState]:
        hx, cx, gx = state
        if hx is None:
            hx = torch.zeros(input.shape[0], self.num_nodes, self.hidden_size, dtype=input.dtype, device=input.device)
        if cx is None:
            cx = torch.zeros(input.shape[0], self.num_nodes, self.hidden_size, dtype=input.dtype, device=input.device)
        if gx is None and self.learn_influence:
            gx = torch.nn.functional.normalize(self.G, p=1., dim=1)
            # gx = torch.softmax(self.G, dim=1)
        elif gx is None:
            gx = self.G

        hx = self.r_dropout(hx)

        weight_ih = self.weight_ih[self.node_type_index]
        weight_hh = self.weight_hh[self.node_type_index]
        if self.bias_hh is not None:
            bias_hh = self.bias_hh[self.node_type_index]
        else:
            bias_hh = 0.

        c_mask = (torch.remainder(torch.tensor(t + 1., device=input.device), self.phase) < 0.01).type_as(cx)

        gates = (self.dropout(self.mm(input, weight_ih.transpose(-2, -1))) +
                 self.mm(hx, weight_hh.transpose(-2, -1)) + bias_hh)
        gates = torch.matmul(gx, gates)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 2)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = c_mask * ((forgetgate * cx) + (ingate * cellgate)) + (1 - c_mask) * cx
        hy = outgate * torch.tanh(cy)

        gx = gx + self.G_add
        if self.learn_influence or self.learn_additive_graph_influence:
            gx = torch.nn.functional.normalize(gx, p=1., dim=1)
            # gx = torch.softmax(gx, dim=1)

        return hy, (hy, cy, gx)


class StaticGraphLSTM_(Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, layer_dropout: float = 0.0, **kwargs):
        super().__init__()
        self.layers = ModuleList([StaticGraphLSTMCell_(input_size, hidden_size, **kwargs)]
                                 + [StaticGraphLSTMCell_(hidden_size, hidden_size, **kwargs) for _ in range(num_layers - 1)])
        self.dropout = Dropout(layer_dropout)

    def forward(self, input: torch.Tensor, states: Optional[List[GraphLSTMState]] = None, t_i: int = 0) -> Tuple[torch.Tensor, List[GraphLSTMState]]:
        if states is None:
            n: Optional[torch.Tensor] = None
            states = [(n, n, n)] * len(self.layers)

        output_states: List[GraphLSTMState] = []
        output = input
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            inputs = output.unbind(1)
            outputs: List[torch.Tensor] = []
            for t, input in enumerate(inputs):
                out, state = rnn_layer(input, state, t_i + t)
                outputs += [out]
            output = torch.stack(outputs, dim=1)
            output = self.dropout(output)
            output_states += [state]
            i += 1
        return output, output_states


def StaticGraphLSTM(*args, **kwargs):
    if IFDEF_JITSCRIPT:
        return torch.jit.script(StaticGraphLSTM_(*args, **kwargs))
    else:
        return StaticGraphLSTM_(*args, **kwargs)

GraphGRUState = Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]


class StaticGraphGRUCell_(Module):
    def __init__(self, input_size: int, hidden_size: int, num_nodes: int = None, dropout: float = 0.,
                 recurrent_dropout: float = 0., graph_influence: Union[torch.Tensor, Parameter] = None,
                 learn_influence: bool = False, additive_graph_influence: Union[torch.Tensor, Parameter] = None,
                 learn_additive_graph_influence: bool = False, node_types: torch.Tensor = None,
                 weights_per_type: bool = False, clockwork: bool = False, bias: bool = True):
        """
        :param input_size: The number of expected features in the input `x`
        :param hidden_size: The number of features in the hidden state `h`
        :param num_nodes:
        :param dropout:
        :param recurrent_dropout:
        :param graph_influence:
        :param learn_influence:
        :param additive_graph_influence:
        :param learn_additive_graph_influence:
        :param node_types:
        :param weights_per_type:
        :param bias:
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.learn_influence = learn_influence
        self.learn_additive_graph_influence = learn_additive_graph_influence
        if graph_influence is not None:
            assert num_nodes == graph_influence.shape[0] or num_nodes is None, 'Number of Nodes or Graph Influence Matrix has to be given.'
            num_nodes = graph_influence.shape[0]
            if type(graph_influence) is Parameter:
                assert learn_influence, "Graph Influence Matrix is a Parameter, therefore it must be learnable."
                self.G = graph_influence
            elif learn_influence:
                self.G = Parameter(graph_influence)
            else:
                self.register_buffer('G', graph_influence)
        else:
            assert num_nodes, 'Number of Nodes or Graph Influence Matrix has to be given.'
            eye_influence = torch.eye(num_nodes, num_nodes)
            if learn_influence:
                self.G = Parameter(eye_influence)
            else:
                self.register_buffer('G', eye_influence)

        if additive_graph_influence is not None:
            if type(additive_graph_influence) is Parameter:
                self.G_add = additive_graph_influence
            elif learn_additive_graph_influence:
                self.G_add = Parameter(additive_graph_influence)
            else:
                self.register_buffer('G_add', additive_graph_influence)
        else:
            if learn_additive_graph_influence:
                self.G_add = Parameter(torch.zeros_like(self.G))
            else:
                self.G_add = 0.

        if weights_per_type and node_types is None:
            node_types = torch.tensor([i for i in range(num_nodes)])
        if node_types is not None:
            num_node_types = node_types.max() + 1
            self.weight_ih = Parameter(torch.Tensor(num_node_types, 3 * hidden_size, input_size))
            self.weight_hh = Parameter(torch.Tensor(num_node_types, 3 * hidden_size, hidden_size))
            self.mm = gmm
            self.register_buffer('node_type_index', node_types)
        else:
            self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size))
            self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size))
            self.mm = torch.matmul
            self.register_buffer('node_type_index', None)

        if bias:
            if node_types is not None:
                self.bias_ih = Parameter(torch.Tensor(num_node_types, 3 * hidden_size))
                self.bias_hh = Parameter(torch.Tensor(num_node_types, 3 * hidden_size))
            else:
                self.bias_ih = Parameter(torch.Tensor(3 * hidden_size))
                self.bias_hh = Parameter(torch.Tensor(3 * hidden_size))
        else:
            self.bias_ih = None
            self.bias_hh = None

        self.clockwork = clockwork
        if clockwork:
            phase = torch.arange(0., hidden_size)
            phase = phase - phase.min()
            phase = (phase / phase.max()) * 8.
            phase += 1.
            phase = torch.floor(phase)
            self.register_buffer('phase', phase)
        else:
            phase = torch.ones(hidden_size)
            self.register_buffer('phase', phase)

        self.dropout = Dropout(dropout)
        self.r_dropout = Dropout(recurrent_dropout)

        self.num_nodes = num_nodes

        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            if weight is self.G:
                continue
            if weight is self.G_add:
                continue
            weight.data.uniform_(-stdv, stdv)
            # if (weight is self.weight_hh or weight is self.weight_ih) and len(self.weight_ih.shape) == 3:
            #     weight.data[1:] = weight.data[0]

    def forward(self, input: torch.Tensor, state: GraphGRUState, t: int = 0) -> Tuple[torch.Tensor, GraphGRUState]:
        hx, gx = state
        if hx is None:
            hx = torch.zeros(input.shape[0], self.num_nodes, self.hidden_size, dtype=input.dtype, device=input.device)
        if gx is None and self.learn_influence:
            gx = torch.nn.functional.normalize(self.G, p=1., dim=1)
            # gx = torch.softmax(self.G, dim=1)
        elif gx is None:
            gx = self.G

        hx = self.r_dropout(hx)

        weight_ih = self.weight_ih[self.node_type_index]
        weight_hh = self.weight_hh[self.node_type_index]
        if self.bias_hh is not None:
            bias_hh = self.bias_hh[self.node_type_index]
        else:
            bias_hh = 0.
        if self.bias_ih is not None:
            bias_ih = self.bias_ih[self.node_type_index]
        else:
            bias_ih = 0.

        c_mask = (torch.remainder(torch.tensor(t + 1., device=input.device), self.phase) < 0.01).type_as(hx)

        x_results = self.dropout(self.mm(input, weight_ih.transpose(-2, -1))) + bias_ih
        h_results = self.mm(hx, weight_hh.transpose(-2, -1)) + bias_hh
        x_results = torch.matmul(gx, x_results)
        h_results = torch.matmul(gx, h_results)

        i_r, i_z, i_n = x_results.chunk(3, 2)
        h_r, h_z, h_n = h_results.chunk(3, 2)

        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)

        hy = n - torch.mul(n, z) + torch.mul(z, hx)
        hy = c_mask * hy + (1 - c_mask) * hx

        gx = gx + self.G_add
        if self.learn_influence or self.learn_additive_graph_influence:
            gx = torch.nn.functional.normalize(gx, p=1., dim=1)
            # gx = torch.softmax(gx, dim=1)

        return hy, (hy, gx)


class StaticGraphGRU_(Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, layer_dropout: float = 0.0, **kwargs):
        super().__init__()
        self.layers = ModuleList([StaticGraphGRUCell_(input_size, hidden_size, **kwargs)]
                                 + [StaticGraphGRUCell_(hidden_size, hidden_size, **kwargs) for _ in range(num_layers - 1)])
        self.dropout = Dropout(layer_dropout)

    def forward(self, input: torch.Tensor, states: Optional[List[GraphGRUState]] = None, t_i: int = 0) -> Tuple[torch.Tensor, List[GraphGRUState]]:
        if states is None:
            n: Optional[torch.Tensor] = None
            states = [(n, n)] * len(self.layers)

        output_states: List[GraphGRUState] = []
        output = input
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            inputs = output.unbind(1)
            outputs: List[torch.Tensor] = []
            for t, input in enumerate(inputs):
                out, state = rnn_layer(input, state, t_i + t)
                outputs += [out]
            output = torch.stack(outputs, dim=1)
            output = self.dropout(output)
            output_states += [state]
            i += 1
        return output, output_states


def StaticGraphGRU(*args, **kwargs):
    if IFDEF_JITSCRIPT:
        return torch.jit.script(StaticGraphGRU_(*args, **kwargs))
    else:
        return StaticGraphGRU_(*args, **kwargs)
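Editor's note: a shape-level smoke test for StaticGraphGRU, assuming the package is importable; the sizes are arbitrary. Input is (batch, time, nodes, features), and each layer unrolls its cell over the time dimension.

import torch
from SkeletonDiffusion.src.core.network.layers import StaticGraphGRU

gru = StaticGraphGRU(input_size=3, hidden_size=16, num_layers=2, num_nodes=5)
x = torch.randn(4, 10, 5, 3)    # (batch, time, nodes, features)
out, states = gru(x)
print(out.shape)                 # torch.Size([4, 10, 5, 16])
print(len(states))               # 2 -- one (h, g) state per layer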
SkeletonDiffusion/src/core/network/nn/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .generator import Denoiser
from .autoencoder import AutoEncoder
SkeletonDiffusion/src/core/network/nn/autoencoder.py
ADDED
@@ -0,0 +1,105 @@
from typing import Tuple, Union

import torch
import torch.nn as nn

from ..layers import StaticGraphLinear
from .decoder import Decoder
from .encoder import Encoder


class AutoEncoder(nn.Module):
    def __init__(self,
                 num_nodes: int,
                 encoder_hidden_size: int,
                 decoder_hidden_size: int,
                 latent_size: int,
                 node_types: torch.Tensor = None,
                 input_size: int = 3,
                 z_activation: str = 'tanh',
                 enc_num_layers: int = 1,
                 loss_pose_type: str = 'l1',
                 **kwargs):
        super().__init__()
        self.param_groups = [{}]
        self.latent_size = latent_size
        self.loss_pose_type = loss_pose_type

        self.encoder = Encoder(num_nodes=num_nodes,
                               input_size=input_size,
                               hidden_size=encoder_hidden_size,
                               output_size=latent_size,
                               node_types=node_types,
                               enc_num_layers=enc_num_layers,
                               recurrent_arch=kwargs['recurrent_arch_enc'])

        assert kwargs['output_size'] == input_size
        self.decoder = Decoder(num_nodes=num_nodes,
                               input_size=latent_size,
                               feature_size=input_size,
                               hidden_size=decoder_hidden_size,
                               node_types=node_types,
                               param_groups=self.param_groups,
                               **kwargs)
        assert z_activation in ['tanh', 'identity'], f"z_activation must be either 'tanh' or 'identity', but got {z_activation}"
        self.z_activation = nn.Tanh() if z_activation == "tanh" else nn.Identity()

    def forward(self, x):
        h, _ = self.encoder(x)
        return h

    def get_past_embedding(self, past, state=None):
        with torch.no_grad():
            h_hat_embedding = self(past)
            z_past = self.z_activation(h_hat_embedding)
        return z_past

    def get_embedding(self, future, state=None):
        z = self.forward(future)
        return z

    def get_train_embeddings(self, y, past, state=None):
        z_past = self.get_past_embedding(past, state=state)
        z = self.get_embedding(y, state=state)
        return z_past, z

    def decode(self, x: torch.Tensor, h: torch.Tensor, z: torch.Tensor, ph=1, state=None):
        x_tiled = x[:, -2:]
        out, _ = self.decoder(x=x_tiled,
                              h=h,
                              z=z,
                              ph=ph,
                              state=state)  # [B * Z, T, N, D]
        return out

    def autoencode(self, y, past, ph=1, state=None):
        z_past, z = self.get_train_embeddings(y, past, state=state)
        out = self.decode(past, z, z_past, ph)
        return out, z_past, z

    def loss(self, y_pred, y, type=None, reduction="mean", **kwargs):
        type = self.loss_pose_type if type is None else type
        if type == "mse":
            out = torch.nn.MSELoss(reduction="none")(y_pred, y)
        elif type in ["l1", "L1"]:
            out = torch.nn.L1Loss(reduction="none")(y_pred, y)
        else:
            assert 0, "Not implemented"
        loss = (out.sum(-1)    # spatial size
                   .mean(-1)   # keypoints
                   .mean(-1))  # timesteps
        if reduction == "mean":
            return loss.mean()
        elif reduction == "none":
            return loss
        else:
            assert 0, "Not implemented"
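Editor's note: an illustrative check of the reduction order inside AutoEncoder.loss (sum over the spatial xyz dimension, then mean over keypoints, then mean over timesteps); the tensor sizes below are hypothetical.

import torch

y_pred = torch.randn(4, 25, 17, 3)   # (batch, timesteps, keypoints, xyz)
y = torch.randn(4, 25, 17, 3)
per_elem = torch.nn.L1Loss(reduction="none")(y_pred, y)
loss = per_elem.sum(-1).mean(-1).mean(-1)   # one value per batch element
print(loss.shape)                            # torch.Size([4])
print(loss.mean())                           # the reduction="mean" scalar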