Spaces: Running on Zero

Commit 7b75adb
Parent(s): be73458
root committed: add our app

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full file list.
- .gitattributes +1 -0
- .gitignore +215 -0
- P3-SAM/demo/assets/1.glb +3 -0
- P3-SAM/demo/assets/2.glb +3 -0
- P3-SAM/demo/assets/3.glb +3 -0
- P3-SAM/demo/assets/4.glb +3 -0
- P3-SAM/demo/auto_mask.py +1405 -0
- P3-SAM/demo/auto_mask_no_postprocess.py +943 -0
- P3-SAM/model.py +156 -0
- P3-SAM/utils/chamfer3D/chamfer3D.cu +196 -0
- P3-SAM/utils/chamfer3D/chamfer_cuda.cpp +29 -0
- P3-SAM/utils/chamfer3D/dist_chamfer_3D.py +81 -0
- P3-SAM/utils/chamfer3D/setup.py +14 -0
- XPart/data/000.glb +3 -0
- XPart/data/001.glb +3 -0
- XPart/data/002.glb +3 -0
- XPart/data/003.glb +3 -0
- XPart/data/004.glb +3 -0
- XPart/partgen/bbox_estimator/auto_mask_api.py +1417 -0
- XPart/partgen/config/infer.yaml +122 -0
- XPart/partgen/config/sonata.json +58 -0
- XPart/partgen/models/autoencoders/__init__.py +29 -0
- XPart/partgen/models/autoencoders/attention_blocks.py +770 -0
- XPart/partgen/models/autoencoders/attention_processors.py +32 -0
- XPart/partgen/models/autoencoders/model.py +452 -0
- XPart/partgen/models/autoencoders/surface_extractors.py +164 -0
- XPart/partgen/models/autoencoders/volume_decoders.py +107 -0
- XPart/partgen/models/conditioner/condioner_release.py +170 -0
- XPart/partgen/models/conditioner/part_encoders.py +89 -0
- XPart/partgen/models/conditioner/sonata_extractor.py +315 -0
- XPart/partgen/models/diffusion/schedulers.py +329 -0
- XPart/partgen/models/diffusion/transport/__init__.py +97 -0
- XPart/partgen/models/diffusion/transport/integrators.py +142 -0
- XPart/partgen/models/diffusion/transport/path.py +220 -0
- XPart/partgen/models/diffusion/transport/transport.py +506 -0
- XPart/partgen/models/diffusion/transport/utils.py +54 -0
- XPart/partgen/models/moe_layers.py +209 -0
- XPart/partgen/models/partformer_dit.py +756 -0
- XPart/partgen/models/sonata/__init__.py +35 -0
- XPart/partgen/models/sonata/data.py +84 -0
- XPart/partgen/models/sonata/model.py +874 -0
- XPart/partgen/models/sonata/module.py +107 -0
- XPart/partgen/models/sonata/registry.py +340 -0
- XPart/partgen/models/sonata/serialization/__init__.py +9 -0
- XPart/partgen/models/sonata/serialization/default.py +82 -0
- XPart/partgen/models/sonata/serialization/hilbert.py +318 -0
- XPart/partgen/models/sonata/serialization/z_order.py +145 -0
- XPart/partgen/models/sonata/structure.py +159 -0
- XPart/partgen/models/sonata/transform.py +1330 -0
- XPart/partgen/models/sonata/utils.py +75 -0
.gitattributes
CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.glb filter=lfs diff=lfs merge=lfs -text
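The added rule is what `git lfs track "*.glb"` writes into .gitattributes: from this commit on, `.glb` meshes go through Git LFS, so the repository itself only stores small pointer files (version / oid / size), as the asset diffs below show.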
.gitignore
ADDED

@@ -0,0 +1,215 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# Streamlit
+.streamlit/secrets.toml
+
+results/
+weights/
+.gradio/
+demo/segment_result.glb
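Most of this is the stock GitHub Python template; the project-specific entries are the last four (results/, weights/, .gradio/, demo/segment_result.glb), which keep inference outputs, downloaded checkpoints, the Gradio cache, and the demo's saved segmentation result out of version control.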
P3-SAM/demo/assets/1.glb
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b57626cf269cb1e5b949586b6b6b87efa552c66d0a45371fed0c9f47db4c3314
+size 29140044
P3-SAM/demo/assets/2.glb
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cedbb7365e8506d42809cc547fc0d69d9118af1d419c8b8748bae01b545634c
+size 8529600
P3-SAM/demo/assets/3.glb
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b179b1141dc0dd847e0be289c3f9b0dfd9446995b38b5050d41df7cbabbc516d
+size 30475016
P3-SAM/demo/assets/4.glb
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be9fea2e4b63233ab30f929ca9f735ccbd154b1f9652e9bbaacd87839829f02b
+size 29621968
P3-SAM/demo/auto_mask.py
ADDED

@@ -0,0 +1,1405 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import numpy as np
+import argparse
+import trimesh
+from sklearn.decomposition import PCA
+import fpsample
+from tqdm import tqdm
+import threading
+import random
+
+# from tqdm.notebook import tqdm
+import time
+import copy
+import shutil
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
+from collections import defaultdict
+
+import numba
+from numba import njit
+
+sys.path.append('..')
+from model import build_P3SAM, load_state_dict
+
+class P3SAM(nn.Module):
+    def __init__(self):
+        super().__init__()
+        build_P3SAM(self)
+
+    def load_state_dict(self,
+                        ckpt_path=None,
+                        state_dict=None,
+                        strict=True,
+                        assign=False,
+                        ignore_seg_mlp=False,
+                        ignore_seg_s2_mlp=False,
+                        ignore_iou_mlp=False):
+        load_state_dict(self,
+                        ckpt_path=ckpt_path,
+                        state_dict=state_dict,
+                        strict=strict,
+                        assign=assign,
+                        ignore_seg_mlp=ignore_seg_mlp,
+                        ignore_seg_s2_mlp=ignore_seg_s2_mlp,
+                        ignore_iou_mlp=ignore_iou_mlp)
+
+    def forward(self, feats, points, point_prompt, iter=1):
+        '''
+        feats: [K, N, 512]
+        points: [K, N, 3]
+        point_prompt: [K, N, 3]
+        '''
+        # print(feats.shape, points.shape, point_prompt.shape)
+        point_num = points.shape[1]
+        feats = feats.transpose(0, 1)  # [N, K, 512]
+        points = points.transpose(0, 1)  # [N, K, 3]
+        point_prompt = point_prompt.transpose(0, 1)  # [N, K, 3]
+        feats_seg = torch.cat([feats, points, point_prompt], dim=-1)  # [N, K, 512+3+3]
+
+        # predict mask, stage 1
+        pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1)  # [N, K]
+        pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1)  # [N, K]
+        pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1)  # [N, K]
+        pred_mask = torch.stack(
+            [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1
+        )  # [N, K, 3]
+
+        for _ in range(iter):
+            # predict mask, stage 2
+            feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1)  # [N, K, 512+3+3+3]
+            feats_seg_global = self.seg_s2_mlp_g(feats_seg_2)  # [N, K, 512]
+            feats_seg_global = torch.max(feats_seg_global, dim=0).values  # [K, 512]
+            feats_seg_global = feats_seg_global.unsqueeze(0).repeat(
+                point_num, 1, 1
+            )  # [N, K, 512]
+            feats_seg_3 = torch.cat(
+                [feats_seg_global, feats_seg_2], dim=-1
+            )  # [N, K, 512+3+3+3+512]
+            pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1)  # [N, K]
+            pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1)  # [N, K]
+            pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1)  # [N, K]
+            pred_mask_s2 = torch.stack(
+                [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1
+            )  # [N, K, 3]
+            pred_mask = pred_mask_s2
+
+        mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32)  # [N, K]
+        mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32)  # [N, K]
+        mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32)  # [N, K]
+
+        feats_iou = torch.cat(
+            [feats_seg_global, feats_seg, pred_mask_s2], dim=-1
+        )  # [N, K, 512+3+3+3+512]
+        feats_iou = self.iou_mlp(feats_iou)  # [N, K, 512]
+        feats_iou = torch.max(feats_iou, dim=0).values  # [K, 512]
+        pred_iou = self.iou_mlp_out(feats_iou)  # [K, 3]
+        pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32)  # [K, 3]
+
+        mask_1 = mask_1.transpose(0, 1)  # [K, N]
+        mask_2 = mask_2.transpose(0, 1)  # [K, N]
+        mask_3 = mask_3.transpose(0, 1)  # [K, N]
+
+        return mask_1, mask_2, mask_3, pred_iou
+
+
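For orientation, the forward pass takes K prompt columns over N points and returns three candidate masks plus a per-prompt IoU head. A minimal smoke test of that contract (hypothetical: assumes a built checkpoint, a GPU, and this file's imports; the checkpoint path is made up):

    model = P3SAM()
    model.load_state_dict(ckpt_path="p3sam.ckpt")   # hypothetical checkpoint path
    model = model.cuda().eval()
    K, N = 4, 2048
    feats = torch.randn(K, N, 512).cuda()           # per-point features (see get_feat below)
    points = torch.rand(K, N, 3).cuda()
    prompts = points[:, :1, :].repeat(1, N, 1)      # one prompt per column, broadcast over N
    m1, m2, m3, iou = model(feats, points, prompts)
    # m1, m2, m3: [K, N] mask probabilities; iou: [K, 3] predicted quality per candidate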
+def normalize_pc(pc):
+    """
+    pc: (N, 3)
+    """
+    max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0)
+    center = (max_ + min_) / 2
+    scale = (max_ - min_) / 2
+    scale = np.max(np.abs(scale))
+    pc = (pc - center) / (scale + 1e-10)
+    return pc
+
+
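normalize_pc recenters the cloud and divides by the largest half-extent, so the longest axis spans [-1, 1]. A quick check with made-up values:

    pc = np.array([[0.0, 0.0, 0.0], [2.0, 4.0, 6.0]])
    # center = (1, 2, 3); half-extents = (1, 2, 3); scale = 3
    print(normalize_pc(pc))  # [[-1/3, -2/3, -1], [1/3, 2/3, 1]]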
+@torch.no_grad()
+def get_feat(model, points, normals):
+    data_dict = {
+        "coord": points,
+        "normal": normals,
+        "color": np.ones_like(points),
+        "batch": np.zeros(points.shape[0], dtype=np.int64),
+    }
+    data_dict = model.transform(data_dict)
+    for k in data_dict:
+        if isinstance(data_dict[k], torch.Tensor):
+            data_dict[k] = data_dict[k].cuda()
+    point = model.sonata(data_dict)
+    while "pooling_parent" in point.keys():
+        assert "pooling_inverse" in point.keys()
+        parent = point.pop("pooling_parent")
+        inverse = point.pop("pooling_inverse")
+        parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
+        point = parent
+    feat = point.feat  # [M, 1232]
+    feat = model.mlp(feat)  # [M, 512]
+    feat = feat[point.inverse]  # [N, 512]
+    feats = feat
+    return feats
+
+
+@torch.no_grad()
+def get_mask(model, feats, points, point_prompt, iter=1):
+    """
+    feats: [N, 512]
+    points: [N, 3]
+    point_prompt: [K, 3]
+    """
+    point_num = points.shape[0]
+    prompt_num = point_prompt.shape[0]
+    feats = feats.unsqueeze(1)  # [N, 1, 512]
+    feats = feats.repeat(1, prompt_num, 1).cuda()  # [N, K, 512]
+    points = torch.from_numpy(points).float().cuda().unsqueeze(1)  # [N, 1, 3]
+    points = points.repeat(1, prompt_num, 1)  # [N, K, 3]
+    prompt_coord = (
+        torch.from_numpy(point_prompt).float().cuda().unsqueeze(0)
+    )  # [1, K, 3]
+    prompt_coord = prompt_coord.repeat(point_num, 1, 1)  # [N, K, 3]
+
+    feats = feats.transpose(0, 1)  # [K, N, 512]
+    points = points.transpose(0, 1)  # [K, N, 3]
+    prompt_coord = prompt_coord.transpose(0, 1)  # [K, N, 3]
+
+    mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter)
+
+    mask_1 = mask_1.transpose(0, 1)  # [N, K]
+    mask_2 = mask_2.transpose(0, 1)  # [N, K]
+    mask_3 = mask_3.transpose(0, 1)  # [N, K]
+
+    mask_1 = mask_1.detach().cpu().numpy() > 0.5
+    mask_2 = mask_2.detach().cpu().numpy() > 0.5
+    mask_3 = mask_3.detach().cpu().numpy() > 0.5
+
+    org_iou = pred_iou.detach().cpu().numpy()  # [K, 3]
+
+    return mask_1, mask_2, mask_3, org_iou
+
+
+def cal_iou(m1, m2):
+    return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2))
+
+
+def cal_single_iou(m1, m2):
+    return np.sum(np.logical_and(m1, m2)) / np.sum(m1)
+
+
+def iou_3d(box1, box2, signle=None):
+    """
+    Compute the intersection-over-union (IoU) of two 3D bounding boxes.
+
+    Args:
+        box1 (list): first box [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max]
+        box2 (list): second box [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max]
+
+    Returns:
+        float: the IoU value
+    """
+    # intersection box coordinates
+    intersection_xmin = max(box1[0], box2[0])
+    intersection_ymin = max(box1[1], box2[1])
+    intersection_zmin = max(box1[2], box2[2])
+    intersection_xmax = min(box1[3], box2[3])
+    intersection_ymax = min(box1[4], box2[4])
+    intersection_zmax = min(box1[5], box2[5])
+
+    # check whether the boxes intersect at all
+    if (
+        intersection_xmin >= intersection_xmax
+        or intersection_ymin >= intersection_ymax
+        or intersection_zmin >= intersection_zmax
+    ):
+        return 0.0  # no intersection
+
+    # intersection volume
+    intersection_volume = (
+        (intersection_xmax - intersection_xmin)
+        * (intersection_ymax - intersection_ymin)
+        * (intersection_zmax - intersection_zmin)
+    )
+
+    # volumes of the two boxes
+    box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2])
+    box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2])
+
+    if signle is None:
+        # union volume
+        union_volume = box1_volume + box2_volume - intersection_volume
+    elif signle == "1":
+        union_volume = box1_volume
+    elif signle == "2":
+        union_volume = box2_volume
+    else:
+        raise ValueError("signle must be None or 1 or 2")
+
+    # IoU
+    iou = intersection_volume / union_volume if union_volume > 0 else 0.0
+    return iou
+
+
+def cal_point_bbox_iou(p1, p2, signle=None):
+    min_p1 = np.min(p1, axis=0)
+    max_p1 = np.max(p1, axis=0)
+    min_p2 = np.min(p2, axis=0)
+    max_p2 = np.max(p2, axis=0)
+    box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]]
+    box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]]
+    return iou_3d(box1, box2, signle)
+
+
+def cal_bbox_iou(points, m1, m2):
+    p1 = points[m1]
+    p2 = points[m2]
+    return cal_point_bbox_iou(p1, p2)
+
+
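As a quick numeric check of iou_3d (hypothetical boxes): a unit cube against the same cube shifted by 0.5 along x overlaps in half its volume, so IoU = 0.5 / 1.5 = 1/3:

    print(iou_3d([0, 0, 0, 1, 1, 1], [0.5, 0, 0, 1.5, 1, 1]))  # ~0.3333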
+def clean_mesh(mesh):
+    """
+    mesh: trimesh.Trimesh
+    """
+    # 1. merge nearby vertices
+    mesh.merge_vertices()
+
+    # 2. drop duplicate vertices
+    # 3. drop duplicate faces
+    mesh.process(True)
+    return mesh
+
+
+def remove_outliers_iqr(data, factor=1.5):
+    """
+    Remove outliers based on the IQR.
+    :param data: input list or NumPy array
+    :param factor: multiple of the IQR (default 1.5)
+    :return: list with outliers removed
+    """
+    data = np.array(data, dtype=np.float32)
+    q1 = np.percentile(data, 25)  # first quartile
+    q3 = np.percentile(data, 75)  # third quartile
+    iqr = q3 - q1  # interquartile range
+    lower_bound = q1 - factor * iqr
+    upper_bound = q3 + factor * iqr
+    return data[(data >= lower_bound) & (data <= upper_bound)].tolist()
+
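For instance (made-up data), a single far-out value falls outside the 1.5×IQR fence and is dropped:

    print(remove_outliers_iqr([1, 2, 3, 4, 100]))  # -> [1.0, 2.0, 3.0, 4.0]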
+def better_aabb(points):
+    x = points[:, 0]
+    y = points[:, 1]
+    z = points[:, 2]
+    x = remove_outliers_iqr(x)
+    y = remove_outliers_iqr(y)
+    z = remove_outliers_iqr(z)
+    min_xyz = np.array([np.min(x), np.min(y), np.min(z)])
+    max_xyz = np.array([np.max(x), np.max(y), np.max(z)])
+    return [min_xyz, max_xyz]
+
+def fix_label(face_ids, adjacent_faces, use_aabb=False, mesh=None, show_info=False):
+    if use_aabb:
+        def _cal_aabb(face_ids, i, _points_org):
+            _part_mask = face_ids == i
+            _faces = mesh.faces[_part_mask]
+            _faces = np.reshape(_faces, (-1))
+            _points = mesh.vertices[_faces]
+            min_xyz, max_xyz = better_aabb(_points)
+            _part_mask = (
+                (_points_org[:, 0] >= min_xyz[0])
+                & (_points_org[:, 0] <= max_xyz[0])
+                & (_points_org[:, 1] >= min_xyz[1])
+                & (_points_org[:, 1] <= max_xyz[1])
+                & (_points_org[:, 2] >= min_xyz[2])
+                & (_points_org[:, 2] <= max_xyz[2])
+            )
+            _part_mask = np.reshape(_part_mask, (-1, 3))
+            _part_mask = np.all(_part_mask, axis=1)
+            return i, [min_xyz, max_xyz], _part_mask
+        with Timer("compute AABBs"):
+            aabb = {}
+            unique_ids = np.unique(face_ids)
+            # print(max(unique_ids))
+            aabb_face_mask = {}
+            _faces = mesh.faces
+            _vertices = mesh.vertices
+            _faces = np.reshape(_faces, (-1))
+            _points = _vertices[_faces]
+            with ThreadPoolExecutor(max_workers=20) as executor:
+                futures = []
+                for i in unique_ids:
+                    if i < 0:
+                        continue
+                    futures.append(executor.submit(_cal_aabb, face_ids, i, _points))
+                for future in futures:
+                    res = future.result()
+                    aabb[res[0]] = res[1]
+                    aabb_face_mask[res[0]] = res[2]
+
+    # _faces = mesh.faces
+    # _vertices = mesh.vertices
+    # _faces = np.reshape(_faces, (-1))
+    # _points = _vertices[_faces]
+    # aabb_face_mask = cal_aabb_mask(_points, face_ids)
+
+    with Timer("merge mesh"):
+        loop_cnt = 1
+        changed = True
+        progress = tqdm(disable=not show_info)
+        no_mask_ids = np.where(face_ids < 0)[0].tolist()
+        faces_max = adjacent_faces.shape[0]
+        while changed and loop_cnt <= 50:
+            changed = False
+            # collect faces that still have no label
+            new_no_mask_ids = []
+            for i in no_mask_ids:
+                # if face_ids[i] < 0:
+                # look at the neighbors
+                if not (0 <= i < faces_max):
+                    continue
+                _adj_faces = adjacent_faces[i]
+                _adj_ids = []
+                for j in _adj_faces:
+                    if j == -1:
+                        break
+                    if face_ids[j] >= 0:
+                        _tar_id = face_ids[j]
+                        if use_aabb:
+                            _mask = aabb_face_mask[_tar_id]
+                            if _mask[i]:
+                                _adj_ids.append(_tar_id)
+                        else:
+                            _adj_ids.append(_tar_id)
+                if len(_adj_ids) == 0:
+                    new_no_mask_ids.append(i)
+                    continue
+                _max_id = np.argmax(np.bincount(_adj_ids))
+                face_ids[i] = _max_id
+                changed = True
+            no_mask_ids = new_no_mask_ids
+            # print(loop_cnt)
+            progress.update(1)
+            # progress.set_description(f"merge loop: {loop_cnt} {np.sum(face_ids < 0)}")
+            loop_cnt += 1
+    return face_ids
+
+
+def save_mesh(save_path, mesh, face_ids, color_map):
+    face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8)
+    for i in tqdm(range(len(mesh.faces)), disable=True):
+        _max_id = face_ids[i]
+        if _max_id == -2:
+            continue
+        face_colors[i, :3] = color_map[_max_id]
+
+    mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
+    mesh_save.visual.face_colors = face_colors
+    mesh_save.export(save_path)
+    mesh_save.export(save_path.replace(".glb", ".ply"))
+    # print('mesh saved')
+
+    scene_mesh = trimesh.Scene()
+    scene_mesh.add_geometry(mesh_save)
+    unique_ids = np.unique(face_ids)
+    aabb = []
+    for i in unique_ids:
+        if i == -1 or i == -2:
+            continue
+        _part_mask = face_ids == i
+        _faces = mesh.faces[_part_mask]
+        _faces = np.reshape(_faces, (-1))
+        _points = mesh.vertices[_faces]
+        min_xyz, max_xyz = better_aabb(_points)
+        center = (min_xyz + max_xyz) / 2
+        size = max_xyz - min_xyz
+        box = trimesh.path.creation.box_outline()
+        box.vertices *= size
+        box.vertices += center
+        box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]])
+        box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8)
+        box.colors = box_color
+        scene_mesh.add_geometry(box)
+        min_xyz = np.min(_points, axis=0)
+        max_xyz = np.max(_points, axis=0)
+        aabb.append([min_xyz, max_xyz])
+    scene_mesh.export(save_path.replace(".glb", "_aabb.glb"))
+    aabb = np.array(aabb)
+    np.save(save_path.replace(".glb", "_aabb.npy"), aabb)
+    np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids)
+
+
+def get_aabb_from_face_ids(mesh, face_ids):
+    unique_ids = np.unique(face_ids)
+    aabb = []
+    for i in unique_ids:
+        if i == -1 or i == -2:
+            continue
+        _part_mask = face_ids == i
+        _faces = mesh.faces[_part_mask]
+        _faces = np.reshape(_faces, (-1))
+        _points = mesh.vertices[_faces]
+        min_xyz = np.min(_points, axis=0)
+        max_xyz = np.max(_points, axis=0)
+        aabb.append([min_xyz, max_xyz])
+    return np.array(aabb)
+
+
+def calculate_face_areas(mesh):
+    """
+    Compute the area of every triangle face.
+    :param mesh: trimesh.Trimesh object
+    :return: array of face areas (n_faces,)
+    """
+    return mesh.area_faces
+    # # extract vertices and face indices
+    # vertices = mesh.vertices
+    # faces = mesh.faces
+
+    # # coordinates of the three vertices of each face
+    # v0 = vertices[faces[:, 0]]
+    # v1 = vertices[faces[:, 1]]
+    # v2 = vertices[faces[:, 2]]
+
+    # # two edge vectors
+    # edge1 = v1 - v0
+    # edge2 = v2 - v0
+
+    # # norm of the cross product (twice the triangle area)
+    # cross_product = np.cross(edge1, edge2)
+    # areas = 0.5 * np.linalg.norm(cross_product, axis=1)
+
+    # return areas
+
+
+def get_connected_region(face_ids, adjacent_faces, return_face_part_ids=False):
+    vis = [False] * len(face_ids)
+    parts = []
+    face_part_ids = np.ones_like(face_ids) * -1
+    for i in range(len(face_ids)):
+        if vis[i]:
+            continue
+        _part = []
+        _queue = [i]
+        while len(_queue) > 0:
+            _cur_face = _queue.pop(0)
+            if vis[_cur_face]:
+                continue
+            vis[_cur_face] = True
+            _part.append(_cur_face)
+            face_part_ids[_cur_face] = len(parts)
+            if not (0 <= _cur_face < adjacent_faces.shape[0]):
+                continue
+            _cur_face_id = face_ids[_cur_face]
+            _adj_faces = adjacent_faces[_cur_face]
+            for j in _adj_faces:
+                if j == -1:
+                    break
+                if not vis[j] and face_ids[j] == _cur_face_id:
+                    _queue.append(j)
+        parts.append(_part)
+    if return_face_part_ids:
+        return parts, face_part_ids
+    else:
+        return parts
+
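This is a plain BFS flood fill that groups adjacent faces sharing the same label; a toy case (hypothetical arrays) with four faces in a row labelled [0, 0, 1, 1]:

    adj = np.array([[1, -1], [0, 2], [1, 3], [2, -1]], dtype=np.int32)  # rows padded with -1
    print(get_connected_region(np.array([0, 0, 1, 1]), adj))  # -> [[0, 1], [2, 3]]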
+def aabb_distance(box1, box2):
+    """
+    Compute the nearest distance between two axis-aligned bounding boxes (AABBs).
+    :param box1: pair ([min_x, min_y, min_z], [max_x, max_y, max_z])
+    :param box2: pair ([min_x, min_y, min_z], [max_x, max_y, max_z])
+    :return: nearest distance (float)
+    """
+    # unpack coordinates
+    min1, max1 = box1
+    min2, max2 = box2
+
+    # separation along each axis (0 when the boxes overlap on that axis)
+    dx = max(0, min2[0] - max1[0], min1[0] - max2[0])
+    dy = max(0, min2[1] - max1[1], min1[1] - max2[1])
+    dz = max(0, min2[2] - max1[2], min1[2] - max2[2])
+
+    # if every axis overlaps, the distance is 0
+    if dx == 0 and dy == 0 and dz == 0:
+        return 0.0
+
+    # Euclidean distance between the closest points
+    return np.sqrt(dx**2 + dy**2 + dz**2)
+
+def aabb_volume(aabb):
+    """
+    Compute the volume of an axis-aligned bounding box (AABB).
+    :param aabb: pair ([min_x, min_y, min_z], [max_x, max_y, max_z])
+    :return: volume (float)
+    """
+    # unpack coordinates
+    min_xyz, max_xyz = aabb
+
+    # volume
+    dx = max_xyz[0] - min_xyz[0]
+    dy = max_xyz[1] - min_xyz[1]
+    dz = max_xyz[2] - min_xyz[2]
+    return dx * dy * dz
+
+def find_neighbor_part(parts, adjacent_faces, parts_aabb=None, parts_ids=None):
+    face2part = {}
+    for i, part in enumerate(parts):
+        for face in part:
+            face2part[face] = i
+    neighbor_parts = []
+    for i, part in enumerate(parts):
+        neighbor_part = set()
+        for face in part:
+            if not (0 <= face < adjacent_faces.shape[0]):
+                continue
+            for adj_face in adjacent_faces[face]:
+                if adj_face == -1:
+                    break
+                if adj_face not in face2part:
+                    continue
+                if face2part[adj_face] == i:
+                    continue
+                if parts_ids is not None and parts_ids[face2part[adj_face]] in [-1, -2]:
+                    continue
+                neighbor_part.add(face2part[adj_face])
+        neighbor_part = list(neighbor_part)
+        if parts_aabb is not None and parts_ids is not None and (parts_ids[i] == -1 or parts_ids[i] == -2) and len(neighbor_part) == 0:
+            min_dis = np.inf
+            min_idx = -1
+            for j, _part in tqdm(enumerate(parts)):
+                if j == i:
+                    continue
+                if parts_ids[j] == -1 or parts_ids[j] == -2:
+                    continue
+                aabb_1 = parts_aabb[i]
+                aabb_2 = parts_aabb[j]
+                dis = aabb_distance(aabb_1, aabb_2)
+                if dis < min_dis:
+                    min_dis = dis
+                    min_idx = j
+                elif dis == min_dis:
+                    if aabb_volume(parts_aabb[j]) < aabb_volume(parts_aabb[min_idx]):
+                        min_idx = j
+            neighbor_part = [min_idx]
+        neighbor_parts.append(neighbor_part)
+    return neighbor_parts
+
+
+def do_post_process(face_areas, parts, adjacent_faces, face_ids, threshold=0.95, show_info=False):
+    # # build the face adjacency
+    # mesh_save = mesh.copy()
+    # face_adjacency = mesh.face_adjacency
+    # adjacent_faces = {}
+    # for face1, face2 in face_adjacency:
+    #     if face1 not in adjacent_faces:
+    #         adjacent_faces[face1] = []
+    #     if face2 not in adjacent_faces:
+    #         adjacent_faces[face2] = []
+    #     adjacent_faces[face1].append(face2)
+    #     adjacent_faces[face2].append(face1)
+
+    # parts = get_connected_region(face_ids, adjacent_faces)
+
+    unique_ids = np.unique(face_ids)
+    if show_info:
+        print(f"number of connected regions: {len(parts)}")
+        print(f"number of ids: {len(unique_ids)}")
+
+    # face_areas = calculate_face_areas(mesh)
+    total_area = np.sum(face_areas)
+    if show_info:
+        print(f"total area: {total_area}")
+    part_areas = []
+    for i, part in enumerate(parts):
+        part_area = np.sum(face_areas[part])
+        part_areas.append(float(part_area / total_area))
+
+    sorted_parts = sorted(zip(part_areas, parts), key=lambda x: x[0], reverse=True)
+    parts = [x[1] for x in sorted_parts]
+    part_areas = [x[0] for x in sorted_parts]
+    integral_part_areas = np.cumsum(part_areas)
+
+    neighbor_parts = find_neighbor_part(parts, adjacent_faces)
+
+    new_face_ids = face_ids.copy()
+
+    for i, part in enumerate(parts):
+        if integral_part_areas[i] > threshold and part_areas[i] < 0.01:
+            if len(neighbor_parts[i]) > 0:
+                max_area = 0
+                max_part = -1
+                for j in neighbor_parts[i]:
+                    if integral_part_areas[j] > threshold:
+                        continue
+                    if part_areas[j] > max_area:
+                        max_area = part_areas[j]
+                        max_part = j
+                if max_part != -1:
+                    if show_info:
+                        print(f"merging parts: {i} {max_part}")
+                    parts[max_part].extend(part)
+                    parts[i] = []
+                    target_face_id = face_ids[parts[max_part][0]]
+                    for face in part:
+                        new_face_ids[face] = target_face_id
+
+    return new_face_ids
+
+
+def do_no_mask_process(parts, face_ids):
+    # # build the face adjacency
+    # mesh_save = mesh.copy()
+    # face_adjacency = mesh.face_adjacency
+    # adjacent_faces = {}
+    # for face1, face2 in face_adjacency:
+    #     if face1 not in adjacent_faces:
+    #         adjacent_faces[face1] = []
+    #     if face2 not in adjacent_faces:
+    #         adjacent_faces[face2] = []
+    #     adjacent_faces[face1].append(face2)
+    #     adjacent_faces[face2].append(face1)
+    # parts = get_connected_region(face_ids, adjacent_faces)
+
+    unique_ids = np.unique(face_ids)
+    max_id = np.max(unique_ids)
+    # give every still-unlabelled (-1 / -2) connected region a fresh id
+    if -1 in unique_ids or -2 in unique_ids:
+        new_face_ids = face_ids.copy()
+        for i, part in enumerate(parts):
+            if face_ids[part[0]] == -1 or face_ids[part[0]] == -2:
+                for face in part:
+                    new_face_ids[face] = max_id + 1
+                max_id += 1
+        return new_face_ids
+    else:
+        return face_ids
+
+
+def union_aabb(aabb1, aabb2):
+    min_xyz1 = aabb1[0]
+    max_xyz1 = aabb1[1]
+    min_xyz2 = aabb2[0]
+    max_xyz2 = aabb2[1]
+    min_xyz = np.minimum(min_xyz1, min_xyz2)
+    max_xyz = np.maximum(max_xyz1, max_xyz2)
+    return [min_xyz, max_xyz]
+
+
+def aabb_increase(aabb1, aabb2):
+    min_xyz_before = aabb1[0]
+    max_xyz_before = aabb1[1]
+    min_xyz_after, max_xyz_after = union_aabb(aabb1, aabb2)
+    min_xyz_increase = np.abs(min_xyz_after - min_xyz_before) / np.abs(min_xyz_before)
+    max_xyz_increase = np.abs(max_xyz_after - max_xyz_before) / np.abs(max_xyz_before)
+    return min_xyz_increase, max_xyz_increase
+
+def sort_multi_list(multi_list, key=lambda x: x[0], reverse=False):
+    '''
+    multi_list: [list1, list2, list3, ...], each of length N
+    key: sort key, defaults to the first element
+    reverse: sort order; ascending by default
+    return:
+        [list1, list2, list3, ...]: the lists, all sorted in the same order
+    '''
+    sorted_list = sorted(zip(*multi_list), key=key, reverse=reverse)
+    return zip(*sorted_list)
+
+
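For example (made-up values), sorting two parallel lists by the first one, largest first:

    areas, names = sort_multi_list([[0.2, 0.5, 0.1], ["a", "b", "c"]],
                                   key=lambda x: x[0], reverse=True)
    print(list(areas), list(names))  # [0.5, 0.2, 0.1] ['b', 'a', 'c']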
+class Timer:
+    STATE = True
+    def __init__(self, name):
+        self.name = name
+
+    def __enter__(self):
+        if not Timer.STATE:
+            return
+        self.start_time = time.time()
+        return self  # returned so the with-block can access the timer
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not Timer.STATE:
+            return
+        self.end_time = time.time()
+        self.elapsed_time = self.end_time - self.start_time
+        print(f">>>>>> block '{self.name}' ran in {self.elapsed_time:.4f} s")
+
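Timer is used throughout mesh_sam below as a print-only profiler; usage is simply:

    with Timer("feature extraction"):
        time.sleep(0.1)  # stand-in for real work
    # >>>>>> block 'feature extraction' ran in 0.1002 s  (roughly)
    Timer.STATE = False  # flip off to silence all timing output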
+###################### NUMBA acceleration ######################
+@njit
+def build_adjacent_faces_numba(face_adjacency):
+    """
+    Build the adjacent-face array, accelerated with Numba.
+    :param face_adjacency: (N, 2) numpy array of adjacent face pairs.
+    :return: adjacent_faces, a (n_faces, max_degree) array of neighbor ids, padded with -1.
+    """
+    n_faces = np.max(face_adjacency) + 1  # total number of faces
+    n_edges = face_adjacency.shape[0]  # total number of adjacency pairs
+
+    # step 1: count the degree of every face
+    degrees = np.zeros(n_faces, dtype=np.int32)
+    for i in range(n_edges):
+        f1, f2 = face_adjacency[i]
+        degrees[f1] += 1
+        degrees[f2] += 1
+    max_degree = np.max(degrees)  # maximum degree
+
+    adjacent_faces = np.ones((n_faces, max_degree), dtype=np.int32) * -1  # adjacency array
+    adjacent_faces_count = np.zeros(n_faces, dtype=np.int32)  # per-face fill counter
+    for i in range(n_edges):
+        f1, f2 = face_adjacency[i]
+        adjacent_faces[f1, adjacent_faces_count[f1]] = f2
+        adjacent_faces_count[f1] += 1
+        adjacent_faces[f2, adjacent_faces_count[f2]] = f1
+        adjacent_faces_count[f2] += 1
+    return adjacent_faces
+###################### NUMBA acceleration ######################
+
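The padded-matrix layout is what the `-1` sentinel checks in the loops above iterate over; e.g. (hypothetical pairs) three faces where 0-1 and 1-2 share an edge:

    pairs = np.array([[0, 1], [1, 2]], dtype=np.int64)
    print(build_adjacent_faces_numba(pairs))
    # [[ 1 -1]
    #  [ 0  2]
    #  [ 1 -1]]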
+def mesh_sam(
+    model,
+    mesh,
+    save_path,
+    point_num=100000,
+    prompt_num=400,
+    save_mid_res=False,
+    show_info=False,
+    post_process=False,
+    threshold=0.95,
+    clean_mesh_flag=True,
+    seed=42,
+    prompt_bs=32,
+):
+    with Timer("load mesh"):
+        model, model_parallel = model
+        if clean_mesh_flag:
+            mesh = clean_mesh(mesh)
+        mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
+        if show_info:
+            print(f"vertices: {mesh.vertices.shape[0]} faces: {mesh.faces.shape[0]}")
+
+    point_num = 100000
+    prompt_num = 400
+    with Timer("get adjacent faces"):
+        face_adjacency = mesh.face_adjacency
+    with Timer("process adjacent faces"):
+        adjacent_faces = build_adjacent_faces_numba(face_adjacency)
+
+
+    with Timer("sample point cloud"):
+        _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed)
+        _points_org = _points.copy()
+        _points = normalize_pc(_points)
+        normals = mesh.face_normals[face_idx]
+        if show_info:
+            print(f"points: {point_num} faces: {mesh.faces.shape[0]}")
+
+    with Timer("extract features"):
+        _feats = get_feat(model, _points, normals)
+        if show_info:
+            print("features precomputed")
+
+    if save_mid_res:
+        feat_save = _feats.float().detach().cpu().numpy()
+        data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
+        pca = PCA(n_components=3)
+        data_reduced = pca.fit_transform(data_scaled)
+        data_reduced = (data_reduced - data_reduced.min()) / (
+            data_reduced.max() - data_reduced.min()
+        )
+        _colors_pca = (data_reduced * 255).astype(np.uint8)
+        pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca)
+        pc_save.export(os.path.join(save_path, "point_pca.glb"))
+        pc_save.export(os.path.join(save_path, "point_pca.ply"))
+        if show_info:
+            print("PCA feature colors computed")
+
+    with Timer("FPS-sample prompt points"):
+        fps_idx = fpsample.fps_sampling(_points, prompt_num)
+        _point_prompts = _points[fps_idx]
+        if save_mid_res:
+            trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
+                os.path.join(save_path, "point_prompts_pca.glb")
+            )
+            trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
+                os.path.join(save_path, "point_prompts_pca.ply")
+            )
+        if show_info:
+            print("sampling done")
+
+    with Timer("inference"):
+        bs = prompt_bs
+        step_num = prompt_num // bs + 1
+        mask_res = []
+        iou_res = []
+        for i in tqdm(range(step_num), disable=not show_info):
+            cur_prompt = _point_prompts[bs * i : bs * (i + 1)]
+            pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask(
+                model_parallel, _feats, _points, cur_prompt
+            )
+            pred_mask = np.stack(
+                [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1
+            )  # [N, K, 3]
+            max_idx = np.argmax(pred_iou, axis=-1)  # [K]
+            for j in range(max_idx.shape[0]):
+                mask_res.append(pred_mask[:, j, max_idx[j]])
+                iou_res.append(pred_iou[j, max_idx[j]])
+        mask_res = np.stack(mask_res, axis=-1)  # [N, K]
+        if show_info:
+            print("prompt inference done")
+
+    with Timer("sort by IoU"):
+        iou_res = np.array(iou_res).tolist()
+        mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)]
+        mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True)
+        mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)]
+        iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)]
+
+    with Timer("NMS"):
+        clusters = defaultdict(list)
+        with ThreadPoolExecutor(max_workers=20) as executor:
+            for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
+                _mask = mask_sorted[i]
+                futures = []
+                for j in clusters.keys():
+                    futures.append(executor.submit(cal_iou, _mask, mask_sorted[j]))
+
+                for j, future in zip(clusters.keys(), futures):
+                    if future.result() > 0.9:
+                        clusters[j].append(i)
+                        break
+                else:
+                    clusters[i].append(i)
+
+    if show_info:
+        print(f"NMS done, number of masks: {len(clusters)}")
+
+    if save_mid_res:
+        part_mask_save_path = os.path.join(save_path, "part_mask")
+        if os.path.exists(part_mask_save_path):
+            shutil.rmtree(part_mask_save_path)
+        os.makedirs(part_mask_save_path, exist_ok=True)
+        for i in tqdm(clusters.keys(), desc="saving masks", disable=not show_info):
+            cluster_num = len(clusters[i])
+            cluster_iou = iou_sorted[i]
+            cluster_area = np.sum(mask_sorted[i])
+            if cluster_num <= 2:
+                continue
+            mask_save = mask_sorted[i]
+            mask_save = np.expand_dims(mask_save, axis=-1)
+            mask_save = np.repeat(mask_save, 3, axis=-1)
+            mask_save = (mask_save * 255).astype(np.uint8)
+            point_save = trimesh.points.PointCloud(_points, colors=mask_save)
+            point_save.export(
+                os.path.join(
+                    part_mask_save_path,
+                    f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
+                )
+            )
+
+    # drop clusters backed by only one or two masks
+    with Timer("filter single-mask clusters"):
+        filtered_clusters = []
+        other_clusters = []
+        for i in clusters.keys():
+            if len(clusters[i]) > 2:
+                filtered_clusters.append(i)
+            else:
+                other_clusters.append(i)
+        if show_info:
+            print(
+                f"before filtering: {len(clusters)} clusters, "
+                f"after filtering: {len(filtered_clusters)} clusters"
+            )
+
+    # merge again
+    with Timer("merge again"):
+        filtered_clusters_num = len(filtered_clusters)
+        cluster2 = {}
+        is_union = [False] * filtered_clusters_num
+        for i in range(filtered_clusters_num):
+            if is_union[i]:
+                continue
+            cur_cluster = filtered_clusters[i]
+            cluster2[cur_cluster] = [cur_cluster]
+            for j in range(i + 1, filtered_clusters_num):
+                if is_union[j]:
+                    continue
+                tar_cluster = filtered_clusters[j]
+                if (
+                    cal_bbox_iou(
+                        _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster]
+                    )
+                    > 0.5
+                ):
+                    cluster2[cur_cluster].append(tar_cluster)
+                    is_union[j] = True
+        if show_info:
+            print(f"merged again, cluster count: {len(cluster2.keys())}")
+
+    with Timer("find points without a mask"):
+        no_mask = np.ones(point_num)
+        for i in cluster2:
+            part_mask = mask_sorted[i]
+            no_mask[part_mask] = 0
+        if show_info:
+            print(
+                f"{np.sum(no_mask == 1)} points have no mask,"
+                f" ratio: {np.sum(no_mask == 1) / point_num:.4f}"
+            )
+
+    with Timer("recover missed masks"):
+        # look for masks that were dropped along the way
+        for i in tqdm(range(len(mask_sorted)), desc="missed masks", disable=not show_info):
+            if i in cluster2:
+                continue
+            part_mask = mask_sorted[i]
+            _iou = cal_single_iou(part_mask, no_mask)
+            if _iou > 0.7:
+                cluster2[i] = [i]
+                no_mask[part_mask] = 0
+                if save_mid_res:
+                    mask_save = mask_sorted[i]
+                    mask_save = np.expand_dims(mask_save, axis=-1)
+                    mask_save = np.repeat(mask_save, 3, axis=-1)
+                    mask_save = (mask_save * 255).astype(np.uint8)
+                    point_save = trimesh.points.PointCloud(_points, colors=mask_save)
+                    cluster_iou = iou_sorted[i]
+                    cluster_area = int(np.sum(mask_sorted[i]))
+                    cluster_num = 1
+                    point_save.export(
+                        os.path.join(
+                            part_mask_save_path,
+                            f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
+                        )
+                    )
+        if show_info:
+            print(f"missed masks recovered: {len(cluster2.keys())}")
+
+    with Timer("compute final point-cloud masks"):
+        final_mask = list(cluster2.keys())
+        final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask]
+        final_mask_area = [
+            [final_mask[i], final_mask_area[i]] for i in range(len(final_mask))
+        ]
+        final_mask_area_sorted = sorted(final_mask_area, key=lambda x: x[1], reverse=True)
+        final_mask_sorted = [
+            final_mask_area_sorted[i][0] for i in range(len(final_mask_area))
+        ]
+        final_mask_area_sorted = [
+            final_mask_area_sorted[i][1] for i in range(len(final_mask_area))
+        ]
+        if show_info:
+            print(f"final mask count: {len(final_mask_sorted)}")
+
+    with Timer("color the point cloud"):
+        # build the color map
+        color_map = {}
+        for i in final_mask_sorted:
+            part_color = np.random.rand(3) * 255
+            color_map[i] = part_color
+        # print(color_map)
+
+        result_mask = -np.ones(point_num, dtype=np.int64)
+        for i in final_mask_sorted:
+            part_mask = mask_sorted[i]
+            result_mask[part_mask] = i
+        if save_mid_res:
+            # save the point-cloud result
+            result_colors = np.zeros_like(_colors_pca)
+            for i in final_mask_sorted:
+                part_color = color_map[i]
+                part_mask = mask_sorted[i]
+                result_colors[part_mask, :3] = part_color
+            trimesh.points.PointCloud(_points, colors=result_colors).export(
+                os.path.join(save_path, "auto_mask_cluster.glb")
+            )
+            trimesh.points.PointCloud(_points, colors=result_colors).export(
+                os.path.join(save_path, "auto_mask_cluster.ply")
+
)
|
| 1018 |
+
if show_info:
|
| 1019 |
+
print("保存点云完成")
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
with Timer("投影Mesh并统计label"):
|
| 1023 |
+
# 保存mesh结果
|
| 1024 |
+
face_seg_res = {}
|
| 1025 |
+
for i in final_mask_sorted:
|
| 1026 |
+
_part_mask = result_mask == i
|
| 1027 |
+
_face_idx = face_idx[_part_mask]
|
| 1028 |
+
for k in _face_idx:
|
| 1029 |
+
if k not in face_seg_res:
|
| 1030 |
+
face_seg_res[k] = []
|
| 1031 |
+
face_seg_res[k].append(i)
|
| 1032 |
+
_part_mask = result_mask == -1
|
| 1033 |
+
_face_idx = face_idx[_part_mask]
|
| 1034 |
+
for k in _face_idx:
|
| 1035 |
+
if k not in face_seg_res:
|
| 1036 |
+
face_seg_res[k] = []
|
| 1037 |
+
face_seg_res[k].append(-1)
|
| 1038 |
+
|
| 1039 |
+
face_ids = -np.ones(len(mesh.faces), dtype=np.int64) * 2
|
| 1040 |
+
for i in tqdm(face_seg_res, leave=False, disable=True):
|
| 1041 |
+
_seg_ids = np.array(face_seg_res[i])
|
| 1042 |
+
# 获取最多的seg_id
|
| 1043 |
+
_max_id = np.argmax(np.bincount(_seg_ids + 2)) - 2
|
| 1044 |
+
face_ids[i] = _max_id
|
| 1045 |
+
face_ids_org = face_ids.copy()
|
| 1046 |
+
if show_info:
|
| 1047 |
+
print("生成face_ids完成")
|
| 1048 |
+
|
| 1049 |
+
|
| 1050 |
+
with Timer("第一次修复face_ids"):
|
| 1051 |
+
face_ids += 1
|
| 1052 |
+
face_ids = fix_label(face_ids, adjacent_faces, mesh=mesh, show_info=show_info)
|
| 1053 |
+
face_ids -= 1
|
| 1054 |
+
if show_info:
|
| 1055 |
+
print("修复face_ids完成")
|
| 1056 |
+
|
| 1057 |
+
color_map[-1] = np.array([255, 0, 0], dtype=np.uint8)
|
| 1058 |
+
|
| 1059 |
+
if save_mid_res:
|
| 1060 |
+
save_mesh(
|
| 1061 |
+
os.path.join(save_path, "auto_mask_mesh.glb"), mesh, face_ids, color_map
|
| 1062 |
+
)
|
| 1063 |
+
save_mesh(
|
| 1064 |
+
os.path.join(save_path, "auto_mask_mesh_org.glb"),
|
| 1065 |
+
mesh,
|
| 1066 |
+
face_ids_org,
|
| 1067 |
+
color_map,
|
| 1068 |
+
)
|
| 1069 |
+
if show_info:
|
| 1070 |
+
print("保存mesh结果完成")
|
| 1071 |
+
|
| 1072 |
+
with Timer("计算连通区域"):
|
| 1073 |
+
face_areas = calculate_face_areas(mesh)
|
| 1074 |
+
mesh_total_area = np.sum(face_areas)
|
| 1075 |
+
parts = get_connected_region(face_ids, adjacent_faces)
|
| 1076 |
+
connected_parts, _face_connected_parts_ids = get_connected_region(np.ones_like(face_ids), adjacent_faces, return_face_part_ids=True)
|
| 1077 |
+
if show_info:
|
| 1078 |
+
print(f"共{len(parts)}个mesh")
|
| 1079 |
+
with Timer("排序连通区域"):
|
| 1080 |
+
parts_cp_idx = []
|
| 1081 |
+
for x in parts:
|
| 1082 |
+
_face_idx = x[0]
|
| 1083 |
+
parts_cp_idx.append(_face_connected_parts_ids[_face_idx])
|
| 1084 |
+
parts_cp_idx = np.array(parts_cp_idx)
|
| 1085 |
+
parts_areas = [float(np.sum(face_areas[x])) for x in parts]
|
| 1086 |
+
connected_parts_areas = [float(np.sum(face_areas[x])) for x in connected_parts]
|
| 1087 |
+
parts_cp_areas = [connected_parts_areas[x] for x in parts_cp_idx]
|
| 1088 |
+
parts_sorted, parts_areas_sorted, parts_cp_areas_sorted = sort_multi_list([parts, parts_areas, parts_cp_areas], key=lambda x: x[1], reverse=True)
|
| 1089 |
+
|
| 1090 |
+
with Timer("去除面积过小的区域"):
|
| 1091 |
+
filtered_parts = []
|
| 1092 |
+
other_parts = []
|
| 1093 |
+
for i in range(len(parts_sorted)):
|
| 1094 |
+
parts = parts_sorted[i]
|
| 1095 |
+
area = parts_areas_sorted[i]
|
| 1096 |
+
cp_area = parts_cp_areas_sorted[i]
|
| 1097 |
+
if area / (cp_area+1e-7) > 0.001:
|
| 1098 |
+
filtered_parts.append(i)
|
| 1099 |
+
else:
|
| 1100 |
+
other_parts.append(i)
|
| 1101 |
+
if show_info:
|
| 1102 |
+
print(f"保留{len(filtered_parts)}个mesh, 其他{len(other_parts)}个mesh")
|
| 1103 |
+
|
| 1104 |
+
with Timer("去除面积过小区域的label"):
|
| 1105 |
+
face_ids_2 = face_ids.copy()
|
| 1106 |
+
part_num = len(cluster2.keys())
|
| 1107 |
+
for j in other_parts:
|
| 1108 |
+
parts = parts_sorted[j]
|
| 1109 |
+
for i in parts:
|
| 1110 |
+
face_ids_2[i] = -1
|
| 1111 |
+
|
| 1112 |
+
with Timer("第二次修复face_ids"):
|
| 1113 |
+
face_ids_3 = face_ids_2.copy()
|
| 1114 |
+
face_ids_3 = fix_label(face_ids_3, adjacent_faces, mesh=mesh, show_info=show_info)
|
| 1115 |
+
|
| 1116 |
+
if save_mid_res:
|
| 1117 |
+
save_mesh(
|
| 1118 |
+
os.path.join(save_path, "auto_mask_mesh_filtered_2.glb"),
|
| 1119 |
+
mesh,
|
| 1120 |
+
face_ids_3,
|
| 1121 |
+
color_map,
|
| 1122 |
+
)
|
| 1123 |
+
if show_info:
|
| 1124 |
+
print("保存mesh结果完成")
|
| 1125 |
+
|
| 1126 |
+
with Timer("第二次计算连通区域"):
|
| 1127 |
+
parts_2 = get_connected_region(face_ids_3, adjacent_faces)
|
| 1128 |
+
parts_areas_2 = [float(np.sum(face_areas[x])) for x in parts_2]
|
| 1129 |
+
parts_ids_2 = [face_ids_3[x[0]] for x in parts_2]
|
| 1130 |
+
|
| 1131 |
+
with Timer("添加过大的缺失part"):
|
| 1132 |
+
color_map_2 = copy.deepcopy(color_map)
|
| 1133 |
+
max_id = np.max(parts_ids_2)
|
| 1134 |
+
for i in range(len(parts_2)):
|
| 1135 |
+
_parts = parts_2[i]
|
| 1136 |
+
_area = parts_areas_2[i]
|
| 1137 |
+
_parts_id = face_ids_3[_parts[0]]
|
| 1138 |
+
if _area / mesh_total_area > 0.001:
|
| 1139 |
+
if _parts_id == -1 or _parts_id == -2:
|
| 1140 |
+
parts_ids_2[i] = max_id + 1
|
| 1141 |
+
max_id += 1
|
| 1142 |
+
color_map_2[max_id] = np.random.rand(3) * 255
|
| 1143 |
+
if show_info:
|
| 1144 |
+
print(f"新增part {max_id}")
|
| 1145 |
+
# else:
|
| 1146 |
+
# parts_ids_2[i] = -1
|
| 1147 |
+
|
| 1148 |
+
with Timer("赋值新的face_ids"):
|
| 1149 |
+
face_ids_4 = face_ids_3.copy()
|
| 1150 |
+
for i in range(len(parts_2)):
|
| 1151 |
+
_parts = parts_2[i]
|
| 1152 |
+
_parts_id = parts_ids_2[i]
|
| 1153 |
+
for j in _parts:
|
| 1154 |
+
face_ids_4[j] = _parts_id
|
| 1155 |
+
with Timer("计算part和label的aabb"):
|
| 1156 |
+
ids_aabb = {}
|
| 1157 |
+
unique_ids = np.unique(face_ids_4)
|
| 1158 |
+
for i in unique_ids:
|
| 1159 |
+
if i < 0:
|
| 1160 |
+
continue
|
| 1161 |
+
_part_mask = face_ids_4 == i
|
| 1162 |
+
_faces = mesh.faces[_part_mask]
|
| 1163 |
+
_faces = np.reshape(_faces, (-1))
|
| 1164 |
+
_points = mesh.vertices[_faces]
|
| 1165 |
+
min_xyz = np.min(_points, axis=0)
|
| 1166 |
+
max_xyz = np.max(_points, axis=0)
|
| 1167 |
+
ids_aabb[i] = [min_xyz, max_xyz]
|
| 1168 |
+
|
| 1169 |
+
parts_2_aabb = []
|
| 1170 |
+
for i in range(len(parts_2)):
|
| 1171 |
+
_parts = parts_2[i]
|
| 1172 |
+
_faces = mesh.faces[_parts]
|
| 1173 |
+
_faces = np.reshape(_faces, (-1))
|
| 1174 |
+
_points = mesh.vertices[_faces]
|
| 1175 |
+
min_xyz = np.min(_points, axis=0)
|
| 1176 |
+
max_xyz = np.max(_points, axis=0)
|
| 1177 |
+
parts_2_aabb.append([min_xyz, max_xyz])
|
| 1178 |
+
|
| 1179 |
+
with Timer("计算part的邻居"):
|
| 1180 |
+
parts_2_neighbor = find_neighbor_part(parts_2, adjacent_faces, parts_2_aabb, parts_ids_2)
|
| 1181 |
+
with Timer("合并无mask区域"):
|
| 1182 |
+
for i in range(len(parts_2)):
|
| 1183 |
+
_parts = parts_2[i]
|
| 1184 |
+
_ids = parts_ids_2[i]
|
| 1185 |
+
if _ids == -1 or _ids == -2:
|
| 1186 |
+
_cur_aabb = parts_2_aabb[i]
|
| 1187 |
+
_min_aabb_increase = 1e10
|
| 1188 |
+
_min_id = -1
|
| 1189 |
+
for j in parts_2_neighbor[i]:
|
| 1190 |
+
if parts_ids_2[j] == -1 or parts_ids_2[j] == -2:
|
| 1191 |
+
continue
|
| 1192 |
+
_tar_id = parts_ids_2[j]
|
| 1193 |
+
_tar_aabb = ids_aabb[_tar_id]
|
| 1194 |
+
_min_increase, _max_increase = aabb_increase(_tar_aabb, _cur_aabb)
|
| 1195 |
+
_increase = max(np.max(_min_increase), np.max(_max_increase))
|
| 1196 |
+
if _min_aabb_increase > _increase:
|
| 1197 |
+
_min_aabb_increase = _increase
|
| 1198 |
+
_min_id = _tar_id
|
| 1199 |
+
if _min_id >= 0:
|
| 1200 |
+
parts_ids_2[i] = _min_id
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
with Timer("再次赋值新的face_ids"):
|
| 1204 |
+
face_ids_4 = face_ids_3.copy()
|
| 1205 |
+
for i in range(len(parts_2)):
|
| 1206 |
+
_parts = parts_2[i]
|
| 1207 |
+
_parts_id = parts_ids_2[i]
|
| 1208 |
+
for j in _parts:
|
| 1209 |
+
face_ids_4[j] = _parts_id
|
| 1210 |
+
|
| 1211 |
+
final_face_ids = face_ids_4
|
| 1212 |
+
if save_mid_res:
|
| 1213 |
+
save_mesh(
|
| 1214 |
+
os.path.join(save_path, "auto_mask_mesh_final.glb"),
|
| 1215 |
+
mesh,
|
| 1216 |
+
face_ids_4,
|
| 1217 |
+
color_map_2,
|
| 1218 |
+
)
|
| 1219 |
+
|
| 1220 |
+
if post_process:
|
| 1221 |
+
parts = get_connected_region(final_face_ids, adjacent_faces)
|
| 1222 |
+
final_face_ids = do_no_mask_process(parts, final_face_ids)
|
| 1223 |
+
face_ids_5 = do_post_process(face_areas, parts, adjacent_faces, face_ids_4, threshold, show_info=show_info)
|
| 1224 |
+
if save_mid_res:
|
| 1225 |
+
save_mesh(
|
| 1226 |
+
os.path.join(save_path, "auto_mask_mesh_final_post.glb"),
|
| 1227 |
+
mesh,
|
| 1228 |
+
face_ids_5,
|
| 1229 |
+
color_map_2,
|
| 1230 |
+
)
|
| 1231 |
+
final_face_ids = face_ids_5
|
| 1232 |
+
with Timer("计算最后的aabb"):
|
| 1233 |
+
aabb = get_aabb_from_face_ids(mesh, final_face_ids)
|
| 1234 |
+
return aabb, final_face_ids, mesh
|
| 1235 |
+
|
| 1236 |
+
|
| 1237 |
+
class AutoMask:
|
| 1238 |
+
def __init__(
|
| 1239 |
+
self,
|
| 1240 |
+
ckpt_path=None,
|
| 1241 |
+
point_num=100000,
|
| 1242 |
+
prompt_num=400,
|
| 1243 |
+
threshold=0.95,
|
| 1244 |
+
post_process=True,
|
| 1245 |
+
):
|
| 1246 |
+
"""
|
| 1247 |
+
ckpt_path: str, 模型路径
|
| 1248 |
+
point_num: int, 采样点数量
|
| 1249 |
+
prompt_num: int, 提示数量
|
| 1250 |
+
threshold: float, 阈值
|
| 1251 |
+
post_process: bool, 是否后处理
|
| 1252 |
+
"""
|
| 1253 |
+
self.model = P3SAM()
|
| 1254 |
+
self.model.load_state_dict(ckpt_path)
|
| 1255 |
+
self.model.eval()
|
| 1256 |
+
self.model_parallel = torch.nn.DataParallel(self.model)
|
| 1257 |
+
self.model.cuda()
|
| 1258 |
+
self.model_parallel.cuda()
|
| 1259 |
+
self.point_num = point_num
|
| 1260 |
+
self.prompt_num = prompt_num
|
| 1261 |
+
self.threshold = threshold
|
| 1262 |
+
self.post_process = post_process
|
| 1263 |
+
|
| 1264 |
+
def predict_aabb(
|
| 1265 |
+
self, mesh, point_num=None, prompt_num=None, threshold=None, post_process=None, save_path=None, save_mid_res=False, show_info=True, clean_mesh_flag=True, seed=42, is_parallel=True, prompt_bs=32
|
| 1266 |
+
):
|
| 1267 |
+
"""
|
| 1268 |
+
Parameters:
|
| 1269 |
+
mesh: trimesh.Trimesh, 输入网格
|
| 1270 |
+
point_num: int, 采样点数量
|
| 1271 |
+
prompt_num: int, 提示数量
|
| 1272 |
+
threshold: float, 阈值
|
| 1273 |
+
post_process: bool, 是否后处理
|
| 1274 |
+
Returns:
|
| 1275 |
+
aabb: np.ndarray, 包围盒
|
| 1276 |
+
face_ids: np.ndarray, 面id
|
| 1277 |
+
"""
|
| 1278 |
+
point_num = point_num if point_num is not None else self.point_num
|
| 1279 |
+
prompt_num = prompt_num if prompt_num is not None else self.prompt_num
|
| 1280 |
+
threshold = threshold if threshold is not None else self.threshold
|
| 1281 |
+
post_process = post_process if post_process is not None else self.post_process
|
| 1282 |
+
return mesh_sam(
|
| 1283 |
+
[self.model, self.model_parallel if is_parallel else self.model],
|
| 1284 |
+
mesh,
|
| 1285 |
+
save_path=save_path,
|
| 1286 |
+
point_num=point_num,
|
| 1287 |
+
prompt_num=prompt_num,
|
| 1288 |
+
threshold=threshold,
|
| 1289 |
+
post_process=post_process,
|
| 1290 |
+
show_info=show_info,
|
| 1291 |
+
save_mid_res=save_mid_res,
|
| 1292 |
+
clean_mesh_flag=clean_mesh_flag,
|
| 1293 |
+
seed=seed,
|
| 1294 |
+
prompt_bs=prompt_bs,
|
| 1295 |
+
)
|
| 1296 |
+
|
| 1297 |
+
def set_seed(seed):
|
| 1298 |
+
random.seed(seed)
|
| 1299 |
+
np.random.seed(seed)
|
| 1300 |
+
torch.manual_seed(seed)
|
| 1301 |
+
if torch.cuda.is_available():
|
| 1302 |
+
torch.cuda.manual_seed(seed)
|
| 1303 |
+
torch.cuda.manual_seed_all(seed)
|
| 1304 |
+
torch.backends.cudnn.deterministic = True
|
| 1305 |
+
torch.backends.cudnn.benchmark = False
|
| 1306 |
+
|
| 1307 |
+
if __name__ == '__main__':
|
| 1308 |
+
argparser = argparse.ArgumentParser()
|
| 1309 |
+
argparser.add_argument('--ckpt_path', type=str, default=None, help='模型路径')
|
| 1310 |
+
argparser.add_argument('--mesh_path', type=str, default='assets/1.glb', help='输入网格路径')
|
| 1311 |
+
argparser.add_argument('--output_path', type=str, default='results/1', help='保存路径')
|
| 1312 |
+
argparser.add_argument('--point_num', type=int, default=100000, help='采样点数量')
|
| 1313 |
+
argparser.add_argument('--prompt_num', type=int, default=400, help='提示数量')
|
| 1314 |
+
argparser.add_argument('--threshold', type=float, default=0.95, help='阈值')
|
| 1315 |
+
argparser.add_argument('--post_process', type=int, default=0, help='是否后处理')
|
| 1316 |
+
argparser.add_argument('--save_mid_res', type=int, default=1, help='是否保存中间结果')
|
| 1317 |
+
argparser.add_argument('--show_info', type=int, default=1, help='是否显示信息')
|
| 1318 |
+
argparser.add_argument('--show_time_info', type=int, default=1, help='是否显示时间信息')
|
| 1319 |
+
argparser.add_argument('--seed', type=int, default=42, help='随机种子')
|
| 1320 |
+
argparser.add_argument('--parallel', type=int, default=1, help='是否使用多卡')
|
| 1321 |
+
argparser.add_argument('--prompt_bs', type=int, default=32, help='提示点推理时的batch size大小')
|
| 1322 |
+
argparser.add_argument('--clean_mesh', type=int, default=1, help='是否清洗网格')
|
| 1323 |
+
args = argparser.parse_args()
|
| 1324 |
+
Timer.STATE = args.show_time_info
|
| 1325 |
+
|
| 1326 |
+
|
| 1327 |
+
output_path = args.output_path
|
| 1328 |
+
os.makedirs(output_path, exist_ok=True)
|
| 1329 |
+
ckpt_path = args.ckpt_path
|
| 1330 |
+
auto_mask = AutoMask(ckpt_path)
|
| 1331 |
+
mesh_path = args.mesh_path
|
| 1332 |
+
if os.path.isdir(mesh_path):
|
| 1333 |
+
for file in os.listdir(mesh_path):
|
| 1334 |
+
if not (file.endswith('.glb') or file.endswith('.obj') or file.endswith('.ply')):
|
| 1335 |
+
continue
|
| 1336 |
+
_mesh_path = os.path.join(mesh_path, file)
|
| 1337 |
+
_output_path = os.path.join(output_path, file[:-4])
|
| 1338 |
+
os.makedirs(_output_path, exist_ok=True)
|
| 1339 |
+
mesh = trimesh.load(_mesh_path, force='mesh')
|
| 1340 |
+
set_seed(args.seed)
|
| 1341 |
+
aabb, face_ids, mesh = auto_mask.predict_aabb(mesh,
|
| 1342 |
+
save_path=_output_path,
|
| 1343 |
+
point_num=args.point_num,
|
| 1344 |
+
prompt_num=args.prompt_num,
|
| 1345 |
+
threshold=args.threshold,
|
| 1346 |
+
post_process=args.post_process,
|
| 1347 |
+
save_mid_res=args.save_mid_res,
|
| 1348 |
+
show_info=args.show_info,
|
| 1349 |
+
seed=args.seed,
|
| 1350 |
+
is_parallel=args.parallel,
|
| 1351 |
+
clean_mesh_flag=args.clean_mesh,)
|
| 1352 |
+
else:
|
| 1353 |
+
mesh = trimesh.load(mesh_path, force='mesh')
|
| 1354 |
+
set_seed(args.seed)
|
| 1355 |
+
aabb, face_ids, mesh = auto_mask.predict_aabb(mesh,
|
| 1356 |
+
save_path=output_path,
|
| 1357 |
+
point_num=args.point_num,
|
| 1358 |
+
prompt_num=args.prompt_num,
|
| 1359 |
+
threshold=args.threshold,
|
| 1360 |
+
post_process=args.post_process,
|
| 1361 |
+
save_mid_res=args.save_mid_res,
|
| 1362 |
+
show_info=args.show_info,
|
| 1363 |
+
seed=args.seed,
|
| 1364 |
+
is_parallel=args.parallel,
|
| 1365 |
+
clean_mesh_flag=args.clean_mesh,)
|
| 1366 |
+
|
| 1367 |
+
###############################################
|
| 1368 |
+
## 可以通过以下代码保存返回的结果
|
| 1369 |
+
## You can save the returned result by the following code
|
| 1370 |
+
################# save result #################
|
| 1371 |
+
# color_map = {}
|
| 1372 |
+
# unique_ids = np.unique(face_ids)
|
| 1373 |
+
# for i in unique_ids:
|
| 1374 |
+
# if i == -1:
|
| 1375 |
+
# continue
|
| 1376 |
+
# part_color = np.random.rand(3) * 255
|
| 1377 |
+
# color_map[i] = part_color
|
| 1378 |
+
# face_colors = []
|
| 1379 |
+
# for i in face_ids:
|
| 1380 |
+
# if i == -1:
|
| 1381 |
+
# face_colors.append([0, 0, 0])
|
| 1382 |
+
# else:
|
| 1383 |
+
# face_colors.append(color_map[i])
|
| 1384 |
+
# face_colors = np.array(face_colors).astype(np.uint8)
|
| 1385 |
+
# mesh_save = mesh.copy()
|
| 1386 |
+
# mesh_save.visual.face_colors = face_colors
|
| 1387 |
+
# mesh_save.export(os.path.join(output_path, 'auto_mask_mesh.glb'))
|
| 1388 |
+
# scene_mesh = trimesh.Scene()
|
| 1389 |
+
# scene_mesh.add_geometry(mesh_save)
|
| 1390 |
+
# for i in range(len(aabb)):
|
| 1391 |
+
# min_xyz, max_xyz = aabb[i]
|
| 1392 |
+
# center = (min_xyz + max_xyz) / 2
|
| 1393 |
+
# size = max_xyz - min_xyz
|
| 1394 |
+
# box = trimesh.path.creation.box_outline()
|
| 1395 |
+
# box.vertices *= size
|
| 1396 |
+
# box.vertices += center
|
| 1397 |
+
# scene_mesh.add_geometry(box)
|
| 1398 |
+
# scene_mesh.export(os.path.join(output_path, 'auto_mask_aabb.glb'))
|
| 1399 |
+
################# save result #################
|
| 1400 |
+
|
| 1401 |
+
'''
|
| 1402 |
+
python auto_mask.py --parallel 0
|
| 1403 |
+
python auto_mask.py --ckpt_path ../weights/last.ckpt --mesh_path assets/1.glb --output_path results/1 --parallel 0
|
| 1404 |
+
python auto_mask.py --ckpt_path ../weights/last.ckpt --mesh_path assets --output_path results/all
|
| 1405 |
+
'''
|
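For reference, a minimal sketch of driving the same pipeline from Python instead of the CLI; the checkpoint and mesh paths below mirror the example commands above and are assumptions, not fixed paths:

# Programmatic usage sketch (illustrative, not part of the original file):
# import trimesh
# from auto_mask import AutoMask, set_seed
#
# auto_mask = AutoMask(ckpt_path='../weights/last.ckpt')
# set_seed(42)
# mesh = trimesh.load('assets/1.glb', force='mesh')
# aabb, face_ids, mesh = auto_mask.predict_aabb(
#     mesh, save_path='results/1', save_mid_res=False, show_info=True
# )
# print(aabb.shape)      # (num_parts, 2, 3): min/max corner per part
# print(face_ids.shape)  # (num_faces,): part id per face, -1 for unassigned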
P3-SAM/demo/auto_mask_no_postprocess.py
ADDED
@@ -0,0 +1,943 @@
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import argparse
import trimesh
from sklearn.decomposition import PCA
import fpsample
from tqdm import tqdm
import threading
import random

# from tqdm.notebook import tqdm
import time
import copy
import shutil
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from collections import defaultdict

import numba
from numba import njit

sys.path.append("..")
from model import build_P3SAM, load_state_dict

from utils.chamfer3D.dist_chamfer_3D import chamfer_3DDist

cmd_loss = chamfer_3DDist()


class P3SAM(nn.Module):
    def __init__(self):
        super().__init__()
        build_P3SAM(self)

    def load_state_dict(self,
                        ckpt_path=None,
                        state_dict=None,
                        strict=True,
                        assign=False,
                        ignore_seg_mlp=False,
                        ignore_seg_s2_mlp=False,
                        ignore_iou_mlp=False):
        load_state_dict(self,
                        ckpt_path=ckpt_path,
                        state_dict=state_dict,
                        strict=strict,
                        assign=assign,
                        ignore_seg_mlp=ignore_seg_mlp,
                        ignore_seg_s2_mlp=ignore_seg_s2_mlp,
                        ignore_iou_mlp=ignore_iou_mlp)

    def forward(self, feats, points, point_prompt, iter=1):
        """
        feats: [K, N, 512]
        points: [K, N, 3]
        point_prompt: [K, N, 3]
        """
        # print(feats.shape, points.shape, point_prompt.shape)
        point_num = points.shape[1]
        feats = feats.transpose(0, 1)  # [N, K, 512]
        points = points.transpose(0, 1)  # [N, K, 3]
        point_prompt = point_prompt.transpose(0, 1)  # [N, K, 3]
        feats_seg = torch.cat([feats, points, point_prompt], dim=-1)  # [N, K, 512+3+3]

        # predict masks, stage 1
        pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1)  # [N, K]
        pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1)  # [N, K]
        pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1)  # [N, K]
        pred_mask = torch.stack(
            [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1
        )  # [N, K, 3]

        for _ in range(iter):
            # predict masks, stage 2
            feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1)  # [N, K, 512+3+3+3]
            feats_seg_global = self.seg_s2_mlp_g(feats_seg_2)  # [N, K, 512]
            feats_seg_global = torch.max(feats_seg_global, dim=0).values  # [K, 512]
            feats_seg_global = feats_seg_global.unsqueeze(0).repeat(
                point_num, 1, 1
            )  # [N, K, 512]
            feats_seg_3 = torch.cat(
                [feats_seg_global, feats_seg_2], dim=-1
            )  # [N, K, 512+3+3+3+512]
            pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1)  # [N, K]
            pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1)  # [N, K]
            pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1)  # [N, K]
            pred_mask_s2 = torch.stack(
                [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1
            )  # [N, K, 3]
            pred_mask = pred_mask_s2

        mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32)  # [N, K]
        mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32)  # [N, K]
        mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32)  # [N, K]

        feats_iou = torch.cat(
            [feats_seg_global, feats_seg, pred_mask_s2], dim=-1
        )  # [N, K, 512+3+3+3+512]
        feats_iou = self.iou_mlp(feats_iou)  # [N, K, 512]
        feats_iou = torch.max(feats_iou, dim=0).values  # [K, 512]
        pred_iou = self.iou_mlp_out(feats_iou)  # [K, 3]
        pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32)  # [K, 3]

        mask_1 = mask_1.transpose(0, 1)  # [K, N]
        mask_2 = mask_2.transpose(0, 1)  # [K, N]
        mask_3 = mask_3.transpose(0, 1)  # [K, N]

        return mask_1, mask_2, mask_3, pred_iou


def normalize_pc(pc):
    """
    pc: (N, 3)
    """
    max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0)
    center = (max_ + min_) / 2
    scale = (max_ - min_) / 2
    scale = np.max(np.abs(scale))
    pc = (pc - center) / (scale + 1e-10)
    return pc


@torch.no_grad()
def get_feat(model, points, normals):
    data_dict = {
        "coord": points,
        "normal": normals,
        "color": np.ones_like(points),
        "batch": np.zeros(points.shape[0], dtype=np.int64),
    }
    data_dict = model.transform(data_dict)
    for k in data_dict:
        if isinstance(data_dict[k], torch.Tensor):
            data_dict[k] = data_dict[k].cuda()
    point = model.sonata(data_dict)
    while "pooling_parent" in point.keys():
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")
        parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
        point = parent
    feat = point.feat  # [M, 1232]
    feat = model.mlp(feat)  # [M, 512]
    feat = feat[point.inverse]  # [N, 512]
    feats = feat
    return feats


@torch.no_grad()
def get_mask(model, feats, points, point_prompt, iter=1):
    """
    feats: [N, 512]
    points: [N, 3]
    point_prompt: [K, 3]
    """
    point_num = points.shape[0]
    prompt_num = point_prompt.shape[0]
    feats = feats.unsqueeze(1)  # [N, 1, 512]
    feats = feats.repeat(1, prompt_num, 1).cuda()  # [N, K, 512]
    points = torch.from_numpy(points).float().cuda().unsqueeze(1)  # [N, 1, 3]
    points = points.repeat(1, prompt_num, 1)  # [N, K, 3]
    prompt_coord = (
        torch.from_numpy(point_prompt).float().cuda().unsqueeze(0)
    )  # [1, K, 3]
    prompt_coord = prompt_coord.repeat(point_num, 1, 1)  # [N, K, 3]

    feats = feats.transpose(0, 1)  # [K, N, 512]
    points = points.transpose(0, 1)  # [K, N, 3]
    prompt_coord = prompt_coord.transpose(0, 1)  # [K, N, 3]

    mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter)

    mask_1 = mask_1.transpose(0, 1)  # [N, K]
    mask_2 = mask_2.transpose(0, 1)  # [N, K]
    mask_3 = mask_3.transpose(0, 1)  # [N, K]

    mask_1 = mask_1.detach().cpu().numpy() > 0.5
    mask_2 = mask_2.detach().cpu().numpy() > 0.5
    mask_3 = mask_3.detach().cpu().numpy() > 0.5

    org_iou = pred_iou.detach().cpu().numpy()  # [K, 3]

    return mask_1, mask_2, mask_3, org_iou
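
# Added note (illustrative): for N sampled points and K prompt points, get_mask
# returns three boolean arrays of shape [N, K] (one per mask head) plus
# org_iou of shape [K, 3], the predicted IoU of each head for each prompt.
# The caller keeps, per prompt, the head with the highest predicted IoU.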


def cal_iou(m1, m2):
    return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2))


def cal_single_iou(m1, m2):
    # intersection over the area of m1 only (containment ratio)
    return np.sum(np.logical_and(m1, m2)) / np.sum(m1)


def iou_3d(box1, box2, signle=None):
    """
    Compute the intersection-over-union (IoU) of two 3D bounding boxes.

    Args:
        box1 (list): first box  [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max]
        box2 (list): second box [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max]
        signle (str, optional): 'signle' (sic) is kept from the original code;
            if "1" or "2", normalize by that box's volume instead of the union.

    Returns:
        float: the IoU value
    """
    # intersection coordinates
    intersection_xmin = max(box1[0], box2[0])
    intersection_ymin = max(box1[1], box2[1])
    intersection_zmin = max(box1[2], box2[2])
    intersection_xmax = min(box1[3], box2[3])
    intersection_ymax = min(box1[4], box2[4])
    intersection_zmax = min(box1[5], box2[5])

    # check whether the boxes intersect at all
    if (
        intersection_xmin >= intersection_xmax
        or intersection_ymin >= intersection_ymax
        or intersection_zmin >= intersection_zmax
    ):
        return 0.0  # no intersection

    # intersection volume
    intersection_volume = (
        (intersection_xmax - intersection_xmin)
        * (intersection_ymax - intersection_ymin)
        * (intersection_zmax - intersection_zmin)
    )

    # volumes of the two boxes
    box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2])
    box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2])

    if signle is None:
        # union volume
        union_volume = box1_volume + box2_volume - intersection_volume
    elif signle == "1":
        union_volume = box1_volume
    elif signle == "2":
        union_volume = box2_volume
    else:
        raise ValueError("signle must be None or 1 or 2")

    # IoU
    iou = intersection_volume / union_volume if union_volume > 0 else 0.0
    return iou
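
# Added worked example (illustrative): two unit cubes overlapping by half along
# x intersect in volume 0.5; the union is 1 + 1 - 0.5 = 1.5, so IoU = 1/3.
# >>> iou_3d([0, 0, 0, 1, 1, 1], [0.5, 0, 0, 1.5, 1, 1])
# 0.3333333333333333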


def cal_point_bbox_iou(p1, p2, signle=None):
    min_p1 = np.min(p1, axis=0)
    max_p1 = np.max(p1, axis=0)
    min_p2 = np.min(p2, axis=0)
    max_p2 = np.max(p2, axis=0)
    box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]]
    box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]]
    return iou_3d(box1, box2, signle)


def cal_bbox_iou(points, m1, m2):
    p1 = points[m1]
    p2 = points[m2]
    return cal_point_bbox_iou(p1, p2)


def clean_mesh(mesh):
    """
    mesh: trimesh.Trimesh
    """
    # 1. merge nearby vertices
    mesh.merge_vertices()

    # 2. remove duplicate vertices
    # 3. remove duplicate faces
    mesh.process(True)
    return mesh


def get_aabb_from_face_ids(mesh, face_ids):
    unique_ids = np.unique(face_ids)
    aabb = []
    for i in unique_ids:
        if i == -1 or i == -2:
            continue
        _part_mask = face_ids == i
        _faces = mesh.faces[_part_mask]
        _faces = np.reshape(_faces, (-1))
        _points = mesh.vertices[_faces]
        min_xyz = np.min(_points, axis=0)
        max_xyz = np.max(_points, axis=0)
        aabb.append([min_xyz, max_xyz])
    return np.array(aabb)


class Timer:
    # STATE is toggled from __main__ via `Timer.STATE = args.show_time_info`;
    # the guard in __exit__ is added so that setting it to 0 actually silences
    # the timers (the original assigned the attribute but never read it).
    STATE = True

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start_time = time.time()
        return self  # returned so the elapsed time is accessible inside the with-block

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if Timer.STATE:
            print(f">>>>>> block {self.name} took {self.elapsed_time:.4f} s")


def sample_points_pre_face(vertices, faces, n_point_per_face=2000):
    n_f = faces.shape[0]  # number of faces

    # draw random numbers u, v
    u = np.sqrt(np.random.rand(n_f, n_point_per_face, 1))  # (n_f, n_point_per_face, 1)
    v = np.random.rand(n_f, n_point_per_face, 1)  # (n_f, n_point_per_face, 1)

    # barycentric coordinates (the sqrt on u makes sampling uniform over the triangle)
    w0 = 1 - u
    w1 = u * (1 - v)
    w2 = u * v  # (n_f, n_point_per_face, 1)

    # gather the three vertices of every face
    face_v_0 = vertices[faces[:, 0].reshape(-1)]  # (n_f, 3)
    face_v_1 = vertices[faces[:, 1].reshape(-1)]  # (n_f, 3)
    face_v_2 = vertices[faces[:, 2].reshape(-1)]  # (n_f, 3)

    # expand dims to broadcast against w0, w1, w2
    face_v_0 = face_v_0.reshape(n_f, 1, 3)  # (n_f, 1, 3)
    face_v_1 = face_v_1.reshape(n_f, 1, 3)  # (n_f, 1, 3)
    face_v_2 = face_v_2.reshape(n_f, 1, 3)  # (n_f, 1, 3)

    # coordinates of every sampled point
    points = w0 * face_v_0 + w1 * face_v_1 + w2 * face_v_2  # (n_f, n_point_per_face, 3)

    return points
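
# Added example (illustrative): the barycentric weights satisfy
# w0 + w1 + w2 = (1 - u) + u(1 - v) + uv = 1, so every sample lies on its face.
# >>> verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
# >>> faces = np.array([[0, 1, 2]])
# >>> sample_points_pre_face(verts, faces, n_point_per_face=5).shape
# (1, 5, 3)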


def cal_cd_batch(p1, p2, pn=100000):
    p1_n = p1.shape[0]
    batch_num = (p1_n + pn - 1) // pn
    p2_cuda = torch.from_numpy(p2).cuda().float().unsqueeze(0)
    p1_cuda = torch.from_numpy(p1).cuda().float().unsqueeze(0)
    cd_res = []
    for i in tqdm(range(batch_num)):
        start_idx = i * pn
        end_idx = min((i + 1) * pn, p1_n)
        _p1_cuda = p1_cuda[:, start_idx:end_idx, :]
        _, _, idx, _ = cmd_loss(_p1_cuda, p2_cuda)
        idx = idx[0].detach().cpu().numpy()
        cd_res.append(idx)
    cd_res = np.concatenate(cd_res, axis=0)
    return cd_res
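
# Added note (illustrative): cal_cd_batch(p1, p2) returns, for every point in
# p1, the index of its nearest neighbor in p2 (the chamfer correspondence),
# processing at most `pn` query points per GPU batch to bound memory use.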


def remove_outliers_iqr(data, factor=1.5):
    """
    Remove outliers based on the IQR.
    :param data: input list or NumPy array
    :param factor: multiple of the IQR (default 1.5)
    :return: list with outliers removed
    """
    data = np.array(data, dtype=np.float32)
    q1 = np.percentile(data, 25)  # first quartile
    q3 = np.percentile(data, 75)  # third quartile
    iqr = q3 - q1  # interquartile range
    lower_bound = q1 - factor * iqr
    upper_bound = q3 + factor * iqr
    return data[(data >= lower_bound) & (data <= upper_bound)].tolist()
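
# Added worked example (illustrative): for [1, 2, 3, 4, 100], q1 = 2, q3 = 4,
# IQR = 2, so the keep range is [-1, 7] and 100 is dropped.
# >>> remove_outliers_iqr([1, 2, 3, 4, 100])
# [1.0, 2.0, 3.0, 4.0]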


def better_aabb(points):
    # per-axis IQR filtering makes the box robust to stray outlier points
    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]
    x = remove_outliers_iqr(x)
    y = remove_outliers_iqr(y)
    z = remove_outliers_iqr(z)
    min_xyz = np.array([np.min(x), np.min(y), np.min(z)])
    max_xyz = np.array([np.max(x), np.max(y), np.max(z)])
    return [min_xyz, max_xyz]


def save_mesh(save_path, mesh, face_ids, color_map):
    face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8)
    for i in tqdm(range(len(mesh.faces)), disable=True):
        _max_id = face_ids[i]
        if _max_id == -2:
            continue
        face_colors[i, :3] = color_map[_max_id]

    mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
    mesh_save.visual.face_colors = face_colors
    mesh_save.export(save_path)
    mesh_save.export(save_path.replace(".glb", ".ply"))
    # print('mesh saved')

    scene_mesh = trimesh.Scene()
    scene_mesh.add_geometry(mesh_save)
    unique_ids = np.unique(face_ids)
    aabb = []
    for i in unique_ids:
        if i == -1 or i == -2:
            continue
        _part_mask = face_ids == i
        _faces = mesh.faces[_part_mask]
        _faces = np.reshape(_faces, (-1))
        _points = mesh.vertices[_faces]
        min_xyz, max_xyz = better_aabb(_points)
        center = (min_xyz + max_xyz) / 2
        size = max_xyz - min_xyz
        box = trimesh.path.creation.box_outline()
        box.vertices *= size
        box.vertices += center
        box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]])
        box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8)
        box.colors = box_color
        scene_mesh.add_geometry(box)
        min_xyz = np.min(_points, axis=0)
        max_xyz = np.max(_points, axis=0)
        aabb.append([min_xyz, max_xyz])
    scene_mesh.export(save_path.replace(".glb", "_aabb.glb"))
    aabb = np.array(aabb)
    np.save(save_path.replace(".glb", "_aabb.npy"), aabb)
    np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids)


def mesh_sam(
    model,
    mesh,
    save_path,
    point_num=100000,
    prompt_num=400,
    save_mid_res=False,
    show_info=False,
    post_process=False,
    threshold=0.95,
    clean_mesh_flag=True,
    seed=42,
    prompt_bs=32,
):
    with Timer("Load mesh"):
        model, model_parallel = model
        if clean_mesh_flag:
            mesh = clean_mesh(mesh)
        mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=False)
        if show_info:
            print(f"vertices: {mesh.vertices.shape[0]}  faces: {mesh.faces.shape[0]}")

    # note: both arguments are overridden with their defaults here (kept as in the original)
    point_num = 100000
    prompt_num = 400

    with Timer("Sample point cloud"):
        _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed)
        _points_org = _points.copy()
        _points = normalize_pc(_points)
        normals = mesh.face_normals[face_idx]
        # _points = _points + np.random.normal(0, 1, size=_points.shape) * 0.01
        # normals = normals * 0.  # debug: no normals
        if show_info:
            print(f"points: {point_num}  faces: {mesh.faces.shape[0]}")

    with Timer("Extract features"):
        _feats = get_feat(model, _points, normals)
        if show_info:
            print("features precomputed")

    if save_mid_res:
        feat_save = _feats.float().detach().cpu().numpy()
        data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
        pca = PCA(n_components=3)
        data_reduced = pca.fit_transform(data_scaled)
        data_reduced = (data_reduced - data_reduced.min()) / (
            data_reduced.max() - data_reduced.min()
        )
        _colors_pca = (data_reduced * 255).astype(np.uint8)
        pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca)
        pc_save.export(os.path.join(save_path, "point_pca.glb"))
        pc_save.export(os.path.join(save_path, "point_pca.ply"))
        if show_info:
            print("PCA feature colors computed")

    with Timer("FPS-sample prompt points"):
        fps_idx = fpsample.fps_sampling(_points, prompt_num)
        _point_prompts = _points[fps_idx]
        if save_mid_res:
            trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
                os.path.join(save_path, "point_prompts_pca.glb")
            )
            trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
                os.path.join(save_path, "point_prompts_pca.ply")
            )
        if show_info:
            print("sampling done")

    with Timer("Inference"):
        bs = prompt_bs
        step_num = prompt_num // bs + 1
        mask_res = []
        iou_res = []
        for i in tqdm(range(step_num), disable=not show_info):
            cur_prompt = _point_prompts[bs * i : bs * (i + 1)]
            pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask(
                model_parallel, _feats, _points, cur_prompt
            )
            pred_mask = np.stack(
                [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1
            )  # [N, K, 3]
            max_idx = np.argmax(pred_iou, axis=-1)  # [K]
            for j in range(max_idx.shape[0]):
                mask_res.append(pred_mask[:, j, max_idx[j]])
                iou_res.append(pred_iou[j, max_idx[j]])
        mask_res = np.stack(mask_res, axis=-1)  # [N, K]
        if show_info:
            print("prompt inference done")

    with Timer("Sort by IoU"):
        iou_res = np.array(iou_res).tolist()
        mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)]
        mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True)
        mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)]
        iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)]

    # single-threaded NMS variant kept from the original for reference:
    # clusters = {}
    # for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
    #     _mask = mask_sorted[i]
    #     union_flag = False
    #     for j in clusters.keys():
    #         if cal_iou(_mask, mask_sorted[j]) > 0.9:
    #             clusters[j].append(i)
    #             union_flag = True
    #             break
    #     if not union_flag:
    #         clusters[i] = [i]
    with Timer("NMS"):
        clusters = defaultdict(list)
        with ThreadPoolExecutor(max_workers=20) as executor:
            for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
                _mask = mask_sorted[i]
                futures = []
                for j in clusters.keys():
                    futures.append(executor.submit(cal_iou, _mask, mask_sorted[j]))

                # for-else: if no existing cluster absorbs mask i, it seeds a new one
                for j, future in zip(clusters.keys(), futures):
                    if future.result() > 0.9:
                        clusters[j].append(i)
                        break
                else:
                    clusters[i].append(i)

        # print(clusters)
        if show_info:
            print(f"NMS done, mask count: {len(clusters)}")

    if save_mid_res:
        part_mask_save_path = os.path.join(save_path, "part_mask")
        if os.path.exists(part_mask_save_path):
            shutil.rmtree(part_mask_save_path)
        os.makedirs(part_mask_save_path, exist_ok=True)
        for i in tqdm(clusters.keys(), desc="saving masks", disable=not show_info):
            cluster_num = len(clusters[i])
            cluster_iou = iou_sorted[i]
            cluster_area = np.sum(mask_sorted[i])
            if cluster_num <= 2:
                continue
            mask_save = mask_sorted[i]
            mask_save = np.expand_dims(mask_save, axis=-1)
            mask_save = np.repeat(mask_save, 3, axis=-1)
            mask_save = (mask_save * 255).astype(np.uint8)
            point_save = trimesh.points.PointCloud(_points, colors=mask_save)
            point_save.export(
                os.path.join(
                    part_mask_save_path,
                    f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
                )
            )

    # drop clusters supported by too few masks
    with Timer("Filter single-mask clusters"):
        filtered_clusters = []
        other_clusters = []
        for i in clusters.keys():
            if len(clusters[i]) > 2:
                filtered_clusters.append(i)
            else:
                other_clusters.append(i)
        if show_info:
            print(
                f"before filtering: {len(clusters)} clusters, "
                f"after filtering: {len(filtered_clusters)} clusters"
            )

    # merge again
    with Timer("Merge again"):
        filtered_clusters_num = len(filtered_clusters)
        cluster2 = {}
        is_union = [False] * filtered_clusters_num
        for i in range(filtered_clusters_num):
            if is_union[i]:
                continue
            cur_cluster = filtered_clusters[i]
            cluster2[cur_cluster] = [cur_cluster]
            for j in range(i + 1, filtered_clusters_num):
                if is_union[j]:
                    continue
                tar_cluster = filtered_clusters[j]
                # if cal_single_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.9:
                # if cal_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.5:
                if (
                    cal_bbox_iou(
                        _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster]
                    )
                    > 0.5
                ):
                    cluster2[cur_cluster].append(tar_cluster)
                    is_union[j] = True
        if show_info:
            print(f"merged again, cluster count: {len(cluster2.keys())}")

    with Timer("Find points without a mask"):
        no_mask = np.ones(point_num)
        for i in cluster2:
            part_mask = mask_sorted[i]
            no_mask[part_mask] = 0
        if show_info:
            print(
                f"{np.sum(no_mask == 1)} points have no mask,"
                f" ratio: {np.sum(no_mask == 1) / point_num:.4f}"
            )

    with Timer("Recover missed masks"):
        # look for masks that the clustering missed
        for i in tqdm(range(len(mask_sorted)), desc="missed masks", disable=not show_info):
            if i in cluster2:
                continue
            part_mask = mask_sorted[i]
            _iou = cal_single_iou(part_mask, no_mask)
            if _iou > 0.7:
                cluster2[i] = [i]
                no_mask[part_mask] = 0
                if save_mid_res:
                    mask_save = mask_sorted[i]
                    mask_save = np.expand_dims(mask_save, axis=-1)
                    mask_save = np.repeat(mask_save, 3, axis=-1)
                    mask_save = (mask_save * 255).astype(np.uint8)
                    point_save = trimesh.points.PointCloud(_points, colors=mask_save)
                    cluster_iou = iou_sorted[i]
                    cluster_area = int(np.sum(mask_sorted[i]))
                    cluster_num = 1
                    point_save.export(
                        os.path.join(
                            part_mask_save_path,
                            f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
                        )
                    )
        # print(cluster2)
        # print(len(cluster2.keys()))
        if show_info:
            print(f"masks after recovery: {len(cluster2.keys())}")

    with Timer("Compute final point-cloud masks"):
        final_mask = list(cluster2.keys())
        final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask]
        final_mask_area = [
            [final_mask[i], final_mask_area[i]] for i in range(len(final_mask))
        ]
        final_mask_area_sorted = sorted(
            final_mask_area, key=lambda x: x[1], reverse=True
        )
        final_mask_sorted = [
            final_mask_area_sorted[i][0] for i in range(len(final_mask_area))
        ]
        final_mask_area_sorted = [
            final_mask_area_sorted[i][1] for i in range(len(final_mask_area))
        ]
        # print(final_mask_sorted)
        # print(final_mask_area_sorted)
        if show_info:
            print(f"final mask count: {len(final_mask_sorted)}")

    with Timer("Color the point cloud"):
        # build the color map
        color_map = {}
        for i in final_mask_sorted:
            part_color = np.random.rand(3) * 255
            color_map[i] = part_color
        # print(color_map)

        result_mask = -np.ones(point_num, dtype=np.int64)
        for i in final_mask_sorted:
            part_mask = mask_sorted[i]
            result_mask[part_mask] = i
        if save_mid_res:
            # save the point-cloud result
            result_colors = np.zeros_like(_colors_pca)
            for i in final_mask_sorted:
                part_color = color_map[i]
                part_mask = mask_sorted[i]
                result_colors[part_mask, :3] = part_color
            trimesh.points.PointCloud(_points, colors=result_colors).export(
                os.path.join(save_path, "auto_mask_cluster.glb")
            )
            trimesh.points.PointCloud(_points, colors=result_colors).export(
                os.path.join(save_path, "auto_mask_cluster.ply")
            )
            if show_info:
                print("point cloud saved")

    with Timer("Post-processing"):
        # project the point labels back onto mesh faces via chamfer nearest neighbors
        valid_mask = result_mask >= 0
        _org = _points_org[valid_mask]
        _results = result_mask[valid_mask]
        pre_face = 10
        _face_points = sample_points_pre_face(
            mesh.vertices, mesh.faces, n_point_per_face=pre_face
        )
        _face_points = np.reshape(_face_points, (len(mesh.faces) * pre_face, 3))
        _idx = cal_cd_batch(_face_points, _org)
        _idx_res = _results[_idx]
        _idx_res = np.reshape(_idx_res, (-1, pre_face))

        face_ids = []
        for i in range(len(mesh.faces)):
            _label = np.argmax(np.bincount(_idx_res[i] + 2)) - 2
            face_ids.append(_label)
        final_face_ids = np.array(face_ids)

    if save_mid_res:
        save_mesh(
            os.path.join(save_path, "auto_mask_mesh_final.glb"),
            mesh,
            final_face_ids,
            color_map,
        )

    with Timer("Compute final AABBs"):
        aabb = get_aabb_from_face_ids(mesh, final_face_ids)
    return aabb, final_face_ids, mesh
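
# Added worked example (illustrative): the per-face majority vote shifts labels
# by +2 so that -2/-1 become valid np.bincount bins, then shifts back:
# >>> labels = np.array([-1, 3, 3, -2, 3])
# >>> np.argmax(np.bincount(labels + 2)) - 2
# 3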


class AutoMask:
    def __init__(
        self,
        ckpt_path=None,
        point_num=100000,
        prompt_num=400,
        threshold=0.95,
        post_process=True,
        automask_instance=None,
    ):
        """
        ckpt_path: str, checkpoint path
        point_num: int, number of sampled points
        prompt_num: int, number of prompt points
        threshold: float, post-processing threshold
        post_process: bool, whether to run post-processing
        automask_instance: AutoMask, optional existing instance whose loaded model is reused
        """
        if automask_instance is not None:
            self.model = automask_instance.model
            self.model_parallel = automask_instance.model_parallel
        else:
            self.model = P3SAM()
            self.model.load_state_dict(ckpt_path)
            self.model.eval()
            self.model_parallel = torch.nn.DataParallel(self.model)
            self.model.cuda()
            self.model_parallel.cuda()
        self.point_num = point_num
        self.prompt_num = prompt_num
        self.threshold = threshold
        self.post_process = post_process

    def predict_aabb(
        self,
        mesh,
        point_num=None,
        prompt_num=None,
        threshold=None,
        post_process=None,
        save_path=None,
        save_mid_res=False,
        show_info=True,
        clean_mesh_flag=True,
        seed=42,
        is_parallel=True,
        prompt_bs=32,
    ):
        """
        Parameters:
            mesh: trimesh.Trimesh, input mesh
            point_num: int, number of sampled points
            prompt_num: int, number of prompt points
            threshold: float, post-processing threshold
            post_process: bool, whether to run post-processing
        Returns:
            aabb: np.ndarray, bounding boxes
            face_ids: np.ndarray, per-face part ids
        """
        point_num = point_num if point_num is not None else self.point_num
        prompt_num = prompt_num if prompt_num is not None else self.prompt_num
        threshold = threshold if threshold is not None else self.threshold
        post_process = post_process if post_process is not None else self.post_process
        return mesh_sam(
            [self.model, self.model_parallel if is_parallel else self.model],
            mesh,
            save_path=save_path,
            point_num=point_num,
            prompt_num=prompt_num,
            threshold=threshold,
            post_process=post_process,
            show_info=show_info,
            save_mid_res=save_mid_res,
            clean_mesh_flag=clean_mesh_flag,
            seed=seed,
            prompt_bs=prompt_bs,
        )


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        "--ckpt_path", type=str, default=None, help="checkpoint path"
    )
    argparser.add_argument(
        "--mesh_path", type=str, default="assets/1.glb", help="input mesh path"
    )
    argparser.add_argument(
        "--output_path", type=str, default="results/1", help="output path"
    )
    argparser.add_argument("--point_num", type=int, default=100000, help="number of sampled points")
    argparser.add_argument("--prompt_num", type=int, default=400, help="number of prompt points")
    argparser.add_argument("--threshold", type=float, default=0.95, help="threshold")
    argparser.add_argument("--post_process", type=int, default=0, help="whether to post-process")
    argparser.add_argument(
        "--save_mid_res", type=int, default=1, help="whether to save intermediate results"
    )
    argparser.add_argument("--show_info", type=int, default=1, help="whether to print progress info")
    argparser.add_argument(
        "--show_time_info", type=int, default=1, help="whether to print timing info"
    )
    argparser.add_argument("--seed", type=int, default=42, help="random seed")
    argparser.add_argument("--parallel", type=int, default=1, help="whether to use multiple GPUs")
    argparser.add_argument(
        "--prompt_bs", type=int, default=32, help="batch size for prompt inference"
    )
    argparser.add_argument("--clean_mesh", type=int, default=1, help="whether to clean the mesh")
    args = argparser.parse_args()
    Timer.STATE = args.show_time_info

    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)
    ckpt_path = args.ckpt_path
    auto_mask = AutoMask(ckpt_path)
    mesh_path = args.mesh_path
    if os.path.isdir(mesh_path):
        for file in os.listdir(mesh_path):
            if not (
                file.endswith(".glb") or file.endswith(".obj") or file.endswith(".ply")
            ):
                continue
            _mesh_path = os.path.join(mesh_path, file)
            _output_path = os.path.join(output_path, file[:-4])
            os.makedirs(_output_path, exist_ok=True)
            mesh = trimesh.load(_mesh_path, force="mesh")
            set_seed(args.seed)
            aabb, face_ids, mesh = auto_mask.predict_aabb(
                mesh,
                save_path=_output_path,
                point_num=args.point_num,
                prompt_num=args.prompt_num,
                threshold=args.threshold,
                post_process=args.post_process,
|
| 882 |
+
save_mid_res=args.save_mid_res,
|
| 883 |
+
show_info=args.show_info,
|
| 884 |
+
seed=args.seed,
|
| 885 |
+
is_parallel=args.parallel,
|
| 886 |
+
clean_mesh_flag=args.clean_mesh,
|
| 887 |
+
)
|
| 888 |
+
else:
|
| 889 |
+
mesh = trimesh.load(mesh_path, force="mesh")
|
| 890 |
+
set_seed(args.seed)
|
| 891 |
+
aabb, face_ids, mesh = auto_mask.predict_aabb(
|
| 892 |
+
mesh,
|
| 893 |
+
save_path=output_path,
|
| 894 |
+
point_num=args.point_num,
|
| 895 |
+
prompt_num=args.prompt_num,
|
| 896 |
+
threshold=args.threshold,
|
| 897 |
+
post_process=args.post_process,
|
| 898 |
+
save_mid_res=args.save_mid_res,
|
| 899 |
+
show_info=args.show_info,
|
| 900 |
+
seed=args.seed,
|
| 901 |
+
is_parallel=args.parallel,
|
| 902 |
+
clean_mesh_flag=args.clean_mesh,
|
| 903 |
+
)
|
| 904 |
+
|
| 905 |
+
###############################################
|
| 906 |
+
## 可以通过以下代码保存返回的结果
|
| 907 |
+
## You can save the returned result by the following code
|
| 908 |
+
################# save result #################
|
| 909 |
+
# color_map = {}
|
| 910 |
+
# unique_ids = np.unique(face_ids)
|
| 911 |
+
# for i in unique_ids:
|
| 912 |
+
# if i == -1:
|
| 913 |
+
# continue
|
| 914 |
+
# part_color = np.random.rand(3) * 255
|
| 915 |
+
# color_map[i] = part_color
|
| 916 |
+
# face_colors = []
|
| 917 |
+
# for i in face_ids:
|
| 918 |
+
# if i == -1:
|
| 919 |
+
# face_colors.append([0, 0, 0])
|
| 920 |
+
# else:
|
| 921 |
+
# face_colors.append(color_map[i])
|
| 922 |
+
# face_colors = np.array(face_colors).astype(np.uint8)
|
| 923 |
+
# mesh_save = mesh.copy()
|
| 924 |
+
# mesh_save.visual.face_colors = face_colors
|
| 925 |
+
# mesh_save.export(os.path.join(output_path, 'auto_mask_mesh.glb'))
|
| 926 |
+
# scene_mesh = trimesh.Scene()
|
| 927 |
+
# scene_mesh.add_geometry(mesh_save)
|
| 928 |
+
# for i in range(len(aabb)):
|
| 929 |
+
# min_xyz, max_xyz = aabb[i]
|
| 930 |
+
# center = (min_xyz + max_xyz) / 2
|
| 931 |
+
# size = max_xyz - min_xyz
|
| 932 |
+
# box = trimesh.path.creation.box_outline()
|
| 933 |
+
# box.vertices *= size
|
| 934 |
+
# box.vertices += center
|
| 935 |
+
# scene_mesh.add_geometry(box)
|
| 936 |
+
# scene_mesh.export(os.path.join(output_path, 'auto_mask_aabb.glb'))
|
| 937 |
+
################# save result #################
|
| 938 |
+
|
| 939 |
+
"""
|
| 940 |
+
python auto_mask_no_postprocess.py --parallel 0
|
| 941 |
+
python auto_mask_no_postprocess.py --ckpt_path ../weights/p3sam.ckpt --mesh_path assets/1.glb --output_path results/1 --parallel 0
|
| 942 |
+
python auto_mask_no_postprocess.py --ckpt_path ../weights/p3sam.ckpt --mesh_path assets --output_path results/all_no_postprocess
|
| 943 |
+
"""
|
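Beyond the CLI above, the programmatic entry point is small. A minimal sketch (paths are placeholders; it simply mirrors the `__main__` block of this file):

import trimesh
from auto_mask_no_postprocess import AutoMask, set_seed

auto_mask = AutoMask("../weights/p3sam.ckpt")  # or None to auto-download the checkpoint
mesh = trimesh.load("assets/1.glb", force="mesh")
set_seed(42)
aabb, face_ids, mesh = auto_mask.predict_aabb(mesh, save_path="results/1")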
P3-SAM/model.py
ADDED
@@ -0,0 +1,156 @@
import os
import sys
import torch
import torch.nn as nn
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'XPart/partgen'))
from models import sonata
from utils.misc import smart_load_model

'''
This is the P3-SAM model.
The model is composed of three parts:
1. Sonata: a pretrained 3D backbone for point-cloud feature extraction.
2. SEG1+SEG2: a two-stage multi-head segmentor.
3. IoU prediction: an IoU predictor that scores the three candidate masks.
'''
def build_P3SAM(self):
    ######################## Sonata ########################
    self.sonata = sonata.load("sonata", repo_id="facebook/sonata", download_root='/root/sonata')
    self.mlp = nn.Sequential(
        nn.Linear(1232, 512),
        nn.GELU(),
        nn.Linear(512, 512),
        nn.GELU(),
        nn.Linear(512, 512),
    )
    self.transform = sonata.transform.default()
    ######################## Sonata ########################

    ######################## SEG1 ########################
    self.seg_mlp_1 = nn.Sequential(
        nn.Linear(512 + 3 + 3, 512),
        nn.GELU(),
        nn.Linear(512, 512),
        nn.GELU(),
        nn.Linear(512, 1),
    )
    self.seg_mlp_2 = nn.Sequential(
        nn.Linear(512 + 3 + 3, 512),
        nn.GELU(),
        nn.Linear(512, 512),
        nn.GELU(),
        nn.Linear(512, 1),
    )
    self.seg_mlp_3 = nn.Sequential(
        nn.Linear(512 + 3 + 3, 512),
        nn.GELU(),
        nn.Linear(512, 512),
        nn.GELU(),
        nn.Linear(512, 1),
    )
    ######################## SEG1 ########################

    ######################## SEG2 ########################
    self.seg_s2_mlp_g = nn.Sequential(
        nn.Linear(512 + 3 + 3 + 3, 256),
        nn.GELU(),
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
    )
    self.seg_s2_mlp_1 = nn.Sequential(
        nn.Linear(512 + 3 + 3 + 3 + 256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 1),
    )
    self.seg_s2_mlp_2 = nn.Sequential(
        nn.Linear(512 + 3 + 3 + 3 + 256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 1),
    )
    self.seg_s2_mlp_3 = nn.Sequential(
        nn.Linear(512 + 3 + 3 + 3 + 256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 1),
    )
    ######################## SEG2 ########################


    self.iou_mlp = nn.Sequential(
        nn.Linear(512 + 3 + 3 + 3 + 256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
    )
    self.iou_mlp_out = nn.Sequential(
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 256),
        nn.GELU(),
        nn.Linear(256, 3),
    )
    self.iou_criterion = torch.nn.MSELoss()

'''
Load the P3-SAM model from a checkpoint.
If ckpt_path is not None, load the checkpoint from the given path.
If state_dict is not None, load the weights from the given state_dict.
If both ckpt_path and state_dict are None, download the checkpoint from Hugging Face and load it.
'''
def load_state_dict(self,
                    ckpt_path=None,
                    state_dict=None,
                    strict=True,
                    assign=False,
                    ignore_seg_mlp=False,
                    ignore_seg_s2_mlp=False,
                    ignore_iou_mlp=False):
    if ckpt_path is not None:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    elif state_dict is None:
        # download from Hugging Face
        print('trying to download model from huggingface...')
        from huggingface_hub import hf_hub_download
        ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam.ckpt", local_dir='/cache/P3-SAM/')
        print(f'downloaded model from huggingface to: {ckpt_path}')
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    local_state_dict = self.state_dict()
    seen_keys = {k: False for k in local_state_dict.keys()}
    for k, v in state_dict.items():
        if k.startswith("dit."):
            k = k[4:]
        if k in local_state_dict:
            seen_keys[k] = True
            if local_state_dict[k].shape == v.shape:
                local_state_dict[k].copy_(v)
            else:
                print(f"mismatching shape for key {k}: model has {local_state_dict[k].shape} but checkpoint has {v.shape}")
        else:
            print(f"unexpected key {k} in loaded state dict")
    seg_mlp_flag = False
    seg_s2_mlp_flag = False
    iou_mlp_flag = False
    for k in seen_keys:
        if not seen_keys[k]:
            if ignore_seg_mlp and 'seg_mlp' in k:
                seg_mlp_flag = True
            elif ignore_seg_s2_mlp and 'seg_s2_mlp' in k:
                seg_s2_mlp_flag = True
            elif ignore_iou_mlp and 'iou_mlp' in k:
                iou_mlp_flag = True
            else:
                print(f"missing key {k} in loaded state dict")
    if ignore_seg_mlp and seg_mlp_flag:
        print("seg_mlp keys are missing in the loaded state dict; ignoring them as requested")
    if ignore_seg_s2_mlp and seg_s2_mlp_flag:
        print("seg_s2_mlp keys are missing in the loaded state dict; ignoring them as requested")
    if ignore_iou_mlp and iou_mlp_flag:
        print("iou_mlp keys are missing in the loaded state dict; ignoring them as requested")
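Note that `build_P3SAM` and `load_state_dict` are deliberately written as free functions taking `self`, so any `nn.Module` can mix them in; the bbox estimator's `YSAM` class later in this commit consumes them exactly this way. A minimal sketch of the pattern (illustrative, not a file from the commit; the `P3SAM` name matches the class the demo scripts instantiate):

import torch.nn as nn
from model import build_P3SAM, load_state_dict

class P3SAM(nn.Module):
    def __init__(self):
        super().__init__()
        build_P3SAM(self)  # attaches sonata, the MLP heads and the transform

    def load_state_dict(self, ckpt_path=None, **kwargs):
        # delegates to the loader above (auto-downloads when ckpt_path is None)
        load_state_dict(self, ckpt_path=ckpt_path, **kwargs)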
P3-SAM/utils/chamfer3D/chamfer3D.cu
ADDED
@@ -0,0 +1,196 @@

#include <stdio.h>
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>


// Brute-force nearest-neighbor kernel: for every point of xyz, find the closest
// point of xyz2. xyz2 is streamed through shared memory in tiles of 512 points,
// and the inner distance loop is manually unrolled 4x.
__global__ void NmDistanceKernel(int b, int n, const float *xyz, int m, const float *xyz2, float *result, int *result_i){
    const int batch = 512;
    __shared__ float buf[batch * 3];
    for (int i = blockIdx.x; i < b; i += gridDim.x){
        for (int k2 = 0; k2 < m; k2 += batch){
            // Load the current tile of xyz2 into shared memory.
            int end_k = min(m, k2 + batch) - k2;
            for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x){
                buf[j] = xyz2[(i * m + k2) * 3 + j];
            }
            __syncthreads();
            for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y){
                float x1 = xyz[(i * n + j) * 3 + 0];
                float y1 = xyz[(i * n + j) * 3 + 1];
                float z1 = xyz[(i * n + j) * 3 + 2];
                int best_i = 0;
                float best = 0;
                int end_ka = end_k - (end_k & 3);
                if (end_ka == batch){
                    for (int k = 0; k < batch; k += 4){
                        {
                            float x2 = buf[k * 3 + 0] - x1;
                            float y2 = buf[k * 3 + 1] - y1;
                            float z2 = buf[k * 3 + 2] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (k == 0 || d < best){
                                best = d;
                                best_i = k + k2;
                            }
                        }
                        {
                            float x2 = buf[k * 3 + 3] - x1;
                            float y2 = buf[k * 3 + 4] - y1;
                            float z2 = buf[k * 3 + 5] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (d < best){
                                best = d;
                                best_i = k + k2 + 1;
                            }
                        }
                        {
                            float x2 = buf[k * 3 + 6] - x1;
                            float y2 = buf[k * 3 + 7] - y1;
                            float z2 = buf[k * 3 + 8] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (d < best){
                                best = d;
                                best_i = k + k2 + 2;
                            }
                        }
                        {
                            float x2 = buf[k * 3 + 9] - x1;
                            float y2 = buf[k * 3 + 10] - y1;
                            float z2 = buf[k * 3 + 11] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (d < best){
                                best = d;
                                best_i = k + k2 + 3;
                            }
                        }
                    }
                }else{
                    for (int k = 0; k < end_ka; k += 4){
                        {
                            float x2 = buf[k * 3 + 0] - x1;
                            float y2 = buf[k * 3 + 1] - y1;
                            float z2 = buf[k * 3 + 2] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (k == 0 || d < best){
                                best = d;
                                best_i = k + k2;
                            }
                        }
                        {
                            float x2 = buf[k * 3 + 3] - x1;
                            float y2 = buf[k * 3 + 4] - y1;
                            float z2 = buf[k * 3 + 5] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (d < best){
                                best = d;
                                best_i = k + k2 + 1;
                            }
                        }
                        {
                            float x2 = buf[k * 3 + 6] - x1;
                            float y2 = buf[k * 3 + 7] - y1;
                            float z2 = buf[k * 3 + 8] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (d < best){
                                best = d;
                                best_i = k + k2 + 2;
                            }
                        }
                        {
                            float x2 = buf[k * 3 + 9] - x1;
                            float y2 = buf[k * 3 + 10] - y1;
                            float z2 = buf[k * 3 + 11] - z1;
                            float d = x2 * x2 + y2 * y2 + z2 * z2;
                            if (d < best){
                                best = d;
                                best_i = k + k2 + 3;
                            }
                        }
                    }
                }
                // Remainder loop for tile sizes not divisible by 4.
                for (int k = end_ka; k < end_k; k++){
                    float x2 = buf[k * 3 + 0] - x1;
                    float y2 = buf[k * 3 + 1] - y1;
                    float z2 = buf[k * 3 + 2] - z1;
                    float d = x2 * x2 + y2 * y2 + z2 * z2;
                    if (k == 0 || d < best){
                        best = d;
                        best_i = k + k2;
                    }
                }
                // Keep the running minimum across tiles.
                if (k2 == 0 || result[(i * n + j)] > best){
                    result[(i * n + j)] = best;
                    result_i[(i * n + j)] = best_i;
                }
            }
            __syncthreads();
        }
    }
}
// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){

    const auto batch_size = xyz1.size(0);
    const auto n = xyz1.size(1); // num_points point cloud A
    const auto m = xyz2.size(1); // num_points point cloud B

    NmDistanceKernel<<<dim3(32, 16, 1), 512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
    NmDistanceKernel<<<dim3(32, 16, 1), 512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
        //THError("aborting");
        return 0;
    }
    return 1;
}
__global__ void NmDistanceGradKernel(int b, int n, const float *xyz1, int m, const float *xyz2, const float *grad_dist1, const int *idx1, float *grad_xyz1, float *grad_xyz2){
    for (int i = blockIdx.x; i < b; i += gridDim.x){
        for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y){
            float x1 = xyz1[(i * n + j) * 3 + 0];
            float y1 = xyz1[(i * n + j) * 3 + 1];
            float z1 = xyz1[(i * n + j) * 3 + 2];
            int j2 = idx1[i * n + j];
            float x2 = xyz2[(i * m + j2) * 3 + 0];
            float y2 = xyz2[(i * m + j2) * 3 + 1];
            float z2 = xyz2[(i * m + j2) * 3 + 2];
            // Gradient of the squared distance; scattered with atomics because
            // one point of xyz2 can be the nearest neighbor of many xyz1 points.
            float g = grad_dist1[i * n + j] * 2;
            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
        }
    }
}
// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
    // cudaMemset(grad_xyz1, 0, b * n * 3 * 4);
    // cudaMemset(grad_xyz2, 0, b * m * 3 * 4);

    const auto batch_size = xyz1.size(0);
    const auto n = xyz1.size(1); // num_points point cloud A
    const auto m = xyz2.size(1); // num_points point cloud B

    NmDistanceGradKernel<<<dim3(1, 16, 1), 256>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), graddist1.data<float>(), idx1.data<int>(), gradxyz1.data<float>(), gradxyz2.data<float>());
    NmDistanceGradKernel<<<dim3(1, 16, 1), 256>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), graddist2.data<float>(), idx2.data<int>(), gradxyz2.data<float>(), gradxyz1.data<float>());

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
        //THError("aborting");
        return 0;
    }
    return 1;
}
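The kernel returns squared distances (no square root is taken) together with the nearest-neighbor indices. As a sanity reference, the same quantities can be computed for one batch element with plain NumPy; this is an illustrative check, not part of the commit:

import numpy as np

def chamfer_reference(a, b):
    # a: (n, 3), b: (m, 3). Returns squared NN distances and indices in both
    # directions, matching dist1/idx1 and dist2/idx2 of the kernel above.
    d = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)  # (n, m)
    return d.min(axis=1), d.argmin(axis=1), d.min(axis=0), d.argmin(axis=0)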
P3-SAM/utils/chamfer3D/chamfer_cuda.cpp
ADDED
@@ -0,0 +1,29 @@
#include <torch/torch.h>
#include <vector>

/// TMP
// #include "common.h"
/// NOT TMP

int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2,
                         at::Tensor idx1, at::Tensor idx2);

int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1,
                          at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2,
                          at::Tensor idx1, at::Tensor idx2);

int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2,
                    at::Tensor idx1, at::Tensor idx2) {
    return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}

int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2,
                     at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {

    return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
    m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}
P3-SAM/utils/chamfer3D/dist_chamfer_3D.py
ADDED
@@ -0,0 +1,81 @@
from torch import nn
from torch.autograd import Function
import torch
import importlib
import os
chamfer_found = importlib.find_loader("chamfer_3D") is not None
if not chamfer_found:
    ## Cool trick from https://github.com/chrdiller
    print("Jitting Chamfer 3D")
    cur_path = os.path.dirname(os.path.abspath(__file__))
    build_path = cur_path.replace('chamfer3D', 'tmp')
    os.makedirs(build_path, exist_ok=True)

    from torch.utils.cpp_extension import load
    chamfer_3D = load(name="chamfer_3D",
                      sources=[
                          "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
                          "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
                      ], build_directory=build_path)
    print("Loaded JIT 3D CUDA chamfer distance")

else:
    import chamfer_3D
    print("Loaded compiled 3D CUDA chamfer distance")


# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction(Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        batchsize, n, dim = xyz1.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
        _, m, dim = xyz2.size()
        assert dim == 3, "Wrong last dimension for the chamfer distance's input! Check with .size()"
        device = xyz1.device

        dist1 = torch.zeros(batchsize, n)
        dist2 = torch.zeros(batchsize, m)

        idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
        idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)

        dist1 = dist1.to(device)
        dist2 = dist2.to(device)
        idx1 = idx1.to(device)
        idx2 = idx2.to(device)
        torch.cuda.set_device(device)

        chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()
        device = graddist1.device

        gradxyz1 = torch.zeros(xyz1.size())
        gradxyz2 = torch.zeros(xyz2.size())

        gradxyz1 = gradxyz1.to(device)
        gradxyz2 = gradxyz2.to(device)
        chamfer_3D.backward(
            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
        )
        return gradxyz1, gradxyz2


class chamfer_3DDist(nn.Module):
    def __init__(self):
        super(chamfer_3DDist, self).__init__()

    def forward(self, input1, input2):
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        return chamfer_3DFunction.apply(input1, input2)
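A minimal usage sketch (illustrative; it assumes this directory is on sys.path, and per the asserts above the inputs must be CUDA tensors of shape [B, N, 3]):

import torch
from dist_chamfer_3D import chamfer_3DDist

cd = chamfer_3DDist()
p1 = torch.rand(2, 1024, 3, device="cuda", requires_grad=True)
p2 = torch.rand(2, 2048, 3, device="cuda")
dist1, dist2, idx1, idx2 = cd(p1, p2)  # squared distances + NN indices
loss = dist1.mean() + dist2.mean()     # symmetric Chamfer loss
loss.backward()                        # gradients via chamfer_cuda_backward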
P3-SAM/utils/chamfer3D/setup.py
ADDED
@@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='chamfer_3D',
    ext_modules=[
        CUDAExtension('chamfer_3D', [
            "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
            "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
        ]),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
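This setup.py is the ahead-of-time alternative to the JIT branch in dist_chamfer_3D.py: running `python setup.py install` from this directory builds the same `chamfer_3D` extension via setuptools' BuildExtension, after which `import chamfer_3D` succeeds and the jitting fallback is skipped.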
XPart/data/000.glb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1b728ba92d353d87c5acd2289d9f19a0dc9ab6ceacdde488e7c5d3b456b2ff8
size 9000484
XPart/data/001.glb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:98601f642c444d8466007b5b35e33a57d3f0bade873c9a7d7c57db039ea7a318
size 9000676
XPart/data/002.glb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5dbd8408e37e41be78358bbca5bc1ef4d81e6413c92a19bc9c2f136760c0248a
size 8999812
XPart/data/003.glb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ad8390bdafd83b8aea0f7e0159b24d9a8070c58140cd1706b80c6217fe8cf14d
size 9000880
XPart/data/004.glb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7fb21265f371436d18437a9cb420e012fe07cf1e50fb20c1473223c921a101dc
size 9000796
XPart/partgen/bbox_estimator/auto_mask_api.py
ADDED
@@ -0,0 +1,1417 @@
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import argparse
import trimesh
from sklearn.decomposition import PCA
import fpsample
from tqdm import tqdm
from collections import defaultdict

# from tqdm.notebook import tqdm
import time
import copy
import shutil
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

from numba import njit

#################################
# sonata is imported from this package's models instead of the standalone path
from ..models import sonata

#################################
sys.path.append("../P3-SAM")
from model import build_P3SAM, load_state_dict


class YSAM(nn.Module):
    def __init__(self):
        super().__init__()
        build_P3SAM(self)

    def load_state_dict(
        self,
        state_dict=None,
        strict=True,
        assign=False,
        ignore_seg_mlp=False,
        ignore_seg_s2_mlp=False,
        ignore_iou_mlp=False,
    ):
        load_state_dict(
            self,
            state_dict=state_dict,
            strict=strict,
            assign=assign,
            ignore_seg_mlp=ignore_seg_mlp,
            ignore_seg_s2_mlp=ignore_seg_s2_mlp,
            ignore_iou_mlp=ignore_iou_mlp,
        )

    def forward(self, feats, points, point_prompt, iter=1):
        """
        feats: [K, N, 512]
        points: [K, N, 3]
        point_prompt: [K, N, 3]
        """
        # print(feats.shape, points.shape, point_prompt.shape)
        point_num = points.shape[1]
        feats = feats.transpose(0, 1)  # [N, K, 512]
        points = points.transpose(0, 1)  # [N, K, 3]
        point_prompt = point_prompt.transpose(0, 1)  # [N, K, 3]
        feats_seg = torch.cat([feats, points, point_prompt], dim=-1)  # [N, K, 512+3+3]

        # Predict masks, stage 1
        pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1)  # [N, K]
        pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1)  # [N, K]
        pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1)  # [N, K]
        pred_mask = torch.stack(
            [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1
        )  # [N, K, 3]

        for _ in range(iter):
            # Predict masks, stage 2 (refines the stage-1 masks)
            feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1)  # [N, K, 512+3+3+3]
            feats_seg_global = self.seg_s2_mlp_g(feats_seg_2)  # [N, K, 256]
            feats_seg_global = torch.max(feats_seg_global, dim=0).values  # [K, 256]
            feats_seg_global = feats_seg_global.unsqueeze(0).repeat(
                point_num, 1, 1
            )  # [N, K, 256]
            feats_seg_3 = torch.cat(
                [feats_seg_global, feats_seg_2], dim=-1
            )  # [N, K, 256+512+3+3+3]
            pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1)  # [N, K]
            pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1)  # [N, K]
            pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1)  # [N, K]
            pred_mask_s2 = torch.stack(
                [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1
            )  # [N, K, 3]
            pred_mask = pred_mask_s2

        mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32)  # [N, K]
        mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32)  # [N, K]
        mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32)  # [N, K]

        feats_iou = torch.cat(
            [feats_seg_global, feats_seg, pred_mask_s2], dim=-1
        )  # [N, K, 256+512+3+3+3]
        feats_iou = self.iou_mlp(feats_iou)  # [N, K, 256]
        feats_iou = torch.max(feats_iou, dim=0).values  # [K, 256]
        pred_iou = self.iou_mlp_out(feats_iou)  # [K, 3]
        pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32)  # [K, 3]

        mask_1 = mask_1.transpose(0, 1)  # [K, N]
        mask_2 = mask_2.transpose(0, 1)  # [K, N]
        mask_3 = mask_3.transpose(0, 1)  # [K, N]

        return mask_1, mask_2, mask_3, pred_iou

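# Note (added for readability, not in the original file): with K prompt points
# and N sampled points, YSAM.forward consumes [K, N, 512] features plus
# [K, N, 3] coordinates and prompts, and returns three [K, N] per-point mask
# probabilities together with a [K, 3] predicted IoU that is later used to
# pick the best of the three candidate masks per prompt.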
def normalize_pc(pc):
    """
    pc: (N, 3)
    """
    max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0)
    center = (max_ + min_) / 2
    scale = (max_ - min_) / 2
    scale = np.max(np.abs(scale))
    pc = (pc - center) / (scale + 1e-10)
    return pc


@torch.no_grad()
def get_feat(model, points, normals):
    data_dict = {
        "coord": points,
        "normal": normals,
        "color": np.ones_like(points),
        "batch": np.zeros(points.shape[0], dtype=np.int64),
    }
    data_dict = model.transform(data_dict)
    for k in data_dict:
        if isinstance(data_dict[k], torch.Tensor):
            data_dict[k] = data_dict[k].cuda()
    point = model.sonata(data_dict)
    # Unwind Sonata's pooling hierarchy, concatenating child features onto parents
    while "pooling_parent" in point.keys():
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")
        parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
        point = parent
    feat = point.feat  # [M, 1232]
    feat = model.mlp(feat)  # [M, 512]
    feat = feat[point.inverse]  # [N, 512]
    feats = feat
    return feats


@torch.no_grad()
def get_mask(model, feats, points, point_prompt, iter=1):
    """
    feats: [N, 512]
    points: [N, 3]
    point_prompt: [K, 3]
    """
    point_num = points.shape[0]
    prompt_num = point_prompt.shape[0]
    feats = feats.unsqueeze(1)  # [N, 1, 512]
    feats = feats.repeat(1, prompt_num, 1).cuda()  # [N, K, 512]
    points = torch.from_numpy(points).float().cuda().unsqueeze(1)  # [N, 1, 3]
    points = points.repeat(1, prompt_num, 1)  # [N, K, 3]
    prompt_coord = (
        torch.from_numpy(point_prompt).float().cuda().unsqueeze(0)
    )  # [1, K, 3]
    prompt_coord = prompt_coord.repeat(point_num, 1, 1)  # [N, K, 3]

    feats = feats.transpose(0, 1)  # [K, N, 512]
    points = points.transpose(0, 1)  # [K, N, 3]
    prompt_coord = prompt_coord.transpose(0, 1)  # [K, N, 3]

    mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter)

    mask_1 = mask_1.transpose(0, 1)  # [N, K]
    mask_2 = mask_2.transpose(0, 1)  # [N, K]
    mask_3 = mask_3.transpose(0, 1)  # [N, K]

    mask_1 = mask_1.detach().cpu().numpy() > 0.5
    mask_2 = mask_2.detach().cpu().numpy() > 0.5
    mask_3 = mask_3.detach().cpu().numpy() > 0.5

    org_iou = pred_iou.detach().cpu().numpy()  # [K, 3]

    return mask_1, mask_2, mask_3, org_iou

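# A minimal end-to-end sketch (illustrative only, not in the original file;
# `model` is a YSAM instance, `points`/`normals` are float32 arrays of shape (N, 3)):
# feats = get_feat(model, points, normals)                  # [N, 512] Sonata features
# m1, m2, m3, iou = get_mask(model, feats, points, points[:400])
# # iou is [K, 3]: for prompt k, iou[k].argmax() selects which of the three
# # candidate masks (m1/m2/m3 columns) to keep.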
def cal_iou(m1, m2):
    return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2))


def cal_single_iou(m1, m2):
    return np.sum(np.logical_and(m1, m2)) / np.sum(m1)


def iou_3d(box1, box2, signle=None):
    """
    Compute the intersection over union (IoU) of two 3D bounding boxes.

    Parameters:
        box1 (list): coordinates of the first box [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max]
        box2 (list): coordinates of the second box [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max]

    Returns:
        float: the IoU value
    """
    # Intersection coordinates
    intersection_xmin = max(box1[0], box2[0])
    intersection_ymin = max(box1[1], box2[1])
    intersection_zmin = max(box1[2], box2[2])
    intersection_xmax = min(box1[3], box2[3])
    intersection_ymax = min(box1[4], box2[4])
    intersection_zmax = min(box1[5], box2[5])

    # Check whether the boxes intersect at all
    if (
        intersection_xmin >= intersection_xmax
        or intersection_ymin >= intersection_ymax
        or intersection_zmin >= intersection_zmax
    ):
        return 0.0  # no intersection

    # Intersection volume
    intersection_volume = (
        (intersection_xmax - intersection_xmin)
        * (intersection_ymax - intersection_ymin)
        * (intersection_zmax - intersection_zmin)
    )

    # Volumes of the two boxes
    box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2])
    box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2])

    # "signle" (sic, kept for API compatibility) selects the denominator:
    # None -> union, "1" -> box1 volume, "2" -> box2 volume
    if signle is None:
        union_volume = box1_volume + box2_volume - intersection_volume
    elif signle == "1":
        union_volume = box1_volume
    elif signle == "2":
        union_volume = box2_volume
    else:
        raise ValueError("signle must be None or 1 or 2")

    # IoU
    iou = intersection_volume / union_volume if union_volume > 0 else 0.0
    return iou

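# Worked example (illustrative, not in the original file): two unit cubes that
# overlap in a 1 x 1 x 0.5 slab:
#   iou_3d([0, 0, 0, 1, 1, 1], [0, 0, 0.5, 1, 1, 1.5])
# intersection = 0.5, union = 1 + 1 - 0.5 = 1.5, so the result is 1/3.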
def cal_point_bbox_iou(p1, p2, signle=None):
    min_p1 = np.min(p1, axis=0)
    max_p1 = np.max(p1, axis=0)
    min_p2 = np.min(p2, axis=0)
    max_p2 = np.max(p2, axis=0)
    box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]]
    box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]]
    return iou_3d(box1, box2, signle)


def cal_bbox_iou(points, m1, m2):
    p1 = points[m1]
    p2 = points[m2]
    return cal_point_bbox_iou(p1, p2)


def clean_mesh(mesh):
    """
    mesh: trimesh.Trimesh
    """
    # 1. Merge nearby vertices
    mesh.merge_vertices()

    # 2. Remove duplicate vertices
    # 3. Remove duplicate faces
    mesh.process(True)
    return mesh


# @njit
def remove_outliers_iqr(data, factor=1.5):
    """
    Remove outliers based on the IQR.
    :param data: input list or NumPy array
    :param factor: multiple of the IQR (default 1.5)
    :return: list with the outliers removed
    """
    data = np.array(data, dtype=np.float32)
    q1 = np.percentile(data, 25)  # first quartile
    q3 = np.percentile(data, 75)  # third quartile
    iqr = q3 - q1  # interquartile range
    lower_bound = q1 - factor * iqr
    upper_bound = q3 + factor * iqr
    return data[(data >= lower_bound) & (data <= upper_bound)].tolist()


# @njit
def better_aabb(points):
    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]
    x = remove_outliers_iqr(x)
    y = remove_outliers_iqr(y)
    z = remove_outliers_iqr(z)
    min_xyz = np.array([np.min(x), np.min(y), np.min(z)])
    max_xyz = np.array([np.max(x), np.max(y), np.max(z)])
    return [min_xyz, max_xyz]

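# Note (illustrative, not in the original file): better_aabb is an
# outlier-robust bounding box. A single stray vertex far away from a part would
# stretch a plain min/max AABB, but per axis it falls outside
# [q1 - 1.5*IQR, q3 + 1.5*IQR] and is therefore ignored here.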
def fix_label(face_ids, adjacent_faces, use_aabb=False, mesh=None, show_info=False):
|
| 310 |
+
if use_aabb:
|
| 311 |
+
|
| 312 |
+
def _cal_aabb(face_ids, i, _points_org):
|
| 313 |
+
_part_mask = face_ids == i
|
| 314 |
+
_faces = mesh.faces[_part_mask]
|
| 315 |
+
_faces = np.reshape(_faces, (-1))
|
| 316 |
+
_points = mesh.vertices[_faces]
|
| 317 |
+
min_xyz, max_xyz = better_aabb(_points)
|
| 318 |
+
_part_mask = (
|
| 319 |
+
(_points_org[:, 0] >= min_xyz[0])
|
| 320 |
+
& (_points_org[:, 0] <= max_xyz[0])
|
| 321 |
+
& (_points_org[:, 1] >= min_xyz[1])
|
| 322 |
+
& (_points_org[:, 1] <= max_xyz[1])
|
| 323 |
+
& (_points_org[:, 2] >= min_xyz[2])
|
| 324 |
+
& (_points_org[:, 2] <= max_xyz[2])
|
| 325 |
+
)
|
| 326 |
+
_part_mask = np.reshape(_part_mask, (-1, 3))
|
| 327 |
+
_part_mask = np.all(_part_mask, axis=1)
|
| 328 |
+
return i, [min_xyz, max_xyz], _part_mask
|
| 329 |
+
|
| 330 |
+
with Timer("计算aabb"):
|
| 331 |
+
aabb = {}
|
| 332 |
+
unique_ids = np.unique(face_ids)
|
| 333 |
+
# print(max(unique_ids))
|
| 334 |
+
aabb_face_mask = {}
|
| 335 |
+
_faces = mesh.faces
|
| 336 |
+
_vertices = mesh.vertices
|
| 337 |
+
_faces = np.reshape(_faces, (-1))
|
| 338 |
+
_points = _vertices[_faces]
|
| 339 |
+
with ThreadPoolExecutor(max_workers=20) as executor:
|
| 340 |
+
futures = []
|
| 341 |
+
for i in unique_ids:
|
| 342 |
+
if i < 0:
|
| 343 |
+
continue
|
| 344 |
+
futures.append(executor.submit(_cal_aabb, face_ids, i, _points))
|
| 345 |
+
for future in futures:
|
| 346 |
+
res = future.result()
|
| 347 |
+
aabb[res[0]] = res[1]
|
| 348 |
+
aabb_face_mask[res[0]] = res[2]
|
| 349 |
+
|
| 350 |
+
# _faces = mesh.faces
|
| 351 |
+
# _vertices = mesh.vertices
|
| 352 |
+
# _faces = np.reshape(_faces, (-1))
|
| 353 |
+
# _points = _vertices[_faces]
|
| 354 |
+
# aabb_face_mask = cal_aabb_mask(_points, face_ids)
|
| 355 |
+
|
| 356 |
+
with Timer("合并mesh"):
|
| 357 |
+
loop_cnt = 1
|
| 358 |
+
changed = True
|
| 359 |
+
progress = tqdm(disable=not show_info)
|
| 360 |
+
no_mask_ids = np.where(face_ids < 0)[0].tolist()
|
| 361 |
+
faces_max = adjacent_faces.shape[0]
|
| 362 |
+
while changed and loop_cnt <= 50:
|
| 363 |
+
changed = False
|
| 364 |
+
# 获取无色面片
|
| 365 |
+
new_no_mask_ids = []
|
| 366 |
+
for i in no_mask_ids:
|
| 367 |
+
# if face_ids[i] < 0:
|
| 368 |
+
# 找邻居
|
| 369 |
+
if not (0 <= i < faces_max):
|
| 370 |
+
continue
|
| 371 |
+
_adj_faces = adjacent_faces[i]
|
| 372 |
+
_adj_ids = []
|
| 373 |
+
for j in _adj_faces:
|
| 374 |
+
if j == -1:
|
| 375 |
+
break
|
| 376 |
+
if face_ids[j] >= 0:
|
| 377 |
+
_tar_id = face_ids[j]
|
| 378 |
+
if use_aabb:
|
| 379 |
+
_mask = aabb_face_mask[_tar_id]
|
| 380 |
+
if _mask[i]:
|
| 381 |
+
_adj_ids.append(_tar_id)
|
| 382 |
+
else:
|
| 383 |
+
_adj_ids.append(_tar_id)
|
| 384 |
+
if len(_adj_ids) == 0:
|
| 385 |
+
new_no_mask_ids.append(i)
|
| 386 |
+
continue
|
| 387 |
+
_max_id = np.argmax(np.bincount(_adj_ids))
|
| 388 |
+
face_ids[i] = _max_id
|
| 389 |
+
changed = True
|
| 390 |
+
no_mask_ids = new_no_mask_ids
|
| 391 |
+
# print(loop_cnt)
|
| 392 |
+
progress.update(1)
|
| 393 |
+
# progress.set_description(f"合并mesh循环:{loop_cnt} {np.sum(face_ids < 0)}")
|
| 394 |
+
loop_cnt += 1
|
| 395 |
+
return face_ids
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def save_mesh(save_path, mesh, face_ids, color_map):
|
| 399 |
+
face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8)
|
| 400 |
+
for i in tqdm(range(len(mesh.faces)), disable=True):
|
| 401 |
+
_max_id = face_ids[i]
|
| 402 |
+
if _max_id == -2:
|
| 403 |
+
continue
|
| 404 |
+
face_colors[i, :3] = color_map[_max_id]
|
| 405 |
+
|
| 406 |
+
mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
|
| 407 |
+
mesh_save.visual.face_colors = face_colors
|
| 408 |
+
mesh_save.export(save_path)
|
| 409 |
+
mesh_save.export(save_path.replace(".glb", ".ply"))
|
| 410 |
+
# print('保存mesh完成')
|
| 411 |
+
|
| 412 |
+
scene_mesh = trimesh.Scene()
|
| 413 |
+
scene_mesh.add_geometry(mesh_save)
|
| 414 |
+
unique_ids = np.unique(face_ids)
|
| 415 |
+
aabb = []
|
| 416 |
+
for i in unique_ids:
|
| 417 |
+
if i == -1 or i == -2:
|
| 418 |
+
continue
|
| 419 |
+
_part_mask = face_ids == i
|
| 420 |
+
_faces = mesh.faces[_part_mask]
|
| 421 |
+
_faces = np.reshape(_faces, (-1))
|
| 422 |
+
_points = mesh.vertices[_faces]
|
| 423 |
+
min_xyz, max_xyz = better_aabb(_points)
|
| 424 |
+
center = (min_xyz + max_xyz) / 2
|
| 425 |
+
size = max_xyz - min_xyz
|
| 426 |
+
box = trimesh.path.creation.box_outline()
|
| 427 |
+
box.vertices *= size
|
| 428 |
+
box.vertices += center
|
| 429 |
+
box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]])
|
| 430 |
+
box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8)
|
| 431 |
+
box.colors = box_color
|
| 432 |
+
scene_mesh.add_geometry(box)
|
| 433 |
+
min_xyz = np.min(_points, axis=0)
|
| 434 |
+
max_xyz = np.max(_points, axis=0)
|
| 435 |
+
aabb.append([min_xyz, max_xyz])
|
| 436 |
+
scene_mesh.export(save_path.replace(".glb", "_aabb.glb"))
|
| 437 |
+
aabb = np.array(aabb)
|
| 438 |
+
np.save(save_path.replace(".glb", "_aabb.npy"), aabb)
|
| 439 |
+
np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def get_aabb_from_face_ids(mesh, face_ids):
|
| 443 |
+
unique_ids = np.unique(face_ids)
|
| 444 |
+
aabb = []
|
| 445 |
+
for i in unique_ids:
|
| 446 |
+
if i == -1 or i == -2:
|
| 447 |
+
continue
|
| 448 |
+
_part_mask = face_ids == i
|
| 449 |
+
_faces = mesh.faces[_part_mask]
|
| 450 |
+
_faces = np.reshape(_faces, (-1))
|
| 451 |
+
_points = mesh.vertices[_faces]
|
| 452 |
+
min_xyz = np.min(_points, axis=0)
|
| 453 |
+
max_xyz = np.max(_points, axis=0)
|
| 454 |
+
aabb.append([min_xyz, max_xyz])
|
| 455 |
+
return np.array(aabb)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def calculate_face_areas(mesh):
|
| 459 |
+
"""
|
| 460 |
+
计算每个三角形面片的面积
|
| 461 |
+
:param mesh: trimesh.Trimesh 对象
|
| 462 |
+
:return: 面片面积数组 (n_faces,)
|
| 463 |
+
"""
|
| 464 |
+
return mesh.area_faces
|
| 465 |
+
# # 提取顶点和面片索引
|
| 466 |
+
# vertices = mesh.vertices
|
| 467 |
+
# faces = mesh.faces
|
| 468 |
+
|
| 469 |
+
# # 获取所有三个顶点的坐标
|
| 470 |
+
# v0 = vertices[faces[:, 0]]
|
| 471 |
+
# v1 = vertices[faces[:, 1]]
|
| 472 |
+
# v2 = vertices[faces[:, 2]]
|
| 473 |
+
|
| 474 |
+
# # 计算两个边向量
|
| 475 |
+
# edge1 = v1 - v0
|
| 476 |
+
# edge2 = v2 - v0
|
| 477 |
+
|
| 478 |
+
# # 计算叉积的模长(向量面积的两倍)
|
| 479 |
+
# cross_product = np.cross(edge1, edge2)
|
| 480 |
+
# areas = 0.5 * np.linalg.norm(cross_product, axis=1)
|
| 481 |
+
|
| 482 |
+
# return areas
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def get_connected_region(face_ids, adjacent_faces, return_face_part_ids=False):
    vis = [False] * len(face_ids)
    parts = []
    face_part_ids = np.ones_like(face_ids) * -1
    for i in range(len(face_ids)):
        if vis[i]:
            continue
        _part = []
        _queue = [i]
        while len(_queue) > 0:
            _cur_face = _queue.pop(0)
            if vis[_cur_face]:
                continue
            vis[_cur_face] = True
            _part.append(_cur_face)
            face_part_ids[_cur_face] = len(parts)
            if not (0 <= _cur_face < adjacent_faces.shape[0]):
                continue
            _cur_face_id = face_ids[_cur_face]
            _adj_faces = adjacent_faces[_cur_face]
            for j in _adj_faces:
                if j == -1:
                    break
                if not vis[j] and face_ids[j] == _cur_face_id:
                    _queue.append(j)
        parts.append(_part)
    if return_face_part_ids:
        return parts, face_part_ids
    else:
        return parts


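# Illustrative sketch (toy data, not from the pipeline): four faces in a chain
# 0-1-2-3, with faces 0/1 labeled 7 and faces 2/3 labeled 9, split into two
# regions. `adjacent_faces` uses the same -1-padded layout produced by
# build_adjacent_faces_numba below.
def _demo_get_connected_region():
    face_ids = np.array([7, 7, 9, 9])
    adjacent_faces = np.array([[1, -1], [0, 2], [1, 3], [2, -1]], dtype=np.int32)
    parts = get_connected_region(face_ids, adjacent_faces)
    assert parts == [[0, 1], [2, 3]]

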
def aabb_distance(box1, box2):
    """
    Compute the closest distance between two axis-aligned bounding boxes (AABBs).
    :param box1: pair (min_xyz, max_xyz)
    :param box2: pair (min_xyz, max_xyz)
    :return: closest distance (float); 0.0 if the boxes overlap
    """
    min1, max1 = box1
    min2, max2 = box2

    # Separation (gap) along each axis; zero when the boxes overlap on that axis.
    dx = max(0, min2[0] - max1[0], min1[0] - max2[0])
    dy = max(0, min2[1] - max1[1], min1[1] - max2[1])
    dz = max(0, min2[2] - max1[2], min1[2] - max2[2])

    # Overlap on all three axes means the boxes intersect.
    if dx == 0 and dy == 0 and dz == 0:
        return 0.0

    # Euclidean distance between the closest points.
    return np.sqrt(dx**2 + dy**2 + dz**2)


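# Quick sanity check (illustrative): unit boxes separated by 3 along x and 4
# along y are 5 apart; overlapping boxes report 0.
def _demo_aabb_distance():
    a = (np.zeros(3), np.ones(3))
    b = (np.array([4.0, 5.0, 0.0]), np.array([5.0, 6.0, 1.0]))
    assert np.isclose(aabb_distance(a, b), 5.0)
    assert aabb_distance(a, (np.ones(3) * 0.5, np.ones(3) * 2)) == 0.0

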
def aabb_volume(aabb):
    """
    Compute the volume of an axis-aligned bounding box (AABB).
    :param aabb: pair (min_xyz, max_xyz)
    :return: volume (float)
    """
    min_xyz, max_xyz = aabb
    dx = max_xyz[0] - min_xyz[0]
    dy = max_xyz[1] - min_xyz[1]
    dz = max_xyz[2] - min_xyz[2]
    return dx * dy * dz


def find_neighbor_part(parts, adjacent_faces, parts_aabb=None, parts_ids=None):
    face2part = {}
    for i, part in enumerate(parts):
        for face in part:
            face2part[face] = i
    neighbor_parts = []
    for i, part in enumerate(parts):
        neighbor_part = set()
        for face in part:
            if not (0 <= face < adjacent_faces.shape[0]):
                continue
            for adj_face in adjacent_faces[face]:
                if adj_face == -1:
                    break
                if adj_face not in face2part:
                    continue
                if face2part[adj_face] == i:
                    continue
                if parts_ids is not None and parts_ids[face2part[adj_face]] in [-1, -2]:
                    continue
                neighbor_part.add(face2part[adj_face])
        neighbor_part = list(neighbor_part)
        if (
            parts_aabb is not None
            and parts_ids is not None
            and (parts_ids[i] == -1 or parts_ids[i] == -2)
            and len(neighbor_part) == 0
        ):
            # An unlabeled part with no labeled neighbors falls back to the
            # closest labeled part by AABB distance (ties broken by smaller volume).
            min_dis = np.inf
            min_idx = -1
            for j, _part in enumerate(parts):
                if j == i:
                    continue
                if parts_ids[j] == -1 or parts_ids[j] == -2:
                    continue
                aabb_1 = parts_aabb[i]
                aabb_2 = parts_aabb[j]
                dis = aabb_distance(aabb_1, aabb_2)
                if dis < min_dis:
                    min_dis = dis
                    min_idx = j
                elif dis == min_dis:
                    if aabb_volume(parts_aabb[j]) < aabb_volume(parts_aabb[min_idx]):
                        min_idx = j
            neighbor_part = [min_idx]
        neighbor_parts.append(neighbor_part)
    return neighbor_parts


def do_post_process(
    face_areas, parts, adjacent_faces, face_ids, threshold=0.95, show_info=False
):
    unique_ids = np.unique(face_ids)
    if show_info:
        print(f"Connected regions: {len(parts)}")
        print(f"Unique ids: {len(unique_ids)}")

    total_area = np.sum(face_areas)
    if show_info:
        print(f"Total area: {total_area}")
    part_areas = []
    for i, part in enumerate(parts):
        part_area = np.sum(face_areas[part])
        part_areas.append(float(part_area / total_area))

    # Sort parts by area ratio (largest first) and accumulate the ratios.
    sorted_parts = sorted(zip(part_areas, parts), key=lambda x: x[0], reverse=True)
    parts = [x[1] for x in sorted_parts]
    part_areas = [x[0] for x in sorted_parts]
    integral_part_areas = np.cumsum(part_areas)

    neighbor_parts = find_neighbor_part(parts, adjacent_faces)

    new_face_ids = face_ids.copy()

    # Merge each tiny tail part (past the cumulative-area threshold) into its
    # largest neighbor that is still below the threshold.
    for i, part in enumerate(parts):
        if integral_part_areas[i] > threshold and part_areas[i] < 0.01:
            if len(neighbor_parts[i]) > 0:
                max_area = 0
                max_part = -1
                for j in neighbor_parts[i]:
                    if integral_part_areas[j] > threshold:
                        continue
                    if part_areas[j] > max_area:
                        max_area = part_areas[j]
                        max_part = j
                if max_part != -1:
                    if show_info:
                        print(f"Merging part {i} into {max_part}")
                    parts[max_part].extend(part)
                    parts[i] = []
                    target_face_id = face_ids[parts[max_part][0]]
                    for face in part:
                        new_face_ids[face] = target_face_id

    return new_face_ids


def do_no_mask_process(parts, face_ids):
    unique_ids = np.unique(face_ids)
    max_id = np.max(unique_ids)
    # Assign a fresh label to every connected region that is still unlabeled
    # (-1) or unassigned (-2). Note: the original `if -1 or -2 in unique_ids`
    # was always truthy; the membership test below is the intended check.
    if -1 in unique_ids or -2 in unique_ids:
        new_face_ids = face_ids.copy()
        for i, part in enumerate(parts):
            if face_ids[part[0]] == -1 or face_ids[part[0]] == -2:
                for face in part:
                    new_face_ids[face] = max_id + 1
                max_id += 1
        return new_face_ids
    else:
        return face_ids


def union_aabb(aabb1, aabb2):
    min_xyz1 = aabb1[0]
    max_xyz1 = aabb1[1]
    min_xyz2 = aabb2[0]
    max_xyz2 = aabb2[1]
    min_xyz = np.minimum(min_xyz1, min_xyz2)
    max_xyz = np.maximum(max_xyz1, max_xyz2)
    return [min_xyz, max_xyz]


def aabb_increase(aabb1, aabb2):
    # Relative growth of aabb1 along each axis after taking the union with
    # aabb2. Note: the growth is relative to the absolute corner coordinates,
    # so coordinates near zero inflate the ratio.
    min_xyz_before = aabb1[0]
    max_xyz_before = aabb1[1]
    min_xyz_after, max_xyz_after = union_aabb(aabb1, aabb2)
    min_xyz_increase = np.abs(min_xyz_after - min_xyz_before) / np.abs(min_xyz_before)
    max_xyz_increase = np.abs(max_xyz_after - max_xyz_before) / np.abs(max_xyz_before)
    return min_xyz_increase, max_xyz_increase


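# Worked example (illustrative): taking the union with a box that extends 2
# units further along +x doubles aabb1's max-x magnitude, a relative increase
# of 1.0 on that axis; coordinates are kept away from zero on purpose.
def _demo_aabb_increase():
    a = [np.array([1.0, 1.0, 1.0]), np.array([2.0, 2.0, 2.0])]
    b = [np.array([1.0, 1.0, 1.0]), np.array([4.0, 2.0, 2.0])]
    min_inc, max_inc = aabb_increase(a, b)
    assert np.allclose(min_inc, 0.0) and np.allclose(max_inc, [1.0, 0.0, 0.0])

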
def sort_multi_list(multi_list, key=lambda x: x[0], reverse=False):
    """
    multi_list: [list1, list2, list3, ...], each of the same length N
    key: sort key, defaults to the first element of each zipped tuple
    reverse: sort order, ascending by default
    return:
        [list1, list2, list3, ...]: the same lists, all reordered consistently
    """
    sorted_list = sorted(zip(*multi_list), key=key, reverse=reverse)
    return zip(*sorted_list)


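# Usage sketch (illustrative): reorder parallel lists by one of them.
def _demo_sort_multi_list():
    names, scores = sort_multi_list(
        [["a", "b", "c"], [0.2, 0.9, 0.5]], key=lambda x: x[1], reverse=True
    )
    assert list(names) == ["b", "c", "a"] and list(scores) == [0.9, 0.5, 0.2]

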
# def sample_mesh(mesh, adjacent_faces, point_num=100000):
#     connected_parts = get_connected_region(np.ones(len(mesh.faces)), adjacent_faces)
#     _points, face_idx = trimesh.sample.sample_surface(mesh, point_num)
#     face_sampled = np.zeros(len(mesh.faces), dtype=np.bool)
#     face_sampled[face_idx] = True
#     for parts in connected_parts

# def parallel_run(model_parallel, feats, points, prompts):
#     bs = prompts.shape[0]
#     prompts_1 = prompts[:bs//2]
#     prompts_2 = prompts[bs//2:]
#     device_1 = 'cuda:0'
#     device_2 = 'cuda:1'
#     pred_mask_1_1, pred_mask_2_1, pred_mask_3_1, pred_iou_1 = get_mask(
#         model_parallel.module.to(device_1), feats, points, prompts_1, device=device_1
#     )
#     pred_mask_1_2, pred_mask_2_2, pred_mask_3_2, pred_iou_2 = get_mask(
#         model_parallel.module.to(device_2), feats, points, prompts_2, device=device_2
#     )
#     pred_mask_1 = np.concatenate([pred_mask_1_1, pred_mask_1_2], axis=1)
#     pred_mask_2 = np.concatenate([pred_mask_2_1, pred_mask_2_2], axis=1)
#     pred_mask_3 = np.concatenate([pred_mask_3_1, pred_mask_3_2], axis=1)
#     pred_iou = np.concatenate([pred_iou_1, pred_iou_2], axis=0)
#     return pred_mask_1, pred_mask_2, pred_mask_3, pred_iou


############################################################################################


class Timer:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start_time = time.time()
        return self  # return self so the timer can be inspected inside the with block

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        print(f">>>>>> block '{self.name}' took {self.elapsed_time:.4f} s")


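# Usage sketch (illustrative): the timer prints the elapsed wall-clock time of
# any block on exit, as used throughout mesh_sam below.
def _demo_timer():
    with Timer("sleep"):
        time.sleep(0.1)

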
###################### NUMBA acceleration ######################
@njit
def build_adjacent_faces_numba(face_adjacency):
    """
    Build the face-adjacency table with Numba.
    :param face_adjacency: (N, 2) numpy array of adjacent face pairs.
    :return: adjacent_faces: (n_faces, max_degree) int32 array; row i lists the
        faces adjacent to face i, padded with -1.
    """
    n_faces = np.max(face_adjacency) + 1  # total number of faces
    n_edges = face_adjacency.shape[0]  # total number of adjacency pairs

    # Pass 1: count each face's degree.
    degrees = np.zeros(n_faces, dtype=np.int32)
    for i in range(n_edges):
        f1, f2 = face_adjacency[i]
        degrees[f1] += 1
        degrees[f2] += 1
    max_degree = np.max(degrees)

    # Pass 2: scatter the pairs into a -1-padded (n_faces, max_degree) table.
    adjacent_faces = np.ones((n_faces, max_degree), dtype=np.int32) * -1
    adjacent_faces_count = np.zeros(n_faces, dtype=np.int32)
    for i in range(n_edges):
        f1, f2 = face_adjacency[i]
        adjacent_faces[f1, adjacent_faces_count[f1]] = f2
        adjacent_faces_count[f1] += 1
        adjacent_faces[f2, adjacent_faces_count[f2]] = f1
        adjacent_faces_count[f2] += 1
    return adjacent_faces


###################### NUMBA acceleration ######################


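# Usage sketch (illustrative): trimesh exposes the adjacency pairs directly, so
# the padded table is one call away.
def _demo_build_adjacent_faces():
    box = trimesh.creation.box()
    adjacent_faces = build_adjacent_faces_numba(box.face_adjacency)
    assert adjacent_faces.shape[0] == len(box.faces)

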
def mesh_sam(
    model,
    mesh,
    save_path,
    point_num=100000,
    prompt_num=400,
    save_mid_res=False,
    show_info=False,
    post_process=False,
    threshold=0.95,
    clean_mesh_flag=True,
    seed=42,
):
    with Timer("load mesh"):
        model, model_parallel = model
        if clean_mesh_flag:
            mesh = clean_mesh(mesh)
        mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=False)
        if show_info:
            print(f"vertices: {mesh.vertices.shape[0]} faces: {mesh.faces.shape[0]}")

    with Timer("get face adjacency"):
        face_adjacency = mesh.face_adjacency
    with Timer("build adjacency table"):
        adjacent_faces = build_adjacent_faces_numba(face_adjacency)

    with Timer("sample point cloud"):
        _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed)
        _points_org = _points.copy()
        _points = normalize_pc(_points)
        normals = mesh.face_normals[face_idx]
        if show_info:
            print(f"points: {point_num} faces: {mesh.faces.shape[0]}")

    with Timer("extract features"):
        _feats = get_feat(model, _points, normals)
        if show_info:
            print("features precomputed")

    if save_mid_res:
        feat_save = _feats.float().detach().cpu().numpy()
        data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
        pca = PCA(n_components=3)
        data_reduced = pca.fit_transform(data_scaled)
        data_reduced = (data_reduced - data_reduced.min()) / (
            data_reduced.max() - data_reduced.min()
        )
        _colors_pca = (data_reduced * 255).astype(np.uint8)
        pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca)
        pc_save.export(os.path.join(save_path, "point_pca.glb"))
        pc_save.export(os.path.join(save_path, "point_pca.ply"))
        if show_info:
            print("PCA feature colors computed")

    with Timer("FPS-sample prompt points"):
        fps_idx = fpsample.fps_sampling(_points, prompt_num)
        _point_prompts = _points[fps_idx]
        if save_mid_res:
            trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
                os.path.join(save_path, "point_prompts_pca.glb")
            )
            trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
                os.path.join(save_path, "point_prompts_pca.ply")
            )
        if show_info:
            print("prompt sampling done")

    with Timer("inference"):
        bs = 64
        step_num = prompt_num // bs + 1
        mask_res = []
        iou_res = []
        for i in tqdm(range(step_num), disable=not show_info):
            cur_prompt = _point_prompts[bs * i : bs * (i + 1)]
            pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask(
                model_parallel, _feats, _points, cur_prompt
            )
            pred_mask = np.stack(
                [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1
            )  # [N, K, 3]
            max_idx = np.argmax(pred_iou, axis=-1)  # [K]
            for j in range(max_idx.shape[0]):
                mask_res.append(pred_mask[:, j, max_idx[j]])
                iou_res.append(pred_iou[j, max_idx[j]])
        mask_res = np.stack(mask_res, axis=-1)  # [N, K]
        if show_info:
            print("prompt inference done")

    with Timer("sort by IoU"):
        iou_res = np.array(iou_res).tolist()
        mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)]
        mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True)
        mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)]
        iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)]

    with Timer("NMS"):
        clusters = defaultdict(list)
        with ThreadPoolExecutor(max_workers=20) as executor:
            for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
                _mask = mask_sorted[i]
                futures = []
                for j in clusters.keys():
                    futures.append(executor.submit(cal_iou, _mask, mask_sorted[j]))

                for j, future in zip(clusters.keys(), futures):
                    if future.result() > 0.9:
                        clusters[j].append(i)
                        break
                else:
                    clusters[i].append(i)

    if show_info:
        print(f"NMS done, masks: {len(clusters)}")

    if save_mid_res:
        part_mask_save_path = os.path.join(save_path, "part_mask")
        if os.path.exists(part_mask_save_path):
            shutil.rmtree(part_mask_save_path)
        os.makedirs(part_mask_save_path, exist_ok=True)
        for i in tqdm(clusters.keys(), desc="save masks", disable=not show_info):
            cluster_num = len(clusters[i])
            cluster_iou = iou_sorted[i]
            cluster_area = np.sum(mask_sorted[i])
            if cluster_num <= 2:
                continue
            mask_save = mask_sorted[i]
            mask_save = np.expand_dims(mask_save, axis=-1)
            mask_save = np.repeat(mask_save, 3, axis=-1)
            mask_save = (mask_save * 255).astype(np.uint8)
            point_save = trimesh.points.PointCloud(_points, colors=mask_save)
            point_save.export(
                os.path.join(
                    part_mask_save_path,
                    f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
                )
            )

    # Drop clusters that are supported by only one or two masks.
    with Timer("filter single-mask clusters"):
        filtered_clusters = []
        other_clusters = []
        for i in clusters.keys():
            if len(clusters[i]) > 2:
                filtered_clusters.append(i)
            else:
                other_clusters.append(i)
        if show_info:
            print(
                f"before filtering: {len(clusters)} clusters, "
                f"after filtering: {len(filtered_clusters)} clusters"
            )

    # Merge again, this time by bounding-box overlap of the masks.
    with Timer("merge again"):
        filtered_clusters_num = len(filtered_clusters)
        cluster2 = {}
        is_union = [False] * filtered_clusters_num
        for i in range(filtered_clusters_num):
            if is_union[i]:
                continue
            cur_cluster = filtered_clusters[i]
            cluster2[cur_cluster] = [cur_cluster]
            for j in range(i + 1, filtered_clusters_num):
                if is_union[j]:
                    continue
                tar_cluster = filtered_clusters[j]
                if (
                    cal_bbox_iou(
                        _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster]
                    )
                    > 0.5
                ):
                    cluster2[cur_cluster].append(tar_cluster)
                    is_union[j] = True
        if show_info:
            print(f"merged again, clusters: {len(cluster2.keys())}")

    with Timer("find uncovered points"):
        no_mask = np.ones(point_num)
        for i in cluster2:
            part_mask = mask_sorted[i]
            no_mask[part_mask] = 0
        if show_info:
            print(
                f"{np.sum(no_mask == 1)} points have no mask,"
                f" ratio: {np.sum(no_mask == 1) / point_num:.4f}"
            )

    with Timer("recover missed masks"):
        # Re-admit discarded masks that mostly cover still-unmasked points.
        for i in tqdm(range(len(mask_sorted)), desc="missed masks", disable=not show_info):
            if i in cluster2:
                continue
            part_mask = mask_sorted[i]
            _iou = cal_single_iou(part_mask, no_mask)
            if _iou > 0.7:
                cluster2[i] = [i]
                no_mask[part_mask] = 0
                if save_mid_res:
                    mask_save = mask_sorted[i]
                    mask_save = np.expand_dims(mask_save, axis=-1)
                    mask_save = np.repeat(mask_save, 3, axis=-1)
                    mask_save = (mask_save * 255).astype(np.uint8)
                    point_save = trimesh.points.PointCloud(_points, colors=mask_save)
                    cluster_iou = iou_sorted[i]
                    cluster_area = int(np.sum(mask_sorted[i]))
                    cluster_num = 1
                    point_save.export(
                        os.path.join(
                            part_mask_save_path,
                            f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
                        )
                    )
        if show_info:
            print(f"masks after recovery: {len(cluster2.keys())}")

    with Timer("compute final point-cloud masks"):
        final_mask = list(cluster2.keys())
        final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask]
        final_mask_area = [
            [final_mask[i], final_mask_area[i]] for i in range(len(final_mask))
        ]
        final_mask_area_sorted = sorted(
            final_mask_area, key=lambda x: x[1], reverse=True
        )
        final_mask_sorted = [
            final_mask_area_sorted[i][0] for i in range(len(final_mask_area))
        ]
        final_mask_area_sorted = [
            final_mask_area_sorted[i][1] for i in range(len(final_mask_area))
        ]
        if show_info:
            print(f"final masks: {len(final_mask_sorted)}")

    with Timer("color the point cloud"):
        # Build a color map, one random color per mask.
        color_map = {}
        for i in final_mask_sorted:
            part_color = np.random.rand(3) * 255
            color_map[i] = part_color

        result_mask = -np.ones(point_num, dtype=np.int64)
        for i in final_mask_sorted:
            part_mask = mask_sorted[i]
            result_mask[part_mask] = i
    if save_mid_res:
        # Save the clustered point cloud.
        result_colors = np.zeros_like(_colors_pca)
        for i in final_mask_sorted:
            part_color = color_map[i]
            part_mask = mask_sorted[i]
            result_colors[part_mask, :3] = part_color
        trimesh.points.PointCloud(_points, colors=result_colors).export(
            os.path.join(save_path, "auto_mask_cluster.glb")
        )
        trimesh.points.PointCloud(_points, colors=result_colors).export(
            os.path.join(save_path, "auto_mask_cluster.ply")
        )
        if show_info:
            print("point cloud saved")

    with Timer("project to mesh and vote labels"):
        # For every face, collect the labels of the points sampled from it and
        # keep the majority vote (-1 = unmasked point, -2 = no sampled point).
        face_seg_res = {}
        for i in final_mask_sorted:
            _part_mask = result_mask == i
            _face_idx = face_idx[_part_mask]
            for k in _face_idx:
                if k not in face_seg_res:
                    face_seg_res[k] = []
                face_seg_res[k].append(i)
        _part_mask = result_mask == -1
        _face_idx = face_idx[_part_mask]
        for k in _face_idx:
            if k not in face_seg_res:
                face_seg_res[k] = []
            face_seg_res[k].append(-1)

        face_ids = -np.ones(len(mesh.faces), dtype=np.int64) * 2
        for i in tqdm(face_seg_res, leave=False, disable=True):
            _seg_ids = np.array(face_seg_res[i])
            # majority label for this face
            _max_id = np.argmax(np.bincount(_seg_ids + 2)) - 2
            face_ids[i] = _max_id
        face_ids_org = face_ids.copy()
        if show_info:
            print("face_ids generated")

    with Timer("first face_ids repair"):
        face_ids += 1
        face_ids = fix_label(face_ids, adjacent_faces, mesh=mesh, show_info=show_info)
        face_ids -= 1
        if show_info:
            print("face_ids repaired")

    color_map[-1] = np.array([255, 0, 0], dtype=np.uint8)

    if save_mid_res:
        save_mesh(
            os.path.join(save_path, "auto_mask_mesh.glb"), mesh, face_ids, color_map
        )
        save_mesh(
            os.path.join(save_path, "auto_mask_mesh_org.glb"),
            mesh,
            face_ids_org,
            color_map,
        )
        if show_info:
            print("mesh results saved")

    with Timer("compute connected regions"):
        face_areas = calculate_face_areas(mesh)
        mesh_total_area = np.sum(face_areas)
        parts = get_connected_region(face_ids, adjacent_faces)
        connected_parts, _face_connected_parts_ids = get_connected_region(
            np.ones_like(face_ids), adjacent_faces, return_face_part_ids=True
        )
        if show_info:
            print(f"{len(parts)} regions in total")
    with Timer("sort connected regions"):
        parts_cp_idx = []
        for x in parts:
            _face_idx = x[0]
            parts_cp_idx.append(_face_connected_parts_ids[_face_idx])
        parts_cp_idx = np.array(parts_cp_idx)
        parts_areas = [float(np.sum(face_areas[x])) for x in parts]
        connected_parts_areas = [float(np.sum(face_areas[x])) for x in connected_parts]
        parts_cp_areas = [connected_parts_areas[x] for x in parts_cp_idx]
        parts_sorted, parts_areas_sorted, parts_cp_areas_sorted = sort_multi_list(
            [parts, parts_areas, parts_cp_areas], key=lambda x: x[1], reverse=True
        )

    with Timer("drop regions that are too small"):
        filtered_parts = []
        other_parts = []
        for i in range(len(parts_sorted)):
            parts = parts_sorted[i]
            area = parts_areas_sorted[i]
            cp_area = parts_cp_areas_sorted[i]
            if area / (cp_area + 1e-7) > 0.001:
                filtered_parts.append(i)
            else:
                other_parts.append(i)
        if show_info:
            print(f"kept {len(filtered_parts)} meshes, dropped {len(other_parts)}")

    with Timer("clear labels of the dropped regions"):
        face_ids_2 = face_ids.copy()
        part_num = len(cluster2.keys())
        for j in other_parts:
            parts = parts_sorted[j]
            for i in parts:
                face_ids_2[i] = -1

    with Timer("second face_ids repair"):
        face_ids_3 = face_ids_2.copy()
        face_ids_3 = fix_label(
            face_ids_3, adjacent_faces, mesh=mesh, show_info=show_info
        )

    if save_mid_res:
        save_mesh(
            os.path.join(save_path, "auto_mask_mesh_filtered_2.glb"),
            mesh,
            face_ids_3,
            color_map,
        )
        if show_info:
            print("mesh results saved")

    with Timer("second pass of connected regions"):
        parts_2 = get_connected_region(face_ids_3, adjacent_faces)
        parts_areas_2 = [float(np.sum(face_areas[x])) for x in parts_2]
        parts_ids_2 = [face_ids_3[x[0]] for x in parts_2]

    with Timer("add large missing parts"):
        color_map_2 = copy.deepcopy(color_map)
        max_id = np.max(parts_ids_2)
        for i in range(len(parts_2)):
            _parts = parts_2[i]
            _area = parts_areas_2[i]
            _parts_id = face_ids_3[_parts[0]]
            if _area / mesh_total_area > 0.001:
                if _parts_id == -1 or _parts_id == -2:
                    parts_ids_2[i] = max_id + 1
                    max_id += 1
                    color_map_2[max_id] = np.random.rand(3) * 255
                    if show_info:
                        print(f"added part {max_id}")

    with Timer("assign the new face_ids"):
        face_ids_4 = face_ids_3.copy()
        for i in range(len(parts_2)):
            _parts = parts_2[i]
            _parts_id = parts_ids_2[i]
            for j in _parts:
                face_ids_4[j] = _parts_id
    with Timer("compute per-label and per-part AABBs"):
        ids_aabb = {}
        unique_ids = np.unique(face_ids_4)
        for i in unique_ids:
            if i < 0:
                continue
            _part_mask = face_ids_4 == i
            _faces = mesh.faces[_part_mask]
            _faces = np.reshape(_faces, (-1))
            _points = mesh.vertices[_faces]
            min_xyz = np.min(_points, axis=0)
            max_xyz = np.max(_points, axis=0)
            ids_aabb[i] = [min_xyz, max_xyz]

        parts_2_aabb = []
        for i in range(len(parts_2)):
            _parts = parts_2[i]
            _faces = mesh.faces[_parts]
            _faces = np.reshape(_faces, (-1))
            _points = mesh.vertices[_faces]
            min_xyz = np.min(_points, axis=0)
            max_xyz = np.max(_points, axis=0)
            parts_2_aabb.append([min_xyz, max_xyz])

    with Timer("find part neighbors"):
        parts_2_neighbor = find_neighbor_part(
            parts_2, adjacent_faces, parts_2_aabb, parts_ids_2
        )
    with Timer("merge unmasked regions"):
        # Give every remaining unlabeled region the neighboring label whose
        # AABB grows the least when absorbing it.
        for i in range(len(parts_2)):
            _parts = parts_2[i]
            _ids = parts_ids_2[i]
            if _ids == -1 or _ids == -2:
                _cur_aabb = parts_2_aabb[i]
                _min_aabb_increase = 1e10
                _min_id = -1
                for j in parts_2_neighbor[i]:
                    if parts_ids_2[j] == -1 or parts_ids_2[j] == -2:
                        continue
                    _tar_id = parts_ids_2[j]
                    _tar_aabb = ids_aabb[_tar_id]
                    _min_increase, _max_increase = aabb_increase(_tar_aabb, _cur_aabb)
                    _increase = max(np.max(_min_increase), np.max(_max_increase))
                    if _min_aabb_increase > _increase:
                        _min_aabb_increase = _increase
                        _min_id = _tar_id
                if _min_id >= 0:
                    parts_ids_2[i] = _min_id

    with Timer("assign the new face_ids again"):
        face_ids_4 = face_ids_3.copy()
        for i in range(len(parts_2)):
            _parts = parts_2[i]
            _parts_id = parts_ids_2[i]
            for j in _parts:
                face_ids_4[j] = _parts_id

    final_face_ids = face_ids_4
    if save_mid_res:
        save_mesh(
            os.path.join(save_path, "auto_mask_mesh_final.glb"),
            mesh,
            face_ids_4,
            color_map_2,
        )

    if post_process:
        parts = get_connected_region(final_face_ids, adjacent_faces)
        final_face_ids = do_no_mask_process(parts, final_face_ids)
        face_ids_5 = do_post_process(
            face_areas,
            parts,
            adjacent_faces,
            face_ids_4,
            threshold,
            show_info=show_info,
        )
        if save_mid_res:
            save_mesh(
                os.path.join(save_path, "auto_mask_mesh_final_post.glb"),
                mesh,
                face_ids_5,
                color_map_2,
            )
        final_face_ids = face_ids_5
    with Timer("compute final AABBs"):
        aabb = get_aabb_from_face_ids(mesh, final_face_ids)
    return aabb, final_face_ids, mesh


class AutoMask:
    def __init__(
        self,
        ckpt_path,
        point_num=100000,
        prompt_num=400,
        threshold=0.95,
        post_process=True,
    ):
        """
        ckpt_path: str, checkpoint path
        point_num: int, number of sampled points
        prompt_num: int, number of prompt points
        threshold: float, post-processing area threshold
        post_process: bool, whether to run post-processing
        """
        self.model = YSAM()
        self.model.load_state_dict(
            state_dict=torch.load(ckpt_path, map_location="cpu")["state_dict"]
        )
        self.model.eval()
        self.model_parallel = torch.nn.DataParallel(self.model)
        self.model.cuda()
        self.model_parallel.cuda()
        self.point_num = point_num
        self.prompt_num = prompt_num
        self.threshold = threshold
        self.post_process = post_process

    def predict_aabb(
        self,
        mesh,
        point_num=None,
        prompt_num=None,
        threshold=None,
        post_process=None,
        save_path=None,
        save_mid_res=False,
        show_info=True,
        clean_mesh_flag=True,
        seed=42,
    ):
        """
        Parameters:
            mesh: trimesh.Trimesh, input mesh
            point_num: int, number of sampled points
            prompt_num: int, number of prompt points
            threshold: float, post-processing area threshold
            post_process: bool, whether to run post-processing
        Returns:
            aabb: np.ndarray, per-part bounding boxes
            face_ids: np.ndarray, per-face part ids
        """
        point_num = point_num if point_num is not None else self.point_num
        prompt_num = prompt_num if prompt_num is not None else self.prompt_num
        threshold = threshold if threshold is not None else self.threshold
        post_process = post_process if post_process is not None else self.post_process
        return mesh_sam(
            [self.model, self.model_parallel],
            mesh,
            save_path=save_path,
            point_num=point_num,
            prompt_num=prompt_num,
            threshold=threshold,
            post_process=post_process,
            show_info=show_info,
            save_mid_res=save_mid_res,
            clean_mesh_flag=clean_mesh_flag,
            seed=seed,
        )
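# End-to-end usage sketch (illustrative; the checkpoint and asset paths below
# are assumptions about the repo layout, not guaranteed by it):
#
#   automask = AutoMask("checkpoints/p3sam.ckpt")
#   mesh = trimesh.load("P3-SAM/demo/assets/1.glb", force="mesh")
#   aabb, face_ids, mesh = automask.predict_aabb(mesh, save_path="./output")
#   print(aabb.shape)  # (num_parts, 2, 3)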
XPart/partgen/config/infer.yaml
ADDED
@@ -0,0 +1,122 @@
name: "Xpart Pipeline release"

ckpt_path: checkpoints/xpart.pt

shapevae:
  target: partgen.models.autoencoders.VolumeDecoderShapeVAE
  params:
    num_latents: &num_latents 1024
    embed_dim: 64
    num_freqs: 8
    include_pi: false
    heads: 16
    width: 1024
    num_encoder_layers: 8
    num_decoder_layers: 16
    qkv_bias: false
    qk_norm: true
    scale_factor: &z_scale_factor 1.0039506158752403
    geo_decoder_mlp_expand_ratio: 4
    geo_decoder_downsample_ratio: 1
    geo_decoder_ln_post: true
    point_feats: 4
    pc_size: &pc_size 81920
    pc_sharpedge_size: &pc_sharpedge_size 0

bbox_predictor:
  target: partgen.bbox_estimator.auto_mask_api.AutoMask
  params:
    ckpt_path: checkpoints/p3sam.ckpt
conditioner:
  target: partgen.models.conditioner.condioner_release.Conditioner
  params:
    use_geo: true
    use_obj: true
    use_seg_feat: true
    geo_cfg:
      target: partgen.models.conditioner.part_encoders.PartEncoder
      output_dim: &cross2_output_dim 1024
      params:
        use_local: true
        local_feat_type: latents_shape # [latents, miche-point-query-structural-vae]
        num_tokens_cond: &num_tokens_cond 4096 # num_tokens: 2048 for holopart conditioner
        local_geo_cfg:
          target: partgen.models.autoencoders.VolumeDecoderShapeVAE
          params:
            num_latents: *num_tokens_cond
            embed_dim: 64
            num_freqs: 8
            include_pi: false
            heads: 16
            width: 1024
            num_encoder_layers: 8
            num_decoder_layers: 16
            qkv_bias: false
            qk_norm: true
            scale_factor: *z_scale_factor
            geo_decoder_mlp_expand_ratio: 4
            geo_decoder_downsample_ratio: 1
            geo_decoder_ln_post: true
            point_feats: 4
            pc_size: &pc_size_bbox 81920
            pc_sharpedge_size: &pc_sharpedge_size_bbox 0

    obj_encoder_cfg:
      target: partgen.models.autoencoders.VolumeDecoderShapeVAE
      output_dim: &cross1_output_dim 1024
      params:
        num_latents: 4096
        embed_dim: 64
        num_freqs: 8
        include_pi: false
        heads: 16
        width: 1024
        num_encoder_layers: 8
        num_decoder_layers: 16
        qkv_bias: false
        qk_norm: true
        scale_factor: 1.0039506158752403
        geo_decoder_mlp_expand_ratio: 4
        geo_decoder_downsample_ratio: 1
        geo_decoder_ln_post: true
        point_feats: 4
        pc_size: *pc_size
        pc_sharpedge_size: *pc_sharpedge_size
    seg_feat_cfg:
      target: partgen.models.conditioner.sonata_extractor.SonataFeatureExtractor

model:
  target: partgen.models.partformer_dit.PartFormerDITPlain
  params:
    use_self_attention: true
    use_cross_attention: true
    use_cross_attention_2: true
    # cond
    use_bbox_cond: false
    num_freqs: 8
    use_part_embed: true
    valid_num: 50 #*valid_num
    # para
    input_size: *num_latents
    in_channels: 64
    hidden_size: 2048
    encoder_hidden_dim: *cross1_output_dim # for object mesh
    encoder_hidden2_dim: *cross2_output_dim # for part in bbox
    depth: 21
    num_heads: 16
    qk_norm: true
    qkv_bias: false
    qk_norm_type: 'rms'
    with_decoupled_ca: false
    decoupled_ca_dim: *num_tokens_cond
    decoupled_ca_weight: 1.0
    use_attention_pooling: false
    use_pos_emb: false
    num_moe_layers: 6
    num_experts: 8
    moe_top_k: 2

scheduler:
  target: partgen.models.diffusion.schedulers.FlowMatchEulerDiscreteScheduler
  params:
    num_train_timesteps: 1000
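# Loading sketch (illustrative; this repo's actual entry point may differ).
# target/params configs in this style are typically read with OmegaConf and
# resolved with a small importlib helper:
#
#   import importlib
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("XPart/partgen/config/infer.yaml")
#   module, name = cfg.shapevae.target.rsplit(".", 1)
#   shapevae = getattr(importlib.import_module(module), name)(**cfg.shapevae.params)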
XPart/partgen/config/sonata.json
ADDED
@@ -0,0 +1,58 @@
{
    "in_channels": 9,
    "order": ["z", "z-trans", "hilbert", "hilbert-trans"],
    "stride": [2, 2, 2, 2],
    "enc_depths": [3, 3, 3, 12, 3],
    "enc_channels": [48, 96, 192, 384, 512],
    "enc_num_head": [3, 6, 12, 24, 32],
    "enc_patch_size": [1024, 1024, 1024, 1024, 1024],
    "mlp_ratio": 4,
    "qkv_bias": true,
    "qk_scale": null,
    "attn_drop": 0.0,
    "proj_drop": 0.0,
    "drop_path": 0.3,
    "shuffle_orders": true,
    "pre_norm": true,
    "enable_rpe": false,
    "enable_flash": true,
    "upcast_attention": false,
    "upcast_softmax": false,
    "traceable": true,
    "enc_mode": true,
    "mask_token": true
}
XPart/partgen/models/autoencoders/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from .attention_blocks import CrossAttentionDecoder
from .attention_processors import (
    CrossAttentionProcessor,
)
from .model import VectsetVAE, VolumeDecoderShapeVAE

from .surface_extractors import (
    SurfaceExtractors,
    MCSurfaceExtractor,
    DMCSurfaceExtractor,
    Latent2MeshOutput,
)
from .volume_decoders import (
    VanillaVolumeDecoder,
)
XPart/partgen/models/autoencoders/attention_blocks.py
ADDED
@@ -0,0 +1,770 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


import os
from typing import Optional, Union, List

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from .attention_processors import CrossAttentionProcessor
from ...utils.misc import logger

scaled_dot_product_attention = nn.functional.scaled_dot_product_attention

if os.environ.get("USE_SAGEATTN", "0") == "1":
    try:
        from sageattention import sageattn
    except ImportError:
        raise ImportError(
            'Please install the package "sageattention" to use this USE_SAGEATTN.'
        )
    scaled_dot_product_attention = sageattn


class FourierEmbedder(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(x[..., i]),
            sin(f_1*x[..., i]),
            sin(f_2*x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(x[..., i]),
            cos(f_1*x[..., i]),
            cos(f_2*x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]  # only present if include_input is True.
        ], here f_i is the frequency.

    Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
    If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
    Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
        input_dim (int): the input dimension, default is 3;
        include_input (bool): include the input tensor or not, default is True.

    Attributes:
        frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];

        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise, it is input_dim * num_freqs * 2.

    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """The initialization"""

        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies).view(
                *x.shape[:-1], -1
            )
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x


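# Shape check (illustrative): with num_freqs=8 and include_input=True, a 3-D
# input maps to 3 * (8 * 2 + 1) = 51 channels, matching out_dim.
def _demo_fourier_embedder():
    embedder = FourierEmbedder(num_freqs=8, input_dim=3)
    x = torch.randn(2, 1024, 3)
    assert embedder(x).shape == (2, 1024, embedder.out_dim) == (2, 1024, 51)

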
class DropPath(nn.Module):
|
| 144 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 145 |
+
|
| 146 |
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
| 147 |
+
super(DropPath, self).__init__()
|
| 148 |
+
self.drop_prob = drop_prob
|
| 149 |
+
self.scale_by_keep = scale_by_keep
|
| 150 |
+
|
| 151 |
+
def forward(self, x):
|
| 152 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 153 |
+
|
| 154 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 155 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 156 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 157 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 158 |
+
'survival rate' as the argument.
|
| 159 |
+
|
| 160 |
+
"""
|
| 161 |
+
if self.drop_prob == 0.0 or not self.training:
|
| 162 |
+
return x
|
| 163 |
+
keep_prob = 1 - self.drop_prob
|
| 164 |
+
shape = (x.shape[0],) + (1,) * (
|
| 165 |
+
x.ndim - 1
|
| 166 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 167 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 168 |
+
if keep_prob > 0.0 and self.scale_by_keep:
|
| 169 |
+
random_tensor.div_(keep_prob)
|
| 170 |
+
return x * random_tensor
|
| 171 |
+
|
| 172 |
+
def extra_repr(self):
|
| 173 |
+
return f"drop_prob={round(self.drop_prob, 3):0.3f}"


class MLP(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        expand_ratio: int = 4,
        output_width: int = None,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * expand_ratio)
        self.c_proj = nn.Linear(
            width * expand_ratio, output_width if output_width is not None else width
        )
        self.gelu = nn.GELU()
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, x):
        return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        heads: int,
        width=None,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )

        self.attn_processor = CrossAttentionProcessor()

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k, v = map(
            lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v)
        )
        out = self.attn_processor(self, q, k, v)
        out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
        return out


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        data_width: Optional[int] = None,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        kv_cache: bool = False,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, bias=qkv_bias)
        self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadCrossAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        self.kv_cache = kv_cache
        self.data = None

    def forward(self, x, data):
        x = self.c_q(x)
        if self.kv_cache:
            if self.data is None:
                self.data = self.c_kv(data)
                logger.info(
                    "Save kv cache, this should be called only once for one mesh"
                )
            data = self.data
        else:
            data = self.c_kv(data)
        x = self.attention(x, data)
        x = self.c_proj(x)
        return x
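
# Sketch of the kv_cache path above (widths/heads/shapes are assumptions):
# with kv_cache=True the key/value projection of `data` is computed on the
# first call and reused afterwards, so repeated query batches against the
# same latent set (e.g. chunked occupancy queries for one mesh) skip c_kv.
def _cross_attention_kv_cache_demo():
    attn = MultiheadCrossAttention(width=64, heads=8, kv_cache=True)
    data = torch.rand(1, 512, 64)              # hypothetical latent tokens
    q1 = attn(torch.rand(1, 100, 64), data)    # first call caches c_kv(data)
    q2 = attn(torch.rand(1, 100, 64), data)    # reuses the cached projection
    return q1, q2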


class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        data_width: Optional[int] = None,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            width=width,
            heads=heads,
            data_width=data_width,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
        self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class QKVMultiheadAttention(nn.Module):
    def __init__(
        self, *, heads: int, width=None, qk_norm=False, norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.heads = heads
        self.q_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        q, k, v = map(
            lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v)
        )
        out = (
            scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
        )
        return out


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(
            heads=heads,
            width=width,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
        )
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, x):
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.drop_path(self.c_proj(x))
        return x


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            width=width,
            heads=heads,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )
        self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
        self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        width: int,
        layers: int,
        heads: int,
        qkv_bias: bool = True,
        norm_layer=nn.LayerNorm,
        qk_norm: bool = False,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width=width,
                heads=heads,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                qk_norm=qk_norm,
                drop_path_rate=drop_path_rate,
            )
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class CrossAttentionDecoder(nn.Module):

    def __init__(
        self,
        *,
        out_channels: int,
        fourier_embedder: FourierEmbedder,
        width: int,
        heads: int,
        mlp_expand_ratio: int = 4,
        downsample_ratio: int = 1,
        enable_ln_post: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
    ):
        super().__init__()

        self.enable_ln_post = enable_ln_post
        self.fourier_embedder = fourier_embedder
        self.downsample_ratio = downsample_ratio
        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
        if self.downsample_ratio != 1:
            self.latents_proj = nn.Linear(width * downsample_ratio, width)
        if not self.enable_ln_post:
            qk_norm = False
        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            width=width,
            mlp_expand_ratio=mlp_expand_ratio,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
        )

        if self.enable_ln_post:
            self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_channels)
        self.label_type = label_type
        self.count = 0

    def set_cross_attention_processor(self, processor):
        self.cross_attn_decoder.attn.attention.attn_processor = processor

    # def set_default_cross_attention_processor(self):
    #     self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor

    def forward(self, queries=None, query_embeddings=None, latents=None):
        if query_embeddings is None:
            query_embeddings = self.query_proj(
                self.fourier_embedder(queries).to(latents.dtype)
            )
        self.count += query_embeddings.shape[1]
        if self.downsample_ratio != 1:
            latents = self.latents_proj(latents)
        x = self.cross_attn_decoder(query_embeddings, latents)
        if self.enable_ln_post:
            x = self.ln_post(x)
        occ = self.output_proj(x)
        return occ
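
# Illustrative sketch of decoding occupancy logits at arbitrary 3D query
# points (widths/counts are assumptions chosen for demonstration): queries
# are Fourier-embedded, projected to `width`, and cross-attend to the latents.
def _cross_attention_decoder_demo():
    embedder = FourierEmbedder(num_freqs=8, include_pi=True)
    decoder = CrossAttentionDecoder(
        out_channels=1, fourier_embedder=embedder, width=64, heads=8
    )
    latents = torch.rand(1, 512, 64)    # hypothetical decoded latent tokens
    queries = torch.rand(1, 2048, 3)    # hypothetical grid sample points
    logits = decoder(queries=queries, latents=latents)  # -> [1, 2048, 1]
    return logits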


def fps(
    src: torch.Tensor,
    batch: Optional[Tensor] = None,
    ratio: Optional[Union[Tensor, float]] = None,
    random_start: bool = True,
    batch_size: Optional[int] = None,
    ptr: Optional[Union[Tensor, List[int]]] = None,
):
    src = src.float()
    from torch_cluster import fps as fps_fn

    output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)
    return output
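
# Sketch of farthest-point sampling as wrapped above (sizes are assumptions):
# for a flattened [B*N, 3] cloud plus a matching batch-index vector, `fps`
# returns indices of roughly ratio*N points per batch element, chosen to
# maximize mutual distance. Requires torch_cluster at runtime.
def _fps_demo():
    pts = torch.rand(2 * 1000, 3)
    batch = torch.repeat_interleave(torch.arange(2), 1000)
    idx = fps(pts, batch, ratio=0.25)   # ~250 indices per batch element
    return pts[idx]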


class PointCrossAttentionEncoder(nn.Module):

    def __init__(
        self,
        *,
        num_latents: int,
        downsample_ratio: float,
        pc_size: int,
        pc_sharpedge_size: int,
        fourier_embedder: FourierEmbedder,
        point_feats: int,
        width: int,
        heads: int,
        layers: int,
        normal_pe: bool = False,
        qkv_bias: bool = True,
        use_ln_post: bool = False,
        use_checkpoint: bool = False,
        qk_norm: bool = False,
    ):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents
        self.downsample_ratio = downsample_ratio
        self.point_feats = point_feats
        self.normal_pe = normal_pe

        if pc_sharpedge_size == 0:
            print(
                "PointCrossAttentionEncoder INFO: pc_sharpedge_size is not given,"
                " using pc_size as pc_sharpedge_size"
            )
        else:
            print(
                "PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using"
                f" pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}"
            )

        self.pc_size = pc_size
        self.pc_sharpedge_size = pc_sharpedge_size

        self.fourier_embedder = fourier_embedder

        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
        self.cross_attn = ResidualCrossAttentionBlock(
            width=width, heads=heads, qkv_bias=qkv_bias, qk_norm=qk_norm
        )

        self.self_attn = None
        if layers > 0:
            self.self_attn = Transformer(
                width=width,
                layers=layers,
                heads=heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
            )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    def sample_points_and_latents(
        self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None
    ):
        B, N, D = pc.shape
        num_pts = self.num_latents * self.downsample_ratio

        # Compute number of latents
        num_latents = int(num_pts / self.downsample_ratio)

        # Compute the number of random and sharpedge latents
        num_random_query = (
            self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
        )
        num_sharpedge_query = num_latents - num_random_query

        # Split random and sharpedge surface points
        random_pc, sharpedge_pc = torch.split(
            pc, [self.pc_size, self.pc_sharpedge_size], dim=1
        )
        assert (
            random_pc.shape[1] <= self.pc_size
        ), "Random surface points size must be less than or equal to pc_size"
        assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, (
            "Sharpedge surface points size must be less than or equal to"
            " pc_sharpedge_size"
        )

        # Randomly select random surface points and random query points
        input_random_pc_size = int(num_random_query * self.downsample_ratio)
        random_query_ratio = num_random_query / input_random_pc_size
        idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[
            :input_random_pc_size
        ]
        input_random_pc = random_pc[:, idx_random_pc, :]
        flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
        N_down = int(flatten_input_random_pc.shape[0] / B)
        batch_down = torch.arange(B).to(pc.device)
        batch_down = torch.repeat_interleave(batch_down, N_down)
        idx_query_random = fps(
            flatten_input_random_pc, batch_down, ratio=random_query_ratio
        )
        query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)

        # Randomly select sharpedge surface points and sharpedge query points
        input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
        if input_sharpedge_pc_size == 0:
            input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(
                pc.device
            )
            query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(
                pc.device
            )
        else:
            sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
            idx_sharpedge_pc = torch.randperm(
                sharpedge_pc.shape[1], device=sharpedge_pc.device
            )[:input_sharpedge_pc_size]
            input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
            flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(
                B * input_sharpedge_pc_size, D
            )
            N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
            batch_down = torch.arange(B).to(pc.device)
            batch_down = torch.repeat_interleave(batch_down, N_down)
            idx_query_sharpedge = fps(
                flatten_input_sharpedge_surface_points,
                batch_down,
                ratio=sharpedge_query_ratio,
            )
            query_sharpedge_pc = flatten_input_sharpedge_surface_points[
                idx_query_sharpedge
            ].view(B, -1, D)

        # Concatenate random and sharpedge surface points and query points
        query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
        input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)

        # PE
        query = self.fourier_embedder(query_pc)
        data = self.fourier_embedder(input_pc)

        # Concat normal if given
        if self.point_feats != 0:

            random_surface_feats, sharpedge_surface_feats = torch.split(
                feats, [self.pc_size, self.pc_sharpedge_size], dim=1
            )
            input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
            flatten_input_random_surface_feats = input_random_surface_feats.view(
                B * input_random_pc_size, -1
            )
            query_random_feats = flatten_input_random_surface_feats[
                idx_query_random
            ].view(B, -1, flatten_input_random_surface_feats.shape[-1])

            if input_sharpedge_pc_size == 0:
                input_sharpedge_surface_feats = torch.zeros(
                    B, 0, self.point_feats, dtype=input_random_surface_feats.dtype
                ).to(pc.device)
                query_sharpedge_feats = torch.zeros(
                    B, 0, self.point_feats, dtype=query_random_feats.dtype
                ).to(pc.device)
            else:
                input_sharpedge_surface_feats = sharpedge_surface_feats[
                    :, idx_sharpedge_pc, :
                ]
                flatten_input_sharpedge_surface_feats = (
                    input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size, -1)
                )
                query_sharpedge_feats = flatten_input_sharpedge_surface_feats[
                    idx_query_sharpedge
                ].view(B, -1, flatten_input_sharpedge_surface_feats.shape[-1])

            query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
            input_feats = torch.cat(
                [input_random_surface_feats, input_sharpedge_surface_feats], dim=1
            )

            if self.normal_pe:
                query_normal_pe = self.fourier_embedder(query_feats[..., :3])
                input_normal_pe = self.fourier_embedder(input_feats[..., :3])
                query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
                input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)

            query = torch.cat([query, query_feats], dim=-1)
            data = torch.cat([data, input_feats], dim=-1)

        if input_sharpedge_pc_size == 0:
            query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
            input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)

        # print(f'query_pc: {query_pc.shape}')
        # print(f'input_pc: {input_pc.shape}')
        # print(f'query_random_pc: {query_random_pc.shape}')
        # print(f'input_random_pc: {input_random_pc.shape}')
        # print(f'query_sharpedge_pc: {query_sharpedge_pc.shape}')
        # print(f'input_sharpedge_pc: {input_sharpedge_pc.shape}')

        return (
            query.view(B, -1, query.shape[-1]),
            data.view(B, -1, data.shape[-1]),
            [
                query_pc,
                input_pc,
                query_random_pc,
                input_random_pc,
                query_sharpedge_pc,
                input_sharpedge_pc,
            ],
        )

    def forward(self, pc, feats):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
            pc_infos (List[torch.FloatTensor]): the sampled query/input point clouds

        """

        query, data, pc_infos = self.sample_points_and_latents(pc, feats)

        query = self.input_proj(query)
        data = self.input_proj(data)

        latents = self.cross_attn(query, data)
        if self.self_attn is not None:
            latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents, pc_infos
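
# End-to-end sketch of the encoder above (all sizes are illustrative
# assumptions): a surface cloud of pc_size random + pc_sharpedge_size
# sharp-edge points is FPS-downsampled into num_latents query points, which
# cross-attend to the full sampled cloud and are refined by self-attention.
def _point_encoder_demo():
    embedder = FourierEmbedder(num_freqs=8, include_pi=True)
    enc = PointCrossAttentionEncoder(
        num_latents=256, downsample_ratio=4, pc_size=2048, pc_sharpedge_size=2048,
        fourier_embedder=embedder, point_feats=3, width=64, heads=8, layers=2,
    )
    pc = torch.rand(1, 4096, 3)      # positions
    feats = torch.rand(1, 4096, 3)   # per-point normals
    latents, pc_infos = enc(pc, feats)   # latents: [1, 256, 64]
    return latents, pc_infos
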
XPart/partgen/models/autoencoders/attention_processors.py
ADDED
@@ -0,0 +1,32 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import os

import torch
import torch.nn.functional as F

scaled_dot_product_attention = F.scaled_dot_product_attention
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError:
        raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.')
    scaled_dot_product_attention = sageattn


class CrossAttentionProcessor:
    def __call__(self, attn, q, k, v):
        out = scaled_dot_product_attention(q, k, v)
        return out
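
# Usage sketch for the processor switch above (tensor shapes are assumptions):
# setting CA_USE_SAGEATTN=1 in the environment before import swaps
# F.scaled_dot_product_attention for sageattention's kernel module-wide,
# e.g. `CA_USE_SAGEATTN=1 python demo.py`.
def _cross_attention_processor_demo():
    proc = CrossAttentionProcessor()
    q = torch.rand(1, 8, 100, 8)    # [batch, heads, queries, head_dim]
    k = torch.rand(1, 8, 512, 8)
    v = torch.rand(1, 8, 512, 8)
    return proc(attn=None, q=q, k=k, v=v)   # -> [1, 8, 100, 8]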
XPart/partgen/models/autoencoders/model.py
ADDED
@@ -0,0 +1,452 @@
# Open Source Model Licensed under the Apache License Version 2.0
# and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited
# ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


import os
from typing import Tuple, List, Union

from functools import partial

import copy
import numpy as np
import torch
import torch.nn as nn
import yaml

from .attention_blocks import (
    FourierEmbedder,
    Transformer,
    CrossAttentionDecoder,
    PointCrossAttentionEncoder,
)
from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors, Latent2MeshOutput
from .volume_decoders import (
    VanillaVolumeDecoder,
)
from ...utils.misc import logger, synchronize_timer, smart_load_model
from ...utils.mesh_utils import extract_geometry_fast


class DiagonalGaussianDistribution(object):
    def __init__(
        self,
        parameters: Union[torch.Tensor, List[torch.Tensor]],
        deterministic=False,
        feat_dim=1,
    ):
        """
        Initialize a diagonal Gaussian distribution with mean and log-variance parameters.

        Args:
            parameters (Union[torch.Tensor, List[torch.Tensor]]):
                Either a single tensor containing concatenated mean and log-variance along `feat_dim`,
                or a list of two tensors [mean, logvar].
            deterministic (bool, optional): If True, the distribution is deterministic (zero variance).
                Default is False.
            feat_dim (int, optional): Dimension along which mean and logvar are
                concatenated if parameters is a single tensor. Default is 1.
        """
        self.feat_dim = feat_dim
        self.parameters = parameters

        if isinstance(parameters, list):
            self.mean = parameters[0]
            self.logvar = parameters[1]
        else:
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """
        Sample from the diagonal Gaussian distribution.

        Returns:
            torch.Tensor: A sample tensor with the same shape as the mean.
        """
        x = self.mean + self.std * torch.randn_like(self.mean)
        return x

    def kl(self, other=None, dims=(1, 2, 3)):
        """
        Compute the Kullback-Leibler (KL) divergence between this distribution and another.

        If `other` is None, compute KL divergence to a standard normal distribution N(0, I).

        Args:
            other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.
            dims (tuple, optional): Dimensions along which to compute the mean KL divergence.
                Default is (1, 2, 3).

        Returns:
            torch.Tensor: The mean KL divergence value.
        """
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=dims,
                )

    def nll(self, sample, dims=(1, 2, 3)):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean
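
# Sketch of the reparameterized sampling and KL above (shapes are
# assumptions). For a factorized Gaussian,
# KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2),
# which is exactly the `other is None` branch, averaged over `dims`.
def _diagonal_gaussian_demo():
    moments = torch.rand(4, 256, 2 * 64)          # hypothetical [B, L, 2*embed_dim]
    posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
    z = posterior.sample()                        # mean + std * eps
    kl = posterior.kl(dims=(1, 2))                # [B]-shaped mean KL
    return z, kl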


class VectsetVAE(nn.Module):

    @classmethod
    @synchronize_timer("VectsetVAE Model Loading")
    def from_single_file(
        cls,
        ckpt_path,
        config_path,
        device="cuda",
        dtype=torch.float16,
        use_safetensors=None,
        **kwargs,
    ):
        # load config
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        # load ckpt
        if use_safetensors:
            ckpt_path = ckpt_path.replace(".ckpt", ".safetensors")
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Model file {ckpt_path} not found")

        logger.info(f"Loading model from {ckpt_path}")
        if use_safetensors:
            import safetensors.torch

            ckpt = safetensors.torch.load_file(ckpt_path, device="cpu")
        else:
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)

        model_kwargs = config["params"]
        model_kwargs.update(kwargs)

        model = cls(**model_kwargs)
        model.load_state_dict(ckpt)
        model.to(device=device, dtype=dtype)
        return model

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        device="cuda",
        dtype=torch.float16,
        use_safetensors=False,
        variant="fp16",
        subfolder="hunyuan3d-vae-v2-1",
        **kwargs,
    ):
        config_path, ckpt_path = smart_load_model(
            model_path,
            subfolder=subfolder,
            use_safetensors=use_safetensors,
            variant=variant,
        )

        return cls.from_single_file(
            ckpt_path,
            config_path,
            device=device,
            dtype=dtype,
            use_safetensors=use_safetensors,
            **kwargs,
        )

    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")
        state_dict = state_dict.get("state_dict", state_dict)
        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]
        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and"
            f" {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def __init__(self, volume_decoder=None, surface_extractor=None):
        super().__init__()
        if volume_decoder is None:
            volume_decoder = VanillaVolumeDecoder()
        if surface_extractor is None:
            surface_extractor = MCSurfaceExtractor()
        self.volume_decoder = volume_decoder
        self.surface_extractor = surface_extractor

    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
        with synchronize_timer("Volume decoding"):
            grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
        with synchronize_timer("Surface extraction"):
            outputs = self.surface_extractor(grid_logits, **kwargs)
        return outputs


class VolumeDecoderShapeVAE(VectsetVAE):
    def __init__(
        self,
        *,
        num_latents: int,
        embed_dim: int,
        width: int,
        heads: int,
        num_decoder_layers: int,
        num_encoder_layers: int = 8,
        pc_size: int = 5120,
        pc_sharpedge_size: int = 5120,
        point_feats: int = 3,
        downsample_ratio: int = 20,
        geo_decoder_downsample_ratio: int = 1,
        geo_decoder_mlp_expand_ratio: int = 4,
        geo_decoder_ln_post: bool = True,
        num_freqs: int = 8,
        include_pi: bool = True,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        label_type: str = "binary",
        drop_path_rate: float = 0.0,
        scale_factor: float = 1.0,
        use_ln_post: bool = True,
        ckpt_path=None,
        volume_decoder=None,
        surface_extractor=None,
    ):
        super().__init__(volume_decoder, surface_extractor)
        self.geo_decoder_ln_post = geo_decoder_ln_post
        self.downsample_ratio = downsample_ratio

        self.fourier_embedder = FourierEmbedder(
            num_freqs=num_freqs, include_pi=include_pi
        )

        self.encoder = PointCrossAttentionEncoder(
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            downsample_ratio=self.downsample_ratio,
            pc_size=pc_size,
            pc_sharpedge_size=pc_sharpedge_size,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            qkv_bias=qkv_bias,
            use_ln_post=use_ln_post,
            qk_norm=qk_norm,
        )

        self.pre_kl = nn.Linear(width, embed_dim * 2)
        self.post_kl = nn.Linear(embed_dim, width)

        self.transformer = Transformer(
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            drop_path_rate=drop_path_rate,
        )

        self.geo_decoder = CrossAttentionDecoder(
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
            downsample_ratio=geo_decoder_downsample_ratio,
            enable_ln_post=self.geo_decoder_ln_post,
            width=width // geo_decoder_downsample_ratio,
            heads=heads // geo_decoder_downsample_ratio,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            label_type=label_type,
        )

        self.scale_factor = scale_factor
        self.latent_shape = (num_latents, embed_dim)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path)

    def forward(self, latents):
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents

    def encode(self, surface, sample_posterior=True, return_pc_info=False):
        pc, feats = surface[:, :, :3], surface[:, :, 3:]
        latents, pc_infos = self.encoder(pc, feats)
        # print(latents.shape, self.pre_kl.weight.shape)
        moments = self.pre_kl(latents)
        posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
        if sample_posterior:
            latents = posterior.sample()
        else:
            latents = posterior.mode()
        if return_pc_info:
            return latents, pc_infos
        else:
            return latents

    def encode_shape(self, surface, return_pc_info=False):
        pc, feats = surface[:, :, :3], surface[:, :, 3:]
        latents, pc_infos = self.encoder(pc, feats)
        if return_pc_info:
            return latents, pc_infos
        else:
            return latents

    def decode(self, latents):
        latents = self.post_kl(latents)
        latents = self.transformer(latents)
        return latents

    def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        logits = self.geo_decoder(queries=queries, latents=latents).squeeze(-1)
        return logits

    def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
        coarse_kwargs = copy.deepcopy(kwargs)
        coarse_kwargs["octree_resolution"] = 256

        with synchronize_timer("Coarse Volume decoding"):
            coarse_grid_logits = self.volume_decoder(
                latents, self.geo_decoder, **coarse_kwargs
            )
        with synchronize_timer("Coarse Surface extraction"):
            coarse_mesh = self.surface_extractor(coarse_grid_logits, **coarse_kwargs)

        assert len(coarse_mesh) == 1
        bbox_gen_by_coarse_matching_cube_mesh = np.stack(
            [coarse_mesh[0].mesh_v.max(0), coarse_mesh[0].mesh_v.min(0)]
        )
        bbox_gen_by_coarse_matching_cube_mesh_range = (
            bbox_gen_by_coarse_matching_cube_mesh[0]
            - bbox_gen_by_coarse_matching_cube_mesh[1]
        )

        # extend by 10%
        bbox_gen_by_coarse_matching_cube_mesh[0] += (
            bbox_gen_by_coarse_matching_cube_mesh_range * 0.1
        )
        bbox_gen_by_coarse_matching_cube_mesh[1] -= (
            bbox_gen_by_coarse_matching_cube_mesh_range * 0.1
        )
        with synchronize_timer("Fine-grained Volume decoding"):
            grid_logits = self.volume_decoder(
                latents,
                self.geo_decoder,
                bbox_corner=bbox_gen_by_coarse_matching_cube_mesh[None],
                **kwargs,
            )
        with synchronize_timer("Fine-grained Surface extraction"):
            outputs = self.surface_extractor(
                grid_logits,
                bbox_corner=bbox_gen_by_coarse_matching_cube_mesh[None],
                **kwargs,
            )

        return outputs

    def latent2mesh_2(
        self,
        latents: torch.FloatTensor,
        bounds: Union[Tuple[float], List[float], float] = 1.1,
        octree_depth: int = 7,
        num_chunks: int = 10000,
        mc_level: float = -1 / 512,
        octree_resolution: int = None,
        mc_mode: str = "mc",
    ) -> List[Latent2MeshOutput]:
        """
        Args:
            latents: [bs, num_latents, dim]
            bounds:
            octree_depth:
            num_chunks:
        Returns:
            mesh_outputs (List[MeshOutput]): the mesh outputs list.
        """
        outputs = []
        geometric_func = partial(self.query_geometry, latents=latents)
        # 2. decode geometry
        device = latents.device
        if mc_mode == "dmc" and not hasattr(self, "diffdmc"):
            from diso import DiffDMC

            self.diffdmc = DiffDMC(dtype=torch.float32).to(device)
        mesh_v_f, has_surface = extract_geometry_fast(
            geometric_func=geometric_func,
            device=device,
            batch_size=len(latents),
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
            disable=False,
            mc_level=mc_level,
            octree_resolution=octree_resolution,
            diffdmc=self.diffdmc if mc_mode == "dmc" else None,
            mc_mode=mc_mode,
        )
        # 3. decode texture
        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
            if not is_surface:
                outputs.append(None)
                continue
            out = Latent2MeshOutput()
            out.mesh_v = mesh_v
            out.mesh_f = mesh_f
            outputs.append(out)
        return outputs
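
# Sketch of the full shape round-trip implied above (argument values are
# assumptions; `vae` is a loaded VolumeDecoderShapeVAE and `surface` a
# [B, N, 3 + point_feats] sampled surface whose N matches
# pc_size + pc_sharpedge_size): encode to KL latents, decode back to tokens,
# then run the coarse pass at octree_resolution=256 to estimate a tight bbox
# before the fine-grained extraction.
def _latents2mesh_pipeline_demo(vae, surface):
    latents = vae.encode(surface)                 # [B, num_latents, embed_dim]
    tokens = vae.decode(latents)                  # [B, num_latents, width]
    meshes = vae.latents2mesh(
        tokens, octree_resolution=512, mc_level=0.0, bounds=1.01, num_chunks=8000
    )
    return meshes[0].mesh_v, meshes[0].mesh_f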
XPart/partgen/models/autoencoders/surface_extractors.py
ADDED
@@ -0,0 +1,164 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from typing import Union, Tuple, List

import numpy as np
import torch
from skimage import measure


class Latent2MeshOutput:
    def __init__(self, mesh_v=None, mesh_f=None):
        self.mesh_v = mesh_v
        self.mesh_f = mesh_f


def center_vertices(vertices):
    """Translate the vertices so that bounding box is centered at zero."""
    vert_min = vertices.min(dim=0)[0]
    vert_max = vertices.max(dim=0)[0]
    vert_center = 0.5 * (vert_min + vert_max)
    return vertices - vert_center


class SurfaceExtractor:
    def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
        """
        Compute grid size, bounding box minimum coordinates, and bounding box size based on input
        bounds and resolution.

        Args:
            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single
                float representing half side length.
                If float, bounds are assumed symmetric around zero in all axes.
                Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].
            octree_resolution (int): Resolution of the octree grid.

        Returns:
            grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.
            bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).
            bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).
        """
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        bbox_size = bbox_max - bbox_min
        grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
        return grid_size, bbox_min, bbox_size

    def run(self, *args, **kwargs):
        """
        Abstract method to extract surface mesh from grid logits.

        This method should be implemented by subclasses.

        Raises:
            NotImplementedError: Always, since this is an abstract method.
        """
        raise NotImplementedError

    def __call__(self, grid_logits, **kwargs):
        """
        Process a batch of grid logits to extract surface meshes.

        Args:
            grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).
            **kwargs: Additional keyword arguments passed to the `run` method.

        Returns:
            List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.
                If extraction fails for a grid, None is appended at that position.
        """
        outputs = []
        for i in range(grid_logits.shape[0]):
            try:
                vertices, faces = self.run(grid_logits[i], **kwargs)
                vertices = vertices.astype(np.float32)
                faces = np.ascontiguousarray(faces)
                outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))

            except Exception:
                import traceback
                traceback.print_exc()
                outputs.append(None)

        return outputs


class MCSurfaceExtractor(SurfaceExtractor):
    def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
        """
        Extract surface mesh using the Marching Cubes algorithm.

        Args:
            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
            mc_level (float): The level (iso-value) at which to extract the surface.
            bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.
            octree_resolution (int): Resolution of the octree grid.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing:
                - vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding
                  box coordinates.
                - faces (np.ndarray): Extracted mesh faces (triangles).
        """
        vertices, faces, normals, _ = measure.marching_cubes(grid_logit.cpu().numpy(),
                                                             mc_level,
                                                             method="lewiner")
        grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
        vertices = vertices / grid_size * bbox_size + bbox_min
        return vertices, faces
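
# Sketch of marching-cubes extraction over a synthetic field (the sphere and
# sizes are assumptions): skimage returns vertices in voxel-index space, and
# the line above maps them to world space as v / grid_size * bbox_size +
# bbox_min, so index 64 in a 129^3 grid over [-1.01, 1.01] lands near 0.
def _mc_extractor_demo():
    res = 128
    axis = torch.linspace(-1.01, 1.01, res + 1)
    x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
    field = 0.5 - (x**2 + y**2 + z**2).sqrt()   # positive inside a radius-0.5 sphere
    extractor = MCSurfaceExtractor()
    outs = extractor(field[None], mc_level=0.0, bounds=1.01, octree_resolution=res)
    return outs[0].mesh_v, outs[0].mesh_f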


class DMCSurfaceExtractor(SurfaceExtractor):
    def run(self, grid_logit, *, octree_resolution, **kwargs):
        """
        Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm.

        Args:
            grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
            octree_resolution (int): Resolution of the octree grid.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple[np.ndarray, np.ndarray]: Tuple containing:
                - vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.
                - faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.

        Raises:
            ImportError: If the 'diso' package is not installed.
        """
        device = grid_logit.device
        if not hasattr(self, 'dmc'):
            try:
                from diso import DiffDMC
                self.dmc = DiffDMC(dtype=torch.float32).to(device)
            except ImportError:
                raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
        sdf = -grid_logit / octree_resolution
        sdf = sdf.to(torch.float32).contiguous()
        verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
        verts = center_vertices(verts)
        vertices = verts.detach().cpu().numpy()
        faces = faces.detach().cpu().numpy()[:, ::-1]
        return vertices, faces


SurfaceExtractors = {
    'mc': MCSurfaceExtractor,
    'dmc': DMCSurfaceExtractor,
}
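
# Usage sketch for the registry above: callers pick an extractor class by
# key. (_select_surface_extractor is a hypothetical helper for illustration;
# the fallback mirrors the ImportError hint in DMCSurfaceExtractor.)
def _select_surface_extractor(mc_algo: str = "mc"):
    cls = SurfaceExtractors.get(mc_algo, MCSurfaceExtractor)
    return cls()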
XPart/partgen/models/autoencoders/volume_decoders.py
ADDED
@@ -0,0 +1,107 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.


from typing import Union, Tuple, List, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from tqdm import tqdm

from .attention_blocks import CrossAttentionDecoder
from ...utils.misc import logger
from ...utils.mesh_utils import (
    extract_near_surface_volume_fn,
    generate_dense_grid_points,
)


class VanillaVolumeDecoder:
    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """
        Perform volume decoding with a vanilla decoder.

        Args:
            latents (torch.FloatTensor): Latent vectors to decode.
            geo_decoder (Callable): The geometry decoder function.
            bounds (Union[Tuple[float], List[float], float]): Bounding box for the volume.
            num_chunks (int): Number of chunks to process at a time.
            octree_resolution (int): Resolution of the octree for sampling points.
            enable_pbar (bool): Whether to enable progress bar.
        Returns:
            grid_logits (torch.FloatTensor): Decoded 3D volume logits.
        """
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        if isinstance(bounds, float):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, length = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij",
        )
        xyz_samples = (
            torch.from_numpy(xyz_samples)
            .to(device, dtype=dtype)
            .contiguous()
            .reshape(-1, 3)
        )

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(
            range(0, xyz_samples.shape[0], num_chunks),
            desc="Volume Decoding",
            disable=not enable_pbar,
        ):
            chunk_queries = xyz_samples[start : start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits
|
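Usage note for volume_decoders.py above: a minimal sketch of driving VanillaVolumeDecoder end to end. The latent shape, the stand-in geo_decoder, and the octree_resolution value are illustrative assumptions; only the keyword signature comes from the file.

import torch

decoder = VanillaVolumeDecoder()
latents = torch.randn(1, 2048, 64)  # hypothetical (batch, tokens, channels) latents

def geo_decoder(queries, latents):
    # stand-in for the real cross-attention decoder: one logit per query point
    return queries.sum(dim=-1, keepdim=True)

grid_logits = decoder(
    latents=latents,
    geo_decoder=geo_decoder,
    bounds=1.01,             # scalar expands to [-1.01]*3 + [1.01]*3
    num_chunks=10000,        # query points decoded per chunk
    octree_resolution=256,   # dense grid resolution (assumed valid value)
)
# grid_logits: (1, *grid_size) logits, ready for a surface extractor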
XPart/partgen/models/conditioner/condioner_release.py
ADDED
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .part_encoders import PartEncoder
from .sonata_extractor import SonataFeatureExtractor
from ..autoencoders import VolumeDecoderShapeVAE
from ...utils.misc import (
    instantiate_from_config,
    instantiate_non_trainable_model,
)


def debug_sonata_feat(points, feats):
    """Color each point by a 3-component PCA of its feature for quick visual checks."""
    from sklearn.decomposition import PCA
    import numpy as np
    import trimesh

    point_num = points.shape[0]
    feat_save = feats.float().detach().cpu().numpy()
    data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
    pca = PCA(n_components=3)
    data_reduced = pca.fit_transform(data_scaled)
    data_reduced = (data_reduced - data_reduced.min()) / (
        data_reduced.max() - data_reduced.min()
    )
    colors_255 = (data_reduced * 255).astype(np.uint8)
    colors_255 = np.concatenate(
        [colors_255, np.ones((point_num, 1), dtype=np.uint8) * 255], axis=-1
    )
    pc_save = trimesh.points.PointCloud(points, colors=colors_255)
    return pc_save
    # pc_save.export(os.path.join("debug", "point_pca.glb"))


class Conditioner(torch.nn.Module):

    def __init__(
        self,
        use_image=False,
        use_geo=True,
        use_obj=True,
        use_seg_feat=False,
        geo_cfg=None,
        obj_encoder_cfg=None,
        seg_feat_cfg=None,
        **kwargs
    ):
        super().__init__()
        self.use_image = use_image
        self.use_obj = use_obj
        self.use_geo = use_geo
        self.use_seg_feat = use_seg_feat
        self.geo_cfg = geo_cfg
        self.obj_encoder_cfg = obj_encoder_cfg
        self.seg_feat_cfg = seg_feat_cfg
        if use_geo and geo_cfg is not None:
            self.geo_encoder: PartEncoder = instantiate_from_config(geo_cfg)
            if hasattr(geo_cfg, "output_dim"):
                self.geo_out_proj = torch.nn.Linear(1024 + 512, geo_cfg.output_dim)

        if use_obj and obj_encoder_cfg is not None:
            self.obj_encoder: VolumeDecoderShapeVAE = instantiate_non_trainable_model(
                obj_encoder_cfg
            )
            if hasattr(obj_encoder_cfg, "output_dim"):
                self.obj_out_proj = torch.nn.Linear(
                    1024 + 512, obj_encoder_cfg.output_dim
                )
        if use_seg_feat and seg_feat_cfg is not None:
            self.seg_feat_encoder: SonataFeatureExtractor = (
                instantiate_non_trainable_model(seg_feat_cfg)
            )
            if hasattr(seg_feat_cfg, "output_dim"):
                self.seg_feat_outproj = torch.nn.Linear(512, seg_feat_cfg.output_dim)

    def forward(self, part_surface_inbbox, object_surface):
        bz = part_surface_inbbox.shape[0]
        context = {}
        # geo cond: local part geometry tokens
        if self.use_geo:
            context["geo_cond"], local_pc_infos = self.geo_encoder(
                part_surface_inbbox,
                object_surface,
                return_local_pc_info=True,
            )
        # obj cond: frozen global shape tokens
        if self.use_obj:
            with torch.no_grad():
                context["obj_cond"], global_pc_infos = self.obj_encoder.encode_shape(
                    object_surface, return_pc_info=True
                )

        # seg feat cond: per-point Sonata features gathered onto the sampled tokens
        if self.use_seg_feat:
            # TODO: batch size must be one
            num_parts = part_surface_inbbox.shape[0]
            with torch.autocast(device_type="cuda", dtype=torch.float32):
                # encode sonata feature
                with torch.no_grad():
                    point, normal = (
                        object_surface[:1, ..., :3].float(),
                        object_surface[:1, ..., 3:6].float(),
                    )
                    point_feat = self.seg_feat_encoder(point, normal)
                if self.use_obj:
                    # nearest object-surface point for each global token
                    nearest_global_matches = torch.argmin(
                        torch.cdist(global_pc_infos[0], object_surface[..., :3]), dim=-1
                    )
                    global_point_feats = point_feat.expand(num_parts, -1, -1).gather(
                        1,
                        nearest_global_matches.unsqueeze(-1).expand(
                            -1, -1, point_feat.size(-1)
                        ),
                    )
                    context["obj_cond"] = torch.concat(
                        [context["obj_cond"], global_point_feats], dim=-1
                    ).to(dtype=self.obj_out_proj.weight.dtype)
                    if hasattr(self, "obj_out_proj"):
                        context["obj_cond"] = self.obj_out_proj(context["obj_cond"])
                if self.use_geo:
                    # nearest object-surface point for each local token
                    nearest_local_matches = torch.argmin(
                        torch.cdist(local_pc_infos[0], object_surface[..., :3]), dim=-1
                    )
                    local_point_feats = point_feat.expand(num_parts, -1, -1).gather(
                        1,
                        nearest_local_matches.unsqueeze(-1).expand(
                            -1, -1, point_feat.size(-1)
                        ),
                    )
                    context["geo_cond"] = torch.concat(
                        [context["geo_cond"], local_point_feats],
                        dim=-1,
                    ).to(dtype=self.geo_out_proj.weight.dtype)
                    if hasattr(self, "geo_out_proj"):
                        context["geo_cond"] = self.geo_out_proj(context["geo_cond"])
        return context
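A sketch of the conditioning call above, under a few stated assumptions: the three cfg nodes are hypothetical OmegaConf objects carrying an output_dim attribute (so the projection layers get built), and surface tensors pack xyz in channels 0:3 and normals in 3:6, as the seg-feat branch slices them; all shapes are illustrative.

import torch

cond = Conditioner(
    use_geo=True, use_obj=True, use_seg_feat=True,
    geo_cfg=geo_cfg, obj_encoder_cfg=obj_cfg, seg_feat_cfg=seg_cfg,  # hypothetical config nodes
).cuda()
part_surface_inbbox = torch.randn(8, 2048, 6).cuda()  # (num_parts, N, xyz+normal)
object_surface = torch.randn(8, 32768, 6).cuda()      # same object repeated per part
context = cond(part_surface_inbbox, object_surface)
# context["geo_cond"]: per-part local tokens; context["obj_cond"]: global object tokens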
XPart/partgen/models/conditioner/part_encoders.py
ADDED
@@ -0,0 +1,89 @@
import torch.nn as nn
from ...utils.misc import (
    instantiate_from_config,
    instantiate_non_trainable_model,
)
from ..autoencoders.model import (
    VolumeDecoderShapeVAE,
)


class PartEncoder(nn.Module):
    def __init__(
        self,
        use_local=True,
        local_global_feat_dim=None,
        local_geo_cfg=None,
        local_feat_type="latents",
        num_tokens_cond=2048,
    ):
        super().__init__()
        self.local_global_feat_dim = local_global_feat_dim
        self.local_feat_type = local_feat_type
        self.num_tokens_cond = num_tokens_cond
        # local encoder
        self.use_local = use_local
        if use_local:
            if local_geo_cfg is None:
                raise ValueError(
                    "local_geo_cfg must be provided when use_local is True"
                )
            assert (
                "ShapeVAE" in local_geo_cfg.get("target").split(".")[-1]
            ), "local_geo_cfg must be a ShapeVAE config"
            self.local_encoder: VolumeDecoderShapeVAE = instantiate_from_config(
                local_geo_cfg
            )
            if self.local_global_feat_dim is not None:
                self.local_out_layer = nn.Linear(
                    (
                        local_geo_cfg.params.embed_dim
                        if self.local_feat_type == "latents"
                        else local_geo_cfg.params.width
                    ),
                    self.local_global_feat_dim,
                    bias=True,
                )

    def forward(self, part_surface_inbbox, object_surface, return_local_pc_info=False):
        """
        Args:
            part_surface_inbbox: (B, N, C) surface points of the part inside its bounding box
            object_surface: (B, N, C) tensor representing the surface points of the object
        Returns:
            geo_features: (B, num_tokens_cond, C) tensor of local geometry features
            local_pc_infos: sampled point info, only when return_local_pc_info is True
        """
        # random selection if more than num_tokens_cond points
        if self.use_local:
            if self.local_feat_type == "latents":
                local_features, local_pc_infos = self.local_encoder.encode(
                    part_surface_inbbox, sample_posterior=True, return_pc_info=True
                )  # (B, num_tokens_cond, C)
            elif self.local_feat_type == "latents_shape":
                local_features, local_pc_infos = self.local_encoder.encode_shape(
                    part_surface_inbbox, return_pc_info=True
                )  # (B, num_tokens_cond, C)
            elif self.local_feat_type == "miche-point-query-structural-vae":
                local_features, local_pc_infos = self.local_encoder.encode(
                    part_surface_inbbox, sample_posterior=True, return_pc_info=True
                )
                local_features = self.local_encoder(local_features)
            else:
                raise ValueError(
                    f"local_feat_type {self.local_feat_type} not supported"
                )
        # output layer
        geo_features = (
            self.local_out_layer(local_features)
            if hasattr(self, "local_out_layer")
            else local_features
        )
        if return_local_pc_info:
            return geo_features, local_pc_infos
        return geo_features
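A construction sketch for PartEncoder; the target path and dims are hypothetical, and the OmegaConf wrapper is an assumption that matches the .get("target") and .params.embed_dim accesses above.

from omegaconf import OmegaConf

local_geo_cfg = OmegaConf.create({
    "target": "partgen.models.autoencoders.model.VolumeDecoderShapeVAE",  # must name a ShapeVAE class
    "params": {"embed_dim": 64, "width": 1024},  # illustrative dims
})
encoder = PartEncoder(
    use_local=True,
    local_global_feat_dim=1024,   # adds a Linear(embed_dim -> 1024) output layer
    local_geo_cfg=local_geo_cfg,
    local_feat_type="latents",
)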
XPart/partgen/models/conditioner/sonata_extractor.py
ADDED
@@ -0,0 +1,315 @@
import torch
import torch.nn as nn
from .. import sonata

from typing import Dict, Union, Optional
from pathlib import Path


class SonataFeatureExtractor(nn.Module):
    """
    Feature extractor using a Sonata backbone with an MLP projection.
    Supports batch processing and gradient computation.
    """

    def __init__(
        self,
        ckpt_path: Optional[str] = "",
    ):
        super().__init__()

        # Load Sonata model
        self.sonata = sonata.load_by_config(
            str(Path(__file__).parent.parent.parent / "config" / "sonata.json")
        )

        # Define MLP projection head (same as in train-sonata.py)
        self.mlp = nn.Sequential(
            nn.Linear(1232, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 512),
        )

        # Define transform
        self.transform = sonata.transform.default()

        # Load checkpoint if provided
        if ckpt_path:
            self.load_checkpoint(ckpt_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model weights from a checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Extract state dict from a Lightning checkpoint
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            # Remove the 'model.' prefix added by Lightning, if present
            state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
        else:
            state_dict = checkpoint

        # Debug: show a sample of the keys in the checkpoint
        print("\n=== Checkpoint Keys ===")
        print(f"Total keys in checkpoint: {len(state_dict)}")
        print("\nSample keys:")
        for key in list(state_dict.keys())[:10]:
            print(f"  {key}")
        if len(state_dict) > 10:
            print(f"  ... and {len(state_dict) - 10} more keys")

        # Load only the relevant weights
        sonata_dict = {
            k.replace("sonata.", ""): v
            for k, v in state_dict.items()
            if k.startswith("sonata.")
        }
        mlp_dict = {
            k.replace("mlp.", ""): v
            for k, v in state_dict.items()
            if k.startswith("mlp.")
        }

        print(f"\nFound {len(sonata_dict)} Sonata keys")
        print(f"Found {len(mlp_dict)} MLP keys")

        # Load Sonata weights and show missing/unexpected keys
        if sonata_dict:
            print("\n=== Loading Sonata Weights ===")
            result = self.sonata.load_state_dict(sonata_dict, strict=False)
            if result.missing_keys:
                print(f"\nMissing keys ({len(result.missing_keys)}):")
                for key in result.missing_keys[:20]:  # Show first 20
                    print(f"  - {key}")
                if len(result.missing_keys) > 20:
                    print(f"  ... and {len(result.missing_keys) - 20} more")
            else:
                print("No missing keys!")

            if result.unexpected_keys:
                print(f"\nUnexpected keys ({len(result.unexpected_keys)}):")
                for key in result.unexpected_keys[:20]:  # Show first 20
                    print(f"  - {key}")
                if len(result.unexpected_keys) > 20:
                    print(f"  ... and {len(result.unexpected_keys) - 20} more")
            else:
                print("No unexpected keys!")

        # Load MLP weights
        if mlp_dict:
            print("\n=== Loading MLP Weights ===")
            result = self.mlp.load_state_dict(mlp_dict, strict=False)
            if result.missing_keys:
                print(f"\nMissing keys: {result.missing_keys}")
            if result.unexpected_keys:
                print(f"Unexpected keys: {result.unexpected_keys}")
            print("MLP weights loaded successfully!")

        print(f"\n✓ Loaded checkpoint from {checkpoint_path}")

    def prepare_batch_data(
        self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
    ) -> Dict:
        """
        Prepare batch data for the Sonata model.

        Args:
            points: [B, N, 3] or [N, 3] tensor of point coordinates
            normals: [B, N, 3] or [N, 3] tensor of normals (optional)

        Returns:
            Dictionary formatted for Sonata input
        """
        # Handle single batch case
        if points.dim() == 2:
            points = points.unsqueeze(0)
            if normals is not None:
                normals = normals.unsqueeze(0)
        B, N, _ = points.shape

        # Prepare batch indices
        batch_idx = torch.arange(B).view(-1, 1).repeat(1, N).reshape(-1)

        # Flatten points for Sonata format
        coord = points.reshape(B * N, 3)

        if normals is not None:
            normal = normals.reshape(B * N, 3)
        else:
            # Generate dummy normals if not provided
            normal = torch.ones_like(coord)

        # Generate dummy colors
        color = torch.ones_like(coord)

        # Convert a tensor to a numpy array, handling BFloat16
        def to_numpy(tensor):
            # First move to CPU if needed
            if tensor.is_cuda:
                tensor = tensor.cpu()
            # Convert BFloat16 or other unsupported dtypes to float32
            if tensor.dtype not in [
                torch.float32,
                torch.float64,
                torch.int32,
                torch.int64,
                torch.uint8,
                torch.int8,
                torch.int16,
            ]:
                tensor = tensor.to(torch.float32)
            # Then convert to numpy
            return tensor.numpy()

        # Create data dict
        data_dict = {
            "coord": to_numpy(coord),
            "normal": to_numpy(normal),
            "color": to_numpy(color),
            "batch": to_numpy(batch_idx),
        }

        # Apply transform
        data_dict = self.transform(data_dict)

        return data_dict, B, N

    def forward(
        self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Extract features from point clouds.

        Args:
            points: [B, N, 3] or [N, 3] tensor of point coordinates
            normals: [B, N, 3] or [N, 3] tensor of normals (optional)

        Returns:
            features: [B, N, 512] or [N, 512] tensor of features
        """
        # Remember whether the input was a single unbatched cloud
        single_batch = points.dim() == 2

        # Prepare data for Sonata
        data_dict, B, N = self.prepare_batch_data(points, normals)

        # Move to GPU if needed and convert to the appropriate dtype
        device = points.device
        dtype = points.dtype

        for key in data_dict.keys():
            if isinstance(data_dict[key], torch.Tensor):
                # Convert floating-point tensors to the right device and dtype
                if data_dict[key].is_floating_point():
                    data_dict[key] = data_dict[key].to(device=device, dtype=dtype)
                else:
                    # For integer tensors, just move to device without changing dtype
                    data_dict[key] = data_dict[key].to(device)

        # Extract Sonata features
        point = self.sonata(data_dict)

        # Handle pooling layers (same as in train-sonata.py)
        while "pooling_parent" in point.keys():
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent

        # Get features and apply MLP
        feat = point.feat  # [M, 1232]
        feat = self.mlp(feat)  # [M, 512]

        # Map back to original points
        feat = feat[point.inverse]  # [B*N, 512]

        # Reshape to batch format
        feat = feat.reshape(B, -1, feat.shape[-1])  # [B, N, 512]

        # Return in original format
        if single_batch:
            feat = feat.squeeze(0)  # [N, 512]

        return feat

    def extract_features_batch(
        self,
        points_list: list,
        normals_list: Optional[list] = None,
        batch_size: int = 8,
    ) -> list:
        """
        Extract features for multiple point clouds in batches.

        Args:
            points_list: List of [N_i, 3] tensors
            normals_list: List of [N_i, 3] tensors (optional)
            batch_size: Batch size for processing

        Returns:
            List of [N_i, 512] feature tensors
        """
        features_list = []

        # Process in batches
        for i in range(0, len(points_list), batch_size):
            batch_points = points_list[i : i + batch_size]
            batch_normals = normals_list[i : i + batch_size] if normals_list else None

            # Find max points in batch
            max_n = max(p.shape[0] for p in batch_points)

            # Pad to same size, tracking validity masks for unpadding
            padded_points = []
            masks = []
            for points in batch_points:
                n = points.shape[0]
                if n < max_n:
                    padding = torch.zeros(max_n - n, 3, device=points.device)
                    points = torch.cat([points, padding], dim=0)
                padded_points.append(points)
                mask = torch.zeros(max_n, dtype=torch.bool, device=points.device)
                mask[:n] = True
                masks.append(mask)

            # Stack batch
            batch_tensor = torch.stack(padded_points)  # [B, max_n, 3]

            # Handle normals similarly if provided
            if batch_normals:
                padded_normals = []
                for j, normals in enumerate(batch_normals):
                    n = normals.shape[0]
                    if n < max_n:
                        padding = torch.ones(max_n - n, 3, device=normals.device)
                        normals = torch.cat([normals, padding], dim=0)
                    padded_normals.append(normals)
                normals_tensor = torch.stack(padded_normals)
            else:
                normals_tensor = None

            # Extract features
            with torch.cuda.amp.autocast(enabled=True):
                batch_features = self.forward(
                    batch_tensor, normals_tensor
                )  # [B, max_n, 512]

            # Unpad and add to results
            for j, (feat, mask) in enumerate(zip(batch_features, masks)):
                features_list.append(feat[mask])

        return features_list
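A sketch of single-cloud feature extraction with SonataFeatureExtractor above; the point count, device placement, and no-grad context are illustrative choices, not requirements from the file.

import torch

extractor = SonataFeatureExtractor(ckpt_path="").cuda().eval()  # empty path skips checkpoint loading
points = torch.rand(50000, 3, device="cuda")                    # [N, 3] point cloud
normals = torch.nn.functional.normalize(torch.randn(50000, 3, device="cuda"), dim=-1)
with torch.no_grad():
    feats = extractor(points, normals)                          # [N, 512] per-point features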
XPart/partgen/models/diffusion/schedulers.py
ADDED
@@ -0,0 +1,329 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler, except our timesteps are reversed.

    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
    ):
        timesteps = np.linspace(
            1, num_train_timesteps, num_train_timesteps, dtype=np.float32
        ).copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when the scheduler is used for training, or the pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps) for t in timestep
            ]
        elif self.step_index is not None:
            # add_noise is called after the first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add_noise is called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError(
                "you have to pass a value for `mu` when `use_dynamic_shifting` is set"
                " to `True`"
            )

        if sigmas is None:
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max),
                self._sigma_to_t(self.sigma_min),
                num_inference_steps,
            )

            sigmas = timesteps / self.config.num_train_timesteps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        self.timesteps = timesteps.to(device=device)
        self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as"
                " timesteps to `EulerDiscreteScheduler.step()` is not supported. Make"
                " sure to pass one of the `scheduler.timesteps` as a timestep.",
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        prev_sample = sample + (sigma_next - sigma) * model_output

        # Cast sample back to model-compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
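A sketch of the denoising loop this scheduler is built for, in the usual diffusers style; `model` and the latent shape are placeholders, and only `set_timesteps`, `timesteps`, `step`, and `prev_sample` come from the class above.

import torch

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0)
scheduler.set_timesteps(num_inference_steps=50, device="cuda")
latents = torch.randn(1, 2048, 64, device="cuda")   # hypothetical latent shape
for t in scheduler.timesteps:
    velocity = model(latents, t)                    # placeholder velocity-prediction network
    latents = scheduler.step(velocity, t, latents).prev_sample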
XPart/partgen/models/diffusion/transport/__init__.py
ADDED
@@ -0,0 +1,97 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from .transport import Transport, ModelType, WeightType, PathType, Sampler


def create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    train_sample_type="uniform",
    mean=0.0,
    std=1.0,
    shift_scale=1.0,
):
    """Function for creating a Transport object.
    **Note**: model prediction defaults to velocity
    Args:
        - path_type: type of path to use; defaults to linear
        - prediction: model prediction type ("velocity", "score", or "noise")
        - loss_weight: weight the loss by "velocity" or "likelihood"; None for no weighting
        - train_eps: small epsilon for avoiding instability during training
        - sample_eps: small epsilon for avoiding instability during sampling
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        train_sample_type=train_sample_type,
        mean=mean,
        std=std,
        shift_scale=shift_scale,
    )

    return state
XPart/partgen/models/diffusion/transport/integrators.py
ADDED
@@ -0,0 +1,142 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import torch as th
import torch.nn as nn
from torchdiffeq import odeint
from functools import partial
from tqdm import tqdm


class sde:
    """SDE solver class"""

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return xhat + 0.5 * self.dt * (K1 + K2), xhat  # at the last time point we do not perform the Heun step

    def __forward_fn(self):
        """TODO: generalize here by adding all private functions ending with steps to it"""
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError("Sampler type not implemented.")

        return sampler

    def sample(self, init, model, **model_kwargs):
        """forward loop of the SDE"""
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples


class ode:
    """ODE solver class"""

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.drift = drift
        self.t = th.linspace(t0, t1, num_steps)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):

        device = x[0].device if isinstance(x, tuple) else x.device

        def _fn(t, x):
            t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
            model_output = self.drift(x, t, model, **model_kwargs)
            return model_output

        t = self.t.to(device)
        atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
        rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
        samples = odeint(
            _fn,
            x,
            t,
            method=self.sampler_type,
            atol=atol,
            rtol=rtol,
        )
        return samples
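How the pieces in this transport package compose: create_transport builds the flow-matching state, and the Sampler wrapper (imported in __init__.py, defined in transport.py) drives the ODE solver above. The sample_ode name and its defaults are assumptions carried over from the upstream SiT API, not confirmed by this diff.

import torch

transport = create_transport(path_type="Linear", prediction="velocity")
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(sampling_method="dopri5", num_steps=50, atol=1e-6, rtol=1e-3)
x0 = torch.randn(1, 2048, 64)            # hypothetical prior sample
samples = sample_fn(x0, model_forward)   # model_forward: placeholder network; final entry is the result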
XPart/partgen/models/diffusion/transport/path.py
ADDED
@@ -0,0 +1,220 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th
import numpy as np
from functools import partial


def expand_t_like_x(t, x):
    """Reshape time t to the broadcastable dimensions of x
    Args:
      t: [batch_dim,], time vector
      x: [batch_dim, ...], data point
    """
    dims = [1] * (len(x.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################

class ICPlan:
    """Linear Coupling Plan"""

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path"""
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path"""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha"""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output the SDE according to the score parametrization"""
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE
        Args:
          x: [batch_dim, ...], data point
          t: [batch_dim,], time vector
          form: str, form of the diffusion term
          norm: float, norm of the diffusion term
        """
        t = expand_t_like_x(t, x)
        choices = {
            "constant": norm,
            "SBDM": norm * self.compute_drift(x, t)[1],
            "sigma": norm * self.compute_sigma_t(t)[0],
            "linear": norm * (1 - t),
            "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "increasing-decreasing": norm * th.sin(np.pi * t) ** 2,
        }

        try:
            diffusion = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return diffusion

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a score model
        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a denoiser
        Args:
          velocity: [batch_dim, ...] shaped tensor; velocity model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform a score prediction model into velocity
        Args:
          score: [batch_dim, ...] shaped tensor; score model output
          x: [batch_dim, ...] shaped tensor; x_t data point
          t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of the time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        # t*x1 + (1-t)*x0 ; t=0 -> x0; t=1 -> x1
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Sample xt from the time-dependent density p_t; rng is required"""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t"""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut


class VPCPlan(ICPlan):
    """class for VP path flow matching"""

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \
            (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
        self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \
            (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min

    def compute_alpha_t(self, t):
        """Compute the coefficient of x1"""
        alpha_t = self.log_mean_coeff(t)
        alpha_t = th.exp(alpha_t)
        d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute the coefficient of x0"""
        p_sigma_t = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
        d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Compute the drift term of the SDE"""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        return -0.5 * beta_t * x, beta_t / 2


class GVPCPlan(ICPlan):
    def __init__(self, sigma=0.0):
        super().__init__(sigma)

    def compute_alpha_t(self, t):
        """Compute the coefficient of x1"""
        alpha_t = th.sin(t * np.pi / 2)
        d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
        return alpha_t, d_alpha_t

    def compute_sigma_t(self, t):
        """Compute the coefficient of x0"""
        sigma_t = th.cos(t * np.pi / 2)
        d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
        return np.pi / (2 * th.tan(t * np.pi / 2))
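A small numerical check of `ICPlan` (a sketch with toy tensors, not repo code): with alpha_t = t and sigma_t = 1 - t, `plan()` should return the interpolant x_t = t*x1 + (1-t)*x0 and the constant velocity u_t = x1 - x0.

import torch as th

plan = ICPlan()
x0, x1 = th.zeros(2, 3), th.ones(2, 3)
t = th.full((2,), 0.25)
_, xt, ut = plan.plan(t, x0, x1)
assert th.allclose(xt, 0.25 * x1 + 0.75 * x0)  # interpolant at t = 0.25
assert th.allclose(ut, x1 - x0)                # d_alpha_t * x1 + d_sigma_t * x0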
XPart/partgen/models/diffusion/transport/transport.py
ADDED
@@ -0,0 +1,506 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th
import numpy as np
import logging

import enum

from . import path
from .utils import EasyDict, log_state, mean_flat
from .integrators import ode, sde


class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)


class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()


class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()


class Transport:

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        train_eps,
        sample_eps,
        train_sample_type="uniform",
        **kwargs,
    ):
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps
        self.train_sample_type = train_sample_type
        if self.train_sample_type == "logit_normal":
            self.mean = kwargs["mean"]
            self.std = kwargs["std"]
            self.shift_scale = kwargs["shift_scale"]
            print(f"using logit normal sample, shift scale is {self.shift_scale}")

    def prior_logp(self, z):
        """
        Standard multivariate normal prior
        Assume z is batched
        """
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])
        _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if type(self.path_sampler) in [path.VPCPlan]:

            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
            self.model_type != ModelType.VELOCITY or sde
        ):  # avoid numerical issues by taking a first semi-implicit step

            t0 = (
                eps
                if (diffusion_form == "SBDM" and sde)
                or self.model_type != ModelType.VELOCITY
                else 0
            )
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

    def sample(self, x1):
        """Sample x0 & t based on the shape of x1 (if needed)
        Args:
          x1 - data point; [batch, *dim]
        """

        x0 = th.randn_like(x1)
        if self.train_sample_type == "uniform":
            t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
            t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
            t = t.to(x1)
        elif self.train_sample_type == "logit_normal":
            t = th.randn((x1.shape[0],)) * self.std + self.mean
            t = t.to(x1)
            t = 1 / (1 + th.exp(-t))

            t = (
                np.sqrt(self.shift_scale)
                * t
                / (1 + (np.sqrt(self.shift_scale) - 1) * t)
            )

        return t, x0, x1

    def training_losses(self, model, x1, model_kwargs=None):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        """
        if model_kwargs is None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)
        B, *_, C = xt.shape
        assert model_output.size() == (B, *xt.size()[1:-1], C)

        terms = {}
        terms["pred"] = model_output
        if self.model_type == ModelType.VELOCITY:
            terms["loss"] = mean_flat(((model_output - ut) ** 2))
        else:
            _, drift_var = self.path_sampler.compute_drift(xt, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            if self.loss_type in [WeightType.VELOCITY]:
                weight = (drift_var / sigma_t) ** 2
            elif self.loss_type in [WeightType.LIKELIHOOD]:
                weight = drift_var / (sigma_t**2)
            elif self.loss_type in [WeightType.NONE]:
                weight = 1
            else:
                raise NotImplementedError()

            if self.model_type == ModelType.NOISE:
                terms["loss"] = mean_flat(weight * ((model_output - x0) ** 2))
            else:
                terms["loss"] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

        return terms

    def get_drift(self):
        """member function for obtaining the drift of the probability flow ODE"""

        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return -drift_mean + drift_var * model_output  # by change of variables

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return -drift_mean + drift_var * score

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert (
                model_output.shape == x.shape
            ), "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining the score of
        x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = (
                lambda x, t, model, **kwargs: model(x, t, **kwargs)
                / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
            )
        elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = (
                lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
                    model(x, t, **kwargs), x, t
                )
            )
        else:
            raise NotImplementedError()

        return score_fn


class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler, supporting different sampling methods
        Args:
        - transport: a transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):

        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(
                x, t, form=diffusion_form, norm=diffusion_norm
            )
            return diffusion

        sde_drift = lambda x, t, model, **kwargs: self.drift(
            x, t, model, **kwargs
        ) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last-step function of the SDE solver"""

        if last_step is None:
            last_step_fn = lambda x, t, model, **model_kwargs: x
        elif last_step == "Mean":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x
                + sde_drift(x, t, model, **model_kwargs) * last_step_size
            )
        elif last_step == "Tweedie":
            alpha = (
                self.transport.path_sampler.compute_alpha_t
            )  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (
                sigma(t)[0][0] ** 2
            ) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x
                + self.drift(x, t, model, **model_kwargs) * last_step_size
            )
        else:
            raise NotImplementedError()

        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with the given SDE settings
        Args:
        - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
        - diffusion_form: functional form of the diffusion coefficient; defaults to matching SBDM
        - diffusion_norm: magnitude of the diffusion coefficient; defaults to 1
        - last_step: type of the last step; defaults to identity
        - last_step_size: size of the last step; defaults to matching the stride of 250 steps over [0,1]
        - num_steps: total number of SDE integration steps
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method,
        )

        last_step_fn = self.__get_last_step(
            sde_drift, last_step=last_step, last_step_size=last_step_size
        )

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Number of samples does not match the number of steps"

            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
    ):
        """returns a sampling function with the given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether to solve the ODE in reverse (data to noise); defaults to False
        """
        if reverse:
            drift = lambda x, t, model, **kwargs: self.drift(
                x, th.ones_like(t) * (1 - t), model, **kwargs
            )
        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        return _ode.sample


    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):
        """returns a sampling function for calculating likelihood with the given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        """

        def _likelihood_drift(x, t, model, **model_kwargs):
            x, _ = x
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                grad = th.autograd.grad(
                    th.sum(self.drift(x, t, model, **model_kwargs) * eps), x
                )[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            init_logp = th.zeros(x.size(0)).to(x)
            input = (x, init_logp)
            drift, delta_logp = _ode.sample(input, model, **model_kwargs)
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
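A minimal sketch of wiring `Transport` and `Sampler` together for velocity prediction on the linear path; the eps values and the toy model below are assumptions for illustration, not settings taken from this repo:

transport = Transport(
    model_type=ModelType.VELOCITY,
    path_type=PathType.LINEAR,
    loss_type=WeightType.NONE,
    train_eps=0.0,
    sample_eps=0.0,
)
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(sampling_method="dopri5", num_steps=50)
# xs = sample_fn(th.randn(batch, n_tokens, width), model)  # model(x, t, **kw) -> velocity
loss = transport.training_losses(lambda x, t: x, th.randn(2, 16, 8))["loss"]  # toy identity "model"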
XPart/partgen/models/diffusion/transport/utils.py
ADDED
@@ -0,0 +1,54 @@
# This file includes code derived from the SiT project (https://github.com/willisma/SiT),
# which is licensed under the MIT License.
#
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch as th


class EasyDict:

    def __init__(self, sub_dict):
        for k, v in sub_dict.items():
            setattr(self, k, v)

    def __getitem__(self, key):
        return getattr(self, key)


def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return th.mean(x, dim=list(range(1, len(x.size()))))


def log_state(state):
    result = []

    sorted_state = dict(sorted(state.items()))
    for key, value in sorted_state.items():
        # Check if the value is an instance of a class
        if "<object" in str(value) or "object at" in str(value):
            result.append(f"{key}: [{value.__class__.__name__}]")
        else:
            result.append(f"{key}: {value}")

    return '\n'.join(result)
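A quick sketch of the helpers above (toy values): `EasyDict` exposes a plain dict through both attribute and item access, and `mean_flat` averages over every non-batch dimension.

import torch as th

d = EasyDict({"lr": 1e-4, "steps": 250})
assert d.lr == d["lr"] == 1e-4
x = th.ones(4, 3, 2)
assert mean_flat(x).shape == (4,)  # one mean per batch element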
XPart/partgen/models/moe_layers.py
ADDED
@@ -0,0 +1,209 @@
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import torch
import torch.nn as nn
import math

import torch.nn.functional as F
from diffusers.models.attention import FeedForward


class AddAuxiliaryLoss(torch.autograd.Function):
    """
    The trick function for adding an auxiliary (aux) loss,
    which includes the gradient of the aux loss during backpropagation.
    """

    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss


class MoEGate(nn.Module):
    def __init__(
        self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01
    ):
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = num_experts

        self.scoring_func = "softmax"
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # top-k selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        # print(bsz, seq_len, h)
        ### compute the gating score
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(
                f"unsupported scoring function for MoE gating: {self.scoring_func}"
            )

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### normalize the gate weights to sum to 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        ### expert-level auxiliary (load-balancing) loss
        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            # always compute the aux loss based on the naive greedy top-k method
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(
                    bsz, self.n_routed_experts, device=hidden_states.device
                )
                ce.scatter_add_(
                    1,
                    topk_idx_for_aux_loss,
                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
                ).div_(seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean()
                aux_loss = aux_loss * self.alpha
            else:
                mask_ce = F.one_hot(
                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
                )
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss


class MoEBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_experts=8,
        moe_top_k=2,
        activation_fn="gelu",
        dropout=0.0,
        final_dropout=False,
        ff_inner_dim=None,
        ff_bias=True,
    ):
        super().__init__()
        self.moe_top_k = moe_top_k
        self.experts = nn.ModuleList([
            FeedForward(
                dim,
                dropout=dropout,
                activation_fn=activation_fn,
                final_dropout=final_dropout,
                inner_dim=ff_inner_dim,
                bias=ff_bias,
            )
            for i in range(num_experts)
        ])
        self.gate = MoEGate(
            embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k
        )

        self.shared_experts = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def initialize_weight(self):
        pass

    def forward(self, hidden_states):
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
            y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
            for i, expert in enumerate(self.experts):
                tmp = expert(hidden_states[flat_topk_idx == i])
                y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
            y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(
                hidden_states, flat_topk_idx, topk_weight.view(-1, 1)
            ).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.moe_top_k
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtypes
            expert_cache = expert_cache.to(expert_out.dtype)
            expert_cache.scatter_reduce_(
                0,
                exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
                expert_out,
                reduce="sum",
            )
        return expert_cache
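A shape-level sketch of `MoEBlock` in eval mode (dimensions are arbitrary toy values; requires diffusers for `FeedForward`): each token is routed to its top-2 experts via `moe_infer`, and the shared expert's output is added on top.

import torch

block = MoEBlock(dim=64, num_experts=4, moe_top_k=2).eval()
tokens = torch.randn(2, 16, 64)        # [batch, seq_len, dim]
with torch.no_grad():
    out = block(tokens)                # top-2 routed experts + shared expert
assert out.shape == tokens.shape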
XPart/partgen/models/partformer_dit.py
ADDED
@@ -0,0 +1,756 @@
| 1 |
+
# Newest version: add local&global context (cross-attn), and local&global attn (self-attn)
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from .moe_layers import MoEBlock
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 15 |
+
"""
|
| 16 |
+
embed_dim: output dimension for each position
|
| 17 |
+
pos: a list of positions to be encoded: size (M,)
|
| 18 |
+
out: (M, D)
|
| 19 |
+
"""
|
| 20 |
+
assert embed_dim % 2 == 0
|
| 21 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 22 |
+
omega /= embed_dim / 2.0
|
| 23 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 24 |
+
|
| 25 |
+
pos = pos.reshape(-1) # (M,)
|
| 26 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 27 |
+
|
| 28 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 29 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 30 |
+
|
| 31 |
+
return np.concatenate([emb_sin, emb_cos], axis=1)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Timesteps(nn.Module):
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
num_channels: int,
|
| 38 |
+
downscale_freq_shift: float = 0.0,
|
| 39 |
+
scale: int = 1,
|
| 40 |
+
max_period: int = 10000,
|
| 41 |
+
):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.num_channels = num_channels
|
| 44 |
+
self.downscale_freq_shift = downscale_freq_shift
|
| 45 |
+
self.scale = scale
|
| 46 |
+
self.max_period = max_period
|
| 47 |
+
|
| 48 |
+
def forward(self, timesteps):
|
| 49 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 50 |
+
embedding_dim = self.num_channels
|
| 51 |
+
half_dim = embedding_dim // 2
|
| 52 |
+
exponent = -math.log(self.max_period) * torch.arange(
|
| 53 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
| 54 |
+
)
|
| 55 |
+
exponent = exponent / (half_dim - self.downscale_freq_shift)
|
| 56 |
+
emb = torch.exp(exponent)
|
| 57 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
| 58 |
+
emb = self.scale * emb
|
| 59 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 60 |
+
if embedding_dim % 2 == 1:
|
| 61 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 62 |
+
return emb
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class TimestepEmbedder(nn.Module):
|
| 66 |
+
"""
|
| 67 |
+
Embeds scalar timesteps into vector representations.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
hidden_size,
|
| 73 |
+
frequency_embedding_size=256,
|
| 74 |
+
cond_proj_dim=None,
|
| 75 |
+
out_size=None,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
if out_size is None:
|
| 79 |
+
out_size = hidden_size
|
| 80 |
+
self.mlp = nn.Sequential(
|
| 81 |
+
nn.Linear(hidden_size, frequency_embedding_size, bias=True),
|
| 82 |
+
nn.GELU(),
|
| 83 |
+
nn.Linear(frequency_embedding_size, out_size, bias=True),
|
| 84 |
+
)
|
| 85 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 86 |
+
|
| 87 |
+
if cond_proj_dim is not None:
|
| 88 |
+
self.cond_proj = nn.Linear(
|
| 89 |
+
cond_proj_dim, frequency_embedding_size, bias=False
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.time_embed = Timesteps(hidden_size)
|
| 93 |
+
|
| 94 |
+
def forward(self, t, condition):
|
| 95 |
+
|
| 96 |
+
t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)
|
| 97 |
+
|
| 98 |
+
        # t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
        if condition is not None:
            t_freq = t_freq + self.cond_proj(condition)

        t = self.mlp(t_freq)
        t = t.unsqueeze(dim=1)
        return t


class MLP(nn.Module):
    def __init__(self, *, width: int):
        super().__init__()
        self.width = width
        self.fc1 = nn.Linear(width, width * 4)
        self.fc2 = nn.Linear(width * 4, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))


class CrossAttention(nn.Module):
    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        **kwargs,
    ):
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim**-0.5

        self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
        self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
        self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.out_proj = nn.Linear(qdim, qdim, bias=True)

        self.with_dca = with_decoupled_ca
        if self.with_dca:
            self.kv_proj_dca = nn.Linear(kdim, 2 * qdim, bias=qkv_bias)
            self.k_norm_dca = (
                norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
                if qk_norm
                else nn.Identity()
            )
            self.dca_dim = decoupled_ca_dim
            self.dca_weight = decoupled_ca_weight
        # zero init
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        """
        b, s1, c = x.shape  # [b, s1, D]

        if self.with_dca:
            token_len = y.shape[1]
            context_dca = y[:, -self.dca_dim :, :]
            kv_dca = self.kv_proj_dca(context_dca).view(
                b, self.dca_dim, 2, self.num_heads, self.head_dim
            )
            k_dca, v_dca = kv_dca.unbind(dim=2)  # [b, s, h, d]
            k_dca = self.k_norm_dca(k_dca)
            y = y[:, : (token_len - self.dca_dim), :]

        _, s2, c = y.shape  # [b, s2, 1024]
        q = self.to_q(x)
        k = self.to_k(y)
        v = self.to_v(y)

        # flatten across batch for the per-head split; reshaped back to
        # (b, s2, h, d) below
        kv = torch.cat((k, v), dim=-1)
        split_size = kv.shape[-1] // self.num_heads // 2
        kv = kv.view(1, -1, self.num_heads, split_size * 2)
        k, v = torch.split(kv, split_size, dim=-1)

        q = q.view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        k = k.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]
        v = v.view(b, s2, self.num_heads, self.head_dim)  # [b, s2, h, d]

        q = self.q_norm(q)
        k = self.k_norm(k)

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=True
        ):
            q, k, v = map(
                lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads),
                (q, k, v),
            )
            context = (
                F.scaled_dot_product_attention(q, k, v)
                .transpose(1, 2)
                .reshape(b, s1, -1)
            )

        if self.with_dca:
            with torch.backends.cuda.sdp_kernel(
                enable_flash=True, enable_math=False, enable_mem_efficient=True
            ):
                k_dca, v_dca = map(
                    lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads),
                    (k_dca, v_dca),
                )
                context_dca = (
                    F.scaled_dot_product_attention(q, k_dca, v_dca)
                    .transpose(1, 2)
                    .reshape(b, s1, -1)
                )

            context = context + self.dca_weight * context_dca

        out = self.out_proj(context)  # context.reshape - B, L1, -1

        return out

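# Editor's note (not part of the commit): when with_decoupled_ca=True,
# CrossAttention splits off the last decoupled_ca_dim tokens of y, attends to
# them in a second cross-attention with separate key/value projections, and
# blends the result back in as
#     context = context + dca_weight * context_dca
# i.e. a decoupled conditioning stream in the spirit of IP-Adapter.
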
class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention
    """

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        use_global_processor=False,
    ):
        super().__init__()
        self.use_global_processor = use_global_processor
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, "dim should be divisible by num_heads"
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim**-0.5

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.out_proj = nn.Linear(dim, dim)

        # set processor
        self.processor = LocalGlobalProcessor(use_global=use_global_processor)

    def forward(self, x):
        return self.processor(self, x)


class AttentionPool(nn.Module):
    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x, attention_mask=None):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)
            global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)
            x = torch.cat([global_emb[None,], x], dim=0)
        else:
            x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1],
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
            ),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return x.squeeze(0)


class LocalGlobalProcessor:
    def __init__(self, use_global=False):
        self.use_global = use_global

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
    ):
        """
        hidden_states: [B, L, C]
        """
        if self.use_global:
            B_old, N_old, C_old = hidden_states.shape
            hidden_states = hidden_states.reshape(1, -1, C_old)
        B, N, C = hidden_states.shape

        q = attn.to_q(hidden_states)
        k = attn.to_k(hidden_states)
        v = attn.to_v(hidden_states)

        qkv = torch.cat((q, k, v), dim=-1)
        split_size = qkv.shape[-1] // attn.num_heads // 3
        qkv = qkv.view(1, -1, attn.num_heads, split_size * 3)
        q, k, v = torch.split(qkv, split_size, dim=-1)

        q = q.reshape(B, N, attn.num_heads, attn.head_dim).transpose(
            1, 2
        )  # [b, h, s, d]
        k = k.reshape(B, N, attn.num_heads, attn.head_dim).transpose(
            1, 2
        )  # [b, h, s, d]
        v = v.reshape(B, N, attn.num_heads, attn.head_dim).transpose(1, 2)

        q = attn.q_norm(q)  # [b, h, s, d]
        k = attn.k_norm(k)  # [b, h, s, d]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=True
        ):
            hidden_states = F.scaled_dot_product_attention(q, k, v)
            hidden_states = hidden_states.transpose(1, 2).reshape(B, N, -1)

        hidden_states = attn.out_proj(hidden_states)
        if self.use_global:
            hidden_states = hidden_states.reshape(B_old, N_old, -1)
        return hidden_states

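# Editor's note (not part of the commit): LocalGlobalProcessor(use_global=True)
# flattens (B, N, C) -> (1, B*N, C) before self-attention, so the tokens of all
# B parts attend to one another ("global" attention), then restores the original
# shape; with use_global=False each part attends only within itself ("local").
# PartFormerDITPlain below switches every second block to the global variant.
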
class PartFormerDitBlock(nn.Module):

    def __init__(
        self,
        hidden_size,
        num_heads,
        use_self_attention: bool = True,
        use_cross_attention: bool = False,
        use_cross_attention_2: bool = False,
        encoder_hidden_dim=1024,  # cross-attn encoder_hidden_states dim
        encoder_hidden2_dim=1024,  # cross-attn 2 encoder_hidden_states dim
        # cross_attn2_weight=0.0,
        qkv_bias=True,
        qk_norm=False,
        norm_layer=nn.LayerNorm,
        qk_norm_layer=nn.RMSNorm,
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        skip_connection=False,
        timested_modulate=False,
        c_emb_size=0,  # time embedding size
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
    ):
        super().__init__()
        # self.cross_attn2_weight = cross_attn2_weight
        use_ele_affine = True
        # ========================= Self-Attention =========================
        self.use_self_attention = use_self_attention
        if self.use_self_attention:
            self.norm1 = norm_layer(
                hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
            )
            self.attn1 = Attention(
                hidden_size,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                norm_layer=qk_norm_layer,
            )

        # ========================= Add =========================
        # Simply use add like SDXL.
        self.timested_modulate = timested_modulate
        if self.timested_modulate:
            self.default_modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(c_emb_size, hidden_size, bias=True)
            )
        # ========================= Cross-Attention =========================
        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.norm2 = norm_layer(
                hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
            )
            self.attn2 = CrossAttention(
                hidden_size,
                encoder_hidden_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                norm_layer=qk_norm_layer,
                with_decoupled_ca=False,
            )
        self.use_cross_attention_2 = use_cross_attention_2
        if self.use_cross_attention_2:
            self.norm2_2 = norm_layer(
                hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
            )
            self.attn2_2 = CrossAttention(
                hidden_size,
                encoder_hidden2_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                norm_layer=qk_norm_layer,
                with_decoupled_ca=with_decoupled_ca,
                decoupled_ca_dim=decoupled_ca_dim,
                decoupled_ca_weight=decoupled_ca_weight,
            )
        # ========================= FFN =========================
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
        self.use_moe = use_moe
        if self.use_moe:
            print("using moe")
            self.moe = MoEBlock(
                hidden_size,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
                dropout=0.0,
                activation_fn="gelu",
                final_dropout=False,
                ff_inner_dim=int(hidden_size * 4.0),
                ff_bias=True,
            )
        else:
            self.mlp = MLP(width=hidden_size)
        # ========================= skip FFN =========================
        if skip_connection:
            self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
            self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
        else:
            self.skip_linear = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_hidden_states_2: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        skip_value: torch.Tensor = None,
    ):
        # skip connection
        if self.skip_linear is not None:
            cat = torch.cat([skip_value, hidden_states], dim=-1)
            hidden_states = self.skip_linear(cat)
            hidden_states = self.skip_norm(hidden_states)
        # local-global attn (self-attn)
        if self.timested_modulate:
            shift_msa = self.default_modulation(temb).unsqueeze(dim=1)
            hidden_states = hidden_states + shift_msa
        if self.use_self_attention:
            attn_output = self.attn1(self.norm1(hidden_states))
            hidden_states = hidden_states + attn_output
        # image cross attn
        if self.use_cross_attention:
            original_cross_out = self.attn2(
                self.norm2(hidden_states),
                encoder_hidden_states,
            )
        # added local-global cross attn
        # 2. Cross-Attention
        if self.use_cross_attention_2:
            cross_out_2 = self.attn2_2(
                self.norm2_2(hidden_states),
                encoder_hidden_states_2,
            )
        hidden_states = (
            hidden_states
            + (original_cross_out if self.use_cross_attention else 0)
            + (cross_out_2 if self.use_cross_attention_2 else 0)
        )

        # FFN Layer
        mlp_inputs = self.norm3(hidden_states)

        if self.use_moe:
            hidden_states = hidden_states + self.moe(mlp_inputs)
        else:
            hidden_states = hidden_states + self.mlp(mlp_inputs)

        return hidden_states


class FinalLayer(nn.Module):
    """
    The final layer of HunYuanDiT.
    """

    def __init__(self, final_hidden_size, out_channels):
        super().__init__()
        self.final_hidden_size = final_hidden_size
        self.norm_final = nn.LayerNorm(
            final_hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)

    def forward(self, x):
        x = self.norm_final(x)
        x = x[:, 1:]  # drop the prepended time token
        x = self.linear(x)
        return x


class PartFormerDITPlain(nn.Module):

    def __init__(
        self,
        input_size=1024,
        in_channels=4,
        hidden_size=1024,
        use_self_attention=True,
        use_cross_attention=True,
        use_cross_attention_2=True,
        encoder_hidden_dim=1024,  # cross-attn encoder_hidden_states dim
        encoder_hidden2_dim=1024,  # cross-attn 2 encoder_hidden_states dim
        depth=24,
        num_heads=16,
        qk_norm=False,
        qkv_bias=True,
        norm_type="layer",
        qk_norm_type="rms",
        with_decoupled_ca=False,
        decoupled_ca_dim=16,
        decoupled_ca_weight=1.0,
        use_pos_emb=False,
        # use_attention_pooling=True,
        guidance_cond_proj_dim=None,
        num_moe_layers: int = 6,
        num_experts: int = 8,
        moe_top_k: int = 2,
        **kwargs,
    ):
        super().__init__()

        self.input_size = input_size
        self.depth = depth
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.num_heads = num_heads

        self.hidden_size = hidden_size
        self.norm = nn.LayerNorm if norm_type == "layer" else nn.RMSNorm
        self.qk_norm = nn.RMSNorm if qk_norm_type == "rms" else nn.LayerNorm
        # embedding
        self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(
            hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim
        )
        # Will use fixed sin-cos embedding:
        self.use_pos_emb = use_pos_emb
        if self.use_pos_emb:
            self.register_buffer("pos_embed", torch.zeros(1, input_size, hidden_size))
            pos = np.arange(self.input_size, dtype=np.float32)
            pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], pos)
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # self.use_attention_pooling = use_attention_pooling
        # if use_attention_pooling:
        #     self.pooler = AttentionPool(
        #         self.text_len, encoder_hidden_dim, num_heads=8, output_dim=1024
        #     )
        #     self.extra_embedder = nn.Sequential(
        #         nn.Linear(1024, hidden_size * 4),
        #         nn.SiLU(),
        #         nn.Linear(hidden_size * 4, hidden_size, bias=True),
        #     )
        # for part embedding
        self.use_bbox_cond = kwargs.get("use_bbox_cond", False)
        if self.use_bbox_cond:
            self.bbox_conditioner = BboxEmbedder(
                out_size=hidden_size,
                num_freqs=kwargs.get("num_freqs", 8),
            )
        self.use_part_embed = kwargs.get("use_part_embed", False)
        if self.use_part_embed:
            self.valid_num = kwargs.get("valid_num", 50)
            self.part_embed = nn.Parameter(torch.randn(self.valid_num, hidden_size))
            # zero init part_embed
            self.part_embed.data.zero_()
        # transformer blocks
        self.blocks = nn.ModuleList([
            PartFormerDitBlock(
                hidden_size,
                num_heads,
                use_self_attention=use_self_attention,
                use_cross_attention=use_cross_attention,
                use_cross_attention_2=use_cross_attention_2,
                encoder_hidden_dim=encoder_hidden_dim,  # cross-attn encoder_hidden_states dim
                encoder_hidden2_dim=encoder_hidden2_dim,  # cross-attn 2 encoder_hidden_states dim
                # cross_attn2_weight=cross_attn2_weight,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                norm_layer=self.norm,
                qk_norm_layer=self.qk_norm,
                with_decoupled_ca=with_decoupled_ca,
                decoupled_ca_dim=decoupled_ca_dim,
                decoupled_ca_weight=decoupled_ca_weight,
                skip_connection=layer > depth // 2,
                use_moe=True if depth - layer <= num_moe_layers else False,
                num_experts=num_experts,
                moe_top_k=moe_top_k,
            )
            for layer in range(depth)
        ])
        # set local-global processor
        for layer, block in enumerate(self.blocks):
            if hasattr(block, "attn1") and (layer + 1) % 2 == 0:
                block.attn1.processor = LocalGlobalProcessor(use_global=True)

        self.depth = depth

        self.final_layer = FinalLayer(hidden_size, self.out_channels)

    def forward(self, x, t, contexts: dict, **kwargs):
        """
        x: [B, N, C]
        t: [B]
        contexts: dict
            image_context: [B, K*ni, C]
            geo_context: [B, K*ng, C] or [B, K*ng, C*2]
        aabb: [B, K, 2, 3]
        num_tokens: [B, N]

        N = K * num_tokens

        For parts pretrain: K = 1
        """
        # prepare input
        aabb: torch.Tensor = kwargs.get("aabb", None)
        # image_context = contexts.get("image_un_cond", None)
        object_context = contexts.get("obj_cond", None)
        geo_context = contexts.get("geo_cond", None)
        num_tokens: torch.Tensor = kwargs.get("num_tokens", None)
        # time embedding and input projection
        t = self.t_embedder(t, condition=kwargs.get("guidance_cond"))
        x = self.x_embedder(x)

        if self.use_pos_emb:
            pos_embed = self.pos_embed.to(x.dtype)
            x = x + pos_embed

        # c is the time embedding (with or without a pooled context)
        # if self.use_attention_pooling:
        #     # TODO: attention_pooling for all contexts
        #     extra_vec = self.pooler(image_context, None)
        #     c = t + self.extra_embedder(extra_vec)  # [B, D]
        # else:
        #     c = t
        c = t
        # bounding box
        if self.use_bbox_cond:
            center_extent = torch.cat(
                [torch.mean(aabb, dim=-2), aabb[..., 1, :] - aabb[..., 0, :]], dim=-1
            )
            bbox_embeds = self.bbox_conditioner(center_extent)
            # TODO: currently only supports batch_size=1
            bbox_embeds = torch.repeat_interleave(
                bbox_embeds, repeats=num_tokens[0], dim=1
            )
            x = x + bbox_embeds
        # part id embedding
        if self.use_part_embed:
            num_parts = aabb.shape[1]
            random_idx = torch.randperm(self.valid_num)[:num_parts]
            part_embeds = self.part_embed[random_idx].unsqueeze(1)
            x = x + part_embeds
        x = torch.cat([c, x], dim=1)
        skip_value_list = []
        for layer, block in enumerate(self.blocks):
            skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()
            x = block(
                hidden_states=x,
                # encoder_hidden_states=image_context,
                encoder_hidden_states=object_context,
                encoder_hidden_states_2=geo_context,
                temb=c,
                skip_value=skip_value,
            )
            if layer < self.depth // 2:
                skip_value_list.append(x)

        x = self.final_layer(x)
        return x
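Editor's note: a minimal, hypothetical sketch (not part of the commit) of driving PartFormerDITPlain end to end. The shapes follow the forward() docstring above, the "obj_cond"/"geo_cond" keys match the lookups in forward, and a CUDA device is assumed because the attention blocks request the flash SDP backend; TimestepEmbedder and MoEBlock come from elsewhere in this commit.

import torch

model = PartFormerDITPlain(
    in_channels=64, hidden_size=1024, depth=8, num_heads=16,
    encoder_hidden_dim=1024, encoder_hidden2_dim=1024,
).cuda().eval()

B, N = 1, 512                                     # parts pretrain: K = 1
x = torch.randn(B, N, 64, device="cuda")          # noisy latent tokens
t = torch.randint(0, 1000, (B,), device="cuda")   # diffusion timestep
contexts = {
    "obj_cond": torch.randn(B, 257, 1024, device="cuda"),  # whole-object tokens
    "geo_cond": torch.randn(B, 257, 1024, device="cuda"),  # per-part geometry tokens
}
with torch.no_grad():
    out = model(x, t, contexts)                   # -> [B, N, 64]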
XPart/partgen/models/sonata/__init__.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .model import load, load_by_config

from . import model
from . import module
from . import structure
from . import data
from . import transform
from . import utils
from . import registry

__all__ = [
    "load",
    "load_by_config",
    "model",
    "module",
    "structure",
    "transform",
    "registry",
    "utils",
]
XPart/partgen/models/sonata/data.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import numpy as np
import torch
from collections.abc import Mapping, Sequence
from huggingface_hub import hf_hub_download
from torch.utils.data.dataloader import default_collate


DATAS = ["sample1", "sample1_high_res", "sample1_dino"]


def load(
    name: str = "sonata",
    download_root: str = None,
):
    if name in DATAS:
        print(f"Loading data from HuggingFace: {name} ...")
        data_path = hf_hub_download(
            repo_id="pointcept/demo",
            filename=f"{name}.npz",
            repo_type="dataset",
            revision="main",
            local_dir=download_root or os.path.expanduser("~/.cache/sonata/data"),
        )
    elif os.path.isfile(name):
        print(f"Loading data from local path: {name} ...")
        data_path = name
    else:
        raise RuntimeError(f"Data {name} not found; available datasets = {DATAS}")
    return dict(np.load(data_path))


def collate_fn(batch):
    """
    Collate function for point clouds that supports dicts and lists;
    'coord' is necessary to determine 'offset'.
    """
    if not isinstance(batch, Sequence):
        raise TypeError(f"{type(batch)} is not supported.")

    if isinstance(batch[0], torch.Tensor):
        return torch.cat(list(batch))
    elif isinstance(batch[0], str):
        # str is also a kind of Sequence, so this check must come before the
        # Sequence branch below
        return list(batch)
    elif isinstance(batch[0], Sequence):
        for data in batch:
            data.append(torch.tensor([data[0].shape[0]]))
        batch = [collate_fn(samples) for samples in zip(*batch)]
        batch[-1] = torch.cumsum(batch[-1], dim=0).int()
        return batch
    elif isinstance(batch[0], Mapping):
        batch = {
            key: (
                collate_fn([d[key] for d in batch])
                if "offset" not in key
                # offset -> bincount -> concat bincount -> concat offset
                else torch.cumsum(
                    collate_fn([d[key].diff(prepend=torch.tensor([0])) for d in batch]),
                    dim=0,
                )
            )
            for key in batch[0]
        }
        return batch
    else:
        return default_collate(batch)
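Editor's note: a small, hypothetical illustration (not part of the commit) of the offset handling in collate_fn: per-cloud tensors are concatenated along dim 0, while "offset" is rebuilt as a running cumulative sum across the batch.

import torch

a = {"coord": torch.randn(100, 3), "offset": torch.tensor([100])}
b = {"coord": torch.randn(60, 3), "offset": torch.tensor([60])}
merged = collate_fn([a, b])
print(merged["coord"].shape)  # torch.Size([160, 3])
print(merged["offset"])       # tensor([100, 160])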
XPart/partgen/models/sonata/model.py
ADDED
@@ -0,0 +1,874 @@
"""
Point Transformer - V3 Mode2 - Sonata
Pointcept detached version

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from packaging import version
from huggingface_hub import hf_hub_download, PyTorchModelHubMixin
from addict import Dict
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
import spconv.pytorch as spconv
import torch_scatter
from timm.layers import DropPath
import json

try:
    import flash_attn
except ImportError:
    flash_attn = None

from .structure import Point
from .module import PointSequential, PointModule
from .utils import offset2bincount

MODELS = [
    "sonata",
    "sonata_small",
    "sonata_linear_prob_head_sc",
]


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class RPE(torch.nn.Module):
    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        idx = (
            coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd
            + self.pos_bnd  # relative position to positive index
            + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride
        )
        out = self.rpe_table.index_select(0, idx.reshape(-1))
        out = out.view(idx.shape + (-1,)).sum(3)
        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
        return out

class SerializedAttention(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enabling Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enabling Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enabling Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            self.attn_drop = attn_drop
        else:
            # when flash attention is disabled, we still don't want to use a
            # mask; consequently, patch_size is set on the fly to the minimum
            # of patch_size_max and the number of points
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)

        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    @torch.no_grad()
    def get_rel_pos(self, point, order):
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    @torch.no_grad()
    def get_padding_and_inverse(self, point):
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad a point cloud when its number of points is larger than
            # patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            _offset = nn.functional.pad(offset, (1, 0))
            _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            point[cu_seqlens_key] = nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        if not self.enable_flash:
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # pad and reshape feat and batch for the serialized point patches
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        feat = feat[inverse]

        # output projection
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point

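# Editor's note (not part of the commit): get_padding_and_inverse pads each
# batch entry up to a multiple of patch_size so serialized patches never mix
# points from different clouds. E.g. with patch_size=1024 and bincounts
# [3000, 500], cloud 0 is padded to 3072 (the padded tail re-attends points
# copied from the previous patch) while cloud 1 stays at 500, since padding is
# skipped for clouds smaller than one patch.
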
class MLP(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        layer_scale=None,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.ls1 = PointSequential(
            LayerScale(channels, init_values=layer_scale)
            if layer_scale is not None
            else nn.Identity()
        )
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.ls2 = PointSequential(
            LayerScale(channels, init_values=layer_scale)
            if layer_scale is not None
            else nn.Identity()
        )
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        self.drop_path = PointSequential(
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )

    def forward(self, point: Point):
        shortcut = point.feat
        point = self.cpe(point)
        point.feat = shortcut + point.feat
        shortcut = point.feat
        if self.pre_norm:
            point = self.norm1(point)
        point = self.drop_path(self.ls1(self.attn(point)))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm1(point)

        shortcut = point.feat
        if self.pre_norm:
            point = self.norm2(point)
        point = self.drop_path(self.ls2(self.mlp(point)))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm2(point)
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point


class GridPooling(PointModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = nn.Linear(in_channels, out_channels)
        # default to None so the checks in forward() don't hit an
        # AttributeError when no norm/act layer is configured
        self.norm = (
            PointSequential(norm_layer(out_channels)) if norm_layer is not None else None
        )
        self.act = PointSequential(act_layer()) if act_layer is not None else None

    def forward(self, point: Point):
        if "grid_coord" in point.keys():
            grid_coord = point.grid_coord
        elif {"coord", "grid_size"}.issubset(point.keys()):
            grid_coord = torch.div(
                point.coord - point.coord.min(0)[0],
                point.grid_size,
                rounding_mode="trunc",
            ).int()
        else:
            raise AssertionError(
                "[grid_coord] or [coord, grid_size] should be included in the Point"
            )
        grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc")
        grid_coord = grid_coord | (point.batch.view(-1, 1) << 48)
        grid_coord, cluster, counts = torch.unique(
            grid_coord,
            sorted=True,
            return_inverse=True,
            return_counts=True,
            dim=0,
        )
        grid_coord = grid_coord & ((1 << 48) - 1)
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reducing attrs e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=grid_coord,
            batch=point.batch[head_indices],
        )
        if "origin_coord" in point.keys():
            point_dict["origin_coord"] = torch_scatter.segment_csr(
                point.origin_coord[indices], idx_ptr, reduce="mean"
            )
        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context
        if "name" in point.keys():
            point_dict["name"] = point.name
        if "split" in point.keys():
            point_dict["split"] = point.split
        if "color" in point.keys():
            point_dict["color"] = torch_scatter.segment_csr(
                point.color[indices], idx_ptr, reduce="mean"
            )
        if "grid_size" in point.keys():
            point_dict["grid_size"] = point.grid_size * self.stride

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        order = point.order
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.serialization(order=order, shuffle_orders=self.shuffle_orders)
        point.sparsify()
        return point

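# Editor's note (not part of the commit): GridPooling packs the batch index
# into the high bits of each voxel key before torch.unique, so voxels from
# different clouds can never collide:
#     key = grid_coord | (batch << 48)    # batch occupies bits >= 48
#     ... torch.unique over keys ...
#     grid_coord = key & ((1 << 48) - 1)  # strip the batch bits again
# This assumes voxel indices stay below 2**48.
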
class GridUnpooling(PointModule):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))

        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))

        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pooling_inverse
        feat = point.feat

        parent = self.proj_skip(parent)
        parent.feat = parent.feat + self.proj(point).feat[inverse]
        parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat)

        if self.traceable:
            point.feat = feat
            parent["unpooling_parent"] = point
        return parent


class Embedding(PointModule):
    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
        mask_token=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels))
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

        if mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, embed_channels))
        else:
            self.mask_token = None

    def forward(self, point: Point):
        point = self.stem(point)
        if "mask" in point.keys():
            point.feat = torch.where(
                point.mask.unsqueeze(-1),
                self.mask_token.to(point.feat.dtype),
                point.feat,
            )
        return point


class PointTransformerV3(PointModule, PyTorchModelHubMixin):
    def __init__(
        self,
        in_channels=6,
        order=("z", "z-trans"),
        stride=(2, 2, 2, 2),
        enc_depths=(3, 3, 3, 12, 3),
        enc_channels=(48, 96, 192, 384, 512),
        enc_num_head=(3, 6, 12, 24, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(3, 3, 3, 3),
        dec_channels=(96, 96, 192, 384),
        dec_num_head=(6, 6, 12, 32),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        layer_scale=None,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        traceable=False,
        mask_token=False,
        enc_mode=False,
        freeze_encoder=False,
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.enc_mode = enc_mode
        self.shuffle_orders = shuffle_orders
        self.freeze_encoder = freeze_encoder

        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.enc_mode or self.num_stages == len(dec_depths) + 1
        assert self.enc_mode or self.num_stages == len(dec_channels) + 1
        assert self.enc_mode or self.num_stages == len(dec_num_head) + 1
        assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1

        print(f"flash attention: {enable_flash}")

        # normalization layer
        ln_layer = nn.LayerNorm
        # activation layers
        act_layer = nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=ln_layer,
            act_layer=act_layer,
            mask_token=mask_token,
        )

        # encoder
        enc_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
        ]
        self.enc = PointSequential()
        for s in range(self.num_stages):
            enc_drop_path_ = enc_drop_path[
                sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
            ]
            enc = PointSequential()
            if s > 0:
                enc.add(
                    GridPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=ln_layer,
|
| 712 |
+
act_layer=act_layer,
|
| 713 |
+
),
|
| 714 |
+
name="down",
|
| 715 |
+
)
|
| 716 |
+
for i in range(enc_depths[s]):
|
| 717 |
+
enc.add(
|
| 718 |
+
Block(
|
| 719 |
+
channels=enc_channels[s],
|
| 720 |
+
num_heads=enc_num_head[s],
|
| 721 |
+
patch_size=enc_patch_size[s],
|
| 722 |
+
mlp_ratio=mlp_ratio,
|
| 723 |
+
qkv_bias=qkv_bias,
|
| 724 |
+
qk_scale=qk_scale,
|
| 725 |
+
attn_drop=attn_drop,
|
| 726 |
+
proj_drop=proj_drop,
|
| 727 |
+
drop_path=enc_drop_path_[i],
|
| 728 |
+
layer_scale=layer_scale,
|
| 729 |
+
norm_layer=ln_layer,
|
| 730 |
+
act_layer=act_layer,
|
| 731 |
+
pre_norm=pre_norm,
|
| 732 |
+
order_index=i % len(self.order),
|
| 733 |
+
cpe_indice_key=f"stage{s}",
|
| 734 |
+
enable_rpe=enable_rpe,
|
| 735 |
+
enable_flash=enable_flash,
|
| 736 |
+
upcast_attention=upcast_attention,
|
| 737 |
+
upcast_softmax=upcast_softmax,
|
| 738 |
+
),
|
| 739 |
+
name=f"block{i}",
|
| 740 |
+
)
|
| 741 |
+
if len(enc) != 0:
|
| 742 |
+
self.enc.add(module=enc, name=f"enc{s}")
|
| 743 |
+
|
| 744 |
+
# decoder
|
| 745 |
+
if not self.enc_mode:
|
| 746 |
+
dec_drop_path = [
|
| 747 |
+
x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
|
| 748 |
+
]
|
| 749 |
+
self.dec = PointSequential()
|
| 750 |
+
dec_channels = list(dec_channels) + [enc_channels[-1]]
|
| 751 |
+
for s in reversed(range(self.num_stages - 1)):
|
| 752 |
+
dec_drop_path_ = dec_drop_path[
|
| 753 |
+
sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
|
| 754 |
+
]
|
| 755 |
+
dec_drop_path_.reverse()
|
| 756 |
+
dec = PointSequential()
|
| 757 |
+
dec.add(
|
| 758 |
+
GridUnpooling(
|
| 759 |
+
in_channels=dec_channels[s + 1],
|
| 760 |
+
skip_channels=enc_channels[s],
|
| 761 |
+
out_channels=dec_channels[s],
|
| 762 |
+
norm_layer=ln_layer,
|
| 763 |
+
act_layer=act_layer,
|
| 764 |
+
traceable=traceable,
|
| 765 |
+
),
|
| 766 |
+
name="up",
|
| 767 |
+
)
|
| 768 |
+
for i in range(dec_depths[s]):
|
| 769 |
+
dec.add(
|
| 770 |
+
Block(
|
| 771 |
+
channels=dec_channels[s],
|
| 772 |
+
num_heads=dec_num_head[s],
|
| 773 |
+
patch_size=dec_patch_size[s],
|
| 774 |
+
mlp_ratio=mlp_ratio,
|
| 775 |
+
qkv_bias=qkv_bias,
|
| 776 |
+
qk_scale=qk_scale,
|
| 777 |
+
attn_drop=attn_drop,
|
| 778 |
+
proj_drop=proj_drop,
|
| 779 |
+
drop_path=dec_drop_path_[i],
|
| 780 |
+
layer_scale=layer_scale,
|
| 781 |
+
norm_layer=ln_layer,
|
| 782 |
+
act_layer=act_layer,
|
| 783 |
+
pre_norm=pre_norm,
|
| 784 |
+
order_index=i % len(self.order),
|
| 785 |
+
cpe_indice_key=f"stage{s}",
|
| 786 |
+
enable_rpe=enable_rpe,
|
| 787 |
+
enable_flash=enable_flash,
|
| 788 |
+
upcast_attention=upcast_attention,
|
| 789 |
+
upcast_softmax=upcast_softmax,
|
| 790 |
+
),
|
| 791 |
+
name=f"block{i}",
|
| 792 |
+
)
|
| 793 |
+
self.dec.add(module=dec, name=f"dec{s}")
|
| 794 |
+
if self.freeze_encoder:
|
| 795 |
+
for p in self.embedding.parameters():
|
| 796 |
+
p.requires_grad = False
|
| 797 |
+
for p in self.enc.parameters():
|
| 798 |
+
p.requires_grad = False
|
| 799 |
+
self.apply(self._init_weights)
|
| 800 |
+
|
| 801 |
+
@staticmethod
|
| 802 |
+
def _init_weights(module):
|
| 803 |
+
if isinstance(module, nn.Linear):
|
| 804 |
+
trunc_normal_(module.weight, std=0.02)
|
| 805 |
+
if module.bias is not None:
|
| 806 |
+
nn.init.zeros_(module.bias)
|
| 807 |
+
elif isinstance(module, spconv.SubMConv3d):
|
| 808 |
+
trunc_normal_(module.weight, std=0.02)
|
| 809 |
+
if module.bias is not None:
|
| 810 |
+
nn.init.zeros_(module.bias)
|
| 811 |
+
|
| 812 |
+
def forward(self, data_dict):
|
| 813 |
+
point = Point(data_dict)
|
| 814 |
+
point = self.embedding(point)
|
| 815 |
+
|
| 816 |
+
point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
|
| 817 |
+
point.sparsify()
|
| 818 |
+
|
| 819 |
+
point = self.enc(point)
|
| 820 |
+
if not self.enc_mode:
|
| 821 |
+
point = self.dec(point)
|
| 822 |
+
return point
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
def load(
|
| 826 |
+
name: str = "sonata",
|
| 827 |
+
repo_id="facebook/sonata",
|
| 828 |
+
download_root: str = None,
|
| 829 |
+
custom_config: dict = None,
|
| 830 |
+
ckpt_only: bool = False,
|
| 831 |
+
):
|
| 832 |
+
if name in MODELS:
|
| 833 |
+
print(f"Loading checkpoint from HuggingFace: {name} ...")
|
| 834 |
+
ckpt_path = hf_hub_download(
|
| 835 |
+
repo_id=repo_id,
|
| 836 |
+
filename=f"{name}.pth",
|
| 837 |
+
repo_type="model",
|
| 838 |
+
revision="main",
|
| 839 |
+
local_dir=download_root or os.path.expanduser("~/.cache/sonata/ckpt"),
|
| 840 |
+
)
|
| 841 |
+
elif os.path.isfile(name):
|
| 842 |
+
print(f"Loading checkpoint in local path: {name} ...")
|
| 843 |
+
ckpt_path = name
|
| 844 |
+
else:
|
| 845 |
+
raise RuntimeError(f"Model {name} not found; available models = {MODELS}")
|
| 846 |
+
|
| 847 |
+
if version.parse(torch.__version__) >= version.parse("2.4"):
|
| 848 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
| 849 |
+
else:
|
| 850 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 851 |
+
if custom_config is not None:
|
| 852 |
+
for key, value in custom_config.items():
|
| 853 |
+
ckpt["config"][key] = value
|
| 854 |
+
|
| 855 |
+
if ckpt_only:
|
| 856 |
+
return ckpt
|
| 857 |
+
|
| 858 |
+
# 关闭flash attention
|
| 859 |
+
# ckpt["config"]['enable_flash'] = False
|
| 860 |
+
|
| 861 |
+
model = PointTransformerV3(**ckpt["config"])
|
| 862 |
+
model.load_state_dict(ckpt["state_dict"])
|
| 863 |
+
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 864 |
+
print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}")
|
| 865 |
+
return model
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def load_by_config(config_path: str):
|
| 869 |
+
with open(config_path, "r") as f:
|
| 870 |
+
config = json.load(f)
|
| 871 |
+
model = PointTransformerV3(**config)
|
| 872 |
+
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 873 |
+
print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}")
|
| 874 |
+
return model
|
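As an illustration only (not part of the commit), here is a minimal sketch of calling the `load` helper above; the import path and availability of the "sonata" checkpoint on the Hub are assumptions:

# Hypothetical usage sketch; assumes this file is importable as
# `partgen.models.sonata.model` and that hf_hub_download can reach
# facebook/sonata.
import torch
from partgen.models.sonata import model as sonata_model  # assumed import path

net = sonata_model.load("sonata").eval()  # downloads sonata.pth, builds PTv3
data = {
    "coord": torch.rand(1024, 3),                   # raw xyz positions
    "feat": torch.rand(1024, 6),                    # matches in_channels=6
    "batch": torch.zeros(1024, dtype=torch.long),   # single-sample batch
    "grid_size": 0.01,                              # lets Point derive grid_coord
}
with torch.no_grad():
    point = net(data)  # forward returns a Point
print(point.feat.shape)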
XPart/partgen/models/sonata/module.py
ADDED
@@ -0,0 +1,107 @@
+"""
+Point Modules
+Pointcept detached version
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+import torch.nn as nn
+import spconv.pytorch as spconv
+from collections import OrderedDict
+
+from .structure import Point
+
+
+class PointModule(nn.Module):
+    r"""PointModule
+    Placeholder; every module subclassing this takes a Point when used in PointSequential.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class PointSequential(PointModule):
+    r"""A sequential container.
+    Modules will be added to it in the order they are passed in the constructor.
+    Alternatively, an ordered dict of modules can also be passed in.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        if len(args) == 1 and isinstance(args[0], OrderedDict):
+            for key, module in args[0].items():
+                self.add_module(key, module)
+        else:
+            for idx, module in enumerate(args):
+                self.add_module(str(idx), module)
+        for name, module in kwargs.items():
+            if sys.version_info < (3, 6):
+                raise ValueError("kwargs only supported in py36+")
+            if name in self._modules:
+                raise ValueError("name exists.")
+            self.add_module(name, module)
+
+    def __getitem__(self, idx):
+        if not (-len(self) <= idx < len(self)):
+            raise IndexError("index {} is out of range".format(idx))
+        if idx < 0:
+            idx += len(self)
+        it = iter(self._modules.values())
+        for i in range(idx):
+            next(it)
+        return next(it)
+
+    def __len__(self):
+        return len(self._modules)
+
+    def add(self, module, name=None):
+        if name is None:
+            name = str(len(self._modules))
+        if name in self._modules:
+            raise KeyError("name exists")
+        self.add_module(name, module)
+
+    def forward(self, input):
+        for k, module in self._modules.items():
+            # Point module
+            if isinstance(module, PointModule):
+                input = module(input)
+            # Spconv module
+            elif spconv.modules.is_spconv_module(module):
+                if isinstance(input, Point):
+                    input.sparse_conv_feat = module(input.sparse_conv_feat)
+                    input.feat = input.sparse_conv_feat.features
+                else:
+                    input = module(input)
+            # PyTorch module
+            else:
+                if isinstance(input, Point):
+                    input.feat = module(input.feat)
+                    if "sparse_conv_feat" in input.keys():
+                        input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
+                            input.feat
+                        )
+                elif isinstance(input, spconv.SparseConvTensor):
+                    if input.indices.shape[0] != 0:
+                        input = input.replace_feature(module(input.features))
+                else:
+                    input = module(input)
+        return input
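A small sketch (assumed import paths, not from this commit) of how `PointSequential` dispatches: `PointModule`s receive the whole `Point`, while plain `nn.Module`s are applied to `point.feat`:

# Illustrative sketch; import paths are assumptions.
import torch
import torch.nn as nn
from partgen.models.sonata.module import PointSequential  # assumed path
from partgen.models.sonata.structure import Point

seq = PointSequential(nn.Linear(6, 16), act=nn.GELU())
pt = Point({"feat": torch.rand(8, 6), "batch": torch.zeros(8, dtype=torch.long)})
pt = seq(pt)          # both modules are plain nn.Modules, so they act on pt.feat
print(pt.feat.shape)  # torch.Size([8, 16])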
XPart/partgen/models/sonata/registry.py
ADDED
@@ -0,0 +1,340 @@
+# @lint-ignore-every LICENSELINT
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from functools import partial
+from collections import abc
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+    """Check whether it is a sequence of some type.
+
+    Args:
+        seq (Sequence): The sequence to be checked.
+        expected_type (type): Expected type of sequence items.
+        seq_type (type, optional): Expected sequence type.
+
+    Returns:
+        bool: Whether the sequence is valid.
+    """
+    if seq_type is None:
+        exp_seq_type = abc.Sequence
+    else:
+        assert isinstance(seq_type, type)
+        exp_seq_type = seq_type
+    if not isinstance(seq, exp_seq_type):
+        return False
+    for item in seq:
+        if not isinstance(item, expected_type):
+            return False
+    return True
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+    """Build a module from configs dict.
+
+    Args:
+        cfg (dict): Config dict. It should at least contain the key "type".
+        registry (:obj:`Registry`): The registry to search the type from.
+        default_args (dict, optional): Default initialization arguments.
+
+    Returns:
+        object: The constructed object.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
+    if "type" not in cfg:
+        if default_args is None or "type" not in default_args:
+            raise KeyError(
+                '`cfg` or `default_args` must contain the key "type", '
+                f"but got {cfg}\n{default_args}"
+            )
+    if not isinstance(registry, Registry):
+        raise TypeError(
+            "registry must be an mmcv.Registry object, " f"but got {type(registry)}"
+        )
+    if not (isinstance(default_args, dict) or default_args is None):
+        raise TypeError(
+            "default_args must be a dict or None, " f"but got {type(default_args)}"
+        )
+
+    args = cfg.copy()
+
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+
+    obj_type = args.pop("type")
+    if isinstance(obj_type, str):
+        obj_cls = registry.get(obj_type)
+        if obj_cls is None:
+            raise KeyError(f"{obj_type} is not in the {registry.name} registry")
+    elif inspect.isclass(obj_type):
+        obj_cls = obj_type
+    else:
+        raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
+    try:
+        return obj_cls(**args)
+    except Exception as e:
+        # Normal TypeError does not print class name.
+        raise type(e)(f"{obj_cls.__name__}: {e}")
+
+
+class Registry:
+    """A registry to map strings to classes.
+
+    Registered object could be built from registry.
+    Example:
+        >>> MODELS = Registry('models')
+        >>> @MODELS.register_module()
+        >>> class ResNet:
+        >>>     pass
+        >>> resnet = MODELS.build(dict(type='ResNet'))
+
+    Please refer to
+    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+    advanced usage.
+
+    Args:
+        name (str): Registry name.
+        build_func(func, optional): Build function to construct instance from
+            Registry, func:`build_from_cfg` is used if neither ``parent`` or
+            ``build_func`` is specified. If ``parent`` is specified and
+            ``build_func`` is not given, ``build_func`` will be inherited
+            from ``parent``. Default: None.
+        parent (Registry, optional): Parent registry. The class registered in
+            children registry could be built from parent. Default: None.
+        scope (str, optional): The scope of registry. It is the key to search
+            for children registry. If not specified, scope will be the name of
+            the package where class is defined, e.g. mmdet, mmcls, mmseg.
+            Default: None.
+    """
+
+    def __init__(self, name, build_func=None, parent=None, scope=None):
+        self._name = name
+        self._module_dict = dict()
+        self._children = dict()
+        self._scope = self.infer_scope() if scope is None else scope
+
+        # self.build_func will be set with the following priority:
+        # 1. build_func
+        # 2. parent.build_func
+        # 3. build_from_cfg
+        if build_func is None:
+            if parent is not None:
+                self.build_func = parent.build_func
+            else:
+                self.build_func = build_from_cfg
+        else:
+            self.build_func = build_func
+        if parent is not None:
+            assert isinstance(parent, Registry)
+            parent._add_children(self)
+            self.parent = parent
+        else:
+            self.parent = None
+
+    def __len__(self):
+        return len(self._module_dict)
+
+    def __contains__(self, key):
+        return self.get(key) is not None
+
+    def __repr__(self):
+        format_str = (
+            self.__class__.__name__ + f"(name={self._name}, "
+            f"items={self._module_dict})"
+        )
+        return format_str
+
+    @staticmethod
+    def infer_scope():
+        """Infer the scope of registry.
+
+        The name of the package where registry is defined will be returned.
+
+        Example:
+            # in mmdet/models/backbone/resnet.py
+            >>> MODELS = Registry('models')
+            >>> @MODELS.register_module()
+            >>> class ResNet:
+            >>>     pass
+            The scope of ``ResNet`` will be ``mmdet``.
+
+
+        Returns:
+            scope (str): The inferred scope name.
+        """
+        # inspect.stack() trace where this function is called, the index-2
+        # indicates the frame where `infer_scope()` is called
+        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+        split_filename = filename.split(".")
+        return split_filename[0]
+
+    @staticmethod
+    def split_scope_key(key):
+        """Split scope and key.
+
+        The first scope will be split from key.
+
+        Examples:
+            >>> Registry.split_scope_key('mmdet.ResNet')
+            'mmdet', 'ResNet'
+            >>> Registry.split_scope_key('ResNet')
+            None, 'ResNet'
+
+        Return:
+            scope (str, None): The first scope.
+            key (str): The remaining key.
+        """
+        split_index = key.find(".")
+        if split_index != -1:
+            return key[:split_index], key[split_index + 1 :]
+        else:
+            return None, key
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def scope(self):
+        return self._scope
+
+    @property
+    def module_dict(self):
+        return self._module_dict
+
+    @property
+    def children(self):
+        return self._children
+
+    def get(self, key):
+        """Get the registry record.
+
+        Args:
+            key (str): The class name in string format.
+
+        Returns:
+            class: The corresponding class.
+        """
+        scope, real_key = self.split_scope_key(key)
+        if scope is None or scope == self._scope:
+            # get from self
+            if real_key in self._module_dict:
+                return self._module_dict[real_key]
+        else:
+            # get from self._children
+            if scope in self._children:
+                return self._children[scope].get(real_key)
+            else:
+                # goto root
+                parent = self.parent
+                while parent.parent is not None:
+                    parent = parent.parent
+                return parent.get(key)
+
+    def build(self, *args, **kwargs):
+        return self.build_func(*args, **kwargs, registry=self)
+
+    def _add_children(self, registry):
+        """Add children for a registry.
+
+        The ``registry`` will be added as children based on its scope.
+        The parent registry could build objects from children registry.
+
+        Example:
+            >>> models = Registry('models')
+            >>> mmdet_models = Registry('models', parent=models)
+            >>> @mmdet_models.register_module()
+            >>> class ResNet:
+            >>>     pass
+            >>> resnet = models.build(dict(type='mmdet.ResNet'))
+        """
+
+        assert isinstance(registry, Registry)
+        assert registry.scope is not None
+        assert (
+            registry.scope not in self.children
+        ), f"scope {registry.scope} exists in {self.name} registry"
+        self.children[registry.scope] = registry
+
+    def _register_module(self, module_class, module_name=None, force=False):
+        if not inspect.isclass(module_class):
+            raise TypeError("module must be a class, " f"but got {type(module_class)}")
+
+        if module_name is None:
+            module_name = module_class.__name__
+        if isinstance(module_name, str):
+            module_name = [module_name]
+        for name in module_name:
+            if not force and name in self._module_dict:
+                raise KeyError(f"{name} is already registered " f"in {self.name}")
+            self._module_dict[name] = module_class
+
+    def deprecated_register_module(self, cls=None, force=False):
+        warnings.warn(
+            "The old API of register_module(module, force=False) "
+            "is deprecated and will be removed, please use the new API "
+            "register_module(name=None, force=False, module=None) instead."
+        )
+        if cls is None:
+            return partial(self.deprecated_register_module, force=force)
+        self._register_module(cls, force=force)
+        return cls
+
+    def register_module(self, name=None, force=False, module=None):
+        """Register a module.
+
+        A record will be added to `self._module_dict`, whose key is the class
+        name or the specified name, and value is the class itself.
+        It can be used as a decorator or a normal function.
+
+        Example:
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module()
+            >>> class ResNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module(name='mnet')
+            >>> class MobileNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> class ResNet:
+            >>>     pass
+            >>> backbones.register_module(ResNet)
+
+        Args:
+            name (str | None): The module name to be registered. If not
+                specified, the class name will be used.
+            force (bool, optional): Whether to override an existing class with
+                the same name. Default: False.
+            module (type): Module class to be registered.
+        """
+        if not isinstance(force, bool):
+            raise TypeError(f"force must be a boolean, but got {type(force)}")
+        # NOTE: This is a workaround to be compatible with the old api,
+        # while it may introduce unexpected bugs.
+        if isinstance(name, type):
+            return self.deprecated_register_module(name, force=force)
+
+        # raise the error ahead of time
+        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+            raise TypeError(
+                "name must be either of None, an instance of str or a sequence"
+                f" of str, but got {type(name)}"
+            )
+
+        # use it as a normal method: x.register_module(module=SomeClass)
+        if module is not None:
+            self._register_module(module_class=module, module_name=name, force=force)
+            return module
+
+        # use it as a decorator: @x.register_module()
+        def _register(cls):
+            self._register_module(module_class=cls, module_name=name, force=force)
+            return cls
+
+        return _register
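A minimal sketch of the vendored registry pattern above (the names here are illustrative, not part of this repo):

# Hypothetical example: register a class by name, then build it from a config dict.
MODELS = Registry("models")

@MODELS.register_module()
class ToyNet:
    def __init__(self, width=64):
        self.width = width

net = MODELS.build(dict(type="ToyNet", width=128))  # dispatches via build_from_cfg
assert isinstance(net, ToyNet) and net.width == 128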
XPart/partgen/models/sonata/serialization/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# init.py
+from .default import (
+    encode,
+    decode,
+    z_order_encode,
+    z_order_decode,
+    hilbert_encode,
+    hilbert_decode,
+)
XPart/partgen/models/sonata/serialization/default.py
ADDED
@@ -0,0 +1,82 @@
+"""
+Serialization Encoding
+Pointcept detached version
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from .z_order import xyz2key as z_order_encode_
+from .z_order import key2xyz as z_order_decode_
+from .hilbert import encode as hilbert_encode_
+from .hilbert import decode as hilbert_decode_
+
+
+@torch.inference_mode()
+def encode(grid_coord, batch=None, depth=16, order="z"):
+    assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
+    if order == "z":
+        code = z_order_encode(grid_coord, depth=depth)
+    elif order == "z-trans":
+        code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
+    elif order == "hilbert":
+        code = hilbert_encode(grid_coord, depth=depth)
+    elif order == "hilbert-trans":
+        code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
+    else:
+        raise NotImplementedError
+    if batch is not None:
+        batch = batch.long()
+        code = batch << depth * 3 | code
+    return code
+
+
+@torch.inference_mode()
+def decode(code, depth=16, order="z"):
+    assert order in {"z", "hilbert"}
+    batch = code >> depth * 3
+    code = code & ((1 << depth * 3) - 1)
+    if order == "z":
+        grid_coord = z_order_decode(code, depth=depth)
+    elif order == "hilbert":
+        grid_coord = hilbert_decode(code, depth=depth)
+    else:
+        raise NotImplementedError
+    return grid_coord, batch
+
+
+def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
+    x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
+    # batch support is blocked here; batched codes are maintained in the Point class
+    code = z_order_encode_(x, y, z, b=None, depth=depth)
+    return code
+
+
+def z_order_decode(code: torch.Tensor, depth):
+    # key2xyz also returns the batch index; it is already stripped in decode()
+    x, y, z, _ = z_order_decode_(code, depth=depth)
+    grid_coord = torch.stack([x, y, z], dim=-1)  # (N, 3)
+    return grid_coord
+
+
+def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
+    return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
+
+
+def hilbert_decode(code: torch.Tensor, depth: int = 16):
+    return hilbert_decode_(code, num_dims=3, num_bits=depth)
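A hedged round-trip sketch for the serialization helpers above; exactness assumes the four-value unpack fix in `z_order_decode` noted in the comment there:

# Illustrative example: encode integer grid coordinates into z-order keys
# (with a batch index packed into the high bits) and decode them back.
import torch

grid_coord = torch.randint(0, 2**10, (5, 3))
batch = torch.zeros(5, dtype=torch.long)
code = encode(grid_coord, batch=batch, depth=16, order="z")
coord_back, batch_back = decode(code, depth=16, order="z")
assert torch.equal(coord_back, grid_coord) and torch.equal(batch_back, batch)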
XPart/partgen/models/sonata/serialization/hilbert.py
ADDED
@@ -0,0 +1,318 @@
+"""
+Hilbert Order
+Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
+Please cite our work if the code is helpful to you.
+"""
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+
+def right_shift(binary, k=1, axis=-1):
+    """Right shift an array of binary values.
+
+    Parameters:
+    -----------
+    binary: An ndarray of binary values.
+
+    k: The number of bits to shift. Default 1.
+
+    axis: The axis along which to shift. Default -1.
+
+    Returns:
+    --------
+    Returns an ndarray with zero prepended and the ends truncated, along
+    whatever axis was specified."""
+
+    # If we're shifting the whole thing, just return zeros.
+    if binary.shape[axis] <= k:
+        return torch.zeros_like(binary)
+
+    # Determine the padding pattern.
+    # padding = [(0,0)] * len(binary.shape)
+    # padding[axis] = (k,0)
+
+    # Determine the slicing pattern to eliminate just the last one.
+    slicing = [slice(None)] * len(binary.shape)
+    slicing[axis] = slice(None, -k)
+    shifted = torch.nn.functional.pad(
+        binary[tuple(slicing)], (k, 0), mode="constant", value=0
+    )
+
+    return shifted
+
+
+def binary2gray(binary, axis=-1):
+    """Convert an array of binary values into Gray codes.
+
+    This uses the classic X ^ (X >> 1) trick to compute the Gray code.
+
+    Parameters:
+    -----------
+    binary: An ndarray of binary values.
+
+    axis: The axis along which to compute the gray code. Default=-1.
+
+    Returns:
+    --------
+    Returns an ndarray of Gray codes.
+    """
+    shifted = right_shift(binary, axis=axis)
+
+    # Do the X ^ (X >> 1) trick.
+    gray = torch.logical_xor(binary, shifted)
+
+    return gray
+
+
+def gray2binary(gray, axis=-1):
+    """Convert an array of Gray codes back into binary values.
+
+    Parameters:
+    -----------
+    gray: An ndarray of gray codes.
+
+    axis: The axis along which to perform Gray decoding. Default=-1.
+
+    Returns:
+    --------
+    Returns an ndarray of binary values.
+    """
+
+    # Loop the log2(bits) number of times necessary, with shift and xor.
+    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
+    while shift > 0:
+        gray = torch.logical_xor(gray, right_shift(gray, shift))
+        shift = torch.div(shift, 2, rounding_mode="floor")
+    return gray
+
+
+def encode(locs, num_dims, num_bits):
+    """Encode an array of locations in a hypercube into a Hilbert integer.
+
+    This is a vectorized-ish version of the Hilbert curve implementation by John
+    Skilling as described in:
+
+    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+    Params:
+    -------
+    locs - An ndarray of locations in a hypercube of num_dims dimensions, in
+           which each dimension runs from 0 to 2**num_bits-1. The shape can
+           be arbitrary, as long as the last dimension of the same has size
+           num_dims.
+
+    num_dims - The dimensionality of the hypercube. Integer.
+
+    num_bits - The number of bits for each dimension. Integer.
+
+    Returns:
+    --------
+    The output is an ndarray of uint64 integers with the same shape as the
+    input, excluding the last dimension, which needs to be num_dims.
+    """
+
+    # Keep around the original shape for later.
+    orig_shape = locs.shape
+    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
+    bitpack_mask_rev = bitpack_mask.flip(-1)
+
+    if orig_shape[-1] != num_dims:
+        raise ValueError(
+            """
+            The shape of locs was surprising in that the last dimension was of size
+            %d, but num_dims=%d. These need to be equal.
+            """
+            % (orig_shape[-1], num_dims)
+        )
+
+    if num_dims * num_bits > 63:
+        raise ValueError(
+            """
+            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into an int64. Are you sure you need that many points on your Hilbert
+            curve?
+            """
+            % (num_dims, num_bits, num_dims * num_bits)
+        )
+
+    # Treat the location integers as 64-bit unsigned and then split them up into
+    # a sequence of uint8s. Preserve the association by dimension.
+    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
+
+    # Now turn these into bits and truncate to num_bits.
+    gray = (
+        locs_uint8.unsqueeze(-1)
+        .bitwise_and(bitpack_mask_rev)
+        .ne(0)
+        .byte()
+        .flatten(-2, -1)[..., -num_bits:]
+    )
+
+    # Run the decoding process the other way.
+    # Iterate forwards through the bits.
+    for bit in range(0, num_bits):
+        # Iterate forwards through the dimensions.
+        for dim in range(0, num_dims):
+            # Identify which ones have this bit active.
+            mask = gray[:, dim, bit]
+
+            # Where this bit is on, invert the 0 dimension for lower bits.
+            gray[:, 0, bit + 1 :] = torch.logical_xor(
+                gray[:, 0, bit + 1 :], mask[:, None]
+            )
+
+            # Where the bit is off, exchange the lower bits with the 0 dimension.
+            to_flip = torch.logical_and(
+                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
+                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+            )
+            gray[:, dim, bit + 1 :] = torch.logical_xor(
+                gray[:, dim, bit + 1 :], to_flip
+            )
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+    # Now flatten out.
+    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
+
+    # Convert Gray back to binary.
+    hh_bin = gray2binary(gray)
+
+    # Pad back out to 64 bits.
+    extra_dims = 64 - num_bits * num_dims
+    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
+
+    # Convert binary values into uint8s.
+    hh_uint8 = (
+        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
+        .sum(2)
+        .squeeze()
+        .type(torch.uint8)
+    )
+
+    # Convert uint8s into uint64s.
+    hh_uint64 = hh_uint8.view(torch.int64).squeeze()
+
+    return hh_uint64
+
+
+def decode(hilberts, num_dims, num_bits):
+    """Decode an array of Hilbert integers into locations in a hypercube.
+
+    This is a vectorized-ish version of the Hilbert curve implementation by John
+    Skilling as described in:
+
+    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
+    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
+
+    Params:
+    -------
+    hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
+               cannot have fewer bits than num_dims * num_bits.
+
+    num_dims - The dimensionality of the hypercube. Integer.
+
+    num_bits - The number of bits for each dimension. Integer.
+
+    Returns:
+    --------
+    The output is an ndarray of unsigned integers with the same shape as hilberts
+    but with an additional dimension of size num_dims.
+    """
+
+    if num_dims * num_bits > 64:
+        raise ValueError(
+            """
+            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
+            into a uint64. Are you sure you need that many points on your Hilbert
+            curve?
+            """
+            % (num_dims, num_bits, num_dims * num_bits)
+        )
+
+    # Handle the case where we got handed a naked integer.
+    hilberts = torch.atleast_1d(hilberts)
+
+    # Keep around the shape for later.
+    orig_shape = hilberts.shape
+    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
+    bitpack_mask_rev = bitpack_mask.flip(-1)
+
+    # Treat each of the hilberts as a sequence of eight uint8.
+    # This treats all of the inputs as uint64 and makes things uniform.
+    hh_uint8 = (
+        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
+    )
+
+    # Turn these lists of uints into lists of bits and then truncate to the size
+    # we actually need for using Skilling's procedure.
+    hh_bits = (
+        hh_uint8.unsqueeze(-1)
+        .bitwise_and(bitpack_mask_rev)
+        .ne(0)
+        .byte()
+        .flatten(-2, -1)[:, -num_dims * num_bits :]
+    )
+
+    # Take the sequence of bits and Gray-code it.
+    gray = binary2gray(hh_bits)
+
+    # There has got to be a better way to do this.
+    # I could index them differently, but the eventual packbits likes it this way.
+    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
+
+    # Iterate backwards through the bits.
+    for bit in range(num_bits - 1, -1, -1):
+        # Iterate backwards through the dimensions.
+        for dim in range(num_dims - 1, -1, -1):
+            # Identify which ones have this bit active.
+            mask = gray[:, dim, bit]
+
+            # Where this bit is on, invert the 0 dimension for lower bits.
+            gray[:, 0, bit + 1 :] = torch.logical_xor(
+                gray[:, 0, bit + 1 :], mask[:, None]
+            )
+
+            # Where the bit is off, exchange the lower bits with the 0 dimension.
+            to_flip = torch.logical_and(
+                torch.logical_not(mask[:, None]),
+                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
+            )
+            gray[:, dim, bit + 1 :] = torch.logical_xor(
+                gray[:, dim, bit + 1 :], to_flip
+            )
+            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
+
+    # Pad back out to 64 bits.
+    extra_dims = 64 - num_bits
+    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
+
+    # Now chop these up into blocks of 8.
+    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
+
+    # Take those blocks and turn them into uint8s.
+    # from IPython import embed; embed()
+    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
+
+    # Finally, treat these as uint64s.
+    flat_locs = locs_uint8.view(torch.int64)
+
+    # Return them in the expected shape.
+    return flat_locs.reshape((*orig_shape, num_dims))
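A hedged sketch for the Hilbert helpers above: encode/decode should round-trip as long as num_dims * num_bits stays under 64.

# Illustrative round trip over the two module-level functions above.
import torch

locs = torch.randint(0, 2**10, (4, 3))   # 3-D points, 10 bits per axis
keys = encode(locs, num_dims=3, num_bits=10)   # one int64 key per point
back = decode(keys, num_dims=3, num_bits=10)   # shape (4, 3)
assert torch.equal(back, locs)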
XPart/partgen/models/sonata/serialization/z_order.py
ADDED
@@ -0,0 +1,145 @@
+# @lint-ignore-every LICENSELINT
+# --------------------------------------------------------
+# Octree-based Sparse Convolutional Neural Networks
+# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Peng-Shuai Wang
+# --------------------------------------------------------
+# --------------------------------------------------------
+# Octree-based Sparse Convolutional Neural Networks
+# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Peng-Shuai Wang
+# --------------------------------------------------------
+# --------------------------------------------------------
+# Octree-based Sparse Convolutional Neural Networks
+# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Peng-Shuai Wang
+# --------------------------------------------------------
+# --------------------------------------------------------
+# Octree-based Sparse Convolutional Neural Networks
+# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Peng-Shuai Wang
+# --------------------------------------------------------
+
+import torch
+from typing import Optional, Union
+
+
+class KeyLUT:
+    def __init__(self):
+        r256 = torch.arange(256, dtype=torch.int64)
+        r512 = torch.arange(512, dtype=torch.int64)
+        zero = torch.zeros(256, dtype=torch.int64)
+        device = torch.device("cpu")
+
+        self._encode = {
+            device: (
+                self.xyz2key(r256, zero, zero, 8),
+                self.xyz2key(zero, r256, zero, 8),
+                self.xyz2key(zero, zero, r256, 8),
+            )
+        }
+        self._decode = {device: self.key2xyz(r512, 9)}
+
+    def encode_lut(self, device=torch.device("cpu")):
+        if device not in self._encode:
+            cpu = torch.device("cpu")
+            self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
+        return self._encode[device]
+
+    def decode_lut(self, device=torch.device("cpu")):
+        if device not in self._decode:
+            cpu = torch.device("cpu")
+            self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
+        return self._decode[device]
+
+    def xyz2key(self, x, y, z, depth):
+        key = torch.zeros_like(x)
+        for i in range(depth):
+            mask = 1 << i
+            key = (
+                key
+                | ((x & mask) << (2 * i + 2))
+                | ((y & mask) << (2 * i + 1))
+                | ((z & mask) << (2 * i + 0))
+            )
+        return key
+
+    def key2xyz(self, key, depth):
+        x = torch.zeros_like(key)
+        y = torch.zeros_like(key)
+        z = torch.zeros_like(key)
+        for i in range(depth):
+            x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
+            y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
+            z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
+        return x, y, z
+
+
+_key_lut = KeyLUT()
+
+
+def xyz2key(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    z: torch.Tensor,
+    b: Optional[Union[torch.Tensor, int]] = None,
+    depth: int = 16,
+):
+    """Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
+    based on pre-computed look up tables. The speed of this function is much
+    faster than the method based on for-loop.
+
+    Args:
+      x (torch.Tensor): The x coordinate.
+      y (torch.Tensor): The y coordinate.
+      z (torch.Tensor): The z coordinate.
+      b (torch.Tensor or int): The batch index of the coordinates, and should be
+          smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
+          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
+      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
+    """
+
+    EX, EY, EZ = _key_lut.encode_lut(x.device)
+    x, y, z = x.long(), y.long(), z.long()
+
+    mask = 255 if depth > 8 else (1 << depth) - 1
+    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
+    if depth > 8:
+        mask = (1 << (depth - 8)) - 1
+        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
+        key = key16 << 24 | key
+
+    if b is not None:
+        b = b.long()
+        key = b << 48 | key
+
+    return key
+
+
+def key2xyz(key: torch.Tensor, depth: int = 16):
+    r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
+    and the batch index based on pre-computed look up tables.
+
+    Args:
+      key (torch.Tensor): The shuffled key.
+      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
+    """
+
+    DX, DY, DZ = _key_lut.decode_lut(key.device)
+    x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
+
+    b = key >> 48
+    key = key & ((1 << 48) - 1)
+
+    n = (depth + 2) // 3
+    for i in range(n):
+        k = key >> (i * 9) & 511
+        x = x | (DX[k] << (i * 3))
+        y = y | (DY[k] << (i * 3))
+        z = z | (DZ[k] << (i * 3))
+
+    return x, y, z, b
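A hedged sketch of the LUT-based z-order helpers above: interleave x/y/z bits into one key, with an optional batch index packed into the high bits, then recover everything.

# Illustrative example over the module-level xyz2key / key2xyz above.
import torch

x = torch.tensor([3, 7])
y = torch.tensor([1, 0])
z = torch.tensor([2, 5])
b = torch.tensor([0, 1])

key = xyz2key(x, y, z, b=b, depth=16)   # batch index lands above bit 48
x2, y2, z2, b2 = key2xyz(key, depth=16)
assert torch.equal(x2, x) and torch.equal(y2, y) and torch.equal(z2, z)
assert torch.equal(b2, b)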
XPart/partgen/models/sonata/structure.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Data structure for 3D point cloud

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import spconv.pytorch as spconv
from addict import Dict

from .serialization import encode
from .utils import offset2batch, batch2offset


class Point(Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary that contains various properties of
    a batched point cloud. Properties with the following names have a specific definition:

    - "coord": original coordinate of the point cloud;
    - "grid_coord": grid coordinate for a specific grid size (related to GridSampling);

    Point also supports the following optional attributes:

    - "offset": if absent, initialized assuming the batch size is 1;
    - "batch": if absent, initialized assuming the batch size is 1;
    - "feat": feature of the point cloud, default input of the model;
    - "grid_size": grid size of the point cloud (related to GridSampling);
    (related to Serialization)
    - "serialized_depth": depth of serialization; 2 ** depth * grid_size describes the maximum point cloud range;
    - "serialized_code": a list of serialization codes;
    - "serialized_order": a list of serialization orders determined by the codes;
    - "serialized_inverse": a list of inverse mappings determined by the codes;
    (related to Sparsify: SpConv)
    - "sparse_shape": sparse shape for the SparseConvTensor;
    - "sparse_conv_feat": SparseConvTensor initialized with information provided by Point;
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If one of "offset" or "batch" does not exist, generate it from the other
        if "batch" not in self.keys() and "offset" in self.keys():
            self["batch"] = offset2batch(self.offset)
        elif "offset" not in self.keys() and "batch" in self.keys():
            self["offset"] = batch2offset(self.batch)

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
        """
        self["order"] = order
        assert "batch" in self.keys()
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()

        if depth is None:
            # Adaptively measure the depth of the serialization cube (length = 2 ^ depth)
            depth = int(self.grid_coord.max() + 1).bit_length()
        self["serialized_depth"] = depth
        # Maximum bit length for the serialization code is 63 (int64)
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Here we follow OCNN and set the depth limitation to 16 (48 bit) for the point position.
        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
        # cube with a grid size of 0.01 meter. We consider this enough for the current stage.
        # We can unlock the limitation by optimizing the z-order encoding function if necessary.
        assert depth <= 16

        # The serialization codes are arranged as the following structure:
        # [Order1 ([n]),
        #  Order2 ([n]),
        #   ...
        #  OrderN ([n])] (k, n)
        code = [
            encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
        ]
        code = torch.stack(code)
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        self["serialized_code"] = code
        self["serialized_order"] = order
        self["serialized_inverse"] = inverse

    def sparsify(self, pad=96):
        """
        Point Cloud Sparsification

        Point cloud is sparse; here we use "sparsify" to specifically refer to
        preparing a "spconv.SparseConvTensor" for SpConv.

        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        pad: padding for the sparse shape.
        """
        assert {"feat", "batch"}.issubset(self.keys())
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()
        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            sparse_shape = torch.add(
                torch.max(self.grid_coord, dim=0).values, pad
            ).tolist()
        sparse_conv_feat = spconv.SparseConvTensor(
            features=self.feat,
            indices=torch.cat(
                [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
            ).contiguous(),
            spatial_shape=sparse_shape,
            batch_size=self.batch[-1].tolist() + 1,
        )
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = sparse_conv_feat
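For orientation, a minimal usage sketch of `Point` (hypothetical shapes, not part of the diff; assumes `spconv` is installed and that the `.serialization` module registers the "z" and "hilbert" orders, as in Pointcept):

# Hypothetical Point usage (not in the diff); shapes are illustrative.
import torch

coord = torch.rand(1000, 3)                 # two clouds of 500 points each
feat = torch.rand(1000, 6)
offset = torch.tensor([500, 1000])          # cumulative point counts; "batch" is derived

point = Point(coord=coord, feat=feat, offset=offset, grid_size=0.01)
point.serialization(order=["z", "hilbert"])  # one code tensor per requested order
point.sparsify(pad=96)                       # builds point["sparse_conv_feat"]
print(point["serialized_code"].shape)        # (2, 1000): (num_orders, num_points)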
XPart/partgen/models/sonata/transform.py
ADDED
@@ -0,0 +1,1330 @@
"""
3D point cloud augmentation

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
import random
|
| 127 |
+
import numbers
|
| 128 |
+
import scipy
|
| 129 |
+
import scipy.ndimage
|
| 130 |
+
import scipy.interpolate
|
| 131 |
+
import scipy.stats
|
| 132 |
+
import numpy as np
|
| 133 |
+
import torch
|
| 134 |
+
import copy
|
| 135 |
+
from collections.abc import Sequence, Mapping
|
| 136 |
+
|
| 137 |
+
from .registry import Registry
|
| 138 |
+
|
| 139 |
+
TRANSFORMS = Registry("transforms")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def index_operator(data_dict, index, duplicate=False):
|
| 143 |
+
# index selection operator for keys in "index_valid_keys"
|
| 144 |
+
# custom these keys by "Update" transform in config
|
| 145 |
+
if "index_valid_keys" not in data_dict:
|
| 146 |
+
data_dict["index_valid_keys"] = [
|
| 147 |
+
"coord",
|
| 148 |
+
"color",
|
| 149 |
+
"normal",
|
| 150 |
+
"strength",
|
| 151 |
+
"segment",
|
| 152 |
+
"instance",
|
| 153 |
+
]
|
| 154 |
+
if not duplicate:
|
| 155 |
+
for key in data_dict["index_valid_keys"]:
|
| 156 |
+
if key in data_dict:
|
| 157 |
+
data_dict[key] = data_dict[key][index]
|
| 158 |
+
return data_dict
|
| 159 |
+
else:
|
| 160 |
+
data_dict_ = dict()
|
| 161 |
+
for key in data_dict.keys():
|
| 162 |
+
if key in data_dict["index_valid_keys"]:
|
| 163 |
+
data_dict_[key] = data_dict[key][index]
|
| 164 |
+
else:
|
| 165 |
+
data_dict_[key] = data_dict[key]
|
| 166 |
+
return data_dict_
|
| 167 |
+
|
| 168 |
+
|
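Every spatial transform below funnels point-wise selection through `index_operator`; a tiny sketch (hypothetical arrays, not part of the diff) shows the contract: keys listed in `index_valid_keys` are subset by `index`, everything else passes through untouched:

# Hypothetical demonstration of index_operator (not in the diff).
import numpy as np

data = {
    "coord": np.random.rand(6, 3),
    "color": np.random.rand(6, 3),
    "scene_id": "scene0000_00",        # non-indexed metadata, kept as-is
}
keep = np.array([0, 2, 5])
out = index_operator(data, keep)       # in-place subset of coord/color
assert out["coord"].shape == (3, 3) and out["scene_id"] == "scene0000_00"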
@TRANSFORMS.register_module()
class Collect(object):
    def __init__(self, keys, offset_keys_dict=None, **kwargs):
        """
        e.g. Collect(keys=("coord",), feat_keys=("coord", "color"))
        """
        if offset_keys_dict is None:
            offset_keys_dict = dict(offset="coord")
        self.keys = keys
        self.offset_keys = offset_keys_dict
        self.kwargs = kwargs

    def __call__(self, data_dict):
        data = dict()
        if isinstance(self.keys, str):
            self.keys = [self.keys]
        for key in self.keys:
            data[key] = data_dict[key]
        for key, value in self.offset_keys.items():
            data[key] = torch.tensor([data_dict[value].shape[0]])
        for name, keys in self.kwargs.items():
            name = name.replace("_keys", "")
            assert isinstance(keys, Sequence)
            data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1)
        return data


@TRANSFORMS.register_module()
class Copy(object):
    def __init__(self, keys_dict=None):
        if keys_dict is None:
            keys_dict = dict(coord="origin_coord", segment="origin_segment")
        self.keys_dict = keys_dict

    def __call__(self, data_dict):
        for key, value in self.keys_dict.items():
            if isinstance(data_dict[key], np.ndarray):
                data_dict[value] = data_dict[key].copy()
            elif isinstance(data_dict[key], torch.Tensor):
                data_dict[value] = data_dict[key].clone().detach()
            else:
                data_dict[value] = copy.deepcopy(data_dict[key])
        return data_dict


@TRANSFORMS.register_module()
class Update(object):
    def __init__(self, keys_dict=None):
        if keys_dict is None:
            keys_dict = dict()
        self.keys_dict = keys_dict

    def __call__(self, data_dict):
        for key, value in self.keys_dict.items():
            data_dict[key] = value
        return data_dict


@TRANSFORMS.register_module()
class ToTensor(object):
    def __call__(self, data):
        if isinstance(data, torch.Tensor):
            return data
        elif isinstance(data, str):
            # str is also a Sequence, so this check must come before the Sequence one
            return data
        elif isinstance(data, int):
            return torch.LongTensor([data])
        elif isinstance(data, float):
            return torch.FloatTensor([data])
        elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool):
            return torch.from_numpy(data)
        elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer):
            return torch.from_numpy(data).long()
        elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating):
            return torch.from_numpy(data).float()
        elif isinstance(data, Mapping):
            result = {sub_key: self(item) for sub_key, item in data.items()}
            return result
        elif isinstance(data, Sequence):
            result = [self(item) for item in data]
            return result
        else:
            raise TypeError(f"type {type(data)} cannot be converted to tensor.")
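A sketch of how `Collect` typically closes a pipeline (hypothetical keys, not part of the diff): each `*_keys` kwarg is concatenated channel-wise into a single tensor under the stripped name:

# Hypothetical Collect usage (not in the diff).
import torch

data = dict(
    coord=torch.rand(100, 3),
    color=torch.rand(100, 3),
)
collect = Collect(keys=("coord",), feat_keys=("coord", "color"))
batch = collect(data)
print(batch["feat"].shape)    # torch.Size([100, 6]): coord and color concatenated
print(batch["offset"])        # tensor([100]): point count used as batch offset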
@TRANSFORMS.register_module()
class NormalizeColor(object):
    def __call__(self, data_dict):
        if "color" in data_dict.keys():
            data_dict["color"] = data_dict["color"] / 255
        return data_dict


@TRANSFORMS.register_module()
class NormalizeCoord(object):
    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            # modified from pointnet2
            centroid = np.mean(data_dict["coord"], axis=0)
            data_dict["coord"] -= centroid
            m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, axis=1)))
            data_dict["coord"] = data_dict["coord"] / m
        return data_dict


@TRANSFORMS.register_module()
class PositiveShift(object):
    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            coord_min = np.min(data_dict["coord"], 0)
            data_dict["coord"] -= coord_min
        return data_dict


@TRANSFORMS.register_module()
class CenterShift(object):
    def __init__(self, apply_z=True):
        self.apply_z = apply_z

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            x_min, y_min, z_min = data_dict["coord"].min(axis=0)
            x_max, y_max, _ = data_dict["coord"].max(axis=0)
            if self.apply_z:
                shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min]
            else:
                shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0]
            data_dict["coord"] -= shift
        return data_dict


@TRANSFORMS.register_module()
class RandomShift(object):
    def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))):
        self.shift = shift

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1])
            shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1])
            shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1])
            data_dict["coord"] += [shift_x, shift_y, shift_z]
        return data_dict


@TRANSFORMS.register_module()
class PointClip(object):
    def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)):
        self.point_cloud_range = point_cloud_range

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            data_dict["coord"] = np.clip(
                data_dict["coord"],
                a_min=self.point_cloud_range[:3],
                a_max=self.point_cloud_range[3:],
            )
        return data_dict
@TRANSFORMS.register_module()
class RandomDropout(object):
    def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
        """
        dropout_ratio: fraction of points to drop when the transform fires
        dropout_application_ratio: probability of applying the dropout at all
        """
        self.dropout_ratio = dropout_ratio
        self.dropout_application_ratio = dropout_application_ratio

    def __call__(self, data_dict):
        if random.random() < self.dropout_application_ratio:
            n = len(data_dict["coord"])
            idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False)
            if "sampled_index" in data_dict:
                # for ScanNet data-efficient settings, make sure labeled points stay sampled
                idx = np.unique(np.append(idx, data_dict["sampled_index"]))
                mask = np.zeros_like(data_dict["segment"]).astype(bool)
                mask[data_dict["sampled_index"]] = True
                data_dict["sampled_index"] = np.where(mask[idx])[0]
            data_dict = index_operator(data_dict, idx)
        return data_dict
@TRANSFORMS.register_module()
class RandomRotate(object):
    def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5):
        self.angle = [-1, 1] if angle is None else angle
        self.axis = axis
        self.always_apply = always_apply
        self.p = p if not self.always_apply else 1
        self.center = center

    def __call__(self, data_dict):
        if random.random() > self.p:
            return data_dict
        angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi
        rot_cos, rot_sin = np.cos(angle), np.sin(angle)
        if self.axis == "x":
            rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
        elif self.axis == "y":
            rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
        elif self.axis == "z":
            rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
        else:
            raise NotImplementedError
        if "coord" in data_dict.keys():
            if self.center is None:
                x_min, y_min, z_min = data_dict["coord"].min(axis=0)
                x_max, y_max, z_max = data_dict["coord"].max(axis=0)
                center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
            else:
                center = self.center
            data_dict["coord"] -= center
            data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
            data_dict["coord"] += center
        if "normal" in data_dict.keys():
            data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
        return data_dict
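A quick sketch of the rotation transform in isolation (hypothetical data, not part of the diff); since it rotates about the bounding-box center, pairwise distances are preserved up to floating-point error:

# Hypothetical RandomRotate check (not in the diff).
import numpy as np

data = {"coord": np.random.rand(100, 3).astype(np.float32)}
before = data["coord"].copy()
data = RandomRotate(angle=[-1, 1], axis="z", always_apply=True)(data)
d0 = np.linalg.norm(before[0] - before[1])
d1 = np.linalg.norm(data["coord"][0] - data["coord"][1])
assert np.isclose(d0, d1, atol=1e-5)   # rotation is an isometry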
@TRANSFORMS.register_module()
class RandomRotateTargetAngle(object):
    def __init__(
        self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75
    ):
        self.angle = angle
        self.axis = axis
        self.always_apply = always_apply
        self.p = p if not self.always_apply else 1
        self.center = center

    def __call__(self, data_dict):
        if random.random() > self.p:
            return data_dict
        angle = np.random.choice(self.angle) * np.pi
        rot_cos, rot_sin = np.cos(angle), np.sin(angle)
        if self.axis == "x":
            rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
        elif self.axis == "y":
            rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
        elif self.axis == "z":
            rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
        else:
            raise NotImplementedError
        if "coord" in data_dict.keys():
            if self.center is None:
                x_min, y_min, z_min = data_dict["coord"].min(axis=0)
                x_max, y_max, z_max = data_dict["coord"].max(axis=0)
                center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
            else:
                center = self.center
            data_dict["coord"] -= center
            data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
            data_dict["coord"] += center
        if "normal" in data_dict.keys():
            data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
        return data_dict
@TRANSFORMS.register_module()
class RandomScale(object):
    def __init__(self, scale=None, anisotropic=False):
        self.scale = scale if scale is not None else [0.95, 1.05]
        self.anisotropic = anisotropic

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            scale = np.random.uniform(
                self.scale[0], self.scale[1], 3 if self.anisotropic else 1
            )
            data_dict["coord"] *= scale
        return data_dict


@TRANSFORMS.register_module()
class RandomFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, data_dict):
        if np.random.rand() < self.p:
            if "coord" in data_dict.keys():
                data_dict["coord"][:, 0] = -data_dict["coord"][:, 0]
            if "normal" in data_dict.keys():
                data_dict["normal"][:, 0] = -data_dict["normal"][:, 0]
        if np.random.rand() < self.p:
            if "coord" in data_dict.keys():
                data_dict["coord"][:, 1] = -data_dict["coord"][:, 1]
            if "normal" in data_dict.keys():
                data_dict["normal"][:, 1] = -data_dict["normal"][:, 1]
        return data_dict


@TRANSFORMS.register_module()
class RandomJitter(object):
    def __init__(self, sigma=0.01, clip=0.05):
        assert clip > 0
        self.sigma = sigma
        self.clip = clip

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            jitter = np.clip(
                self.sigma * np.random.randn(data_dict["coord"].shape[0], 3),
                -self.clip,
                self.clip,
            )
            data_dict["coord"] += jitter
        return data_dict


@TRANSFORMS.register_module()
class ClipGaussianJitter(object):
    def __init__(self, scalar=0.02, store_jitter=False):
        self.scalar = scalar
        self.mean = np.zeros(3)  # zero-mean 3D Gaussian; a scalar mean breaks multivariate_normal
        self.cov = np.identity(3)
        self.quantile = 1.96
        self.store_jitter = store_jitter

    def __call__(self, data_dict):
        if "coord" in data_dict.keys():
            jitter = np.random.multivariate_normal(
                self.mean, self.cov, data_dict["coord"].shape[0]
            )
            jitter = self.scalar * np.clip(jitter / self.quantile, -1, 1)
            data_dict["coord"] += jitter
            if self.store_jitter:
                data_dict["jitter"] = jitter
        return data_dict
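Both jitter transforms bound the perturbation per coordinate; a sketch of the effective bounds (hypothetical values, not part of the diff, and assuming the zero-mean fix above):

# Hypothetical comparison of the two jitter bounds (not in the diff).
import numpy as np

pts = {"coord": np.zeros((10000, 3))}
out = RandomJitter(sigma=0.01, clip=0.05)(pts)
assert np.abs(out["coord"]).max() <= 0.05   # hard clip at `clip`

pts = {"coord": np.zeros((10000, 3))}
out = ClipGaussianJitter(scalar=0.02)(pts)
assert np.abs(out["coord"]).max() <= 0.02   # scalar * clip(z / 1.96, -1, 1)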
| 503 |
+
class ChromaticAutoContrast(object):
|
| 504 |
+
def __init__(self, p=0.2, blend_factor=None):
|
| 505 |
+
self.p = p
|
| 506 |
+
self.blend_factor = blend_factor
|
| 507 |
+
|
| 508 |
+
def __call__(self, data_dict):
|
| 509 |
+
if "color" in data_dict.keys() and np.random.rand() < self.p:
|
| 510 |
+
lo = np.min(data_dict["color"], 0, keepdims=True)
|
| 511 |
+
hi = np.max(data_dict["color"], 0, keepdims=True)
|
| 512 |
+
scale = 255 / (hi - lo)
|
| 513 |
+
contrast_feat = (data_dict["color"][:, :3] - lo) * scale
|
| 514 |
+
blend_factor = (
|
| 515 |
+
np.random.rand() if self.blend_factor is None else self.blend_factor
|
| 516 |
+
)
|
| 517 |
+
data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][
|
| 518 |
+
:, :3
|
| 519 |
+
] + blend_factor * contrast_feat
|
| 520 |
+
return data_dict
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
@TRANSFORMS.register_module()
|
| 524 |
+
class ChromaticTranslation(object):
|
| 525 |
+
def __init__(self, p=0.95, ratio=0.05):
|
| 526 |
+
self.p = p
|
| 527 |
+
self.ratio = ratio
|
| 528 |
+
|
| 529 |
+
def __call__(self, data_dict):
|
| 530 |
+
if "color" in data_dict.keys() and np.random.rand() < self.p:
|
| 531 |
+
tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio
|
| 532 |
+
data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255)
|
| 533 |
+
return data_dict
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
@TRANSFORMS.register_module()
|
| 537 |
+
class ChromaticJitter(object):
|
| 538 |
+
def __init__(self, p=0.95, std=0.005):
|
| 539 |
+
self.p = p
|
| 540 |
+
self.std = std
|
| 541 |
+
|
| 542 |
+
def __call__(self, data_dict):
|
| 543 |
+
if "color" in data_dict.keys() and np.random.rand() < self.p:
|
| 544 |
+
noise = np.random.randn(data_dict["color"].shape[0], 3)
|
| 545 |
+
noise *= self.std * 255
|
| 546 |
+
data_dict["color"][:, :3] = np.clip(
|
| 547 |
+
noise + data_dict["color"][:, :3], 0, 255
|
| 548 |
+
)
|
| 549 |
+
return data_dict
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
@TRANSFORMS.register_module()
|
| 553 |
+
class RandomColorGrayScale(object):
|
| 554 |
+
def __init__(self, p):
|
| 555 |
+
self.p = p
|
| 556 |
+
|
| 557 |
+
@staticmethod
|
| 558 |
+
def rgb_to_grayscale(color, num_output_channels=1):
|
| 559 |
+
if color.shape[-1] < 3:
|
| 560 |
+
raise TypeError(
|
| 561 |
+
"Input color should have at least 3 dimensions, but found {}".format(
|
| 562 |
+
color.shape[-1]
|
| 563 |
+
)
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
if num_output_channels not in (1, 3):
|
| 567 |
+
raise ValueError("num_output_channels should be either 1 or 3")
|
| 568 |
+
|
| 569 |
+
r, g, b = color[..., 0], color[..., 1], color[..., 2]
|
| 570 |
+
gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype)
|
| 571 |
+
gray = np.expand_dims(gray, axis=-1)
|
| 572 |
+
|
| 573 |
+
if num_output_channels == 3:
|
| 574 |
+
gray = np.broadcast_to(gray, color.shape)
|
| 575 |
+
|
| 576 |
+
return gray
|
| 577 |
+
|
| 578 |
+
def __call__(self, data_dict):
|
| 579 |
+
if np.random.rand() < self.p:
|
| 580 |
+
data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3)
|
| 581 |
+
return data_dict
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
@TRANSFORMS.register_module()
|
| 585 |
+
class RandomColorJitter(object):
|
| 586 |
+
"""
|
| 587 |
+
Random Color Jitter for 3D point cloud (refer torchvision)
|
| 588 |
+
"""
|
| 589 |
+
|
| 590 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95):
|
| 591 |
+
self.brightness = self._check_input(brightness, "brightness")
|
| 592 |
+
self.contrast = self._check_input(contrast, "contrast")
|
| 593 |
+
self.saturation = self._check_input(saturation, "saturation")
|
| 594 |
+
self.hue = self._check_input(
|
| 595 |
+
hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
|
| 596 |
+
)
|
| 597 |
+
self.p = p
|
| 598 |
+
|
| 599 |
+
@staticmethod
|
| 600 |
+
def _check_input(
|
| 601 |
+
value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
|
| 602 |
+
):
|
| 603 |
+
if isinstance(value, numbers.Number):
|
| 604 |
+
if value < 0:
|
| 605 |
+
raise ValueError(
|
| 606 |
+
"If {} is a single number, it must be non negative.".format(name)
|
| 607 |
+
)
|
| 608 |
+
value = [center - float(value), center + float(value)]
|
| 609 |
+
if clip_first_on_zero:
|
| 610 |
+
value[0] = max(value[0], 0.0)
|
| 611 |
+
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
| 612 |
+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
| 613 |
+
raise ValueError("{} values should be between {}".format(name, bound))
|
| 614 |
+
else:
|
| 615 |
+
raise TypeError(
|
| 616 |
+
"{} should be a single number or a list/tuple with length 2.".format(
|
| 617 |
+
name
|
| 618 |
+
)
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
# if value is 0 or (1., 1.) for brightness/contrast/saturation
|
| 622 |
+
# or (0., 0.) for hue, do nothing
|
| 623 |
+
if value[0] == value[1] == center:
|
| 624 |
+
value = None
|
| 625 |
+
return value
|
| 626 |
+
|
| 627 |
+
@staticmethod
|
| 628 |
+
def blend(color1, color2, ratio):
|
| 629 |
+
ratio = float(ratio)
|
| 630 |
+
bound = 255.0
|
| 631 |
+
return (
|
| 632 |
+
(ratio * color1 + (1.0 - ratio) * color2)
|
| 633 |
+
.clip(0, bound)
|
| 634 |
+
.astype(color1.dtype)
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
@staticmethod
|
| 638 |
+
def rgb2hsv(rgb):
|
| 639 |
+
r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
|
| 640 |
+
maxc = np.max(rgb, axis=-1)
|
| 641 |
+
minc = np.min(rgb, axis=-1)
|
| 642 |
+
eqc = maxc == minc
|
| 643 |
+
cr = maxc - minc
|
| 644 |
+
s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc))
|
| 645 |
+
cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc)
|
| 646 |
+
rc = (maxc - r) / cr_divisor
|
| 647 |
+
gc = (maxc - g) / cr_divisor
|
| 648 |
+
bc = (maxc - b) / cr_divisor
|
| 649 |
+
|
| 650 |
+
hr = (maxc == r) * (bc - gc)
|
| 651 |
+
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
|
| 652 |
+
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
|
| 653 |
+
h = hr + hg + hb
|
| 654 |
+
h = (h / 6.0 + 1.0) % 1.0
|
| 655 |
+
return np.stack((h, s, maxc), axis=-1)
|
| 656 |
+
|
| 657 |
+
@staticmethod
|
| 658 |
+
def hsv2rgb(hsv):
|
| 659 |
+
h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
|
| 660 |
+
i = np.floor(h * 6.0)
|
| 661 |
+
f = (h * 6.0) - i
|
| 662 |
+
i = i.astype(np.int32)
|
| 663 |
+
|
| 664 |
+
p = np.clip((v * (1.0 - s)), 0.0, 1.0)
|
| 665 |
+
q = np.clip((v * (1.0 - s * f)), 0.0, 1.0)
|
| 666 |
+
t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
|
| 667 |
+
i = i % 6
|
| 668 |
+
mask = np.expand_dims(i, axis=-1) == np.arange(6)
|
| 669 |
+
|
| 670 |
+
a1 = np.stack((v, q, p, p, t, v), axis=-1)
|
| 671 |
+
a2 = np.stack((t, v, v, q, p, p), axis=-1)
|
| 672 |
+
a3 = np.stack((p, p, t, v, v, q), axis=-1)
|
| 673 |
+
a4 = np.stack((a1, a2, a3), axis=-1)
|
| 674 |
+
|
| 675 |
+
return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4)
|
| 676 |
+
|
| 677 |
+
def adjust_brightness(self, color, brightness_factor):
|
| 678 |
+
if brightness_factor < 0:
|
| 679 |
+
raise ValueError(
|
| 680 |
+
"brightness_factor ({}) is not non-negative.".format(brightness_factor)
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
return self.blend(color, np.zeros_like(color), brightness_factor)
|
| 684 |
+
|
| 685 |
+
def adjust_contrast(self, color, contrast_factor):
|
| 686 |
+
if contrast_factor < 0:
|
| 687 |
+
raise ValueError(
|
| 688 |
+
"contrast_factor ({}) is not non-negative.".format(contrast_factor)
|
| 689 |
+
)
|
| 690 |
+
mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color))
|
| 691 |
+
return self.blend(color, mean, contrast_factor)
|
| 692 |
+
|
| 693 |
+
def adjust_saturation(self, color, saturation_factor):
|
| 694 |
+
if saturation_factor < 0:
|
| 695 |
+
raise ValueError(
|
| 696 |
+
"saturation_factor ({}) is not non-negative.".format(saturation_factor)
|
| 697 |
+
)
|
| 698 |
+
gray = RandomColorGrayScale.rgb_to_grayscale(color)
|
| 699 |
+
return self.blend(color, gray, saturation_factor)
|
| 700 |
+
|
| 701 |
+
def adjust_hue(self, color, hue_factor):
|
| 702 |
+
if not (-0.5 <= hue_factor <= 0.5):
|
| 703 |
+
raise ValueError(
|
| 704 |
+
"hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)
|
| 705 |
+
)
|
| 706 |
+
orig_dtype = color.dtype
|
| 707 |
+
hsv = self.rgb2hsv(color / 255.0)
|
| 708 |
+
h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
|
| 709 |
+
h = (h + hue_factor) % 1.0
|
| 710 |
+
hsv = np.stack((h, s, v), axis=-1)
|
| 711 |
+
color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype)
|
| 712 |
+
return color_hue_adj
|
| 713 |
+
|
| 714 |
+
@staticmethod
|
| 715 |
+
def get_params(brightness, contrast, saturation, hue):
|
| 716 |
+
fn_idx = torch.randperm(4)
|
| 717 |
+
b = (
|
| 718 |
+
None
|
| 719 |
+
if brightness is None
|
| 720 |
+
else np.random.uniform(brightness[0], brightness[1])
|
| 721 |
+
)
|
| 722 |
+
c = None if contrast is None else np.random.uniform(contrast[0], contrast[1])
|
| 723 |
+
s = (
|
| 724 |
+
None
|
| 725 |
+
if saturation is None
|
| 726 |
+
else np.random.uniform(saturation[0], saturation[1])
|
| 727 |
+
)
|
| 728 |
+
h = None if hue is None else np.random.uniform(hue[0], hue[1])
|
| 729 |
+
return fn_idx, b, c, s, h
|
| 730 |
+
|
| 731 |
+
def __call__(self, data_dict):
|
| 732 |
+
(
|
| 733 |
+
fn_idx,
|
| 734 |
+
brightness_factor,
|
| 735 |
+
contrast_factor,
|
| 736 |
+
saturation_factor,
|
| 737 |
+
hue_factor,
|
| 738 |
+
) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
|
| 739 |
+
|
| 740 |
+
for fn_id in fn_idx:
|
| 741 |
+
if (
|
| 742 |
+
fn_id == 0
|
| 743 |
+
and brightness_factor is not None
|
| 744 |
+
and np.random.rand() < self.p
|
| 745 |
+
):
|
| 746 |
+
data_dict["color"] = self.adjust_brightness(
|
| 747 |
+
data_dict["color"], brightness_factor
|
| 748 |
+
)
|
| 749 |
+
elif (
|
| 750 |
+
fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p
|
| 751 |
+
):
|
| 752 |
+
data_dict["color"] = self.adjust_contrast(
|
| 753 |
+
data_dict["color"], contrast_factor
|
| 754 |
+
)
|
| 755 |
+
elif (
|
| 756 |
+
fn_id == 2
|
| 757 |
+
and saturation_factor is not None
|
| 758 |
+
and np.random.rand() < self.p
|
| 759 |
+
):
|
| 760 |
+
data_dict["color"] = self.adjust_saturation(
|
| 761 |
+
data_dict["color"], saturation_factor
|
| 762 |
+
)
|
| 763 |
+
elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p:
|
| 764 |
+
data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor)
|
| 765 |
+
return data_dict
|
| 766 |
+
|
| 767 |
+
|
@TRANSFORMS.register_module()
class HueSaturationTranslation(object):
    @staticmethod
    def rgb_to_hsv(rgb):
        # Translated from the source of colorsys.rgb_to_hsv
        # r, g, b should be numpy arrays with values between 0 and 255
        # rgb_to_hsv returns an array of floats between 0.0 and 1.0
        rgb = rgb.astype("float")
        hsv = np.zeros_like(rgb)
        # in case an RGBA array was passed, just copy the A channel
        hsv[..., 3:] = rgb[..., 3:]
        r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
        maxc = np.max(rgb[..., :3], axis=-1)
        minc = np.min(rgb[..., :3], axis=-1)
        hsv[..., 2] = maxc
        mask = maxc != minc
        hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
        rc = np.zeros_like(r)
        gc = np.zeros_like(g)
        bc = np.zeros_like(b)
        rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
        gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
        bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
        hsv[..., 0] = np.select(
            [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc
        )
        hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
        return hsv

    @staticmethod
    def hsv_to_rgb(hsv):
        # Translated from the source of colorsys.hsv_to_rgb
        # h, s should be numpy arrays with values between 0.0 and 1.0
        # v should be a numpy array with values between 0.0 and 255.0
        # hsv_to_rgb returns an array of uints between 0 and 255
        rgb = np.empty_like(hsv)
        rgb[..., 3:] = hsv[..., 3:]
        h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
        i = (h * 6.0).astype("uint8")
        f = (h * 6.0) - i
        p = v * (1.0 - s)
        q = v * (1.0 - s * f)
        t = v * (1.0 - s * (1.0 - f))
        i = i % 6
        conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
        rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
        rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
        rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
        return rgb.astype("uint8")

    def __init__(self, hue_max=0.5, saturation_max=0.2):
        self.hue_max = hue_max
        self.saturation_max = saturation_max

    def __call__(self, data_dict):
        if "color" in data_dict.keys():
            # assume color[:, :3] is rgb
            hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3])
            hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max
            sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max
            hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
            hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
            data_dict["color"][:, :3] = np.clip(
                HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255
            )
        return data_dict
@TRANSFORMS.register_module()
class RandomColorDrop(object):
    def __init__(self, p=0.2, color_augment=0.0):
        self.p = p
        self.color_augment = color_augment

    def __call__(self, data_dict):
        if "color" in data_dict.keys() and np.random.rand() < self.p:
            data_dict["color"] *= self.color_augment
        return data_dict

    def __repr__(self):
        return "RandomColorDrop(color_augment: {}, p: {})".format(
            self.color_augment, self.p
        )
@TRANSFORMS.register_module()
class ElasticDistortion(object):
    def __init__(self, distortion_params=None):
        self.distortion_params = (
            [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params
        )

    @staticmethod
    def elastic_distortion(coords, granularity, magnitude):
        """
        Apply elastic distortion on a sparse coordinate space.
        coords: numpy array of (number of points, at least 3 spatial dims)
        granularity: size of the noise grid (in the same scale [m/cm] as the voxel grid)
        magnitude: noise multiplier
        """
        blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3
        blury = np.ones((1, 3, 1, 1)).astype("float32") / 3
        blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3
        coords_min = coords.min(0)

        # Create a Gaussian noise tensor of the size given by granularity.
        noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
        noise = np.random.randn(*noise_dim, 3).astype(np.float32)

        # Smoothing.
        for _ in range(2):
            noise = scipy.ndimage.filters.convolve(
                noise, blurx, mode="constant", cval=0
            )
            noise = scipy.ndimage.filters.convolve(
                noise, blury, mode="constant", cval=0
            )
            noise = scipy.ndimage.filters.convolve(
                noise, blurz, mode="constant", cval=0
            )

        # Trilinearly interpolate the noise filters for each spatial dimension.
        ax = [
            np.linspace(d_min, d_max, d)
            for d_min, d_max, d in zip(
                coords_min - granularity,
                coords_min + granularity * (noise_dim - 2),
                noise_dim,
            )
        ]
        interp = scipy.interpolate.RegularGridInterpolator(
            ax, noise, bounds_error=False, fill_value=0
        )
        coords += interp(coords) * magnitude
        return coords

    def __call__(self, data_dict):
        if "coord" in data_dict.keys() and self.distortion_params is not None:
            if random.random() < 0.95:
                for granularity, magnitude in self.distortion_params:
                    data_dict["coord"] = self.elastic_distortion(
                        data_dict["coord"], granularity, magnitude
                    )
        return data_dict
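A minimal sketch of the distortion in action (hypothetical cloud, not part of the diff); the noise grid is smoothed before interpolation, so nearby points move together:

# Hypothetical ElasticDistortion usage (not in the diff).
# Note: scipy.ndimage.filters is the legacy namespace; very new SciPy releases
# may require patching the calls above to scipy.ndimage.convolve.
import numpy as np

data = {"coord": (np.random.rand(2048, 3) * 4).astype(np.float32)}
aug = ElasticDistortion(distortion_params=[[0.2, 0.4], [0.8, 1.6]])
out = aug(data)            # with prob 0.95, applies both (granularity, magnitude) passes
print(out["coord"].shape)  # (2048, 3): coordinates smoothly displaced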
+@TRANSFORMS.register_module()
+class GridSample(object):
+    def __init__(
+        self,
+        grid_size=0.05,
+        hash_type="fnv",
+        mode="train",
+        return_inverse=False,
+        return_grid_coord=False,
+        return_min_coord=False,
+        return_displacement=False,
+        project_displacement=False,
+    ):
+        self.grid_size = grid_size
+        self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec
+        assert mode in ["train", "test"]
+        self.mode = mode
+        self.return_inverse = return_inverse
+        self.return_grid_coord = return_grid_coord
+        self.return_min_coord = return_min_coord
+        self.return_displacement = return_displacement
+        self.project_displacement = project_displacement
+
+    def __call__(self, data_dict):
+        assert "coord" in data_dict.keys()
+        scaled_coord = data_dict["coord"] / np.array(self.grid_size)
+        grid_coord = np.floor(scaled_coord).astype(int)
+        min_coord = grid_coord.min(0)
+        grid_coord -= min_coord
+        scaled_coord -= min_coord
+        min_coord = min_coord * np.array(self.grid_size)
+        key = self.hash(grid_coord)
+        idx_sort = np.argsort(key)
+        key_sort = key[idx_sort]
+        _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True)
+        if self.mode == "train":  # train mode
+            idx_select = (
+                np.cumsum(np.insert(count, 0, 0)[0:-1])
+                + np.random.randint(0, count.max(), count.size) % count
+            )
+            idx_unique = idx_sort[idx_select]
+            if "sampled_index" in data_dict:
+                # for the ScanNet data-efficient benchmark, make sure labeled points are sampled
+                idx_unique = np.unique(
+                    np.append(idx_unique, data_dict["sampled_index"])
+                )
+                mask = np.zeros_like(data_dict["segment"]).astype(bool)
+                mask[data_dict["sampled_index"]] = True
+                data_dict["sampled_index"] = np.where(mask[idx_unique])[0]
+            data_dict = index_operator(data_dict, idx_unique)
+            if self.return_inverse:
+                data_dict["inverse"] = np.zeros_like(inverse)
+                data_dict["inverse"][idx_sort] = inverse
+            if self.return_grid_coord:
+                data_dict["grid_coord"] = grid_coord[idx_unique]
+                data_dict["index_valid_keys"].append("grid_coord")
+            if self.return_min_coord:
+                data_dict["min_coord"] = min_coord.reshape([1, 3])
+            if self.return_displacement:
+                displacement = (
+                    scaled_coord - grid_coord - 0.5
+                )  # [0, 1] -> [-0.5, 0.5] displacement to cell center
+                if self.project_displacement:
+                    displacement = np.sum(
+                        displacement * data_dict["normal"], axis=-1, keepdims=True
+                    )
+                data_dict["displacement"] = displacement[idx_unique]
+                data_dict["index_valid_keys"].append("displacement")
+            return data_dict
+
+        elif self.mode == "test":  # test mode
+            data_part_list = []
+            for i in range(count.max()):
+                idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
+                idx_part = idx_sort[idx_select]
+                data_part = index_operator(data_dict, idx_part, duplicate=True)
+                data_part["index"] = idx_part
+                if self.return_inverse:
+                    data_part["inverse"] = np.zeros_like(inverse)
+                    data_part["inverse"][idx_sort] = inverse
+                if self.return_grid_coord:
+                    data_part["grid_coord"] = grid_coord[idx_part]
+                    data_dict["index_valid_keys"].append("grid_coord")
+                if self.return_min_coord:
+                    data_part["min_coord"] = min_coord.reshape([1, 3])
+                if self.return_displacement:
+                    displacement = (
+                        scaled_coord - grid_coord - 0.5
+                    )  # [0, 1] -> [-0.5, 0.5] displacement to cell center
+                    if self.project_displacement:
+                        displacement = np.sum(
+                            displacement * data_dict["normal"], axis=-1, keepdims=True
+                        )
+                    data_dict["displacement"] = displacement[idx_part]
+                    data_dict["index_valid_keys"].append("displacement")
+                data_part_list.append(data_part)
+            return data_part_list
+        else:
+            raise NotImplementedError
+
+    @staticmethod
+    def ravel_hash_vec(arr):
+        """
+        Ravel the coordinates after subtracting the min coordinates.
+        """
+        assert arr.ndim == 2
+        arr = arr.copy()
+        arr -= arr.min(0)
+        arr = arr.astype(np.uint64, copy=False)
+        arr_max = arr.max(0).astype(np.uint64) + 1
+
+        keys = np.zeros(arr.shape[0], dtype=np.uint64)
+        # Fortran-style indexing
+        for j in range(arr.shape[1] - 1):
+            keys += arr[:, j]
+            keys *= arr_max[j + 1]
+        keys += arr[:, -1]
+        return keys
+
+    @staticmethod
+    def fnv_hash_vec(arr):
+        """
+        FNV64-1A
+        """
+        assert arr.ndim == 2
+        # Floor first for negative coordinates
+        arr = arr.copy()
+        arr = arr.astype(np.uint64, copy=False)
+        hashed_arr = np.uint64(14695981039346656037) * np.ones(
+            arr.shape[0], dtype=np.uint64
+        )
+        for j in range(arr.shape[1]):
+            hashed_arr *= np.uint64(1099511628211)
+            hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
+        return hashed_arr
+
+
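GridSample is the voxelization workhorse: points are hashed by grid cell (FNV-1a or raveled integer coordinates), and one representative per occupied cell is kept; in train mode the representative is random, while test mode emits a list of complementary sub-samples that together cover every point. A hedged usage sketch (the input dict layout is an assumption based on the keys this class reads and writes):

    import numpy as np

    sampler = GridSample(
        grid_size=0.02,          # 2 cm voxels
        hash_type="fnv",
        mode="train",
        return_grid_coord=True,
        return_inverse=True,
    )
    data = {
        "coord": np.random.rand(100_000, 3).astype(np.float32),
        "index_valid_keys": ["coord"],  # keys that subsampling should slice
    }
    data = sampler(data)
    # data["coord"]: one point per occupied voxel
    # data["grid_coord"]: its integer cell index
    # data["inverse"]: maps each original point to its kept representative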
+@TRANSFORMS.register_module()
+class SphereCrop(object):
+    def __init__(self, point_max=80000, sample_rate=None, mode="random"):
+        self.point_max = point_max
+        self.sample_rate = sample_rate
+        assert mode in ["random", "center", "all"]
+        self.mode = mode
+
+    def __call__(self, data_dict):
+        point_max = (
+            int(self.sample_rate * data_dict["coord"].shape[0])
+            if self.sample_rate is not None
+            else self.point_max
+        )
+
+        assert "coord" in data_dict.keys()
+        if data_dict["coord"].shape[0] > point_max:
+            if self.mode == "random":
+                center = data_dict["coord"][
+                    np.random.randint(data_dict["coord"].shape[0])
+                ]
+            elif self.mode == "center":
+                center = data_dict["coord"][data_dict["coord"].shape[0] // 2]
+            else:
+                raise NotImplementedError
+            idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[
+                :point_max
+            ]
+            data_dict = index_operator(data_dict, idx_crop)
+        return data_dict
+
+
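SphereCrop bounds the point budget by keeping the point_max nearest neighbors of a random (or central) seed point, so an oversized scan becomes one contiguous spherical chunk rather than a uniform thinning. For instance, continuing the hypothetical data dict from above:

    crop = SphereCrop(point_max=80000, mode="random")
    data = crop(data)  # at most 80k points remain, clustered around one seed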
+@TRANSFORMS.register_module()
+class ShufflePoint(object):
+    def __call__(self, data_dict):
+        assert "coord" in data_dict.keys()
+        shuffle_index = np.arange(data_dict["coord"].shape[0])
+        np.random.shuffle(shuffle_index)
+        data_dict = index_operator(data_dict, shuffle_index)
+        return data_dict
+
+
+@TRANSFORMS.register_module()
+class CropBoundary(object):
+    def __call__(self, data_dict):
+        assert "segment" in data_dict
+        segment = data_dict["segment"].flatten()
+        mask = (segment != 0) * (segment != 1)
+        data_dict = index_operator(data_dict, mask)
+        return data_dict
+
+
+@TRANSFORMS.register_module()
+class ContrastiveViewsGenerator(object):
+    def __init__(
+        self,
+        view_keys=("coord", "color", "normal", "origin_coord"),
+        view_trans_cfg=None,
+    ):
+        self.view_keys = view_keys
+        self.view_trans = Compose(view_trans_cfg)
+
+    def __call__(self, data_dict):
+        view1_dict = dict()
+        view2_dict = dict()
+        for key in self.view_keys:
+            view1_dict[key] = data_dict[key].copy()
+            view2_dict[key] = data_dict[key].copy()
+        view1_dict = self.view_trans(view1_dict)
+        view2_dict = self.view_trans(view2_dict)
+        for key, value in view1_dict.items():
+            data_dict["view1_" + key] = value
+        for key, value in view2_dict.items():
+            data_dict["view2_" + key] = value
+        return data_dict
+
+
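ContrastiveViewsGenerator copies the selected keys twice and runs each copy through the same stochastic pipeline independently, yielding view1_*/view2_* pairs for a contrastive objective. A sketch reusing transforms from this file (the exact recipe is illustrative, not prescribed by the commit):

    gen = ContrastiveViewsGenerator(
        view_keys=("coord", "color"),
        view_trans_cfg=[
            dict(type="CenterShift", apply_z=True),  # registered earlier in this module
            dict(type="ShufflePoint"),
        ],
    )
    data = gen(data)
    # adds "view1_coord", "view1_color", "view2_coord", "view2_color"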
+@TRANSFORMS.register_module()
+class MultiViewGenerator(object):
+    def __init__(
+        self,
+        global_view_num=2,
+        global_view_scale=(0.4, 1.0),
+        local_view_num=4,
+        local_view_scale=(0.1, 0.4),
+        global_shared_transform=None,
+        global_transform=None,
+        local_transform=None,
+        max_size=65536,
+        center_height_scale=(0, 1),
+        shared_global_view=False,
+        view_keys=("coord", "origin_coord", "color", "normal"),
+    ):
+        self.global_view_num = global_view_num
+        self.global_view_scale = global_view_scale
+        self.local_view_num = local_view_num
+        self.local_view_scale = local_view_scale
+        self.global_shared_transform = Compose(global_shared_transform)
+        self.global_transform = Compose(global_transform)
+        self.local_transform = Compose(local_transform)
+        self.max_size = max_size
+        self.center_height_scale = center_height_scale
+        self.shared_global_view = shared_global_view
+        self.view_keys = view_keys
+        assert "coord" in view_keys
+
+    def get_view(self, point, center, scale):
+        coord = point["coord"]
+        max_size = min(self.max_size, coord.shape[0])
+        size = int(np.random.uniform(*scale) * max_size)
+        index = np.argsort(np.sum(np.square(coord - center), axis=-1))[:size]
+        view = dict(index=index)
+        for key in point.keys():
+            if key in self.view_keys:
+                view[key] = point[key][index]
+
+        if "index_valid_keys" in point.keys():
+            # inherit index_valid_keys from point
+            view["index_valid_keys"] = point["index_valid_keys"]
+        return view
+
+    def __call__(self, data_dict):
+        coord = data_dict["coord"]
+        point = self.global_shared_transform(copy.deepcopy(data_dict))
+        z_min = coord[:, 2].min()
+        z_max = coord[:, 2].max()
+        z_min_ = z_min + (z_max - z_min) * self.center_height_scale[0]
+        z_max_ = z_min + (z_max - z_min) * self.center_height_scale[1]
+        center_mask = np.logical_and(coord[:, 2] >= z_min_, coord[:, 2] <= z_max_)
+        # get major global view
+        major_center = coord[np.random.choice(np.where(center_mask)[0])]
+        major_view = self.get_view(point, major_center, self.global_view_scale)
+        major_coord = major_view["coord"]
+        # get global views: restrict the centers of the remaining global views to the major global view
+        if not self.shared_global_view:
+            global_views = [
+                self.get_view(
+                    point=point,
+                    center=major_coord[np.random.randint(major_coord.shape[0])],
+                    scale=self.global_view_scale,
+                )
+                for _ in range(self.global_view_num - 1)
+            ]
+        else:
+            global_views = [
+                {key: value.copy() for key, value in major_view.items()}
+                for _ in range(self.global_view_num - 1)
+            ]
+
+        global_views = [major_view] + global_views
+
+        # get local views: restrict the centers of local views to the major global view
+        cover_mask = np.zeros_like(major_view["index"], dtype=bool)
+        local_views = []
+        for i in range(self.local_view_num):
+            if sum(~cover_mask) == 0:
+                # reset cover mask if all points are sampled
+                cover_mask[:] = False
+            local_view = self.get_view(
+                point=data_dict,
+                center=major_coord[np.random.choice(np.where(~cover_mask)[0])],
+                scale=self.local_view_scale,
+            )
+            local_views.append(local_view)
+            cover_mask[np.isin(major_view["index"], local_view["index"])] = True
+
+        # augmentation and concat
+        view_dict = {}
+        for global_view in global_views:
+            global_view.pop("index")
+            global_view = self.global_transform(global_view)
+            for key in self.view_keys:
+                if f"global_{key}" in view_dict.keys():
+                    view_dict[f"global_{key}"].append(global_view[key])
+                else:
+                    view_dict[f"global_{key}"] = [global_view[key]]
+        view_dict["global_offset"] = np.cumsum(
+            [data.shape[0] for data in view_dict["global_coord"]]
+        )
+        for local_view in local_views:
+            local_view.pop("index")
+            local_view = self.local_transform(local_view)
+            for key in self.view_keys:
+                if f"local_{key}" in view_dict.keys():
+                    view_dict[f"local_{key}"].append(local_view[key])
+                else:
+                    view_dict[f"local_{key}"] = [local_view[key]]
+        view_dict["local_offset"] = np.cumsum(
+            [data.shape[0] for data in view_dict["local_coord"]]
+        )
+        for key in view_dict.keys():
+            if "offset" not in key:
+                view_dict[key] = np.concatenate(view_dict[key], axis=0)
+        data_dict.update(view_dict)
+        return data_dict
+
+
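MultiViewGenerator is point-cloud multi-crop in the DINO style: a major global view anchors the scene, the remaining global and local views are cropped around centers sampled inside it, and the cover mask biases successive local crops toward regions not yet covered. A hedged configuration sketch (key names match the class, data contents are assumed):

    gen = MultiViewGenerator(
        global_view_num=2,
        global_view_scale=(0.4, 1.0),
        local_view_num=4,
        local_view_scale=(0.1, 0.4),
        view_keys=("coord", "color"),
    )
    data = gen(data)
    # data gains concatenated "global_coord"/"local_coord" arrays, plus
    # "global_offset"/"local_offset" cumulative sizes delimiting each view.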
+@TRANSFORMS.register_module()
+class InstanceParser(object):
+    def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1):
+        self.segment_ignore_index = segment_ignore_index
+        self.instance_ignore_index = instance_ignore_index
+
+    def __call__(self, data_dict):
+        coord = data_dict["coord"]
+        segment = data_dict["segment"]
+        instance = data_dict["instance"]
+        mask = ~np.in1d(segment, self.segment_ignore_index)
+        # map ignored instances to the ignore index
+        instance[~mask] = self.instance_ignore_index
+        # reorder the remaining instances
+        unique, inverse = np.unique(instance[mask], return_inverse=True)
+        instance_num = len(unique)
+        instance[mask] = inverse
+        # init instance information
+        centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index
+        bbox = np.ones((instance_num, 8)) * self.instance_ignore_index
+        vacancy = [
+            index for index in self.segment_ignore_index if index >= 0
+        ]  # vacated class indices
+
+        for instance_id in range(instance_num):
+            mask_ = instance == instance_id
+            coord_ = coord[mask_]
+            bbox_min = coord_.min(0)
+            bbox_max = coord_.max(0)
+            bbox_centroid = coord_.mean(0)
+            bbox_center = (bbox_max + bbox_min) / 2
+            bbox_size = bbox_max - bbox_min
+            bbox_theta = np.zeros(1, dtype=coord_.dtype)
+            bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype)
+            # shift class indices to fill the slots vacated by segment_ignore_index
+            bbox_class -= np.greater(bbox_class, vacancy).sum()
+
+            centroid[mask_] = bbox_centroid
+            bbox[instance_id] = np.concatenate(
+                [bbox_center, bbox_size, bbox_theta, bbox_class]
+            )  # 3 + 3 + 1 + 1 = 8
+        data_dict["instance"] = instance
+        data_dict["instance_centroid"] = centroid
+        data_dict["bbox"] = bbox
+        return data_dict
+
+
|
| 1295 |
+
class Compose(object):
|
| 1296 |
+
def __init__(self, cfg=None):
|
| 1297 |
+
self.cfg = cfg if cfg is not None else []
|
| 1298 |
+
self.transforms = []
|
| 1299 |
+
for t_cfg in self.cfg:
|
| 1300 |
+
self.transforms.append(TRANSFORMS.build(t_cfg))
|
| 1301 |
+
|
| 1302 |
+
def __call__(self, data_dict):
|
| 1303 |
+
for t in self.transforms:
|
| 1304 |
+
data_dict = t(data_dict)
|
| 1305 |
+
return data_dict
|
| 1306 |
+
|
| 1307 |
+
|
| 1308 |
+
def default():
|
| 1309 |
+
config = [
|
| 1310 |
+
dict(type="CenterShift", apply_z=True),
|
| 1311 |
+
dict(
|
| 1312 |
+
type="GridSample",
|
| 1313 |
+
# grid_size=0.02,
|
| 1314 |
+
# grid_size=0.01,
|
| 1315 |
+
grid_size=0.005,
|
| 1316 |
+
# grid_size=0.0025,
|
| 1317 |
+
hash_type="fnv",
|
| 1318 |
+
mode="train",
|
| 1319 |
+
return_grid_coord=True,
|
| 1320 |
+
return_inverse=True,
|
| 1321 |
+
),
|
| 1322 |
+
dict(type="NormalizeColor"),
|
| 1323 |
+
dict(type="ToTensor"),
|
| 1324 |
+
dict(
|
| 1325 |
+
type="Collect",
|
| 1326 |
+
keys=("coord", "grid_coord", "color", "inverse"),
|
| 1327 |
+
feat_keys=("coord", "color", "normal"),
|
| 1328 |
+
),
|
| 1329 |
+
]
|
| 1330 |
+
return Compose(config)
|
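default() bundles a ready-made preprocessing pipeline: center shift, 5 mm grid sampling, color normalization, tensor conversion, and key collection (NormalizeColor, ToTensor, and Collect are transforms registered earlier in this file). A hedged usage sketch with made-up inputs:

    import numpy as np

    transform = default()
    point = transform({
        "coord": np.random.rand(50_000, 3).astype(np.float32),
        "color": np.random.randint(0, 255, (50_000, 3)).astype(np.float32),
        "normal": np.random.randn(50_000, 3).astype(np.float32),
    })
    # point now holds tensors for "coord", "grid_coord", "color", "inverse",
    # plus (given feat_keys) a feature tensor built from coord, color, normal.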
XPart/partgen/models/sonata/utils.py
ADDED
@@ -0,0 +1,75 @@
+"""
+General utils
+
+Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
+Please cite our work if the code is helpful to you.
+"""
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import random
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from datetime import datetime
+
+
+@torch.no_grad()
+def offset2bincount(offset):
+    return torch.diff(
+        offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)
+    )
+
+
+@torch.no_grad()
+def bincount2offset(bincount):
+    return torch.cumsum(bincount, dim=0)
+
+
+@torch.no_grad()
+def offset2batch(offset):
+    bincount = offset2bincount(offset)
+    return torch.arange(
+        len(bincount), device=offset.device, dtype=torch.long
+    ).repeat_interleave(bincount)
+
+
+@torch.no_grad()
+def batch2offset(batch):
+    return torch.cumsum(batch.bincount(), dim=0).long()
+
+
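These four helpers convert between the three equivalent batch encodings used across the codebase: per-cloud point counts (bincount), cumulative end indices (offset), and a per-point batch index (batch). A quick round trip:

    import torch

    offset = torch.tensor([3, 5, 9])        # three clouds of 3, 2, and 4 points
    offset2bincount(offset)                 # tensor([3, 2, 4])
    offset2batch(offset)                    # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
    batch2offset(offset2batch(offset))      # tensor([3, 5, 9]), the round trip closes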
+def get_random_seed():
+    seed = (
+        os.getpid()
+        + int(datetime.now().strftime("%S%f"))
+        + int.from_bytes(os.urandom(2), "big")
+    )
+    return seed
+
+
+def set_seed(seed=None):
+    if seed is None:
+        seed = get_random_seed()
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    cudnn.benchmark = False
+    cudnn.deterministic = True
+    os.environ["PYTHONHASHSEED"] = str(seed)
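set_seed pins the Python, NumPy, and all CUDA RNGs and forces deterministic cuDNN kernels; called with no argument, it derives a process-unique seed from the PID, wall clock, and OS entropy. For example:

    set_seed(42)  # reproducible runs, at some cuDNN speed cost
    set_seed()    # fresh seed from get_random_seed()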