root committed on
Commit 7b75adb · 1 parent: be73458

add our app

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +215 -0
  3. P3-SAM/demo/assets/1.glb +3 -0
  4. P3-SAM/demo/assets/2.glb +3 -0
  5. P3-SAM/demo/assets/3.glb +3 -0
  6. P3-SAM/demo/assets/4.glb +3 -0
  7. P3-SAM/demo/auto_mask.py +1405 -0
  8. P3-SAM/demo/auto_mask_no_postprocess.py +943 -0
  9. P3-SAM/model.py +156 -0
  10. P3-SAM/utils/chamfer3D/chamfer3D.cu +196 -0
  11. P3-SAM/utils/chamfer3D/chamfer_cuda.cpp +29 -0
  12. P3-SAM/utils/chamfer3D/dist_chamfer_3D.py +81 -0
  13. P3-SAM/utils/chamfer3D/setup.py +14 -0
  14. XPart/data/000.glb +3 -0
  15. XPart/data/001.glb +3 -0
  16. XPart/data/002.glb +3 -0
  17. XPart/data/003.glb +3 -0
  18. XPart/data/004.glb +3 -0
  19. XPart/partgen/bbox_estimator/auto_mask_api.py +1417 -0
  20. XPart/partgen/config/infer.yaml +122 -0
  21. XPart/partgen/config/sonata.json +58 -0
  22. XPart/partgen/models/autoencoders/__init__.py +29 -0
  23. XPart/partgen/models/autoencoders/attention_blocks.py +770 -0
  24. XPart/partgen/models/autoencoders/attention_processors.py +32 -0
  25. XPart/partgen/models/autoencoders/model.py +452 -0
  26. XPart/partgen/models/autoencoders/surface_extractors.py +164 -0
  27. XPart/partgen/models/autoencoders/volume_decoders.py +107 -0
  28. XPart/partgen/models/conditioner/condioner_release.py +170 -0
  29. XPart/partgen/models/conditioner/part_encoders.py +89 -0
  30. XPart/partgen/models/conditioner/sonata_extractor.py +315 -0
  31. XPart/partgen/models/diffusion/schedulers.py +329 -0
  32. XPart/partgen/models/diffusion/transport/__init__.py +97 -0
  33. XPart/partgen/models/diffusion/transport/integrators.py +142 -0
  34. XPart/partgen/models/diffusion/transport/path.py +220 -0
  35. XPart/partgen/models/diffusion/transport/transport.py +506 -0
  36. XPart/partgen/models/diffusion/transport/utils.py +54 -0
  37. XPart/partgen/models/moe_layers.py +209 -0
  38. XPart/partgen/models/partformer_dit.py +756 -0
  39. XPart/partgen/models/sonata/__init__.py +35 -0
  40. XPart/partgen/models/sonata/data.py +84 -0
  41. XPart/partgen/models/sonata/model.py +874 -0
  42. XPart/partgen/models/sonata/module.py +107 -0
  43. XPart/partgen/models/sonata/registry.py +340 -0
  44. XPart/partgen/models/sonata/serialization/__init__.py +9 -0
  45. XPart/partgen/models/sonata/serialization/default.py +82 -0
  46. XPart/partgen/models/sonata/serialization/hilbert.py +318 -0
  47. XPart/partgen/models/sonata/serialization/z_order.py +145 -0
  48. XPart/partgen/models/sonata/structure.py +159 -0
  49. XPart/partgen/models/sonata/transform.py +1330 -0
  50. XPart/partgen/models/sonata/utils.py +75 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.glb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,215 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # Streamlit
210
+ .streamlit/secrets.toml
211
+
212
+ results/
213
+ weights/
214
+ .gradio/
215
+ demo/segment_result.glb
P3-SAM/demo/assets/1.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b57626cf269cb1e5b949586b6b6b87efa552c66d0a45371fed0c9f47db4c3314
3
+ size 29140044
P3-SAM/demo/assets/2.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cedbb7365e8506d42809cc547fc0d69d9118af1d419c8b8748bae01b545634c
3
+ size 8529600
P3-SAM/demo/assets/3.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b179b1141dc0dd847e0be289c3f9b0dfd9446995b38b5050d41df7cbabbc516d
3
+ size 30475016
P3-SAM/demo/assets/4.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be9fea2e4b63233ab30f929ca9f735ccbd154b1f9652e9bbaacd87839829f02b
3
+ size 29621968
P3-SAM/demo/auto_mask.py ADDED
@@ -0,0 +1,1405 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import argparse
7
+ import trimesh
8
+ from sklearn.decomposition import PCA
9
+ import fpsample
10
+ from tqdm import tqdm
11
+ import threading
12
+ import random
13
+
14
+ # from tqdm.notebook import tqdm
15
+ import time
16
+ import copy
17
+ import shutil
18
+ from pathlib import Path
19
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
20
+ from collections import defaultdict
21
+
22
+ import numba
23
+ from numba import njit
24
+
25
+ sys.path.append('..')
26
+ from model import build_P3SAM, load_state_dict
27
+
28
+ class P3SAM(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ build_P3SAM(self)
32
+
33
+ def load_state_dict(self,
34
+ ckpt_path=None,
35
+ state_dict=None,
36
+ strict=True,
37
+ assign=False,
38
+ ignore_seg_mlp=False,
39
+ ignore_seg_s2_mlp=False,
40
+ ignore_iou_mlp=False):
41
+ load_state_dict(self,
42
+ ckpt_path=ckpt_path,
43
+ state_dict=state_dict,
44
+ strict=strict,
45
+ assign=assign,
46
+ ignore_seg_mlp=ignore_seg_mlp,
47
+ ignore_seg_s2_mlp=ignore_seg_s2_mlp,
48
+ ignore_iou_mlp=ignore_iou_mlp)
49
+
50
+ def forward(self, feats, points, point_prompt, iter=1):
51
+ '''
52
+ feats: [K, N, 512]
53
+ points: [K, N, 3]
54
+ point_prompt: [K, N, 3]
55
+ '''
56
+ # print(feats.shape, points.shape, point_prompt.shape)
57
+ point_num = points.shape[1]
58
+ feats = feats.transpose(0, 1) # [N, K, 512]
59
+ points = points.transpose(0, 1) # [N, K, 3]
60
+ point_prompt = point_prompt.transpose(0, 1) # [N, K, 3]
61
+ feats_seg = torch.cat([feats, points, point_prompt], dim=-1) # [N, K, 512+3+3]
62
+
63
+ # predict mask, stage 1
64
+ pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1) # [N, K]
65
+ pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1) # [N, K]
66
+ pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1) # [N, K]
67
+ pred_mask = torch.stack(
68
+ [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1
69
+ ) # [N, K, 3]
70
+
71
+ for _ in range(iter):
72
+ # predict mask, stage 2
73
+ feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1) # [N, K, 512+3+3+3]
74
+ feats_seg_global = self.seg_s2_mlp_g(feats_seg_2) # [N, K, 512]
75
+ feats_seg_global = torch.max(feats_seg_global, dim=0).values # [K, 512]
76
+ feats_seg_global = feats_seg_global.unsqueeze(0).repeat(
77
+ point_num, 1, 1
78
+ ) # [N, K, 512]
79
+ feats_seg_3 = torch.cat(
80
+ [feats_seg_global, feats_seg_2], dim=-1
81
+ ) # [N, K, 512+3+3+3+512]
82
+ pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1) # [N, K]
83
+ pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1) # [N, K]
84
+ pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1) # [N, K]
85
+ pred_mask_s2 = torch.stack(
86
+ [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1
87
+ ) # [N, K, 3]
88
+ pred_mask = pred_mask_s2
89
+
90
+ mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32) # [N, K]
91
+ mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32) # [N, K]
92
+ mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32) # [N, K]
93
+
94
+ feats_iou = torch.cat(
95
+ [feats_seg_global, feats_seg, pred_mask_s2], dim=-1
96
+ ) # [N, K, 512+3+3+3+512]
97
+ feats_iou = self.iou_mlp(feats_iou) # [N, K, 512]
98
+ feats_iou = torch.max(feats_iou, dim=0).values # [K, 512]
99
+ pred_iou = self.iou_mlp_out(feats_iou) # [K, 3]
100
+ pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32) # [K, 3]
101
+
102
+ mask_1 = mask_1.transpose(0, 1) # [K, N]
103
+ mask_2 = mask_2.transpose(0, 1) # [K, N]
104
+ mask_3 = mask_3.transpose(0, 1) # [K, N]
105
+
106
+ return mask_1, mask_2, mask_3, pred_iou
107
+
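+ # Example shapes (illustration): with K = 32 prompt points and N = 100000 sampled
+ # points, feats is [32, 100000, 512], points and point_prompt are [32, 100000, 3],
+ # each returned mask is [32, 100000], and pred_iou is [32, 3].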
108
+
109
+ def normalize_pc(pc):
110
+ """
111
+ pc: (N, 3)
112
+ """
113
+ max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0)
114
+ center = (max_ + min_) / 2
115
+ scale = (max_ - min_) / 2
116
+ scale = np.max(np.abs(scale))
117
+ pc = (pc - center) / (scale + 1e-10)
118
+ return pc
119
+
120
+
121
+ @torch.no_grad()
122
+ def get_feat(model, points, normals):
123
+ data_dict = {
124
+ "coord": points,
125
+ "normal": normals,
126
+ "color": np.ones_like(points),
127
+ "batch": np.zeros(points.shape[0], dtype=np.int64),
128
+ }
129
+ data_dict = model.transform(data_dict)
130
+ for k in data_dict:
131
+ if isinstance(data_dict[k], torch.Tensor):
132
+ data_dict[k] = data_dict[k].cuda()
133
+ point = model.sonata(data_dict)
134
+ while "pooling_parent" in point.keys():
135
+ assert "pooling_inverse" in point.keys()
136
+ parent = point.pop("pooling_parent")
137
+ inverse = point.pop("pooling_inverse")
138
+ parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
139
+ point = parent
140
+ feat = point.feat # [M, 1232]
141
+ feat = model.mlp(feat) # [M, 512]
142
+ feat = feat[point.inverse] # [N, 512]
143
+ feats = feat
144
+ return feats
145
+
146
+
147
+ @torch.no_grad()
148
+ def get_mask(model, feats, points, point_prompt, iter=1):
149
+ """
150
+ feats: [N, 512]
151
+ points: [N, 3]
152
+ point_prompt: [K, 3]
153
+ """
154
+ point_num = points.shape[0]
155
+ prompt_num = point_prompt.shape[0]
156
+ feats = feats.unsqueeze(1) # [N, 1, 512]
157
+ feats = feats.repeat(1, prompt_num, 1).cuda() # [N, K, 512]
158
+ points = torch.from_numpy(points).float().cuda().unsqueeze(1) # [N, 1, 3]
159
+ points = points.repeat(1, prompt_num, 1) # [N, K, 3]
160
+ prompt_coord = (
161
+ torch.from_numpy(point_prompt).float().cuda().unsqueeze(0)
162
+ ) # [1, K, 3]
163
+ prompt_coord = prompt_coord.repeat(point_num, 1, 1) # [N, K, 3]
164
+
165
+ feats = feats.transpose(0, 1) # [K, N, 512]
166
+ points = points.transpose(0, 1) # [K, N, 3]
167
+ prompt_coord = prompt_coord.transpose(0, 1) # [K, N, 3]
168
+
169
+ mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter)
170
+
171
+ mask_1 = mask_1.transpose(0, 1) # [N, K]
172
+ mask_2 = mask_2.transpose(0, 1) # [N, K]
173
+ mask_3 = mask_3.transpose(0, 1) # [N, K]
174
+
175
+ mask_1 = mask_1.detach().cpu().numpy() > 0.5
176
+ mask_2 = mask_2.detach().cpu().numpy() > 0.5
177
+ mask_3 = mask_3.detach().cpu().numpy() > 0.5
178
+
179
+ org_iou = pred_iou.detach().cpu().numpy() # [K, 3]
180
+
181
+ return mask_1, mask_2, mask_3, org_iou
182
+
183
+
184
+ def cal_iou(m1, m2):
185
+ return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2))
186
+
187
+
188
+ def cal_single_iou(m1, m2):
189
+ return np.sum(np.logical_and(m1, m2)) / np.sum(m1)
190
+
191
+
192
+ def iou_3d(box1, box2, signle=None):
193
+ """
194
+ Compute the IoU (intersection over union) of two 3D axis-aligned bounding boxes.
195
+
196
+ Args:
197
+ box1 (list): coordinates of the first bounding box [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max]
198
+ box2 (list): coordinates of the second bounding box [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max]
199
+
200
+ Returns:
201
+ float: the IoU value
202
+ """
203
+ # intersection coordinates
204
+ intersection_xmin = max(box1[0], box2[0])
205
+ intersection_ymin = max(box1[1], box2[1])
206
+ intersection_zmin = max(box1[2], box2[2])
207
+ intersection_xmax = min(box1[3], box2[3])
208
+ intersection_ymax = min(box1[4], box2[4])
209
+ intersection_zmax = min(box1[5], box2[5])
210
+
211
+ # check whether there is any intersection
212
+ if (
213
+ intersection_xmin >= intersection_xmax
214
+ or intersection_ymin >= intersection_ymax
215
+ or intersection_zmin >= intersection_zmax
216
+ ):
217
+ return 0.0  # no intersection
218
+
219
+ # intersection volume
220
+ intersection_volume = (
221
+ (intersection_xmax - intersection_xmin)
222
+ * (intersection_ymax - intersection_ymin)
223
+ * (intersection_zmax - intersection_zmin)
224
+ )
225
+
226
+ # volumes of the two boxes
227
+ box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2])
228
+ box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2])
229
+
230
+ if signle is None:
231
+ # union volume
232
+ union_volume = box1_volume + box2_volume - intersection_volume
233
+ elif signle == "1":
234
+ union_volume = box1_volume
235
+ elif signle == "2":
236
+ union_volume = box2_volume
237
+ else:
238
+ raise ValueError("signle must be None or 1 or 2")
239
+
240
+ # compute the IoU
241
+ iou = intersection_volume / union_volume if union_volume > 0 else 0.0
242
+ return iou
243
+
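+ # Worked example (illustration): for box1 = [0, 0, 0, 1, 1, 1] and
+ # box2 = [0.5, 0, 0, 1.5, 1, 1], the intersection volume is 0.5 and the
+ # union is 1.5, so iou_3d(box1, box2) returns 0.5 / 1.5 ≈ 0.33.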
244
+
245
+ def cal_point_bbox_iou(p1, p2, signle=None):
246
+ min_p1 = np.min(p1, axis=0)
247
+ max_p1 = np.max(p1, axis=0)
248
+ min_p2 = np.min(p2, axis=0)
249
+ max_p2 = np.max(p2, axis=0)
250
+ box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]]
251
+ box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]]
252
+ return iou_3d(box1, box2, signle)
253
+
254
+
255
+ def cal_bbox_iou(points, m1, m2):
256
+ p1 = points[m1]
257
+ p2 = points[m2]
258
+ return cal_point_bbox_iou(p1, p2)
259
+
260
+
261
+ def clean_mesh(mesh):
262
+ """
263
+ mesh: trimesh.Trimesh
264
+ """
265
+ # 1. merge nearby vertices
266
+ mesh.merge_vertices()
267
+
268
+ # 2. remove duplicate vertices
269
+ # 3. remove duplicate faces
270
+ mesh.process(True)
271
+ return mesh
272
+
273
+
274
+ def remove_outliers_iqr(data, factor=1.5):
275
+ """
276
+ Remove outliers using the IQR rule
277
+ :param data: input list or NumPy array
278
+ :param factor: multiple of the IQR (default 1.5)
279
+ :return: list with outliers removed
280
+ """
281
+ data = np.array(data, dtype=np.float32)
282
+ q1 = np.percentile(data, 25) # first quartile
283
+ q3 = np.percentile(data, 75) # third quartile
284
+ iqr = q3 - q1 # interquartile range
285
+ lower_bound = q1 - factor * iqr
286
+ upper_bound = q3 + factor * iqr
287
+ return data[(data >= lower_bound) & (data <= upper_bound)].tolist()
288
+
289
+ def better_aabb(points):
290
+ x = points[:, 0]
291
+ y = points[:, 1]
292
+ z = points[:, 2]
293
+ x = remove_outliers_iqr(x)
294
+ y = remove_outliers_iqr(y)
295
+ z = remove_outliers_iqr(z)
296
+ min_xyz = np.array([np.min(x), np.min(y), np.min(z)])
297
+ max_xyz = np.array([np.max(x), np.max(y), np.max(z)])
298
+ return [min_xyz, max_xyz]
299
+
300
+ def fix_label(face_ids, adjacent_faces, use_aabb=False, mesh=None, show_info=False):
301
+ if use_aabb:
302
+ def _cal_aabb(face_ids, i, _points_org):
303
+ _part_mask = face_ids == i
304
+ _faces = mesh.faces[_part_mask]
305
+ _faces = np.reshape(_faces, (-1))
306
+ _points = mesh.vertices[_faces]
307
+ min_xyz, max_xyz = better_aabb(_points)
308
+ _part_mask = (
309
+ (_points_org[:, 0] >= min_xyz[0])
310
+ & (_points_org[:, 0] <= max_xyz[0])
311
+ & (_points_org[:, 1] >= min_xyz[1])
312
+ & (_points_org[:, 1] <= max_xyz[1])
313
+ & (_points_org[:, 2] >= min_xyz[2])
314
+ & (_points_org[:, 2] <= max_xyz[2])
315
+ )
316
+ _part_mask = np.reshape(_part_mask, (-1, 3))
317
+ _part_mask = np.all(_part_mask, axis=1)
318
+ return i, [min_xyz, max_xyz], _part_mask
319
+ with Timer("计算aabb"):
320
+ aabb = {}
321
+ unique_ids = np.unique(face_ids)
322
+ # print(max(unique_ids))
323
+ aabb_face_mask = {}
324
+ _faces = mesh.faces
325
+ _vertices = mesh.vertices
326
+ _faces = np.reshape(_faces, (-1))
327
+ _points = _vertices[_faces]
328
+ with ThreadPoolExecutor(max_workers=20) as executor:
329
+ futures = []
330
+ for i in unique_ids:
331
+ if i < 0:
332
+ continue
333
+ futures.append(executor.submit(_cal_aabb, face_ids, i, _points))
334
+ for future in futures:
335
+ res = future.result()
336
+ aabb[res[0]] = res[1]
337
+ aabb_face_mask[res[0]] = res[2]
338
+
339
+ # _faces = mesh.faces
340
+ # _vertices = mesh.vertices
341
+ # _faces = np.reshape(_faces, (-1))
342
+ # _points = _vertices[_faces]
343
+ # aabb_face_mask = cal_aabb_mask(_points, face_ids)
344
+
345
+ with Timer("合并mesh"):
346
+ loop_cnt = 1
347
+ changed = True
348
+ progress = tqdm(disable=not show_info)
349
+ no_mask_ids = np.where(face_ids < 0)[0].tolist()
350
+ faces_max = adjacent_faces.shape[0]
351
+ while changed and loop_cnt <= 50:
352
+ changed = False
353
+ # collect the faces that still have no label
354
+ new_no_mask_ids = []
355
+ for i in no_mask_ids:
356
+ # if face_ids[i] < 0:
357
+ # look at the neighboring faces
358
+ if not (0 <= i < faces_max):
359
+ continue
360
+ _adj_faces = adjacent_faces[i]
361
+ _adj_ids = []
362
+ for j in _adj_faces:
363
+ if j == -1:
364
+ break
365
+ if face_ids[j] >= 0:
366
+ _tar_id = face_ids[j]
367
+ if use_aabb:
368
+ _mask = aabb_face_mask[_tar_id]
369
+ if _mask[i]:
370
+ _adj_ids.append(_tar_id)
371
+ else:
372
+ _adj_ids.append(_tar_id)
373
+ if len(_adj_ids) == 0:
374
+ new_no_mask_ids.append(i)
375
+ continue
376
+ _max_id = np.argmax(np.bincount(_adj_ids))
377
+ face_ids[i] = _max_id
378
+ changed = True
379
+ no_mask_ids = new_no_mask_ids
380
+ # print(loop_cnt)
381
+ progress.update(1)
382
+ # progress.set_description(f"合并mesh循环:{loop_cnt} {np.sum(face_ids < 0)}")
383
+ loop_cnt += 1
384
+ return face_ids
385
+
386
+
387
+ def save_mesh(save_path, mesh, face_ids, color_map):
388
+ face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8)
389
+ for i in tqdm(range(len(mesh.faces)), disable=True):
390
+ _max_id = face_ids[i]
391
+ if _max_id == -2:
392
+ continue
393
+ face_colors[i, :3] = color_map[_max_id]
394
+
395
+ mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
396
+ mesh_save.visual.face_colors = face_colors
397
+ mesh_save.export(save_path)
398
+ mesh_save.export(save_path.replace(".glb", ".ply"))
399
+ # print('保存mesh完成')
400
+
401
+ scene_mesh = trimesh.Scene()
402
+ scene_mesh.add_geometry(mesh_save)
403
+ unique_ids = np.unique(face_ids)
404
+ aabb = []
405
+ for i in unique_ids:
406
+ if i == -1 or i == -2:
407
+ continue
408
+ _part_mask = face_ids == i
409
+ _faces = mesh.faces[_part_mask]
410
+ _faces = np.reshape(_faces, (-1))
411
+ _points = mesh.vertices[_faces]
412
+ min_xyz, max_xyz = better_aabb(_points)
413
+ center = (min_xyz + max_xyz) / 2
414
+ size = max_xyz - min_xyz
415
+ box = trimesh.path.creation.box_outline()
416
+ box.vertices *= size
417
+ box.vertices += center
418
+ box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]])
419
+ box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8)
420
+ box.colors = box_color
421
+ scene_mesh.add_geometry(box)
422
+ min_xyz = np.min(_points, axis=0)
423
+ max_xyz = np.max(_points, axis=0)
424
+ aabb.append([min_xyz, max_xyz])
425
+ scene_mesh.export(save_path.replace(".glb", "_aabb.glb"))
426
+ aabb = np.array(aabb)
427
+ np.save(save_path.replace(".glb", "_aabb.npy"), aabb)
428
+ np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids)
429
+
430
+
431
+ def get_aabb_from_face_ids(mesh, face_ids):
432
+ unique_ids = np.unique(face_ids)
433
+ aabb = []
434
+ for i in unique_ids:
435
+ if i == -1 or i == -2:
436
+ continue
437
+ _part_mask = face_ids == i
438
+ _faces = mesh.faces[_part_mask]
439
+ _faces = np.reshape(_faces, (-1))
440
+ _points = mesh.vertices[_faces]
441
+ min_xyz = np.min(_points, axis=0)
442
+ max_xyz = np.max(_points, axis=0)
443
+ aabb.append([min_xyz, max_xyz])
444
+ return np.array(aabb)
445
+
446
+
447
+ def calculate_face_areas(mesh):
448
+ """
449
+ Compute the area of every triangle face
450
+ :param mesh: trimesh.Trimesh object
451
+ :return: array of face areas, shape (n_faces,)
452
+ """
453
+ return mesh.area_faces
454
+ # # 提取顶点和面片索引
455
+ # vertices = mesh.vertices
456
+ # faces = mesh.faces
457
+
458
+ # # 获取所有三个顶点的坐标
459
+ # v0 = vertices[faces[:, 0]]
460
+ # v1 = vertices[faces[:, 1]]
461
+ # v2 = vertices[faces[:, 2]]
462
+
463
+ # # 计算两个边向量
464
+ # edge1 = v1 - v0
465
+ # edge2 = v2 - v0
466
+
467
+ # # 计算叉积的模长(向量面积的两倍)
468
+ # cross_product = np.cross(edge1, edge2)
469
+ # areas = 0.5 * np.linalg.norm(cross_product, axis=1)
470
+
471
+ # return areas
472
+
473
+
474
+ def get_connected_region(face_ids, adjacent_faces, return_face_part_ids=False):
475
+ vis = [False] * len(face_ids)
476
+ parts = []
477
+ face_part_ids = np.ones_like(face_ids) * -1
478
+ for i in range(len(face_ids)):
479
+ if vis[i]:
480
+ continue
481
+ _part = []
482
+ _queue = [i]
483
+ while len(_queue) > 0:
484
+ _cur_face = _queue.pop(0)
485
+ if vis[_cur_face]:
486
+ continue
487
+ vis[_cur_face] = True
488
+ _part.append(_cur_face)
489
+ face_part_ids[_cur_face] = len(parts)
490
+ if not (0 <= _cur_face < adjacent_faces.shape[0]):
491
+ continue
492
+ _cur_face_id = face_ids[_cur_face]
493
+ _adj_faces = adjacent_faces[_cur_face]
494
+ for j in _adj_faces:
495
+ if j == -1:
496
+ break
497
+ if not vis[j] and face_ids[j] == _cur_face_id:
498
+ _queue.append(j)
499
+ parts.append(_part)
500
+ if return_face_part_ids:
501
+ return parts, face_part_ids
502
+ else:
503
+ return parts
504
+
505
+ def aabb_distance(box1, box2):
506
+ """
507
+ Compute the closest distance between two axis-aligned bounding boxes (AABBs).
508
+ :param box1: tuple of (min_xyz, max_xyz) corner coordinates
509
+ :param box2: tuple of (min_xyz, max_xyz) corner coordinates
510
+ :return: closest distance (float)
511
+ """
512
+ # unpack coordinates
513
+ min1, max1 = box1
514
+ min2, max2 = box2
515
+
516
+ # separation distance along each axis
517
+ dx = max(0, max2[0] - min1[0], max1[0] - min2[0]) # separation along x
518
+ dy = max(0, max2[1] - min1[1], max1[1] - min2[1]) # separation along y
519
+ dz = max(0, max2[2] - min1[2], max1[2] - min2[2]) # separation along z
520
+
521
+ # if the boxes overlap on every axis, the distance is 0
522
+ if dx == 0 and dy == 0 and dz == 0:
523
+ return 0.0
524
+
525
+ # Euclidean distance
526
+ return np.sqrt(dx**2 + dy**2 + dz**2)
527
+
528
+ def aabb_volume(aabb):
529
+ """
530
+ Compute the volume of an axis-aligned bounding box (AABB).
531
+ :param aabb: tuple of (min_xyz, max_xyz) corner coordinates
532
+ :return: volume (float)
533
+ """
534
+ # unpack coordinates
535
+ min_xyz, max_xyz = aabb
536
+
537
+ # compute the volume
538
+ dx = max_xyz[0] - min_xyz[0]
539
+ dy = max_xyz[1] - min_xyz[1]
540
+ dz = max_xyz[2] - min_xyz[2]
541
+ return dx * dy * dz
542
+
543
+ def find_neighbor_part(parts, adjacent_faces, parts_aabb=None, parts_ids=None):
544
+ face2part = {}
545
+ for i, part in enumerate(parts):
546
+ for face in part:
547
+ face2part[face] = i
548
+ neighbor_parts = []
549
+ for i, part in enumerate(parts):
550
+ neighbor_part = set()
551
+ for face in part:
552
+ if not (0 <= face < adjacent_faces.shape[0]):
553
+ continue
554
+ for adj_face in adjacent_faces[face]:
555
+ if adj_face == -1:
556
+ break
557
+ if adj_face not in face2part:
558
+ continue
559
+ if face2part[adj_face] == i:
560
+ continue
561
+ if parts_ids is not None and parts_ids[face2part[adj_face]] in [-1, -2]:
562
+ continue
563
+ neighbor_part.add(face2part[adj_face])
564
+ neighbor_part = list(neighbor_part)
565
+ if parts_aabb is not None and parts_ids is not None and (parts_ids[i] == -1 or parts_ids[i] == -2) and len(neighbor_part) == 0:
566
+ min_dis = np.inf
567
+ min_idx = -1
568
+ for j, _part in tqdm(enumerate(parts)):
569
+ if j == i:
570
+ continue
571
+ if parts_ids[j] == -1 or parts_ids[j] == -2:
572
+ continue
573
+ aabb_1 = parts_aabb[i]
574
+ aabb_2 = parts_aabb[j]
575
+ dis = aabb_distance(aabb_1, aabb_2)
576
+ if dis < min_dis:
577
+ min_dis = dis
578
+ min_idx = j
579
+ elif dis == min_dis:
580
+ if aabb_volume(parts_aabb[j]) < aabb_volume(parts_aabb[min_idx]):
581
+ min_idx = j
582
+ neighbor_part = [min_idx]
583
+ neighbor_parts.append(neighbor_part)
584
+ return neighbor_parts
585
+
586
+
587
+ def do_post_process(face_areas, parts, adjacent_faces, face_ids, threshold=0.95, show_info=False):
588
+ # # 获取邻接面片
589
+ # mesh_save = mesh.copy()
590
+ # face_adjacency = mesh.face_adjacency
591
+ # adjacent_faces = {}
592
+ # for face1, face2 in face_adjacency:
593
+ # if face1 not in adjacent_faces:
594
+ # adjacent_faces[face1] = []
595
+ # if face2 not in adjacent_faces:
596
+ # adjacent_faces[face2] = []
597
+ # adjacent_faces[face1].append(face2)
598
+ # adjacent_faces[face2].append(face1)
599
+
600
+ # parts = get_connected_region(face_ids, adjacent_faces)
601
+
602
+
603
+ unique_ids = np.unique(face_ids)
604
+ if show_info:
605
+ print(f"连通区域数量:{len(parts)}")
606
+ print(f"ID数量:{len(unique_ids)}")
607
+
608
+ # face_areas = calculate_face_areas(mesh)
609
+ total_area = np.sum(face_areas)
610
+ if show_info:
611
+ print(f"总面积:{total_area}")
612
+ part_areas = []
613
+ for i, part in enumerate(parts):
614
+ part_area = np.sum(face_areas[part])
615
+ part_areas.append(float(part_area / total_area))
616
+
617
+ sorted_parts = sorted(zip(part_areas, parts), key=lambda x: x[0], reverse=True)
618
+ parts = [x[1] for x in sorted_parts]
619
+ part_areas = [x[0] for x in sorted_parts]
620
+ integral_part_areas = np.cumsum(part_areas)
621
+
622
+ neighbor_parts = find_neighbor_part(parts, adjacent_faces)
623
+
624
+ new_face_ids = face_ids.copy()
625
+
626
+ for i, part in enumerate(parts):
627
+ if integral_part_areas[i] > threshold and part_areas[i] < 0.01:
628
+ if len(neighbor_parts[i]) > 0:
629
+ max_area = 0
630
+ max_part = -1
631
+ for j in neighbor_parts[i]:
632
+ if integral_part_areas[j] > threshold:
633
+ continue
634
+ if part_areas[j] > max_area:
635
+ max_area = part_areas[j]
636
+ max_part = j
637
+ if max_part != -1:
638
+ if show_info:
639
+ print(f"合并mesh:{i} {max_part}")
640
+ parts[max_part].extend(part)
641
+ parts[i] = []
642
+ target_face_id = face_ids[parts[max_part][0]]
643
+ for face in part:
644
+ new_face_ids[face] = target_face_id
645
+
646
+ return new_face_ids
647
+
648
+
649
+ def do_no_mask_process(parts, face_ids):
650
+ # # 获取邻接面片
651
+ # mesh_save = mesh.copy()
652
+ # face_adjacency = mesh.face_adjacency
653
+ # adjacent_faces = {}
654
+ # for face1, face2 in face_adjacency:
655
+ # if face1 not in adjacent_faces:
656
+ # adjacent_faces[face1] = []
657
+ # if face2 not in adjacent_faces:
658
+ # adjacent_faces[face2] = []
659
+ # adjacent_faces[face1].append(face2)
660
+ # adjacent_faces[face2].append(face1)
661
+ # parts = get_connected_region(face_ids, adjacent_faces)
662
+
663
+ unique_ids = np.unique(face_ids)
664
+ max_id = np.max(unique_ids)
665
+ if -1 in unique_ids or -2 in unique_ids:
666
+ new_face_ids = face_ids.copy()
667
+ for i, part in enumerate(parts):
668
+ if face_ids[part[0]] == -1 or face_ids[part[0]] == -2:
669
+ for face in part:
670
+ new_face_ids[face] = max_id + 1
671
+ max_id += 1
672
+ return new_face_ids
673
+ else:
674
+ return face_ids
675
+
676
+
677
+ def union_aabb(aabb1, aabb2):
678
+ min_xyz1 = aabb1[0]
679
+ max_xyz1 = aabb1[1]
680
+ min_xyz2 = aabb2[0]
681
+ max_xyz2 = aabb2[1]
682
+ min_xyz = np.minimum(min_xyz1, min_xyz2)
683
+ max_xyz = np.maximum(max_xyz1, max_xyz2)
684
+ return [min_xyz, max_xyz]
685
+
686
+
687
+ def aabb_increase(aabb1, aabb2):
688
+ min_xyz_before = aabb1[0]
689
+ max_xyz_before = aabb1[1]
690
+ min_xyz_after, max_xyz_after = union_aabb(aabb1, aabb2)
691
+ min_xyz_increase = np.abs(min_xyz_after - min_xyz_before) / np.abs(min_xyz_before)
692
+ max_xyz_increase = np.abs(max_xyz_after - max_xyz_before) / np.abs(max_xyz_before)
693
+ return min_xyz_increase, max_xyz_increase
694
+
695
+ def sort_multi_list(multi_list, key=lambda x: x[0], reverse=False):
696
+ '''
697
+ multi_list: [list1, list2, list3, list4, ...], len(list1)=N, len(list2)=N, len(list3)=N, ...
698
+ key: sort key function; defaults to sorting by the first element
699
+ reverse: whether to sort in descending order (default False, i.e. ascending)
700
+ return:
701
+ [list1, list2, list3, list4, ...]: the same lists, all sorted in one common order
702
+ '''
703
+ sorted_list = sorted(zip(*multi_list), key=key, reverse=reverse)
704
+ return zip(*sorted_list)
705
+
706
+
707
+ class Timer:
708
+ STATE = True
709
+ def __init__(self, name):
710
+ self.name = name
711
+
712
+ def __enter__(self):
713
+ if not Timer.STATE:
714
+ return
715
+ self.start_time = time.time()
716
+ return self  # return self so it can be used inside the with block
717
+
718
+ def __exit__(self, exc_type, exc_val, exc_tb):
719
+ if not Timer.STATE:
720
+ return
721
+ self.end_time = time.time()
722
+ self.elapsed_time = self.end_time - self.start_time
723
+ print(f">>>>>>代码{self.name} 运行时间: {self.elapsed_time:.4f} 秒")
724
+
725
+ ###################### NUMBA acceleration ######################
726
+ @njit
727
+ def build_adjacent_faces_numba(face_adjacency):
728
+ """
729
+ Build the face-adjacency array, accelerated with Numba.
730
+ :param face_adjacency: (N, 2) numpy array of adjacent face pairs.
731
+ :return:
732
+ - adjacent_faces: (n_faces, max_degree) array; row i lists the faces
733
+ adjacent to face i, padded with -1.
734
+ """
735
+ n_faces = np.max(face_adjacency) + 1 # total number of faces
736
+ n_edges = face_adjacency.shape[0] # number of adjacency edges
737
+
738
+ # step 1: count the number of neighbors (degree) of each face
739
+ degrees = np.zeros(n_faces, dtype=np.int32)
740
+ for i in range(n_edges):
741
+ f1, f2 = face_adjacency[i]
742
+ degrees[f1] += 1
743
+ degrees[f2] += 1
744
+ max_degree = np.max(degrees) # maximum degree
745
+
746
+ adjacent_faces = np.ones((n_faces, max_degree), dtype=np.int32) * -1 # adjacency array, padded with -1
747
+ adjacent_faces_count = np.zeros(n_faces, dtype=np.int32) # per-face neighbor counter
748
+ for i in range(n_edges):
749
+ f1, f2 = face_adjacency[i]
750
+ adjacent_faces[f1, adjacent_faces_count[f1]] = f2
751
+ adjacent_faces_count[f1] += 1
752
+ adjacent_faces[f2, adjacent_faces_count[f2]] = f1
753
+ adjacent_faces_count[f2] += 1
754
+ return adjacent_faces
755
+ ###################### NUMBA acceleration ######################
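+ # Example (illustration): build_adjacent_faces_numba(np.array([[0, 1], [1, 2]]))
+ # returns [[1, -1], [0, 2], [1, -1]]: face 0 borders face 1, face 1 borders
+ # faces 0 and 2, face 2 borders face 1, and -1 marks padding.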
756
+
757
+ def mesh_sam(
758
+ model,
759
+ mesh,
760
+ save_path,
761
+ point_num=100000,
762
+ prompt_num=400,
763
+ save_mid_res=False,
764
+ show_info=False,
765
+ post_process=False,
766
+ threshold=0.95,
767
+ clean_mesh_flag=True,
768
+ seed=42,
769
+ prompt_bs=32,
770
+ ):
771
+ with Timer("加载mesh"):
772
+ model, model_parallel = model
773
+ if clean_mesh_flag:
774
+ mesh = clean_mesh(mesh)
775
+ mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
776
+ if show_info:
777
+ print(f"点数:{mesh.vertices.shape[0]} 面片数:{mesh.faces.shape[0]}")
778
+
779
+ point_num = 100000
780
+ prompt_num = 400
781
+ with Timer("获取邻接面片"):
782
+ face_adjacency = mesh.face_adjacency
783
+ with Timer("处理邻接面片"):
784
+ adjacent_faces = build_adjacent_faces_numba(face_adjacency)
785
+
786
+
787
+ with Timer("采样点云"):
788
+ _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed)
789
+ _points_org = _points.copy()
790
+ _points = normalize_pc(_points)
791
+ normals = mesh.face_normals[face_idx]
792
+ if show_info:
793
+ print(f"点数:{point_num} 面片数:{mesh.faces.shape[0]}")
794
+
795
+ with Timer("获取特征"):
796
+ _feats = get_feat(model, _points, normals)
797
+ if show_info:
798
+ print("预处理特征")
799
+
800
+ if save_mid_res:
801
+ feat_save = _feats.float().detach().cpu().numpy()
802
+ data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
803
+ pca = PCA(n_components=3)
804
+ data_reduced = pca.fit_transform(data_scaled)
805
+ data_reduced = (data_reduced - data_reduced.min()) / (
806
+ data_reduced.max() - data_reduced.min()
807
+ )
808
+ _colors_pca = (data_reduced * 255).astype(np.uint8)
809
+ pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca)
810
+ pc_save.export(os.path.join(save_path, "point_pca.glb"))
811
+ pc_save.export(os.path.join(save_path, "point_pca.ply"))
812
+ if show_info:
813
+ print("PCA获取特征颜色")
814
+
815
+ with Timer("FPS采样提示点"):
816
+ fps_idx = fpsample.fps_sampling(_points, prompt_num)
817
+ _point_prompts = _points[fps_idx]
818
+ if save_mid_res:
819
+ trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
820
+ os.path.join(save_path, "point_prompts_pca.glb")
821
+ )
822
+ trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
823
+ os.path.join(save_path, "point_prompts_pca.ply")
824
+ )
825
+ if show_info:
826
+ print("采样完成")
827
+
828
+ with Timer("推理"):
829
+ bs = prompt_bs
830
+ step_num = prompt_num // bs + 1
831
+ mask_res = []
832
+ iou_res = []
833
+ for i in tqdm(range(step_num), disable=not show_info):
834
+ cur_propmt = _point_prompts[bs * i : bs * (i + 1)]
835
+ pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask(
836
+ model_parallel, _feats, _points, cur_propmt
837
+ )
838
+ pred_mask = np.stack(
839
+ [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1
840
+ ) # [N, K, 3]
841
+ max_idx = np.argmax(pred_iou, axis=-1) # [K]
842
+ for j in range(max_idx.shape[0]):
843
+ mask_res.append(pred_mask[:, j, max_idx[j]])
844
+ iou_res.append(pred_iou[j, max_idx[j]])
845
+ mask_res = np.stack(mask_res, axis=-1) # [N, K]
846
+ if show_info:
847
+ print("prompt inference done")
848
+
849
+ with Timer("根据IOU排序"):
850
+ iou_res = np.array(iou_res).tolist()
851
+ mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)]
852
+ mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True)
853
+ mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)]
854
+ iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)]
855
+
856
+ with Timer("NMS"):
857
+ clusters = defaultdict(list)
858
+ with ThreadPoolExecutor(max_workers=20) as executor:
859
+ for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
860
+ _mask = mask_sorted[i]
861
+ futures = []
862
+ for j in clusters.keys():
863
+ futures.append(executor.submit(cal_iou, _mask, mask_sorted[j]))
864
+
865
+ for j, future in zip(clusters.keys(), futures):
866
+ if future.result() > 0.9:
867
+ clusters[j].append(i)
868
+ break
869
+ else:
870
+ clusters[i].append(i)
871
+
872
+ if show_info:
873
+ print(f"NMS完成,mask数量:{len(clusters)}")
874
+
875
+ if save_mid_res:
876
+ part_mask_save_path = os.path.join(save_path, "part_mask")
877
+ if os.path.exists(part_mask_save_path):
878
+ shutil.rmtree(part_mask_save_path)
879
+ os.makedirs(part_mask_save_path, exist_ok=True)
880
+ for i in tqdm(clusters.keys(), desc="保存mask", disable=not show_info):
881
+ cluster_num = len(clusters[i])
882
+ cluster_iou = iou_sorted[i]
883
+ cluster_area = np.sum(mask_sorted[i])
884
+ if cluster_num <= 2:
885
+ continue
886
+ mask_save = mask_sorted[i]
887
+ mask_save = np.expand_dims(mask_save, axis=-1)
888
+ mask_save = np.repeat(mask_save, 3, axis=-1)
889
+ mask_save = (mask_save * 255).astype(np.uint8)
890
+ point_save = trimesh.points.PointCloud(_points, colors=mask_save)
891
+ point_save.export(
892
+ os.path.join(
893
+ part_mask_save_path,
894
+ f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
895
+ )
896
+ )
897
+
898
+ # filter out clusters that contain only a single mask
899
+ with Timer("过滤只有一个mask的cluster"):
900
+ filtered_clusters = []
901
+ other_clusters = []
902
+ for i in clusters.keys():
903
+ if len(clusters[i]) > 2:
904
+ filtered_clusters.append(i)
905
+ else:
906
+ other_clusters.append(i)
907
+ if show_info:
908
+ print(
909
+ f"过滤前:{len(clusters)} 个cluster,"
910
+ f"过滤后:{len(filtered_clusters)} 个cluster"
911
+ )
912
+
913
+ # merge again
914
+ with Timer("再次合并"):
915
+ filtered_clusters_num = len(filtered_clusters)
916
+ cluster2 = {}
917
+ is_union = [False] * filtered_clusters_num
918
+ for i in range(filtered_clusters_num):
919
+ if is_union[i]:
920
+ continue
921
+ cur_cluster = filtered_clusters[i]
922
+ cluster2[cur_cluster] = [cur_cluster]
923
+ for j in range(i + 1, filtered_clusters_num):
924
+ if is_union[j]:
925
+ continue
926
+ tar_cluster = filtered_clusters[j]
927
+ if (
928
+ cal_bbox_iou(
929
+ _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster]
930
+ )
931
+ > 0.5
932
+ ):
933
+ cluster2[cur_cluster].append(tar_cluster)
934
+ is_union[j] = True
935
+ if show_info:
936
+ print(f"再次合并,合并数量:{len(cluster2.keys())}")
937
+
938
+ with Timer("计算没有mask的点"):
939
+ no_mask = np.ones(point_num)
940
+ for i in cluster2:
941
+ part_mask = mask_sorted[i]
942
+ no_mask[part_mask] = 0
943
+ if show_info:
944
+ print(
945
+ f"{np.sum(no_mask == 1)} 个点没有mask,"
946
+ f" 占比:{np.sum(no_mask == 1) / point_num:.4f}"
947
+ )
948
+
949
+ with Timer("修补遗漏mask"):
950
+ # look for masks that were missed
951
+ for i in tqdm(range(len(mask_sorted)), desc="漏掉mask", disable=not show_info):
952
+ if i in cluster2:
953
+ continue
954
+ part_mask = mask_sorted[i]
955
+ _iou = cal_single_iou(part_mask, no_mask)
956
+ if _iou > 0.7:
957
+ cluster2[i] = [i]
958
+ no_mask[part_mask] = 0
959
+ if save_mid_res:
960
+ mask_save = mask_sorted[i]
961
+ mask_save = np.expand_dims(mask_save, axis=-1)
962
+ mask_save = np.repeat(mask_save, 3, axis=-1)
963
+ mask_save = (mask_save * 255).astype(np.uint8)
964
+ point_save = trimesh.points.PointCloud(_points, colors=mask_save)
965
+ cluster_iou = iou_sorted[i]
966
+ cluster_area = int(np.sum(mask_sorted[i]))
967
+ cluster_num = 1
968
+ point_save.export(
969
+ os.path.join(
970
+ part_mask_save_path,
971
+ f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
972
+ )
973
+ )
974
+ if show_info:
975
+ print(f"修补遗漏mask:{len(cluster2.keys())}")
976
+
977
+ with Timer("计算点云最终mask"):
978
+ final_mask = list(cluster2.keys())
979
+ final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask]
980
+ final_mask_area = [
981
+ [final_mask[i], final_mask_area[i]] for i in range(len(final_mask))
982
+ ]
983
+ final_mask_area_sorted = sorted(final_mask_area, key=lambda x: x[1], reverse=True)
984
+ final_mask_sorted = [
985
+ final_mask_area_sorted[i][0] for i in range(len(final_mask_area))
986
+ ]
987
+ final_mask_area_sorted = [
988
+ final_mask_area_sorted[i][1] for i in range(len(final_mask_area))
989
+ ]
990
+ if show_info:
991
+ print(f"最终mask数量:{len(final_mask_sorted)}")
992
+
993
+ with Timer("点云上色"):
994
+ # build the color map
995
+ color_map = {}
996
+ for i in final_mask_sorted:
997
+ part_color = np.random.rand(3) * 255
998
+ color_map[i] = part_color
999
+ # print(color_map)
1000
+
1001
+ result_mask = -np.ones(point_num, dtype=np.int64)
1002
+ for i in final_mask_sorted:
1003
+ part_mask = mask_sorted[i]
1004
+ result_mask[part_mask] = i
1005
+ if save_mid_res:
1006
+ # save the point-cloud result
1007
+ result_colors = np.zeros_like(_colors_pca)
1008
+ for i in final_mask_sorted:
1009
+ part_color = color_map[i]
1010
+ part_mask = mask_sorted[i]
1011
+ result_colors[part_mask, :3] = part_color
1012
+ trimesh.points.PointCloud(_points, colors=result_colors).export(
1013
+ os.path.join(save_path, "auto_mask_cluster.glb")
1014
+ )
1015
+ trimesh.points.PointCloud(_points, colors=result_colors).export(
1016
+ os.path.join(save_path, "auto_mask_cluster.ply")
1017
+ )
1018
+ if show_info:
1019
+ print("保存点云完成")
1020
+
1021
+
1022
+ with Timer("投影Mesh并统计label"):
1023
+ # save the mesh result
1024
+ face_seg_res = {}
1025
+ for i in final_mask_sorted:
1026
+ _part_mask = result_mask == i
1027
+ _face_idx = face_idx[_part_mask]
1028
+ for k in _face_idx:
1029
+ if k not in face_seg_res:
1030
+ face_seg_res[k] = []
1031
+ face_seg_res[k].append(i)
1032
+ _part_mask = result_mask == -1
1033
+ _face_idx = face_idx[_part_mask]
1034
+ for k in _face_idx:
1035
+ if k not in face_seg_res:
1036
+ face_seg_res[k] = []
1037
+ face_seg_res[k].append(-1)
1038
+
1039
+ face_ids = -np.ones(len(mesh.faces), dtype=np.int64) * 2
1040
+ for i in tqdm(face_seg_res, leave=False, disable=True):
1041
+ _seg_ids = np.array(face_seg_res[i])
1042
+ # take the most frequent seg_id
1043
+ _max_id = np.argmax(np.bincount(_seg_ids + 2)) - 2
1044
+ face_ids[i] = _max_id
1045
+ face_ids_org = face_ids.copy()
1046
+ if show_info:
1047
+ print("生成face_ids完成")
1048
+
1049
+
1050
+ with Timer("第一次修复face_ids"):
1051
+ face_ids += 1
1052
+ face_ids = fix_label(face_ids, adjacent_faces, mesh=mesh, show_info=show_info)
1053
+ face_ids -= 1
1054
+ if show_info:
1055
+ print("修复face_ids完成")
1056
+
1057
+ color_map[-1] = np.array([255, 0, 0], dtype=np.uint8)
1058
+
1059
+ if save_mid_res:
1060
+ save_mesh(
1061
+ os.path.join(save_path, "auto_mask_mesh.glb"), mesh, face_ids, color_map
1062
+ )
1063
+ save_mesh(
1064
+ os.path.join(save_path, "auto_mask_mesh_org.glb"),
1065
+ mesh,
1066
+ face_ids_org,
1067
+ color_map,
1068
+ )
1069
+ if show_info:
1070
+ print("保存mesh结果完成")
1071
+
1072
+ with Timer("计算连通区域"):
1073
+ face_areas = calculate_face_areas(mesh)
1074
+ mesh_total_area = np.sum(face_areas)
1075
+ parts = get_connected_region(face_ids, adjacent_faces)
1076
+ connected_parts, _face_connected_parts_ids = get_connected_region(np.ones_like(face_ids), adjacent_faces, return_face_part_ids=True)
1077
+ if show_info:
1078
+ print(f"共{len(parts)}个mesh")
1079
+ with Timer("排序连通区域"):
1080
+ parts_cp_idx = []
1081
+ for x in parts:
1082
+ _face_idx = x[0]
1083
+ parts_cp_idx.append(_face_connected_parts_ids[_face_idx])
1084
+ parts_cp_idx = np.array(parts_cp_idx)
1085
+ parts_areas = [float(np.sum(face_areas[x])) for x in parts]
1086
+ connected_parts_areas = [float(np.sum(face_areas[x])) for x in connected_parts]
1087
+ parts_cp_areas = [connected_parts_areas[x] for x in parts_cp_idx]
1088
+ parts_sorted, parts_areas_sorted, parts_cp_areas_sorted = sort_multi_list([parts, parts_areas, parts_cp_areas], key=lambda x: x[1], reverse=True)
1089
+
1090
+ with Timer("去除面积过小的区域"):
1091
+ filtered_parts = []
1092
+ other_parts = []
1093
+ for i in range(len(parts_sorted)):
1094
+ parts = parts_sorted[i]
1095
+ area = parts_areas_sorted[i]
1096
+ cp_area = parts_cp_areas_sorted[i]
1097
+ if area / (cp_area+1e-7) > 0.001:
1098
+ filtered_parts.append(i)
1099
+ else:
1100
+ other_parts.append(i)
1101
+ if show_info:
1102
+ print(f"保留{len(filtered_parts)}个mesh, 其他{len(other_parts)}个mesh")
1103
+
1104
+ with Timer("去除面积过小区域的label"):
1105
+ face_ids_2 = face_ids.copy()
1106
+ part_num = len(cluster2.keys())
1107
+ for j in other_parts:
1108
+ parts = parts_sorted[j]
1109
+ for i in parts:
1110
+ face_ids_2[i] = -1
1111
+
1112
+ with Timer("第二次修复face_ids"):
1113
+ face_ids_3 = face_ids_2.copy()
1114
+ face_ids_3 = fix_label(face_ids_3, adjacent_faces, mesh=mesh, show_info=show_info)
1115
+
1116
+ if save_mid_res:
1117
+ save_mesh(
1118
+ os.path.join(save_path, "auto_mask_mesh_filtered_2.glb"),
1119
+ mesh,
1120
+ face_ids_3,
1121
+ color_map,
1122
+ )
1123
+ if show_info:
1124
+ print("保存mesh结果完成")
1125
+
1126
+ with Timer("第二次计算连通区域"):
1127
+ parts_2 = get_connected_region(face_ids_3, adjacent_faces)
1128
+ parts_areas_2 = [float(np.sum(face_areas[x])) for x in parts_2]
1129
+ parts_ids_2 = [face_ids_3[x[0]] for x in parts_2]
1130
+
1131
+ with Timer("添加过大的缺失part"):
1132
+ color_map_2 = copy.deepcopy(color_map)
1133
+ max_id = np.max(parts_ids_2)
1134
+ for i in range(len(parts_2)):
1135
+ _parts = parts_2[i]
1136
+ _area = parts_areas_2[i]
1137
+ _parts_id = face_ids_3[_parts[0]]
1138
+ if _area / mesh_total_area > 0.001:
1139
+ if _parts_id == -1 or _parts_id == -2:
1140
+ parts_ids_2[i] = max_id + 1
1141
+ max_id += 1
1142
+ color_map_2[max_id] = np.random.rand(3) * 255
1143
+ if show_info:
1144
+ print(f"新增part {max_id}")
1145
+ # else:
1146
+ # parts_ids_2[i] = -1
1147
+
1148
+ with Timer("赋值新的face_ids"):
1149
+ face_ids_4 = face_ids_3.copy()
1150
+ for i in range(len(parts_2)):
1151
+ _parts = parts_2[i]
1152
+ _parts_id = parts_ids_2[i]
1153
+ for j in _parts:
1154
+ face_ids_4[j] = _parts_id
1155
+ with Timer("计算part和label的aabb"):
1156
+ ids_aabb = {}
1157
+ unique_ids = np.unique(face_ids_4)
1158
+ for i in unique_ids:
1159
+ if i < 0:
1160
+ continue
1161
+ _part_mask = face_ids_4 == i
1162
+ _faces = mesh.faces[_part_mask]
1163
+ _faces = np.reshape(_faces, (-1))
1164
+ _points = mesh.vertices[_faces]
1165
+ min_xyz = np.min(_points, axis=0)
1166
+ max_xyz = np.max(_points, axis=0)
1167
+ ids_aabb[i] = [min_xyz, max_xyz]
1168
+
1169
+ parts_2_aabb = []
1170
+ for i in range(len(parts_2)):
1171
+ _parts = parts_2[i]
1172
+ _faces = mesh.faces[_parts]
1173
+ _faces = np.reshape(_faces, (-1))
1174
+ _points = mesh.vertices[_faces]
1175
+ min_xyz = np.min(_points, axis=0)
1176
+ max_xyz = np.max(_points, axis=0)
1177
+ parts_2_aabb.append([min_xyz, max_xyz])
1178
+
1179
+ with Timer("计算part的邻居"):
1180
+ parts_2_neighbor = find_neighbor_part(parts_2, adjacent_faces, parts_2_aabb, parts_ids_2)
1181
+ with Timer("合并无mask区域"):
1182
+ for i in range(len(parts_2)):
1183
+ _parts = parts_2[i]
1184
+ _ids = parts_ids_2[i]
1185
+ if _ids == -1 or _ids == -2:
1186
+ _cur_aabb = parts_2_aabb[i]
1187
+ _min_aabb_increase = 1e10
1188
+ _min_id = -1
1189
+ for j in parts_2_neighbor[i]:
1190
+ if parts_ids_2[j] == -1 or parts_ids_2[j] == -2:
1191
+ continue
1192
+ _tar_id = parts_ids_2[j]
1193
+ _tar_aabb = ids_aabb[_tar_id]
1194
+ _min_increase, _max_increase = aabb_increase(_tar_aabb, _cur_aabb)
1195
+ _increase = max(np.max(_min_increase), np.max(_max_increase))
1196
+ if _min_aabb_increase > _increase:
1197
+ _min_aabb_increase = _increase
1198
+ _min_id = _tar_id
1199
+ if _min_id >= 0:
1200
+ parts_ids_2[i] = _min_id
1201
+
1202
+
1203
+ with Timer("再次赋值新的face_ids"):
1204
+ face_ids_4 = face_ids_3.copy()
1205
+ for i in range(len(parts_2)):
1206
+ _parts = parts_2[i]
1207
+ _parts_id = parts_ids_2[i]
1208
+ for j in _parts:
1209
+ face_ids_4[j] = _parts_id
1210
+
1211
+ final_face_ids = face_ids_4
1212
+ if save_mid_res:
1213
+ save_mesh(
1214
+ os.path.join(save_path, "auto_mask_mesh_final.glb"),
1215
+ mesh,
1216
+ face_ids_4,
1217
+ color_map_2,
1218
+ )
1219
+
1220
+ if post_process:
1221
+ parts = get_connected_region(final_face_ids, adjacent_faces)
1222
+ final_face_ids = do_no_mask_process(parts, final_face_ids)
1223
+ face_ids_5 = do_post_process(face_areas, parts, adjacent_faces, face_ids_4, threshold, show_info=show_info)
1224
+ if save_mid_res:
1225
+ save_mesh(
1226
+ os.path.join(save_path, "auto_mask_mesh_final_post.glb"),
1227
+ mesh,
1228
+ face_ids_5,
1229
+ color_map_2,
1230
+ )
1231
+ final_face_ids = face_ids_5
1232
+ with Timer("计算最后的aabb"):
1233
+ aabb = get_aabb_from_face_ids(mesh, final_face_ids)
1234
+ return aabb, final_face_ids, mesh
1235
+
1236
+
1237
+ class AutoMask:
1238
+ def __init__(
1239
+ self,
1240
+ ckpt_path=None,
1241
+ point_num=100000,
1242
+ prompt_num=400,
1243
+ threshold=0.95,
1244
+ post_process=True,
1245
+ ):
1246
+ """
1247
+ ckpt_path: str, path to the model checkpoint
1248
+ point_num: int, number of sampled points
1249
+ prompt_num: int, number of prompt points
1250
+ threshold: float, threshold used by the post-processing step
1251
+ post_process: bool, whether to run post-processing
1252
+ """
1253
+ self.model = P3SAM()
1254
+ self.model.load_state_dict(ckpt_path)
1255
+ self.model.eval()
1256
+ self.model_parallel = torch.nn.DataParallel(self.model)
1257
+ self.model.cuda()
1258
+ self.model_parallel.cuda()
1259
+ self.point_num = point_num
1260
+ self.prompt_num = prompt_num
1261
+ self.threshold = threshold
1262
+ self.post_process = post_process
1263
+
1264
+ def predict_aabb(
1265
+ self, mesh, point_num=None, prompt_num=None, threshold=None, post_process=None, save_path=None, save_mid_res=False, show_info=True, clean_mesh_flag=True, seed=42, is_parallel=True, prompt_bs=32
1266
+ ):
1267
+ """
1268
+ Parameters:
1269
+ mesh: trimesh.Trimesh, input mesh
1270
+ point_num: int, number of sampled points
1271
+ prompt_num: int, number of prompt points
1272
+ threshold: float, threshold used by the post-processing step
1273
+ post_process: bool, whether to run post-processing
1274
+ Returns:
1275
+ aabb: np.ndarray, bounding boxes of the predicted parts
1276
+ face_ids: np.ndarray, per-face part ids
1277
+ """
1278
+ point_num = point_num if point_num is not None else self.point_num
1279
+ prompt_num = prompt_num if prompt_num is not None else self.prompt_num
1280
+ threshold = threshold if threshold is not None else self.threshold
1281
+ post_process = post_process if post_process is not None else self.post_process
1282
+ return mesh_sam(
1283
+ [self.model, self.model_parallel if is_parallel else self.model],
1284
+ mesh,
1285
+ save_path=save_path,
1286
+ point_num=point_num,
1287
+ prompt_num=prompt_num,
1288
+ threshold=threshold,
1289
+ post_process=post_process,
1290
+ show_info=show_info,
1291
+ save_mid_res=save_mid_res,
1292
+ clean_mesh_flag=clean_mesh_flag,
1293
+ seed=seed,
1294
+ prompt_bs=prompt_bs,
1295
+ )
1296
+
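+ # Minimal usage sketch (illustration only; the checkpoint path below is a placeholder):
+ #   auto_mask = AutoMask('weights/p3sam.ckpt')
+ #   mesh = trimesh.load('assets/1.glb', force='mesh')
+ #   aabb, face_ids, mesh = auto_mask.predict_aabb(mesh, save_path='results/1')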
1297
+ def set_seed(seed):
1298
+ random.seed(seed)
1299
+ np.random.seed(seed)
1300
+ torch.manual_seed(seed)
1301
+ if torch.cuda.is_available():
1302
+ torch.cuda.manual_seed(seed)
1303
+ torch.cuda.manual_seed_all(seed)
1304
+ torch.backends.cudnn.deterministic = True
1305
+ torch.backends.cudnn.benchmark = False
1306
+
1307
+ if __name__ == '__main__':
1308
+ argparser = argparse.ArgumentParser()
1309
+ argparser.add_argument('--ckpt_path', type=str, default=None, help='path to the model checkpoint')
1310
+ argparser.add_argument('--mesh_path', type=str, default='assets/1.glb', help='path to the input mesh')
1311
+ argparser.add_argument('--output_path', type=str, default='results/1', help='output directory')
1312
+ argparser.add_argument('--point_num', type=int, default=100000, help='number of sampled points')
1313
+ argparser.add_argument('--prompt_num', type=int, default=400, help='number of prompt points')
1314
+ argparser.add_argument('--threshold', type=float, default=0.95, help='threshold')
1315
+ argparser.add_argument('--post_process', type=int, default=0, help='whether to run post-processing')
1316
+ argparser.add_argument('--save_mid_res', type=int, default=1, help='whether to save intermediate results')
1317
+ argparser.add_argument('--show_info', type=int, default=1, help='whether to print progress information')
1318
+ argparser.add_argument('--show_time_info', type=int, default=1, help='whether to print timing information')
1319
+ argparser.add_argument('--seed', type=int, default=42, help='random seed')
1320
+ argparser.add_argument('--parallel', type=int, default=1, help='whether to use multiple GPUs')
1321
+ argparser.add_argument('--prompt_bs', type=int, default=32, help='batch size for prompt-point inference')
1322
+ argparser.add_argument('--clean_mesh', type=int, default=1, help='whether to clean the mesh')
1323
+ args = argparser.parse_args()
1324
+ Timer.STATE = args.show_time_info
1325
+
1326
+
1327
+ output_path = args.output_path
1328
+ os.makedirs(output_path, exist_ok=True)
1329
+ ckpt_path = args.ckpt_path
1330
+ auto_mask = AutoMask(ckpt_path)
1331
+ mesh_path = args.mesh_path
1332
+ if os.path.isdir(mesh_path):
1333
+ for file in os.listdir(mesh_path):
1334
+ if not (file.endswith('.glb') or file.endswith('.obj') or file.endswith('.ply')):
1335
+ continue
1336
+ _mesh_path = os.path.join(mesh_path, file)
1337
+ _output_path = os.path.join(output_path, file[:-4])
1338
+ os.makedirs(_output_path, exist_ok=True)
1339
+ mesh = trimesh.load(_mesh_path, force='mesh')
1340
+ set_seed(args.seed)
1341
+ aabb, face_ids, mesh = auto_mask.predict_aabb(mesh,
1342
+ save_path=_output_path,
1343
+ point_num=args.point_num,
1344
+ prompt_num=args.prompt_num,
1345
+ threshold=args.threshold,
1346
+ post_process=args.post_process,
1347
+ save_mid_res=args.save_mid_res,
1348
+ show_info=args.show_info,
1349
+ seed=args.seed,
1350
+ is_parallel=args.parallel,
1351
+ clean_mesh_flag=args.clean_mesh,)
1352
+ else:
1353
+ mesh = trimesh.load(mesh_path, force='mesh')
1354
+ set_seed(args.seed)
1355
+ aabb, face_ids, mesh = auto_mask.predict_aabb(mesh,
1356
+ save_path=output_path,
1357
+ point_num=args.point_num,
1358
+ prompt_num=args.prompt_num,
1359
+ threshold=args.threshold,
1360
+ post_process=args.post_process,
1361
+ save_mid_res=args.save_mid_res,
1362
+ show_info=args.show_info,
1363
+ seed=args.seed,
1364
+ is_parallel=args.parallel,
1365
+ clean_mesh_flag=args.clean_mesh,)
1366
+
1367
+ ###############################################
1368
+ ## 可以通过以下代码保存返回的结果
1369
+ ## You can save the returned result by the following code
1370
+ ################# save result #################
1371
+ # color_map = {}
1372
+ # unique_ids = np.unique(face_ids)
1373
+ # for i in unique_ids:
1374
+ # if i == -1:
1375
+ # continue
1376
+ # part_color = np.random.rand(3) * 255
1377
+ # color_map[i] = part_color
1378
+ # face_colors = []
1379
+ # for i in face_ids:
1380
+ # if i == -1:
1381
+ # face_colors.append([0, 0, 0])
1382
+ # else:
1383
+ # face_colors.append(color_map[i])
1384
+ # face_colors = np.array(face_colors).astype(np.uint8)
1385
+ # mesh_save = mesh.copy()
1386
+ # mesh_save.visual.face_colors = face_colors
1387
+ # mesh_save.export(os.path.join(output_path, 'auto_mask_mesh.glb'))
1388
+ # scene_mesh = trimesh.Scene()
1389
+ # scene_mesh.add_geometry(mesh_save)
1390
+ # for i in range(len(aabb)):
1391
+ # min_xyz, max_xyz = aabb[i]
1392
+ # center = (min_xyz + max_xyz) / 2
1393
+ # size = max_xyz - min_xyz
1394
+ # box = trimesh.path.creation.box_outline()
1395
+ # box.vertices *= size
1396
+ # box.vertices += center
1397
+ # scene_mesh.add_geometry(box)
1398
+ # scene_mesh.export(os.path.join(output_path, 'auto_mask_aabb.glb'))
1399
+ ################# save result #################
1400
+
1401
+ '''
1402
+ python auto_mask.py --parallel 0
1403
+ python auto_mask.py --ckpt_path ../weights/last.ckpt --mesh_path assets/1.glb --output_path results/1 --parallel 0
1404
+ python auto_mask.py --ckpt_path ../weights/last.ckpt --mesh_path assets --output_path results/all
1405
+ '''
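For reference, here is a minimal sketch of driving AutoMask from another script instead of the CLI above (the paths are placeholders, and auto_mask.py is assumed to be importable from the working directory):

    import trimesh
    from auto_mask import AutoMask, set_seed

    auto_mask = AutoMask("../weights/last.ckpt")   # pass None to download the checkpoint from Hugging Face
    mesh = trimesh.load("assets/1.glb", force="mesh")
    set_seed(42)
    aabb, face_ids, mesh = auto_mask.predict_aabb(mesh, save_mid_res=False, show_info=False)
    # aabb: (num_parts, 2, 3) min/max corners per part; face_ids: per-face part id, -1 for unassigned faces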
P3-SAM/demo/auto_mask_no_postprocess.py ADDED
@@ -0,0 +1,943 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import argparse
7
+ import trimesh
8
+ from sklearn.decomposition import PCA
9
+ import fpsample
10
+ from tqdm import tqdm
11
+ import threading
12
+ import random
13
+
14
+ # from tqdm.notebook import tqdm
15
+ import time
16
+ import copy
17
+ import shutil
18
+ from pathlib import Path
19
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
20
+ from collections import defaultdict
21
+
22
+ import numba
23
+ from numba import njit
24
+
25
+ sys.path.append("..")
26
+ from model import build_P3SAM, load_state_dict
27
+
28
+ from utils.chamfer3D.dist_chamfer_3D import chamfer_3DDist
29
+
30
+ cmd_loss = chamfer_3DDist()
31
+
32
+
33
+ class P3SAM(nn.Module):
34
+ def __init__(self):
35
+ super().__init__()
36
+ build_P3SAM(self)
37
+
38
+ def load_state_dict(self,
39
+ ckpt_path=None,
40
+ state_dict=None,
41
+ strict=True,
42
+ assign=False,
43
+ ignore_seg_mlp=False,
44
+ ignore_seg_s2_mlp=False,
45
+ ignore_iou_mlp=False):
46
+ load_state_dict(self,
47
+ ckpt_path=ckpt_path,
48
+ state_dict=state_dict,
49
+ strict=strict,
50
+ assign=assign,
51
+ ignore_seg_mlp=ignore_seg_mlp,
52
+ ignore_seg_s2_mlp=ignore_seg_s2_mlp,
53
+ ignore_iou_mlp=ignore_iou_mlp)
54
+
55
+ def forward(self, feats, points, point_prompt, iter=1):
56
+ """
57
+ feats: [K, N, 512]
58
+ points: [K, N, 3]
59
+ point_prompt: [K, N, 3]
60
+ """
61
+ # print(feats.shape, points.shape, point_prompt.shape)
62
+ point_num = points.shape[1]
63
+ feats = feats.transpose(0, 1) # [N, K, 512]
64
+ points = points.transpose(0, 1) # [N, K, 3]
65
+ point_prompt = point_prompt.transpose(0, 1) # [N, K, 3]
66
+ feats_seg = torch.cat([feats, points, point_prompt], dim=-1) # [N, K, 512+3+3]
67
+
68
+ # 预测mask stage-1
69
+ pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1) # [N, K]
70
+ pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1) # [N, K]
71
+ pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1) # [N, K]
72
+ pred_mask = torch.stack(
73
+ [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1
74
+ ) # [N, K, 3]
75
+
76
+ for _ in range(iter):
77
+ # 预测mask stage-2
78
+ feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1) # [N, K, 512+3+3+3]
79
+ feats_seg_global = self.seg_s2_mlp_g(feats_seg_2) # [N, K, 256]
80
+ feats_seg_global = torch.max(feats_seg_global, dim=0).values # [K, 256]
81
+ feats_seg_global = feats_seg_global.unsqueeze(0).repeat(
82
+ point_num, 1, 1
83
+ ) # [N, K, 256]
84
+ feats_seg_3 = torch.cat(
85
+ [feats_seg_global, feats_seg_2], dim=-1
86
+ ) # [N, K, 512+3+3+3+256]
87
+ pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1) # [N, K]
88
+ pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1) # [N, K]
89
+ pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1) # [N, K]
90
+ pred_mask_s2 = torch.stack(
91
+ [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1
92
+ ) # [N, K, 3]
93
+ pred_mask = pred_mask_s2
94
+
95
+ mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32) # [N, K]
96
+ mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32) # [N, K]
97
+ mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32) # [N, K]
98
+
99
+ feats_iou = torch.cat(
100
+ [feats_seg_global, feats_seg, pred_mask_s2], dim=-1
101
+ ) # [N, K, 512+3+3+3+256]
102
+ feats_iou = self.iou_mlp(feats_iou) # [N, K, 256]
103
+ feats_iou = torch.max(feats_iou, dim=0).values # [K, 256]
104
+ pred_iou = self.iou_mlp_out(feats_iou) # [K, 3]
105
+ pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32) # [K, 3]
106
+
107
+ mask_1 = mask_1.transpose(0, 1) # [K, N]
108
+ mask_2 = mask_2.transpose(0, 1) # [K, N]
109
+ mask_3 = mask_3.transpose(0, 1) # [K, N]
110
+
111
+ return mask_1, mask_2, mask_3, pred_iou
112
+
113
+
114
+ def normalize_pc(pc):
115
+ """
116
+ pc: (N, 3)
117
+ """
118
+ max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0)
119
+ center = (max_ + min_) / 2
120
+ scale = (max_ - min_) / 2
121
+ scale = np.max(np.abs(scale))
122
+ pc = (pc - center) / (scale + 1e-10)
123
+ return pc
124
+
125
+
126
+ @torch.no_grad()
127
+ def get_feat(model, points, normals):
128
+ data_dict = {
129
+ "coord": points,
130
+ "normal": normals,
131
+ "color": np.ones_like(points),
132
+ "batch": np.zeros(points.shape[0], dtype=np.int64),
133
+ }
134
+ data_dict = model.transform(data_dict)
135
+ for k in data_dict:
136
+ if isinstance(data_dict[k], torch.Tensor):
137
+ data_dict[k] = data_dict[k].cuda()
138
+ point = model.sonata(data_dict)
139
+ while "pooling_parent" in point.keys():
140
+ assert "pooling_inverse" in point.keys()
141
+ parent = point.pop("pooling_parent")
142
+ inverse = point.pop("pooling_inverse")
143
+ parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
144
+ point = parent
145
+ feat = point.feat # [M, 1232]
146
+ feat = model.mlp(feat) # [M, 512]
147
+ feat = feat[point.inverse] # [N, 512]
148
+ feats = feat
149
+ return feats
150
+
151
+
152
+ @torch.no_grad()
153
+ def get_mask(model, feats, points, point_prompt, iter=1):
154
+ """
155
+ feats: [N, 512]
156
+ points: [N, 3]
157
+ point_prompt: [K, 3]
158
+ """
159
+ point_num = points.shape[0]
160
+ prompt_num = point_prompt.shape[0]
161
+ feats = feats.unsqueeze(1) # [N, 1, 512]
162
+ feats = feats.repeat(1, prompt_num, 1).cuda() # [N, K, 512]
163
+ points = torch.from_numpy(points).float().cuda().unsqueeze(1) # [N, 1, 3]
164
+ points = points.repeat(1, prompt_num, 1) # [N, K, 3]
165
+ prompt_coord = (
166
+ torch.from_numpy(point_prompt).float().cuda().unsqueeze(0)
167
+ ) # [1, K, 3]
168
+ prompt_coord = prompt_coord.repeat(point_num, 1, 1) # [N, K, 3]
169
+
170
+ feats = feats.transpose(0, 1) # [K, N, 512]
171
+ points = points.transpose(0, 1) # [K, N, 3]
172
+ prompt_coord = prompt_coord.transpose(0, 1) # [K, N, 3]
173
+
174
+ mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter)
175
+
176
+ mask_1 = mask_1.transpose(0, 1) # [N, K]
177
+ mask_2 = mask_2.transpose(0, 1) # [N, K]
178
+ mask_3 = mask_3.transpose(0, 1) # [N, K]
179
+
180
+ mask_1 = mask_1.detach().cpu().numpy() > 0.5
181
+ mask_2 = mask_2.detach().cpu().numpy() > 0.5
182
+ mask_3 = mask_3.detach().cpu().numpy() > 0.5
183
+
184
+ org_iou = pred_iou.detach().cpu().numpy() # [K, 3]
185
+
186
+ return mask_1, mask_2, mask_3, org_iou
187
+
188
+
189
+ def cal_iou(m1, m2):
190
+ return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2))
191
+
192
+
193
+ def cal_single_iou(m1, m2):
194
+ return np.sum(np.logical_and(m1, m2)) / np.sum(m1)
195
+
196
+
197
+ def iou_3d(box1, box2, signle=None):
198
+ """
199
+ 计算两个三维边界框的交并比 (IoU)
200
+
201
+ 参数:
202
+ box1 (list): 第一个边界框的坐标 [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max]
203
+ box2 (list): 第二个边界框的坐标 [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max]
204
+
205
+ 返回:
206
+ float: 交并比 (IoU) 值
207
+ """
208
+ # 计算交集的坐标
209
+ intersection_xmin = max(box1[0], box2[0])
210
+ intersection_ymin = max(box1[1], box2[1])
211
+ intersection_zmin = max(box1[2], box2[2])
212
+ intersection_xmax = min(box1[3], box2[3])
213
+ intersection_ymax = min(box1[4], box2[4])
214
+ intersection_zmax = min(box1[5], box2[5])
215
+
216
+ # 判断是否有交集
217
+ if (
218
+ intersection_xmin >= intersection_xmax
219
+ or intersection_ymin >= intersection_ymax
220
+ or intersection_zmin >= intersection_zmax
221
+ ):
222
+ return 0.0 # 无交集
223
+
224
+ # 计算交集的体积
225
+ intersection_volume = (
226
+ (intersection_xmax - intersection_xmin)
227
+ * (intersection_ymax - intersection_ymin)
228
+ * (intersection_zmax - intersection_zmin)
229
+ )
230
+
231
+ # 计算两个盒子的体积
232
+ box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2])
233
+ box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2])
234
+
235
+ if signle is None:
236
+ # 计算并集的体积
237
+ union_volume = box1_volume + box2_volume - intersection_volume
238
+ elif signle == "1":
239
+ union_volume = box1_volume
240
+ elif signle == "2":
241
+ union_volume = box2_volume
242
+ else:
243
+ raise ValueError("signle must be None or 1 or 2")
244
+
245
+ # 计算 IoU
246
+ iou = intersection_volume / union_volume if union_volume > 0 else 0.0
247
+ return iou
248
+
249
+
250
+ def cal_point_bbox_iou(p1, p2, signle=None):
251
+ min_p1 = np.min(p1, axis=0)
252
+ max_p1 = np.max(p1, axis=0)
253
+ min_p2 = np.min(p2, axis=0)
254
+ max_p2 = np.max(p2, axis=0)
255
+ box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]]
256
+ box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]]
257
+ return iou_3d(box1, box2, signle)
258
+
259
+
260
+ def cal_bbox_iou(points, m1, m2):
261
+ p1 = points[m1]
262
+ p2 = points[m2]
263
+ return cal_point_bbox_iou(p1, p2)
264
+
265
+
266
+ def clean_mesh(mesh):
267
+ """
268
+ mesh: trimesh.Trimesh
269
+ """
270
+ # 1. 合并接近的顶点
271
+ mesh.merge_vertices()
272
+
273
+ # 2. 删除重复的顶点
274
+ # 3. 删除重复的面片
275
+ mesh.process(True)
276
+ return mesh
277
+
278
+
279
+ def get_aabb_from_face_ids(mesh, face_ids):
280
+ unique_ids = np.unique(face_ids)
281
+ aabb = []
282
+ for i in unique_ids:
283
+ if i == -1 or i == -2:
284
+ continue
285
+ _part_mask = face_ids == i
286
+ _faces = mesh.faces[_part_mask]
287
+ _faces = np.reshape(_faces, (-1))
288
+ _points = mesh.vertices[_faces]
289
+ min_xyz = np.min(_points, axis=0)
290
+ max_xyz = np.max(_points, axis=0)
291
+ aabb.append([min_xyz, max_xyz])
292
+ return np.array(aabb)
293
+
294
+
295
+ class Timer:
296
+ def __init__(self, name):
297
+ self.name = name
298
+
299
+ def __enter__(self):
300
+ self.start_time = time.time()
301
+ return self # 可以返回 self 以便在 with 块内访问
302
+
303
+ def __exit__(self, exc_type, exc_val, exc_tb):
304
+ self.end_time = time.time()
305
+ self.elapsed_time = self.end_time - self.start_time
306
+ print(f">>>>>>代码{self.name} 运行时间: {self.elapsed_time:.4f} 秒")
307
+
308
+
309
+ def sample_points_pre_face(vertices, faces, n_point_per_face=2000):
310
+ n_f = faces.shape[0] # 面片数量
311
+
312
+ # 生成随机数 u, v
313
+ u = np.sqrt(np.random.rand(n_f, n_point_per_face, 1)) # (n_f, n_point_per_face, 1)
314
+ v = np.random.rand(n_f, n_point_per_face, 1) # (n_f, n_point_per_face, 1)
315
+
316
+ # 计算 barycentric 坐标
317
+ w0 = 1 - u
318
+ w1 = u * (1 - v)
319
+ w2 = u * v # (n_f, n_point_per_face, 1)
320
+
321
+ # 从顶点中提取每个面的三个顶点
322
+ face_v_0 = vertices[faces[:, 0].reshape(-1)] # (n_f, 3)
323
+ face_v_1 = vertices[faces[:, 1].reshape(-1)] # (n_f, 3)
324
+ face_v_2 = vertices[faces[:, 2].reshape(-1)] # (n_f, 3)
325
+
326
+ # 扩展维度以匹配 w0, w1, w2 的形状
327
+ face_v_0 = face_v_0.reshape(n_f, 1, 3) # (n_f, 1, 3)
328
+ face_v_1 = face_v_1.reshape(n_f, 1, 3) # (n_f, 1, 3)
329
+ face_v_2 = face_v_2.reshape(n_f, 1, 3) # (n_f, 1, 3)
330
+
331
+ # 计算每个点的坐标
332
+ points = w0 * face_v_0 + w1 * face_v_1 + w2 * face_v_2 # (n_f, n_point_per_face, 3)
333
+
334
+ return points
335
+
336
+
337
+ def cal_cd_batch(p1, p2, pn=100000):
338
+ p1_n = p1.shape[0]
339
+ batch_num = (p1_n + pn - 1) // pn
340
+ p2_cuda = torch.from_numpy(p2).cuda().float().unsqueeze(0)
341
+ p1_cuda = torch.from_numpy(p1).cuda().float().unsqueeze(0)
342
+ cd_res = []
343
+ for i in tqdm(range(batch_num)):
344
+ start_idx = i * pn
345
+ end_idx = min((i + 1) * pn, p1_n)
346
+ _p1_cuda = p1_cuda[:, start_idx:end_idx, :]
347
+ _, _, idx, _ = cmd_loss(_p1_cuda, p2_cuda)
348
+ idx = idx[0].detach().cpu().numpy()
349
+ cd_res.append(idx)
350
+ cd_res = np.concatenate(cd_res, axis=0)
351
+ return cd_res
352
+
353
+
354
+ def remove_outliers_iqr(data, factor=1.5):
355
+ """
356
+ 基于 IQR 去除离群值
357
+ :param data: 输入的列表或 NumPy 数组
358
+ :param factor: IQR 的倍数(默认 1.5)
359
+ :return: 去除离群值后的列表
360
+ """
361
+ data = np.array(data, dtype=np.float32)
362
+ q1 = np.percentile(data, 25) # 第一四分位数
363
+ q3 = np.percentile(data, 75) # 第三四分位数
364
+ iqr = q3 - q1 # 四分位距
365
+ lower_bound = q1 - factor * iqr
366
+ upper_bound = q3 + factor * iqr
367
+ return data[(data >= lower_bound) & (data <= upper_bound)].tolist()
368
+
369
+
370
+ def better_aabb(points):
371
+ x = points[:, 0]
372
+ y = points[:, 1]
373
+ z = points[:, 2]
374
+ x = remove_outliers_iqr(x)
375
+ y = remove_outliers_iqr(y)
376
+ z = remove_outliers_iqr(z)
377
+ min_xyz = np.array([np.min(x), np.min(y), np.min(z)])
378
+ max_xyz = np.array([np.max(x), np.max(y), np.max(z)])
379
+ return [min_xyz, max_xyz]
380
+
381
+
382
+ def save_mesh(save_path, mesh, face_ids, color_map):
383
+ face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8)
384
+ for i in tqdm(range(len(mesh.faces)), disable=True):
385
+ _max_id = face_ids[i]
386
+ if _max_id == -2:
387
+ continue
388
+ face_colors[i, :3] = color_map[_max_id]
389
+
390
+ mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
391
+ mesh_save.visual.face_colors = face_colors
392
+ mesh_save.export(save_path)
393
+ mesh_save.export(save_path.replace(".glb", ".ply"))
394
+ # print('保存mesh完成')
395
+
396
+ scene_mesh = trimesh.Scene()
397
+ scene_mesh.add_geometry(mesh_save)
398
+ unique_ids = np.unique(face_ids)
399
+ aabb = []
400
+ for i in unique_ids:
401
+ if i == -1 or i == -2:
402
+ continue
403
+ _part_mask = face_ids == i
404
+ _faces = mesh.faces[_part_mask]
405
+ _faces = np.reshape(_faces, (-1))
406
+ _points = mesh.vertices[_faces]
407
+ min_xyz, max_xyz = better_aabb(_points)
408
+ center = (min_xyz + max_xyz) / 2
409
+ size = max_xyz - min_xyz
410
+ box = trimesh.path.creation.box_outline()
411
+ box.vertices *= size
412
+ box.vertices += center
413
+ box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]])
414
+ box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8)
415
+ box.colors = box_color
416
+ scene_mesh.add_geometry(box)
417
+ min_xyz = np.min(_points, axis=0)
418
+ max_xyz = np.max(_points, axis=0)
419
+ aabb.append([min_xyz, max_xyz])
420
+ scene_mesh.export(save_path.replace(".glb", "_aabb.glb"))
421
+ aabb = np.array(aabb)
422
+ np.save(save_path.replace(".glb", "_aabb.npy"), aabb)
423
+ np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids)
424
+
425
+
426
+ def mesh_sam(
427
+ model,
428
+ mesh,
429
+ save_path,
430
+ point_num=100000,
431
+ prompt_num=400,
432
+ save_mid_res=False,
433
+ show_info=False,
434
+ post_process=False,
435
+ threshold=0.95,
436
+ clean_mesh_flag=True,
437
+ seed=42,
438
+ prompt_bs=32,
439
+ ):
440
+ with Timer("加载mesh"):
441
+ model, model_parallel = model
442
+ if clean_mesh_flag:
443
+ mesh = clean_mesh(mesh)
444
+ mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=False)
445
+ if show_info:
446
+ print(f"点数:{mesh.vertices.shape[0]} 面片数:{mesh.faces.shape[0]}")
447
+
448
+ point_num = 100000
449
+ prompt_num = 400
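+ # NOTE: these hard-coded values override the point_num / prompt_num arguments passed to mesh_sam above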
450
+
451
+ with Timer("采样点云"):
452
+ _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed)
453
+ _points_org = _points.copy()
454
+ _points = normalize_pc(_points)
455
+ normals = mesh.face_normals[face_idx]
456
+ # _points = _points + np.random.normal(0, 1, size=_points.shape) * 0.01
457
+ # normals = normals * 0. # debug no normal
458
+ if show_info:
459
+ print(f"点数:{point_num} 面片数:{mesh.faces.shape[0]}")
460
+
461
+ with Timer("获取特征"):
462
+ _feats = get_feat(model, _points, normals)
463
+ if show_info:
464
+ print("预处理特征")
465
+
466
+ if save_mid_res:
467
+ feat_save = _feats.float().detach().cpu().numpy()
468
+ data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
469
+ pca = PCA(n_components=3)
470
+ data_reduced = pca.fit_transform(data_scaled)
471
+ data_reduced = (data_reduced - data_reduced.min()) / (
472
+ data_reduced.max() - data_reduced.min()
473
+ )
474
+ _colors_pca = (data_reduced * 255).astype(np.uint8)
475
+ pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca)
476
+ pc_save.export(os.path.join(save_path, "point_pca.glb"))
477
+ pc_save.export(os.path.join(save_path, "point_pca.ply"))
478
+ if show_info:
479
+ print("PCA获取特征颜色")
480
+
481
+ with Timer("FPS采样提示点"):
482
+ fps_idx = fpsample.fps_sampling(_points, prompt_num)
483
+ _point_prompts = _points[fps_idx]
484
+ if save_mid_res:
485
+ trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
486
+ os.path.join(save_path, "point_prompts_pca.glb")
487
+ )
488
+ trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
489
+ os.path.join(save_path, "point_prompts_pca.ply")
490
+ )
491
+ if show_info:
492
+ print("采样完成")
493
+
494
+ with Timer("推理"):
495
+ bs = prompt_bs
496
+ step_num = prompt_num // bs + 1
497
+ mask_res = []
498
+ iou_res = []
499
+ for i in tqdm(range(step_num), disable=not show_info):
500
+ cur_propmt = _point_prompts[bs * i : bs * (i + 1)]
501
+ pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask(
502
+ model_parallel, _feats, _points, cur_propmt
503
+ )
504
+ pred_mask = np.stack(
505
+ [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1
506
+ ) # [N, K, 3]
507
+ max_idx = np.argmax(pred_iou, axis=-1) # [K]
508
+ for j in range(max_idx.shape[0]):
509
+ mask_res.append(pred_mask[:, j, max_idx[j]])
510
+ iou_res.append(pred_iou[j, max_idx[j]])
511
+ mask_res = np.stack(mask_res, axis=-1) # [N, K]
512
+ if show_info:
513
+ print("prmopt 推理完成")
514
+
515
+ with Timer("根据IOU排序"):
516
+ iou_res = np.array(iou_res).tolist()
517
+ mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)]
518
+ mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True)
519
+ mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)]
520
+ iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)]
521
+
522
+ # clusters = {}
523
+ # for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
524
+ # _mask = mask_sorted[i]
525
+ # union_flag = False
526
+ # for j in clusters.keys():
527
+ # if cal_iou(_mask, mask_sorted[j]) > 0.9:
528
+ # clusters[j].append(i)
529
+ # union_flag = True
530
+ # break
531
+ # if not union_flag:
532
+ # clusters[i] = [i]
533
+ with Timer("NMS"):
534
+ clusters = defaultdict(list)
535
+ with ThreadPoolExecutor(max_workers=20) as executor:
536
+ for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
537
+ _mask = mask_sorted[i]
538
+ futures = []
539
+ for j in clusters.keys():
540
+ futures.append(executor.submit(cal_iou, _mask, mask_sorted[j]))
541
+
542
+ for j, future in zip(clusters.keys(), futures):
543
+ if future.result() > 0.9:
544
+ clusters[j].append(i)
545
+ break
546
+ else:
547
+ clusters[i].append(i)
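+ # for/else: the else branch runs only when no existing cluster matched with IoU > 0.9, so mask i starts a new cluster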
548
+
549
+ # print(clusters)
550
+ if show_info:
551
+ print(f"NMS完成,mask数量:{len(clusters)}")
552
+
553
+ if save_mid_res:
554
+ part_mask_save_path = os.path.join(save_path, "part_mask")
555
+ if os.path.exists(part_mask_save_path):
556
+ shutil.rmtree(part_mask_save_path)
557
+ os.makedirs(part_mask_save_path, exist_ok=True)
558
+ for i in tqdm(clusters.keys(), desc="保存mask", disable=not show_info):
559
+ cluster_num = len(clusters[i])
560
+ cluster_iou = iou_sorted[i]
561
+ cluster_area = np.sum(mask_sorted[i])
562
+ if cluster_num <= 2:
563
+ continue
564
+ mask_save = mask_sorted[i]
565
+ mask_save = np.expand_dims(mask_save, axis=-1)
566
+ mask_save = np.repeat(mask_save, 3, axis=-1)
567
+ mask_save = (mask_save * 255).astype(np.uint8)
568
+ point_save = trimesh.points.PointCloud(_points, colors=mask_save)
569
+ point_save.export(
570
+ os.path.join(
571
+ part_mask_save_path,
572
+ f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
573
+ )
574
+ )
575
+
576
+ # 过滤只有一个mask的cluster
577
+ with Timer("过滤只有一个mask的cluster"):
578
+ filtered_clusters = []
579
+ other_clusters = []
580
+ for i in clusters.keys():
581
+ if len(clusters[i]) > 2:
582
+ filtered_clusters.append(i)
583
+ else:
584
+ other_clusters.append(i)
585
+ if show_info:
586
+ print(
587
+ f"过滤前:{len(clusters)} 个cluster,"
588
+ f"过滤后:{len(filtered_clusters)} 个cluster"
589
+ )
590
+
591
+ # 再次合并
592
+ with Timer("再次合并"):
593
+ filtered_clusters_num = len(filtered_clusters)
594
+ cluster2 = {}
595
+ is_union = [False] * filtered_clusters_num
596
+ for i in range(filtered_clusters_num):
597
+ if is_union[i]:
598
+ continue
599
+ cur_cluster = filtered_clusters[i]
600
+ cluster2[cur_cluster] = [cur_cluster]
601
+ for j in range(i + 1, filtered_clusters_num):
602
+ if is_union[j]:
603
+ continue
604
+ tar_cluster = filtered_clusters[j]
605
+ # if cal_single_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.9:
606
+ # if cal_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.5:
607
+ if (
608
+ cal_bbox_iou(
609
+ _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster]
610
+ )
611
+ > 0.5
612
+ ):
613
+ cluster2[cur_cluster].append(tar_cluster)
614
+ is_union[j] = True
615
+ if show_info:
616
+ print(f"再次合并,合并数量:{len(cluster2.keys())}")
617
+
618
+ with Timer("计算没有mask的点"):
619
+ no_mask = np.ones(point_num)
620
+ for i in cluster2:
621
+ part_mask = mask_sorted[i]
622
+ no_mask[part_mask] = 0
623
+ if show_info:
624
+ print(
625
+ f"{np.sum(no_mask == 1)} 个点没有mask,"
626
+ f" 占比:{np.sum(no_mask == 1) / point_num:.4f}"
627
+ )
628
+
629
+ with Timer("修补遗漏mask"):
630
+ # 查询漏掉的mask
631
+ for i in tqdm(range(len(mask_sorted)), desc="漏掉mask", disable=not show_info):
632
+ if i in cluster2:
633
+ continue
634
+ part_mask = mask_sorted[i]
635
+ _iou = cal_single_iou(part_mask, no_mask)
636
+ if _iou > 0.7:
637
+ cluster2[i] = [i]
638
+ no_mask[part_mask] = 0
639
+ if save_mid_res:
640
+ mask_save = mask_sorted[i]
641
+ mask_save = np.expand_dims(mask_save, axis=-1)
642
+ mask_save = np.repeat(mask_save, 3, axis=-1)
643
+ mask_save = (mask_save * 255).astype(np.uint8)
644
+ point_save = trimesh.points.PointCloud(_points, colors=mask_save)
645
+ cluster_iou = iou_sorted[i]
646
+ cluster_area = int(np.sum(mask_sorted[i]))
647
+ cluster_num = 1
648
+ point_save.export(
649
+ os.path.join(
650
+ part_mask_save_path,
651
+ f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
652
+ )
653
+ )
654
+ # print(cluster2)
655
+ # print(len(cluster2.keys()))
656
+ if show_info:
657
+ print(f"修补遗漏mask:{len(cluster2.keys())}")
658
+
659
+ with Timer("计算点云最终mask"):
660
+ final_mask = list(cluster2.keys())
661
+ final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask]
662
+ final_mask_area = [
663
+ [final_mask[i], final_mask_area[i]] for i in range(len(final_mask))
664
+ ]
665
+ final_mask_area_sorted = sorted(
666
+ final_mask_area, key=lambda x: x[1], reverse=True
667
+ )
668
+ final_mask_sorted = [
669
+ final_mask_area_sorted[i][0] for i in range(len(final_mask_area))
670
+ ]
671
+ final_mask_area_sorted = [
672
+ final_mask_area_sorted[i][1] for i in range(len(final_mask_area))
673
+ ]
674
+ # print(final_mask_sorted)
675
+ # print(final_mask_area_sorted)
676
+ if show_info:
677
+ print(f"最终mask数量:{len(final_mask_sorted)}")
678
+
679
+ with Timer("点云上色"):
680
+ # 生成color map
681
+ color_map = {}
682
+ for i in final_mask_sorted:
683
+ part_color = np.random.rand(3) * 255
684
+ color_map[i] = part_color
685
+ # print(color_map)
686
+
687
+ result_mask = -np.ones(point_num, dtype=np.int64)
688
+ for i in final_mask_sorted:
689
+ part_mask = mask_sorted[i]
690
+ result_mask[part_mask] = i
691
+ if save_mid_res:
692
+ # 保存点云结果
693
+ result_colors = np.zeros_like(_colors_pca)
694
+ for i in final_mask_sorted:
695
+ part_color = color_map[i]
696
+ part_mask = mask_sorted[i]
697
+ result_colors[part_mask, :3] = part_color
698
+ trimesh.points.PointCloud(_points, colors=result_colors).export(
699
+ os.path.join(save_path, "auto_mask_cluster.glb")
700
+ )
701
+ trimesh.points.PointCloud(_points, colors=result_colors).export(
702
+ os.path.join(save_path, "auto_mask_cluster.ply")
703
+ )
704
+ if show_info:
705
+ print("保存点云完成")
706
+
707
+ with Timer("后处理"):
708
+ valid_mask = result_mask >= 0
709
+ _org = _points_org[valid_mask]
710
+ _results = result_mask[valid_mask]
711
+ pre_face = 10
712
+ _face_points = sample_points_pre_face(
713
+ mesh.vertices, mesh.faces, n_point_per_face=pre_face
714
+ )
715
+ _face_points = np.reshape(_face_points, (len(mesh.faces) * pre_face, 3))
716
+ _idx = cal_cd_batch(_face_points, _org)
717
+ _idx_res = _results[_idx]
718
+ _idx_res = np.reshape(_idx_res, (-1, pre_face))
719
+
720
+ face_ids = []
721
+ for i in range(len(mesh.faces)):
722
+ _label = np.argmax(np.bincount(_idx_res[i] + 2)) - 2
723
+ face_ids.append(_label)
724
+ final_face_ids = np.array(face_ids)
725
+
726
+ if save_mid_res:
727
+ save_mesh(
728
+ os.path.join(save_path, "auto_mask_mesh_final.glb"),
729
+ mesh,
730
+ final_face_ids,
731
+ color_map,
732
+ )
733
+
734
+ with Timer("计算最后的aabb"):
735
+ aabb = get_aabb_from_face_ids(mesh, final_face_ids)
736
+ return aabb, final_face_ids, mesh
737
+
738
+
739
+ class AutoMask:
740
+ def __init__(
741
+ self,
742
+ ckpt_path=None,
743
+ point_num=100000,
744
+ prompt_num=400,
745
+ threshold=0.95,
746
+ post_process=True,
747
+ automask_instance=None,
748
+ ):
749
+ """
750
+ ckpt_path: str, path to the model checkpoint
+ point_num: int, number of sampled points
+ prompt_num: int, number of prompt points
+ threshold: float, threshold
+ post_process: bool, whether to run post-processing
755
+ """
756
+ if automask_instance is not None:
757
+ self.model = automask_instance.model
758
+ self.model_parallel = automask_instance.model_parallel
759
+ else:
760
+ self.model = P3SAM()
761
+ self.model.load_state_dict(ckpt_path)
762
+ self.model.eval()
763
+ self.model_parallel = torch.nn.DataParallel(self.model)
764
+ self.model.cuda()
765
+ self.model_parallel.cuda()
766
+ self.point_num = point_num
767
+ self.prompt_num = prompt_num
768
+ self.threshold = threshold
769
+ self.post_process = post_process
770
+
771
+ def predict_aabb(
772
+ self,
773
+ mesh,
774
+ point_num=None,
775
+ prompt_num=None,
776
+ threshold=None,
777
+ post_process=None,
778
+ save_path=None,
779
+ save_mid_res=False,
780
+ show_info=True,
781
+ clean_mesh_flag=True,
782
+ seed=42,
783
+ is_parallel=True,
784
+ prompt_bs=32,
785
+ ):
786
+ """
787
+ Parameters:
788
+ mesh: trimesh.Trimesh, input mesh
+ point_num: int, number of sampled points
+ prompt_num: int, number of prompt points
+ threshold: float, threshold
+ post_process: bool, whether to run post-processing
+ Returns:
+ aabb: np.ndarray, axis-aligned bounding boxes of the parts
+ face_ids: np.ndarray, per-face part ids
796
+ """
797
+ point_num = point_num if point_num is not None else self.point_num
798
+ prompt_num = prompt_num if prompt_num is not None else self.prompt_num
799
+ threshold = threshold if threshold is not None else self.threshold
800
+ post_process = post_process if post_process is not None else self.post_process
801
+ return mesh_sam(
802
+ [self.model, self.model_parallel if is_parallel else self.model],
803
+ mesh,
804
+ save_path=save_path,
805
+ point_num=point_num,
806
+ prompt_num=prompt_num,
807
+ threshold=threshold,
808
+ post_process=post_process,
809
+ show_info=show_info,
810
+ save_mid_res=save_mid_res,
811
+ clean_mesh_flag=clean_mesh_flag,
812
+ seed=seed,
813
+ prompt_bs=prompt_bs,
814
+ )
815
+
816
+
817
+ def set_seed(seed):
818
+ random.seed(seed)
819
+ np.random.seed(seed)
820
+ torch.manual_seed(seed)
821
+ if torch.cuda.is_available():
822
+ torch.cuda.manual_seed(seed)
823
+ torch.cuda.manual_seed_all(seed)
824
+ torch.backends.cudnn.deterministic = True
825
+ torch.backends.cudnn.benchmark = False
826
+
827
+
828
+ if __name__ == "__main__":
829
+ argparser = argparse.ArgumentParser()
830
+ argparser.add_argument(
831
+ "--ckpt_path", type=str, default=None, help="模型路径"
832
+ )
833
+ argparser.add_argument(
834
+ "--mesh_path", type=str, default="assets/1.glb", help="输入网格路径"
835
+ )
836
+ argparser.add_argument(
837
+ "--output_path", type=str, default="results/1", help="保存路径"
838
+ )
839
+ argparser.add_argument("--point_num", type=int, default=100000, help="采样点数量")
840
+ argparser.add_argument("--prompt_num", type=int, default=400, help="提示数量")
841
+ argparser.add_argument("--threshold", type=float, default=0.95, help="阈值")
842
+ argparser.add_argument("--post_process", type=int, default=0, help="是否后处理")
843
+ argparser.add_argument(
844
+ "--save_mid_res", type=int, default=1, help="是否保存中间结果"
845
+ )
846
+ argparser.add_argument("--show_info", type=int, default=1, help="是否显示信息")
847
+ argparser.add_argument(
848
+ "--show_time_info", type=int, default=1, help="是否显示时间信息"
849
+ )
850
+ argparser.add_argument("--seed", type=int, default=42, help="随机种子")
851
+ argparser.add_argument("--parallel", type=int, default=1, help="是否使用多卡")
852
+ argparser.add_argument(
853
+ "--prompt_bs", type=int, default=32, help="提示点推理时的batch size大小"
854
+ )
855
+ argparser.add_argument("--clean_mesh", type=int, default=1, help="是否清洗网格")
856
+ args = argparser.parse_args()
857
+ Timer.STATE = args.show_time_info
858
+
859
+ output_path = args.output_path
860
+ os.makedirs(output_path, exist_ok=True)
861
+ ckpt_path = args.ckpt_path
862
+ auto_mask = AutoMask(ckpt_path)
863
+ mesh_path = args.mesh_path
864
+ if os.path.isdir(mesh_path):
865
+ for file in os.listdir(mesh_path):
866
+ if not (
867
+ file.endswith(".glb") or file.endswith(".obj") or file.endswith(".ply")
868
+ ):
869
+ continue
870
+ _mesh_path = os.path.join(mesh_path, file)
871
+ _output_path = os.path.join(output_path, file[:-4])
872
+ os.makedirs(_output_path, exist_ok=True)
873
+ mesh = trimesh.load(_mesh_path, force="mesh")
874
+ set_seed(args.seed)
875
+ aabb, face_ids, mesh = auto_mask.predict_aabb(
876
+ mesh,
877
+ save_path=_output_path,
878
+ point_num=args.point_num,
879
+ prompt_num=args.prompt_num,
880
+ threshold=args.threshold,
881
+ post_process=args.post_process,
882
+ save_mid_res=args.save_mid_res,
883
+ show_info=args.show_info,
884
+ seed=args.seed,
885
+ is_parallel=args.parallel,
886
+ clean_mesh_flag=args.clean_mesh,
887
+ )
888
+ else:
889
+ mesh = trimesh.load(mesh_path, force="mesh")
890
+ set_seed(args.seed)
891
+ aabb, face_ids, mesh = auto_mask.predict_aabb(
892
+ mesh,
893
+ save_path=output_path,
894
+ point_num=args.point_num,
895
+ prompt_num=args.prompt_num,
896
+ threshold=args.threshold,
897
+ post_process=args.post_process,
898
+ save_mid_res=args.save_mid_res,
899
+ show_info=args.show_info,
900
+ seed=args.seed,
901
+ is_parallel=args.parallel,
902
+ clean_mesh_flag=args.clean_mesh,
903
+ )
904
+
905
+ ###############################################
906
+ ## 可以通过以下代码保存返回的结果
907
+ ## You can save the returned result by the following code
908
+ ################# save result #################
909
+ # color_map = {}
910
+ # unique_ids = np.unique(face_ids)
911
+ # for i in unique_ids:
912
+ # if i == -1:
913
+ # continue
914
+ # part_color = np.random.rand(3) * 255
915
+ # color_map[i] = part_color
916
+ # face_colors = []
917
+ # for i in face_ids:
918
+ # if i == -1:
919
+ # face_colors.append([0, 0, 0])
920
+ # else:
921
+ # face_colors.append(color_map[i])
922
+ # face_colors = np.array(face_colors).astype(np.uint8)
923
+ # mesh_save = mesh.copy()
924
+ # mesh_save.visual.face_colors = face_colors
925
+ # mesh_save.export(os.path.join(output_path, 'auto_mask_mesh.glb'))
926
+ # scene_mesh = trimesh.Scene()
927
+ # scene_mesh.add_geometry(mesh_save)
928
+ # for i in range(len(aabb)):
929
+ # min_xyz, max_xyz = aabb[i]
930
+ # center = (min_xyz + max_xyz) / 2
931
+ # size = max_xyz - min_xyz
932
+ # box = trimesh.path.creation.box_outline()
933
+ # box.vertices *= size
934
+ # box.vertices += center
935
+ # scene_mesh.add_geometry(box)
936
+ # scene_mesh.export(os.path.join(output_path, 'auto_mask_aabb.glb'))
937
+ ################# save result #################
938
+
939
+ """
940
+ python auto_mask_no_postprocess.py --parallel 0
941
+ python auto_mask_no_postprocess.py --ckpt_path ../weights/p3sam.ckpt --mesh_path assets/1.glb --output_path results/1 --parallel 0
942
+ python auto_mask_no_postprocess.py --ckpt_path ../weights/p3sam.ckpt --mesh_path assets --output_path results/all_no_postprocess
943
+ """
P3-SAM/model.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'XPart/partgen'))
6
+ from models import sonata
7
+ from utils.misc import smart_load_model
8
+
9
+ '''
10
+ This is the P3-SAM model.
11
+ The model is composed of three parts:
12
+ 1. Sonata: a pretrained 3D point-cloud encoder for feature extraction.
13
+ 2. SEG1+SEG2: a two-stage multi-head segmentor
14
+ 3. IoU prediction: an IoU predictor
15
+ '''
16
+ def build_P3SAM(self):
17
+ ######################## Sonata ########################
18
+ self.sonata = sonata.load("sonata", repo_id="facebook/sonata", download_root='/root/sonata')
19
+ self.mlp = nn.Sequential(
20
+ nn.Linear(1232, 512),
21
+ nn.GELU(),
22
+ nn.Linear(512, 512),
23
+ nn.GELU(),
24
+ nn.Linear(512, 512),
25
+ )
26
+ self.transform = sonata.transform.default()
27
+ ######################## Sonata ########################
28
+
29
+ ######################## SEG1 ########################
30
+ self.seg_mlp_1 = nn.Sequential(
31
+ nn.Linear(512+3+3, 512),
32
+ nn.GELU(),
33
+ nn.Linear(512, 512),
34
+ nn.GELU(),
35
+ nn.Linear(512, 1),
36
+ )
37
+ self.seg_mlp_2 = nn.Sequential(
38
+ nn.Linear(512+3+3, 512),
39
+ nn.GELU(),
40
+ nn.Linear(512, 512),
41
+ nn.GELU(),
42
+ nn.Linear(512, 1),
43
+ )
44
+ self.seg_mlp_3 = nn.Sequential(
45
+ nn.Linear(512+3+3, 512),
46
+ nn.GELU(),
47
+ nn.Linear(512, 512),
48
+ nn.GELU(),
49
+ nn.Linear(512, 1),
50
+ )
51
+ ######################## SEG1 ########################
52
+
53
+ ######################## SEG2 ########################
54
+ self.seg_s2_mlp_g = nn.Sequential(
55
+ nn.Linear(512+3+3+3, 256),
56
+ nn.GELU(),
57
+ nn.Linear(256, 256),
58
+ nn.GELU(),
59
+ nn.Linear(256, 256),
60
+ )
61
+ self.seg_s2_mlp_1 = nn.Sequential(
62
+ nn.Linear(512+3+3+3+256, 256),
63
+ nn.GELU(),
64
+ nn.Linear(256, 256),
65
+ nn.GELU(),
66
+ nn.Linear(256, 1),
67
+ )
68
+ self.seg_s2_mlp_2 = nn.Sequential(
69
+ nn.Linear(512+3+3+3+256, 256),
70
+ nn.GELU(),
71
+ nn.Linear(256, 256),
72
+ nn.GELU(),
73
+ nn.Linear(256, 1),
74
+ )
75
+ self.seg_s2_mlp_3 = nn.Sequential(
76
+ nn.Linear(512+3+3+3+256, 256),
77
+ nn.GELU(),
78
+ nn.Linear(256, 256),
79
+ nn.GELU(),
80
+ nn.Linear(256, 1),
81
+ )
82
+ ######################## SEG2 ########################
83
+
84
+
85
+ self.iou_mlp = nn.Sequential(
86
+ nn.Linear(512+3+3+3+256, 256),
87
+ nn.GELU(),
88
+ nn.Linear(256, 256),
89
+ nn.GELU(),
90
+ nn.Linear(256, 256),
91
+ )
92
+ self.iou_mlp_out = nn.Sequential(
93
+ nn.Linear(256, 256),
94
+ nn.GELU(),
95
+ nn.Linear(256, 256),
96
+ nn.GELU(),
97
+ nn.Linear(256, 3),
98
+ )
99
+ self.iou_criterion = torch.nn.MSELoss()
100
+
101
+ '''
102
+ Load the P3-SAM model from a checkpoint.
103
+ If ckpt_path is not None, load the checkpoint from the given path.
104
+ If state_dict is not None, load the state_dict from the given state_dict.
105
+ If both ckpt_path and state_dict are None, download the model from huggingface and load the checkpoint.
106
+ '''
107
+ def load_state_dict(self,
108
+ ckpt_path=None,
109
+ state_dict=None,
110
+ strict=True,
111
+ assign=False,
112
+ ignore_seg_mlp=False,
113
+ ignore_seg_s2_mlp=False,
114
+ ignore_iou_mlp=False):
115
+ if ckpt_path is not None:
116
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
117
+ elif state_dict is None:
118
+ # download from huggingface
119
+ print(f'trying to download model from huggingface...')
120
+ from huggingface_hub import hf_hub_download
121
+ ckpt_path = hf_hub_download(repo_id="tencent/Hunyuan3D-Part", filename="p3sam.ckpt", local_dir='/cache/P3-SAM/')
122
+ print(f'download model from huggingface to: {ckpt_path}')
123
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
124
+
125
+ local_state_dict = self.state_dict()
126
+ seen_keys = {k: False for k in local_state_dict.keys()}
127
+ for k, v in state_dict.items():
128
+ if k.startswith("dit."):
129
+ k = k[4:]
130
+ if k in local_state_dict:
131
+ seen_keys[k] = True
132
+ if local_state_dict[k].shape == v.shape:
133
+ local_state_dict[k].copy_(v)
134
+ else:
135
+ print(f"mismatching shape for key {k}: loaded {local_state_dict[k].shape} but model has {v.shape}")
136
+ else:
137
+ print(f"unexpected key {k} in loaded state dict")
138
+ seg_mlp_flag = False
139
+ seg_s2_mlp_flag = False
140
+ iou_mlp_flag = False
141
+ for k in seen_keys:
142
+ if not seen_keys[k]:
143
+ if ignore_seg_mlp and 'seg_mlp' in k:
144
+ seg_mlp_flag = True
145
+ elif ignore_seg_s2_mlp and 'seg_s2_mlp' in k:
146
+ seg_s2_mlp_flag = True
147
+ elif ignore_iou_mlp and 'iou_mlp' in k:
148
+ iou_mlp_flag = True
149
+ else:
150
+ print(f"missing key {k} in loaded state dict")
151
+ if ignore_seg_mlp and seg_mlp_flag:
152
+ print("seg_mlp is missing in loaded state dict, ignore seg_mlp in loaded state dict")
153
+ if ignore_seg_s2_mlp and seg_s2_mlp_flag:
154
+ print("seg_s2_mlp is missing in loaded state dict, ignore seg_s2_mlp in loaded state dict")
155
+ if ignore_iou_mlp and iou_mlp_flag:
156
+ print("iou_mlp is missing in loaded state dict, ignore iou_mlp in loaded state dict")
P3-SAM/utils/chamfer3D/chamfer3D.cu ADDED
@@ -0,0 +1,196 @@
1
+
2
+ #include <stdio.h>
3
+ #include <ATen/ATen.h>
4
+
5
+ #include <cuda.h>
6
+ #include <cuda_runtime.h>
7
+
8
+ #include <vector>
9
+
10
+
11
+
12
+ __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
13
+ const int batch=512;
14
+ __shared__ float buf[batch*3];
15
+ for (int i=blockIdx.x;i<b;i+=gridDim.x){
16
+ for (int k2=0;k2<m;k2+=batch){
17
+ int end_k=min(m,k2+batch)-k2;
18
+ for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
19
+ buf[j]=xyz2[(i*m+k2)*3+j];
20
+ }
21
+ __syncthreads();
22
+ for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
23
+ float x1=xyz[(i*n+j)*3+0];
24
+ float y1=xyz[(i*n+j)*3+1];
25
+ float z1=xyz[(i*n+j)*3+2];
26
+ int best_i=0;
27
+ float best=0;
28
+ int end_ka=end_k-(end_k&3);
29
+ if (end_ka==batch){
30
+ for (int k=0;k<batch;k+=4){
31
+ {
32
+ float x2=buf[k*3+0]-x1;
33
+ float y2=buf[k*3+1]-y1;
34
+ float z2=buf[k*3+2]-z1;
35
+ float d=x2*x2+y2*y2+z2*z2;
36
+ if (k==0 || d<best){
37
+ best=d;
38
+ best_i=k+k2;
39
+ }
40
+ }
41
+ {
42
+ float x2=buf[k*3+3]-x1;
43
+ float y2=buf[k*3+4]-y1;
44
+ float z2=buf[k*3+5]-z1;
45
+ float d=x2*x2+y2*y2+z2*z2;
46
+ if (d<best){
47
+ best=d;
48
+ best_i=k+k2+1;
49
+ }
50
+ }
51
+ {
52
+ float x2=buf[k*3+6]-x1;
53
+ float y2=buf[k*3+7]-y1;
54
+ float z2=buf[k*3+8]-z1;
55
+ float d=x2*x2+y2*y2+z2*z2;
56
+ if (d<best){
57
+ best=d;
58
+ best_i=k+k2+2;
59
+ }
60
+ }
61
+ {
62
+ float x2=buf[k*3+9]-x1;
63
+ float y2=buf[k*3+10]-y1;
64
+ float z2=buf[k*3+11]-z1;
65
+ float d=x2*x2+y2*y2+z2*z2;
66
+ if (d<best){
67
+ best=d;
68
+ best_i=k+k2+3;
69
+ }
70
+ }
71
+ }
72
+ }else{
73
+ for (int k=0;k<end_ka;k+=4){
74
+ {
75
+ float x2=buf[k*3+0]-x1;
76
+ float y2=buf[k*3+1]-y1;
77
+ float z2=buf[k*3+2]-z1;
78
+ float d=x2*x2+y2*y2+z2*z2;
79
+ if (k==0 || d<best){
80
+ best=d;
81
+ best_i=k+k2;
82
+ }
83
+ }
84
+ {
85
+ float x2=buf[k*3+3]-x1;
86
+ float y2=buf[k*3+4]-y1;
87
+ float z2=buf[k*3+5]-z1;
88
+ float d=x2*x2+y2*y2+z2*z2;
89
+ if (d<best){
90
+ best=d;
91
+ best_i=k+k2+1;
92
+ }
93
+ }
94
+ {
95
+ float x2=buf[k*3+6]-x1;
96
+ float y2=buf[k*3+7]-y1;
97
+ float z2=buf[k*3+8]-z1;
98
+ float d=x2*x2+y2*y2+z2*z2;
99
+ if (d<best){
100
+ best=d;
101
+ best_i=k+k2+2;
102
+ }
103
+ }
104
+ {
105
+ float x2=buf[k*3+9]-x1;
106
+ float y2=buf[k*3+10]-y1;
107
+ float z2=buf[k*3+11]-z1;
108
+ float d=x2*x2+y2*y2+z2*z2;
109
+ if (d<best){
110
+ best=d;
111
+ best_i=k+k2+3;
112
+ }
113
+ }
114
+ }
115
+ }
116
+ for (int k=end_ka;k<end_k;k++){
117
+ float x2=buf[k*3+0]-x1;
118
+ float y2=buf[k*3+1]-y1;
119
+ float z2=buf[k*3+2]-z1;
120
+ float d=x2*x2+y2*y2+z2*z2;
121
+ if (k==0 || d<best){
122
+ best=d;
123
+ best_i=k+k2;
124
+ }
125
+ }
126
+ if (k2==0 || result[(i*n+j)]>best){
127
+ result[(i*n+j)]=best;
128
+ result_i[(i*n+j)]=best_i;
129
+ }
130
+ }
131
+ __syncthreads();
132
+ }
133
+ }
134
+ }
135
+ // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
136
+ int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
137
+
138
+ const auto batch_size = xyz1.size(0);
139
+ const auto n = xyz1.size(1); //num_points point cloud A
140
+ const auto m = xyz2.size(1); //num_points point cloud B
141
+
142
+ NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
143
+ NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
144
+
145
+ cudaError_t err = cudaGetLastError();
146
+ if (err != cudaSuccess) {
147
+ printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
148
+ //THError("aborting");
149
+ return 0;
150
+ }
151
+ return 1;
152
+
153
+
154
+ }
155
+ __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
156
+ for (int i=blockIdx.x;i<b;i+=gridDim.x){
157
+ for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
158
+ float x1=xyz1[(i*n+j)*3+0];
159
+ float y1=xyz1[(i*n+j)*3+1];
160
+ float z1=xyz1[(i*n+j)*3+2];
161
+ int j2=idx1[i*n+j];
162
+ float x2=xyz2[(i*m+j2)*3+0];
163
+ float y2=xyz2[(i*m+j2)*3+1];
164
+ float z2=xyz2[(i*m+j2)*3+2];
165
+ float g=grad_dist1[i*n+j]*2;
166
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
167
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
168
+ atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
169
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
170
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
171
+ atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
172
+ }
173
+ }
174
+ }
175
+ // int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
176
+ int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
177
+ // cudaMemset(grad_xyz1,0,b*n*3*4);
178
+ // cudaMemset(grad_xyz2,0,b*m*3*4);
179
+
180
+ const auto batch_size = xyz1.size(0);
181
+ const auto n = xyz1.size(1); //num_points point cloud A
182
+ const auto m = xyz2.size(1); //num_points point cloud B
183
+
184
+ NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
185
+ NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
186
+
187
+ cudaError_t err = cudaGetLastError();
188
+ if (err != cudaSuccess) {
189
+ printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
190
+ //THError("aborting");
191
+ return 0;
192
+ }
193
+ return 1;
194
+
195
+ }
196
+
P3-SAM/utils/chamfer3D/chamfer_cuda.cpp ADDED
@@ -0,0 +1,29 @@
1
+ #include <torch/torch.h>
2
+ #include <vector>
3
+
4
+ /// TMP
5
+ // #include "common.h"
6
+ /// NOT TMP
7
+
8
+ int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2,
9
+ at::Tensor idx1, at::Tensor idx2);
10
+
11
+ int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1,
12
+ at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2,
13
+ at::Tensor idx1, at::Tensor idx2);
14
+
15
+ int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2,
16
+ at::Tensor idx1, at::Tensor idx2) {
17
+ return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
18
+ }
19
+
20
+ int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2,
21
+ at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
22
+
23
+ return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
24
+ }
25
+
26
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
27
+ m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
28
+ m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
29
+ }
P3-SAM/utils/chamfer3D/dist_chamfer_3D.py ADDED
@@ -0,0 +1,81 @@
1
+ from torch import nn
2
+ from torch.autograd import Function
3
+ import torch
4
+ import importlib
5
+ import os
6
+ chamfer_found = importlib.find_loader("chamfer_3D") is not None
7
+ if not chamfer_found:
8
+ ## Cool trick from https://github.com/chrdiller
9
+ print("Jitting Chamfer 3D")
10
+ cur_path = os.path.dirname(os.path.abspath(__file__))
11
+ build_path = cur_path.replace('chamfer3D', 'tmp')
12
+ os.makedirs(build_path, exist_ok=True)
13
+
14
+ from torch.utils.cpp_extension import load
15
+ chamfer_3D = load(name="chamfer_3D",
16
+ sources=[
17
+ "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
18
+ "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
19
+ ], build_directory=build_path)
20
+ print("Loaded JIT 3D CUDA chamfer distance")
21
+
22
+ else:
23
+ import chamfer_3D
24
+ print("Loaded compiled 3D CUDA chamfer distance")
25
+
26
+
27
+ # Chamfer's distance module @thibaultgroueix
28
+ # GPU tensors only
29
+ class chamfer_3DFunction(Function):
30
+ @staticmethod
31
+ def forward(ctx, xyz1, xyz2):
32
+ batchsize, n, dim = xyz1.size()
33
+ assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
34
+ _, m, dim = xyz2.size()
35
+ assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
36
+ device = xyz1.device
37
+
38
+ device = xyz1.device
39
+
40
+ dist1 = torch.zeros(batchsize, n)
41
+ dist2 = torch.zeros(batchsize, m)
42
+
43
+ idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
44
+ idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
45
+
46
+ dist1 = dist1.to(device)
47
+ dist2 = dist2.to(device)
48
+ idx1 = idx1.to(device)
49
+ idx2 = idx2.to(device)
50
+ torch.cuda.set_device(device)
51
+
52
+ chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
53
+ ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
54
+ return dist1, dist2, idx1, idx2
55
+
56
+ @staticmethod
57
+ def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
58
+ xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
59
+ graddist1 = graddist1.contiguous()
60
+ graddist2 = graddist2.contiguous()
61
+ device = graddist1.device
62
+
63
+ gradxyz1 = torch.zeros(xyz1.size())
64
+ gradxyz2 = torch.zeros(xyz2.size())
65
+
66
+ gradxyz1 = gradxyz1.to(device)
67
+ gradxyz2 = gradxyz2.to(device)
68
+ chamfer_3D.backward(
69
+ xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
70
+ )
71
+ return gradxyz1, gradxyz2
72
+
73
+
74
+ class chamfer_3DDist(nn.Module):
75
+ def __init__(self):
76
+ super(chamfer_3DDist, self).__init__()
77
+
78
+ def forward(self, input1, input2):
79
+ input1 = input1.contiguous()
80
+ input2 = input2.contiguous()
81
+ return chamfer_3DFunction.apply(input1, input2)
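A quick usage sketch of the wrapper above, mirroring how cal_cd_batch in the demos consumes it (GPU tensors only; the extension is JIT-compiled on first import if it was not prebuilt):

    import torch
    from dist_chamfer_3D import chamfer_3DDist

    chamfer = chamfer_3DDist()
    a = torch.rand(1, 2048, 3).cuda()   # query point cloud
    b = torch.rand(1, 4096, 3).cuda()   # reference point cloud
    dist_a, dist_b, idx_a, idx_b = chamfer(a, b)
    # dist_a[0, i] is the squared distance from a[0, i] to its nearest neighbour in b,
    # and idx_a[0, i] is that neighbour's index in b.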
P3-SAM/utils/chamfer3D/setup.py ADDED
@@ -0,0 +1,14 @@
1
+ from setuptools import setup
2
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3
+
4
+ setup(
5
+ name='chamfer_3D',
6
+ ext_modules=[
7
+ CUDAExtension('chamfer_3D', [
8
+ "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
9
+ "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
10
+ ]),
11
+ ],
12
+ cmdclass={
13
+ 'build_ext': BuildExtension
14
+ })
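Building the extension ahead of time, e.g. running "python setup.py build_ext --inplace" (or "pip install .") from this directory, lets dist_chamfer_3D.py import the compiled chamfer_3D module directly; otherwise it falls back to the JIT build shown above.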
XPart/data/000.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1b728ba92d353d87c5acd2289d9f19a0dc9ab6ceacdde488e7c5d3b456b2ff8
3
+ size 9000484
XPart/data/001.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98601f642c444d8466007b5b35e33a57d3f0bade873c9a7d7c57db039ea7a318
3
+ size 9000676
XPart/data/002.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5dbd8408e37e41be78358bbca5bc1ef4d81e6413c92a19bc9c2f136760c0248a
3
+ size 8999812
XPart/data/003.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad8390bdafd83b8aea0f7e0159b24d9a8070c58140cd1706b80c6217fe8cf14d
3
+ size 9000880
XPart/data/004.glb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fb21265f371436d18437a9cb420e012fe07cf1e50fb20c1473223c921a101dc
3
+ size 9000796
XPart/partgen/bbox_estimator/auto_mask_api.py ADDED
@@ -0,0 +1,1417 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import argparse
7
+ import trimesh
8
+ from sklearn.decomposition import PCA
9
+ import fpsample
10
+ from tqdm import tqdm
11
+ from collections import defaultdict
12
+
13
+ # from tqdm.notebook import tqdm
14
+ import time
15
+ import copy
16
+ import shutil
17
+ from pathlib import Path
18
+ from concurrent.futures import ThreadPoolExecutor
19
+
20
+ from numba import njit
21
+
22
+ #################################
23
+ # 修改sonata import路径
24
+ from ..models import sonata
25
+
26
+ #################################
27
+ sys.path.append("../P3-SAM")
28
+ from model import build_P3SAM, load_state_dict
29
+
30
+
31
+ class YSAM(nn.Module):
32
+ def __init__(self):
33
+ super().__init__()
34
+ build_P3SAM(self)
35
+
36
+ def load_state_dict(
37
+ self,
38
+ state_dict=None,
39
+ strict=True,
40
+ assign=False,
41
+ ignore_seg_mlp=False,
42
+ ignore_seg_s2_mlp=False,
43
+ ignore_iou_mlp=False,
44
+ ):
45
+ load_state_dict(
46
+ self,
47
+ state_dict=state_dict,
48
+ strict=strict,
49
+ assign=assign,
50
+ ignore_seg_mlp=ignore_seg_mlp,
51
+ ignore_seg_s2_mlp=ignore_seg_s2_mlp,
52
+ ignore_iou_mlp=ignore_iou_mlp,
53
+ )
54
+
55
+ def forward(self, feats, points, point_prompt, iter=1):
56
+ """
57
+ feats: [K, N, 512]
58
+ points: [K, N, 3]
59
+ point_prompt: [K, N, 3]
60
+ """
61
+ # print(feats.shape, points.shape, point_prompt.shape)
62
+ point_num = points.shape[1]
63
+ feats = feats.transpose(0, 1) # [N, K, 512]
64
+ points = points.transpose(0, 1) # [N, K, 3]
65
+ point_prompt = point_prompt.transpose(0, 1) # [N, K, 3]
66
+ feats_seg = torch.cat([feats, points, point_prompt], dim=-1) # [N, K, 512+3+3]
67
+
68
+ # 预测mask stage-1
69
+ pred_mask_1 = self.seg_mlp_1(feats_seg).squeeze(-1) # [N, K]
70
+ pred_mask_2 = self.seg_mlp_2(feats_seg).squeeze(-1) # [N, K]
71
+ pred_mask_3 = self.seg_mlp_3(feats_seg).squeeze(-1) # [N, K]
72
+ pred_mask = torch.stack(
73
+ [pred_mask_1, pred_mask_2, pred_mask_3], dim=-1
74
+ ) # [N, K, 3]
75
+
76
+ for _ in range(iter):
77
+ # 预测mask stage-2
78
+ feats_seg_2 = torch.cat([feats_seg, pred_mask], dim=-1) # [N, K, 512+3+3+3]
79
+ feats_seg_global = self.seg_s2_mlp_g(feats_seg_2) # [N, K, 512]
80
+ feats_seg_global = torch.max(feats_seg_global, dim=0).values # [K, 512]
81
+ feats_seg_global = feats_seg_global.unsqueeze(0).repeat(
82
+ point_num, 1, 1
83
+ ) # [N, K, 512]
84
+ feats_seg_3 = torch.cat(
85
+ [feats_seg_global, feats_seg_2], dim=-1
86
+ ) # [N, K, 512+3+3+3+512]
87
+ pred_mask_s2_1 = self.seg_s2_mlp_1(feats_seg_3).squeeze(-1) # [N, K]
88
+ pred_mask_s2_2 = self.seg_s2_mlp_2(feats_seg_3).squeeze(-1) # [N, K]
89
+ pred_mask_s2_3 = self.seg_s2_mlp_3(feats_seg_3).squeeze(-1) # [N, K]
90
+ pred_mask_s2 = torch.stack(
91
+ [pred_mask_s2_1, pred_mask_s2_2, pred_mask_s2_3], dim=-1
92
+ ) # [N, K, 3]
93
+ pred_mask = pred_mask_s2
94
+
95
+ mask_1 = torch.sigmoid(pred_mask_s2_1).to(dtype=torch.float32) # [N, K]
96
+ mask_2 = torch.sigmoid(pred_mask_s2_2).to(dtype=torch.float32) # [N, K]
97
+ mask_3 = torch.sigmoid(pred_mask_s2_3).to(dtype=torch.float32) # [N, K]
98
+
99
+ feats_iou = torch.cat(
100
+ [feats_seg_global, feats_seg, pred_mask_s2], dim=-1
101
+ ) # [N, K, 512+3+3+3+512]
102
+ feats_iou = self.iou_mlp(feats_iou) # [N, K, 512]
103
+ feats_iou = torch.max(feats_iou, dim=0).values # [K, 512]
104
+ pred_iou = self.iou_mlp_out(feats_iou) # [K, 3]
105
+ pred_iou = torch.sigmoid(pred_iou).to(dtype=torch.float32) # [K, 3]
106
+
107
+ mask_1 = mask_1.transpose(0, 1) # [K, N]
108
+ mask_2 = mask_2.transpose(0, 1) # [K, N]
109
+ mask_3 = mask_3.transpose(0, 1) # [K, N]
110
+
111
+ return mask_1, mask_2, mask_3, pred_iou
112
+
113
+
114
+ def normalize_pc(pc):
115
+ """
116
+ pc: (N, 3)
117
+ """
118
+ max_, min_ = np.max(pc, axis=0), np.min(pc, axis=0)
119
+ center = (max_ + min_) / 2
120
+ scale = (max_ - min_) / 2
121
+ scale = np.max(np.abs(scale))
122
+ pc = (pc - center) / (scale + 1e-10)
123
+ return pc
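A tiny worked example of the normalization above: the cloud is shifted to its bounding-box center and divided by its largest half-extent, so the output always lies inside [-1, 1]^3.

import numpy as np

pc = np.array([[0.0, 0.0, 0.0],
               [2.0, 4.0, 0.0]])
# center = (1, 2, 0), largest half-extent = 2
print(normalize_pc(pc))  # [[-0.5, -1.0, 0.0], [0.5, 1.0, 0.0]]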
124
+
125
+
126
+ @torch.no_grad()
127
+ def get_feat(model, points, normals):
128
+ data_dict = {
129
+ "coord": points,
130
+ "normal": normals,
131
+ "color": np.ones_like(points),
132
+ "batch": np.zeros(points.shape[0], dtype=np.int64),
133
+ }
134
+ data_dict = model.transform(data_dict)
135
+ for k in data_dict:
136
+ if isinstance(data_dict[k], torch.Tensor):
137
+ data_dict[k] = data_dict[k].cuda()
138
+ point = model.sonata(data_dict)
139
+ while "pooling_parent" in point.keys():
140
+ assert "pooling_inverse" in point.keys()
141
+ parent = point.pop("pooling_parent")
142
+ inverse = point.pop("pooling_inverse")
143
+ parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
144
+ point = parent
145
+ feat = point.feat # [M, 1232]
146
+ feat = model.mlp(feat) # [M, 512]
147
+ feat = feat[point.inverse] # [N, 512]
148
+ feats = feat
149
+ return feats
150
+
151
+
152
+ @torch.no_grad()
153
+ def get_mask(model, feats, points, point_prompt, iter=1):
154
+ """
155
+ feats: [N, 512]
156
+ points: [N, 3]
157
+ point_prompt: [K, 3]
158
+ """
159
+ point_num = points.shape[0]
160
+ prompt_num = point_prompt.shape[0]
161
+ feats = feats.unsqueeze(1) # [N, 1, 512]
162
+ feats = feats.repeat(1, prompt_num, 1).cuda() # [N, K, 512]
163
+ points = torch.from_numpy(points).float().cuda().unsqueeze(1) # [N, 1, 3]
164
+ points = points.repeat(1, prompt_num, 1) # [N, K, 3]
165
+ prompt_coord = (
166
+ torch.from_numpy(point_prompt).float().cuda().unsqueeze(0)
167
+ ) # [1, K, 3]
168
+ prompt_coord = prompt_coord.repeat(point_num, 1, 1) # [N, K, 3]
169
+
170
+ feats = feats.transpose(0, 1) # [K, N, 512]
171
+ points = points.transpose(0, 1) # [K, N, 3]
172
+ prompt_coord = prompt_coord.transpose(0, 1) # [K, N, 3]
173
+
174
+ mask_1, mask_2, mask_3, pred_iou = model(feats, points, prompt_coord, iter)
175
+
176
+ mask_1 = mask_1.transpose(0, 1) # [N, K]
177
+ mask_2 = mask_2.transpose(0, 1) # [N, K]
178
+ mask_3 = mask_3.transpose(0, 1) # [N, K]
179
+
180
+ mask_1 = mask_1.detach().cpu().numpy() > 0.5
181
+ mask_2 = mask_2.detach().cpu().numpy() > 0.5
182
+ mask_3 = mask_3.detach().cpu().numpy() > 0.5
183
+
184
+ org_iou = pred_iou.detach().cpu().numpy() # [K, 3]
185
+
186
+ return mask_1, mask_2, mask_3, org_iou
187
+
188
+
189
+ def cal_iou(m1, m2):
190
+ return np.sum(np.logical_and(m1, m2)) / np.sum(np.logical_or(m1, m2))
191
+
192
+
193
+ def cal_single_iou(m1, m2):
194
+ return np.sum(np.logical_and(m1, m2)) / np.sum(m1)
195
+
196
+
197
+ def iou_3d(box1, box2, signle=None):
198
+ """
199
+ 计算两个三维边界框的交并比 (IoU)
200
+
201
+ 参数:
202
+ box1 (list): 第一个边界框的坐标 [x1_min, y1_min, z1_min, x1_max, y1_max, z1_max]
203
+ box2 (list): 第二个边界框的坐标 [x2_min, y2_min, z2_min, x2_max, y2_max, z2_max]
204
+
205
+ 返回:
206
+ float: 交并比 (IoU) 值
207
+ """
208
+ # 计算交集的坐标
209
+ intersection_xmin = max(box1[0], box2[0])
210
+ intersection_ymin = max(box1[1], box2[1])
211
+ intersection_zmin = max(box1[2], box2[2])
212
+ intersection_xmax = min(box1[3], box2[3])
213
+ intersection_ymax = min(box1[4], box2[4])
214
+ intersection_zmax = min(box1[5], box2[5])
215
+
216
+ # 判断是否有交集
217
+ if (
218
+ intersection_xmin >= intersection_xmax
219
+ or intersection_ymin >= intersection_ymax
220
+ or intersection_zmin >= intersection_zmax
221
+ ):
222
+ return 0.0 # 无交集
223
+
224
+ # 计算交集的体积
225
+ intersection_volume = (
226
+ (intersection_xmax - intersection_xmin)
227
+ * (intersection_ymax - intersection_ymin)
228
+ * (intersection_zmax - intersection_zmin)
229
+ )
230
+
231
+ # 计算两个盒子的体积
232
+ box1_volume = (box1[3] - box1[0]) * (box1[4] - box1[1]) * (box1[5] - box1[2])
233
+ box2_volume = (box2[3] - box2[0]) * (box2[4] - box2[1]) * (box2[5] - box2[2])
234
+
235
+ if signle is None:
236
+ # 计算并集的体积
237
+ union_volume = box1_volume + box2_volume - intersection_volume
238
+ elif signle == "1":
239
+ union_volume = box1_volume
240
+ elif signle == "2":
241
+ union_volume = box2_volume
242
+ else:
243
+ raise ValueError("signle must be None or 1 or 2")
244
+
245
+ # 计算 IoU
246
+ iou = intersection_volume / union_volume if union_volume > 0 else 0.0
247
+ return iou
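A quick numeric check of iou_3d: two unit cubes offset by 0.5 along x intersect in a 0.5 x 1 x 1 slab, so the IoU is 0.5 / (1 + 1 - 0.5) = 1/3, and normalizing by box1 alone gives 0.5.

box1 = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
box2 = [0.5, 0.0, 0.0, 1.5, 1.0, 1.0]
print(iou_3d(box1, box2))              # 0.333...
print(iou_3d(box1, box2, signle="1"))  # 0.5, intersection over box1's volume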
248
+
249
+
250
+ def cal_point_bbox_iou(p1, p2, signle=None):
251
+ min_p1 = np.min(p1, axis=0)
252
+ max_p1 = np.max(p1, axis=0)
253
+ min_p2 = np.min(p2, axis=0)
254
+ max_p2 = np.max(p2, axis=0)
255
+ box1 = [min_p1[0], min_p1[1], min_p1[2], max_p1[0], max_p1[1], max_p1[2]]
256
+ box2 = [min_p2[0], min_p2[1], min_p2[2], max_p2[0], max_p2[1], max_p2[2]]
257
+ return iou_3d(box1, box2, signle)
258
+
259
+
260
+ def cal_bbox_iou(points, m1, m2):
261
+ p1 = points[m1]
262
+ p2 = points[m2]
263
+ return cal_point_bbox_iou(p1, p2)
264
+
265
+
266
+ def clean_mesh(mesh):
267
+ """
268
+ mesh: trimesh.Trimesh
269
+ """
270
+ # 1. 合并接近的顶点
271
+ mesh.merge_vertices()
272
+
273
+ # 2. 删除重复的顶点
274
+ # 3. 删除重复的面片
275
+ mesh.process(True)
276
+ return mesh
277
+
278
+
279
+ # @njit
280
+ def remove_outliers_iqr(data, factor=1.5):
281
+ """
282
+ 基于 IQR 去除离群值
283
+ :param data: 输入的列表或 NumPy 数组
284
+ :param factor: IQR 的倍数(默认 1.5)
285
+ :return: 去除离群值后的列表
286
+ """
287
+ data = np.array(data, dtype=np.float32)
288
+ q1 = np.percentile(data, 25) # 第一四分位数
289
+ q3 = np.percentile(data, 75) # 第三四分位数
290
+ iqr = q3 - q1 # 四分位距
291
+ lower_bound = q1 - factor * iqr
292
+ upper_bound = q3 + factor * iqr
293
+ return data[(data >= lower_bound) & (data <= upper_bound)].tolist()
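A quick sanity check of the IQR filter above: for [1, 2, 3, 4, 100] with factor 1.5, q1 = 2 and q3 = 4, so the bounds are [-1, 7] and 100 is dropped.

print(remove_outliers_iqr([1, 2, 3, 4, 100]))  # [1.0, 2.0, 3.0, 4.0]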
294
+
295
+
296
+ # @njit
297
+ def better_aabb(points):
298
+ x = points[:, 0]
299
+ y = points[:, 1]
300
+ z = points[:, 2]
301
+ x = remove_outliers_iqr(x)
302
+ y = remove_outliers_iqr(y)
303
+ z = remove_outliers_iqr(z)
304
+ min_xyz = np.array([np.min(x), np.min(y), np.min(z)])
305
+ max_xyz = np.array([np.max(x), np.max(y), np.max(z)])
306
+ return [min_xyz, max_xyz]
307
+
308
+
309
+ def fix_label(face_ids, adjacent_faces, use_aabb=False, mesh=None, show_info=False):
310
+ if use_aabb:
311
+
312
+ def _cal_aabb(face_ids, i, _points_org):
313
+ _part_mask = face_ids == i
314
+ _faces = mesh.faces[_part_mask]
315
+ _faces = np.reshape(_faces, (-1))
316
+ _points = mesh.vertices[_faces]
317
+ min_xyz, max_xyz = better_aabb(_points)
318
+ _part_mask = (
319
+ (_points_org[:, 0] >= min_xyz[0])
320
+ & (_points_org[:, 0] <= max_xyz[0])
321
+ & (_points_org[:, 1] >= min_xyz[1])
322
+ & (_points_org[:, 1] <= max_xyz[1])
323
+ & (_points_org[:, 2] >= min_xyz[2])
324
+ & (_points_org[:, 2] <= max_xyz[2])
325
+ )
326
+ _part_mask = np.reshape(_part_mask, (-1, 3))
327
+ _part_mask = np.all(_part_mask, axis=1)
328
+ return i, [min_xyz, max_xyz], _part_mask
329
+
330
+ with Timer("计算aabb"):
331
+ aabb = {}
332
+ unique_ids = np.unique(face_ids)
333
+ # print(max(unique_ids))
334
+ aabb_face_mask = {}
335
+ _faces = mesh.faces
336
+ _vertices = mesh.vertices
337
+ _faces = np.reshape(_faces, (-1))
338
+ _points = _vertices[_faces]
339
+ with ThreadPoolExecutor(max_workers=20) as executor:
340
+ futures = []
341
+ for i in unique_ids:
342
+ if i < 0:
343
+ continue
344
+ futures.append(executor.submit(_cal_aabb, face_ids, i, _points))
345
+ for future in futures:
346
+ res = future.result()
347
+ aabb[res[0]] = res[1]
348
+ aabb_face_mask[res[0]] = res[2]
349
+
350
+ # _faces = mesh.faces
351
+ # _vertices = mesh.vertices
352
+ # _faces = np.reshape(_faces, (-1))
353
+ # _points = _vertices[_faces]
354
+ # aabb_face_mask = cal_aabb_mask(_points, face_ids)
355
+
356
+ with Timer("合并mesh"):
357
+ loop_cnt = 1
358
+ changed = True
359
+ progress = tqdm(disable=not show_info)
360
+ no_mask_ids = np.where(face_ids < 0)[0].tolist()
361
+ faces_max = adjacent_faces.shape[0]
362
+ while changed and loop_cnt <= 50:
363
+ changed = False
364
+ # 获取无色面片
365
+ new_no_mask_ids = []
366
+ for i in no_mask_ids:
367
+ # if face_ids[i] < 0:
368
+ # 找邻居
369
+ if not (0 <= i < faces_max):
370
+ continue
371
+ _adj_faces = adjacent_faces[i]
372
+ _adj_ids = []
373
+ for j in _adj_faces:
374
+ if j == -1:
375
+ break
376
+ if face_ids[j] >= 0:
377
+ _tar_id = face_ids[j]
378
+ if use_aabb:
379
+ _mask = aabb_face_mask[_tar_id]
380
+ if _mask[i]:
381
+ _adj_ids.append(_tar_id)
382
+ else:
383
+ _adj_ids.append(_tar_id)
384
+ if len(_adj_ids) == 0:
385
+ new_no_mask_ids.append(i)
386
+ continue
387
+ _max_id = np.argmax(np.bincount(_adj_ids))
388
+ face_ids[i] = _max_id
389
+ changed = True
390
+ no_mask_ids = new_no_mask_ids
391
+ # print(loop_cnt)
392
+ progress.update(1)
393
+ # progress.set_description(f"合并mesh循环:{loop_cnt} {np.sum(face_ids < 0)}")
394
+ loop_cnt += 1
395
+ return face_ids
396
+
397
+
398
+ def save_mesh(save_path, mesh, face_ids, color_map):
399
+ face_colors = np.zeros((len(mesh.faces), 3), dtype=np.uint8)
400
+ for i in tqdm(range(len(mesh.faces)), disable=True):
401
+ _max_id = face_ids[i]
402
+ if _max_id == -2:
403
+ continue
404
+ face_colors[i, :3] = color_map[_max_id]
405
+
406
+ mesh_save = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
407
+ mesh_save.visual.face_colors = face_colors
408
+ mesh_save.export(save_path)
409
+ mesh_save.export(save_path.replace(".glb", ".ply"))
410
+ # print('保存mesh完成')
411
+
412
+ scene_mesh = trimesh.Scene()
413
+ scene_mesh.add_geometry(mesh_save)
414
+ unique_ids = np.unique(face_ids)
415
+ aabb = []
416
+ for i in unique_ids:
417
+ if i == -1 or i == -2:
418
+ continue
419
+ _part_mask = face_ids == i
420
+ _faces = mesh.faces[_part_mask]
421
+ _faces = np.reshape(_faces, (-1))
422
+ _points = mesh.vertices[_faces]
423
+ min_xyz, max_xyz = better_aabb(_points)
424
+ center = (min_xyz + max_xyz) / 2
425
+ size = max_xyz - min_xyz
426
+ box = trimesh.path.creation.box_outline()
427
+ box.vertices *= size
428
+ box.vertices += center
429
+ box_color = np.array([[color_map[i][0], color_map[i][1], color_map[i][2], 255]])
430
+ box_color = np.repeat(box_color, len(box.entities), axis=0).astype(np.uint8)
431
+ box.colors = box_color
432
+ scene_mesh.add_geometry(box)
433
+ min_xyz = np.min(_points, axis=0)
434
+ max_xyz = np.max(_points, axis=0)
435
+ aabb.append([min_xyz, max_xyz])
436
+ scene_mesh.export(save_path.replace(".glb", "_aabb.glb"))
437
+ aabb = np.array(aabb)
438
+ np.save(save_path.replace(".glb", "_aabb.npy"), aabb)
439
+ np.save(save_path.replace(".glb", "_face_ids.npy"), face_ids)
440
+
441
+
442
+ def get_aabb_from_face_ids(mesh, face_ids):
443
+ unique_ids = np.unique(face_ids)
444
+ aabb = []
445
+ for i in unique_ids:
446
+ if i == -1 or i == -2:
447
+ continue
448
+ _part_mask = face_ids == i
449
+ _faces = mesh.faces[_part_mask]
450
+ _faces = np.reshape(_faces, (-1))
451
+ _points = mesh.vertices[_faces]
452
+ min_xyz = np.min(_points, axis=0)
453
+ max_xyz = np.max(_points, axis=0)
454
+ aabb.append([min_xyz, max_xyz])
455
+ return np.array(aabb)
456
+
457
+
458
+ def calculate_face_areas(mesh):
459
+ """
460
+ 计算每个三角形面片的面积
461
+ :param mesh: trimesh.Trimesh 对象
462
+ :return: 面片面积数组 (n_faces,)
463
+ """
464
+ return mesh.area_faces
465
+ # # 提取顶点和面片索引
466
+ # vertices = mesh.vertices
467
+ # faces = mesh.faces
468
+
469
+ # # 获取所有三个顶点的坐标
470
+ # v0 = vertices[faces[:, 0]]
471
+ # v1 = vertices[faces[:, 1]]
472
+ # v2 = vertices[faces[:, 2]]
473
+
474
+ # # 计算两个边向量
475
+ # edge1 = v1 - v0
476
+ # edge2 = v2 - v0
477
+
478
+ # # 计算叉积的模长(向量面积的两倍)
479
+ # cross_product = np.cross(edge1, edge2)
480
+ # areas = 0.5 * np.linalg.norm(cross_product, axis=1)
481
+
482
+ # return areas
483
+
484
+
485
+ def get_connected_region(face_ids, adjacent_faces, return_face_part_ids=False):
486
+ vis = [False] * len(face_ids)
487
+ parts = []
488
+ face_part_ids = np.ones_like(face_ids) * -1
489
+ for i in range(len(face_ids)):
490
+ if vis[i]:
491
+ continue
492
+ _part = []
493
+ _queue = [i]
494
+ while len(_queue) > 0:
495
+ _cur_face = _queue.pop(0)
496
+ if vis[_cur_face]:
497
+ continue
498
+ vis[_cur_face] = True
499
+ _part.append(_cur_face)
500
+ face_part_ids[_cur_face] = len(parts)
501
+ if not (0 <= _cur_face < adjacent_faces.shape[0]):
502
+ continue
503
+ _cur_face_id = face_ids[_cur_face]
504
+ _adj_faces = adjacent_faces[_cur_face]
505
+ for j in _adj_faces:
506
+ if j == -1:
507
+ break
508
+ if not vis[j] and face_ids[j] == _cur_face_id:
509
+ _queue.append(j)
510
+ parts.append(_part)
511
+ if return_face_part_ids:
512
+ return parts, face_part_ids
513
+ else:
514
+ return parts
515
+
516
+
517
+ def aabb_distance(box1, box2):
518
+ """
519
+ 计算两个轴对齐包围盒(AABB)之间的最近距离。
520
+ :param box1: 元组 (min_x, min_y, min_z, max_x, max_y, max_z)
521
+ :param box2: 元组 (min_x, min_y, min_z, max_x, max_y, max_z)
522
+ :return: 最近距离(浮点数)
523
+ """
524
+ # 解包坐标
525
+ min1, max1 = box1
526
+ min2, max2 = box2
527
+
528
+ # 计算各轴上的分离距离
529
+ dx = max(0, min2[0] - max1[0], min1[0] - max2[0]) # x轴分离距离
530
+ dy = max(0, min2[1] - max1[1], min1[1] - max2[1]) # y轴分离距离
531
+ dz = max(0, min2[2] - max1[2], min1[2] - max2[2]) # z轴分离距离
532
+
533
+ # 如果所有轴都重叠,则距离为0
534
+ if dx == 0 and dy == 0 and dz == 0:
535
+ return 0.0
536
+
537
+ # 计算欧几里得距离
538
+ return np.sqrt(dx**2 + dy**2 + dz**2)
539
+
540
+
541
+ def aabb_volume(aabb):
542
+ """
543
+ 计算轴对齐包围盒(AABB)的体积。
544
+ :param aabb: 元组 (min_x, min_y, min_z, max_x, max_y, max_z)
545
+ :return: 体积(浮点数)
546
+ """
547
+ # 解包坐标
548
+ min_xyz, max_xyz = aabb
549
+
550
+ # 计算体积
551
+ dx = max_xyz[0] - min_xyz[0]
552
+ dy = max_xyz[1] - min_xyz[1]
553
+ dz = max_xyz[2] - min_xyz[2]
554
+ return dx * dy * dz
555
+
556
+
557
+ def find_neighbor_part(parts, adjacent_faces, parts_aabb=None, parts_ids=None):
558
+ face2part = {}
559
+ for i, part in enumerate(parts):
560
+ for face in part:
561
+ face2part[face] = i
562
+ neighbor_parts = []
563
+ for i, part in enumerate(parts):
564
+ neighbor_part = set()
565
+ for face in part:
566
+ if not (0 <= face < adjacent_faces.shape[0]):
567
+ continue
568
+ for adj_face in adjacent_faces[face]:
569
+ if adj_face == -1:
570
+ break
571
+ if adj_face not in face2part:
572
+ continue
573
+ if face2part[adj_face] == i:
574
+ continue
575
+ if parts_ids is not None and parts_ids[face2part[adj_face]] in [-1, -2]:
576
+ continue
577
+ neighbor_part.add(face2part[adj_face])
578
+ neighbor_part = list(neighbor_part)
579
+ if (
580
+ parts_aabb is not None
581
+ and parts_ids is not None
582
+ and (parts_ids[i] == -1 or parts_ids[i] == -2)
583
+ and len(neighbor_part) == 0
584
+ ):
585
+ min_dis = np.inf
586
+ min_idx = -1
587
+ for j, _part in enumerate(parts):
588
+ if j == i:
589
+ continue
590
+ if parts_ids[j] == -1 or parts_ids[j] == -2:
591
+ continue
592
+ aabb_1 = parts_aabb[i]
593
+ aabb_2 = parts_aabb[j]
594
+ dis = aabb_distance(aabb_1, aabb_2)
595
+ if dis < min_dis:
596
+ min_dis = dis
597
+ min_idx = j
598
+ elif dis == min_dis:
599
+ if aabb_volume(parts_aabb[j]) < aabb_volume(parts_aabb[min_idx]):
600
+ min_idx = j
601
+ neighbor_part = [min_idx]
602
+ neighbor_parts.append(neighbor_part)
603
+ return neighbor_parts
604
+
605
+
606
+ def do_post_process(
607
+ face_areas, parts, adjacent_faces, face_ids, threshold=0.95, show_info=False
608
+ ):
609
+ # # 获取邻接面片
610
+ # mesh_save = mesh.copy()
611
+ # face_adjacency = mesh.face_adjacency
612
+ # adjacent_faces = {}
613
+ # for face1, face2 in face_adjacency:
614
+ # if face1 not in adjacent_faces:
615
+ # adjacent_faces[face1] = []
616
+ # if face2 not in adjacent_faces:
617
+ # adjacent_faces[face2] = []
618
+ # adjacent_faces[face1].append(face2)
619
+ # adjacent_faces[face2].append(face1)
620
+
621
+ # parts = get_connected_region(face_ids, adjacent_faces)
622
+
623
+ unique_ids = np.unique(face_ids)
624
+ if show_info:
625
+ print(f"连通区域数量:{len(parts)}")
626
+ print(f"ID数量:{len(unique_ids)}")
627
+
628
+ # face_areas = calculate_face_areas(mesh)
629
+ total_area = np.sum(face_areas)
630
+ if show_info:
631
+ print(f"总面积:{total_area}")
632
+ part_areas = []
633
+ for i, part in enumerate(parts):
634
+ part_area = np.sum(face_areas[part])
635
+ part_areas.append(float(part_area / total_area))
636
+
637
+ sorted_parts = sorted(zip(part_areas, parts), key=lambda x: x[0], reverse=True)
638
+ parts = [x[1] for x in sorted_parts]
639
+ part_areas = [x[0] for x in sorted_parts]
640
+ integral_part_areas = np.cumsum(part_areas)
641
+
642
+ neighbor_parts = find_neighbor_part(parts, adjacent_faces)
643
+
644
+ new_face_ids = face_ids.copy()
645
+
646
+ for i, part in enumerate(parts):
647
+ if integral_part_areas[i] > threshold and part_areas[i] < 0.01:
648
+ if len(neighbor_parts[i]) > 0:
649
+ max_area = 0
650
+ max_part = -1
651
+ for j in neighbor_parts[i]:
652
+ if integral_part_areas[j] > threshold:
653
+ continue
654
+ if part_areas[j] > max_area:
655
+ max_area = part_areas[j]
656
+ max_part = j
657
+ if max_part != -1:
658
+ if show_info:
659
+ print(f"合并mesh:{i} {max_part}")
660
+ parts[max_part].extend(part)
661
+ parts[i] = []
662
+ target_face_id = face_ids[parts[max_part][0]]
663
+ for face in part:
664
+ new_face_ids[face] = target_face_id
665
+
666
+ return new_face_ids
667
+
668
+
669
+ def do_no_mask_process(parts, face_ids):
670
+ # # 获取邻接面片
671
+ # mesh_save = mesh.copy()
672
+ # face_adjacency = mesh.face_adjacency
673
+ # adjacent_faces = {}
674
+ # for face1, face2 in face_adjacency:
675
+ # if face1 not in adjacent_faces:
676
+ # adjacent_faces[face1] = []
677
+ # if face2 not in adjacent_faces:
678
+ # adjacent_faces[face2] = []
679
+ # adjacent_faces[face1].append(face2)
680
+ # adjacent_faces[face2].append(face1)
681
+ # parts = get_connected_region(face_ids, adjacent_faces)
682
+
683
+ unique_ids = np.unique(face_ids)
684
+ max_id = np.max(unique_ids)
685
+ if -1 in unique_ids or -2 in unique_ids:
686
+ new_face_ids = face_ids.copy()
687
+ for i, part in enumerate(parts):
688
+ if face_ids[part[0]] == -1 or face_ids[part[0]] == -2:
689
+ for face in part:
690
+ new_face_ids[face] = max_id + 1
691
+ max_id += 1
692
+ return new_face_ids
693
+ else:
694
+ return face_ids
695
+
696
+
697
+ def union_aabb(aabb1, aabb2):
698
+ min_xyz1 = aabb1[0]
699
+ max_xyz1 = aabb1[1]
700
+ min_xyz2 = aabb2[0]
701
+ max_xyz2 = aabb2[1]
702
+ min_xyz = np.minimum(min_xyz1, min_xyz2)
703
+ max_xyz = np.maximum(max_xyz1, max_xyz2)
704
+ return [min_xyz, max_xyz]
705
+
706
+
707
+ def aabb_increase(aabb1, aabb2):
708
+ min_xyz_before = aabb1[0]
709
+ max_xyz_before = aabb1[1]
710
+ min_xyz_after, max_xyz_after = union_aabb(aabb1, aabb2)
711
+ min_xyz_increase = np.abs(min_xyz_after - min_xyz_before) / np.abs(min_xyz_before)
712
+ max_xyz_increase = np.abs(max_xyz_after - max_xyz_before) / np.abs(max_xyz_before)
713
+ return min_xyz_increase, max_xyz_increase
714
+
715
+
716
+ def sort_multi_list(multi_list, key=lambda x: x[0], reverse=False):
717
+ """
718
+ multi_list: [list1, list2, list3, list4, ...], len(list1)=N, len(list2)=N, len(list3)=N, ...
719
+ key: 排序函数,默认按第一个元素排序
720
+ reverse: 排序顺序,默认 reverse=False(升序)
721
+ return:
722
+ [list1, list2, list3, list4, ...]: 按同一个顺序排序后的多个list
723
+ """
724
+ sorted_list = sorted(zip(*multi_list), key=key, reverse=reverse)
725
+ return zip(*sorted_list)
726
+
727
+
728
+ # def sample_mesh(mesh, adjacent_faces, point_num=100000):
729
+ # connected_parts = get_connected_region(np.ones(len(mesh.faces)), adjacent_faces)
730
+ # _points, face_idx = trimesh.sample.sample_surface(mesh, point_num)
731
+ # face_sampled = np.zeros(len(mesh.faces), dtype=np.bool)
732
+ # face_sampled[face_idx] = True
733
+ # for parts in connected_parts
734
+
735
+ # def parallel_run(model_parallel, feats, points, prompts):
736
+ # bs = prompts.shape[0]
737
+ # prompts_1 = prompts[:bs//2]
738
+ # prompts_2 = prompts[bs//2:]
739
+ # device_1 = 'cuda:0'
740
+ # device_2 = 'cuda:1'
741
+ # pred_mask_1_1, pred_mask_2_1, pred_mask_3_1, pred_iou_1 = get_mask(
742
+ # model_parallel.module.to(device_1), feats, points, prompts_1, device=device_1
743
+ # )
744
+ # pred_mask_1_2, pred_mask_2_2, pred_mask_3_2, pred_iou_2 = get_mask(
745
+ # model_parallel.module.to(device_2), feats, points, prompts_2, device=device_2
746
+ # )
747
+ # pred_mask_1 = np.concatenate([pred_mask_1_1, pred_mask_1_2], axis=1)
748
+ # pred_mask_2 = np.concatenate([pred_mask_2_1, pred_mask_2_2], axis=1)
749
+ # pred_mask_3 = np.concatenate([pred_mask_3_1, pred_mask_3_2], axis=1)
750
+ # pred_iou = np.concatenate([pred_iou_1, pred_iou_2], axis=0)
751
+ # return pred_mask_1, pred_mask_2, pred_mask_3, pred_iou
752
+
753
+ ############################################################################################
754
+
755
+
756
+ class Timer:
757
+ def __init__(self, name):
758
+ self.name = name
759
+
760
+ def __enter__(self):
761
+ self.start_time = time.time()
762
+ return self # 可以返回 self 以便在 with 块内访问
763
+
764
+ def __exit__(self, exc_type, exc_val, exc_tb):
765
+ self.end_time = time.time()
766
+ self.elapsed_time = self.end_time - self.start_time
767
+ print(f">>>>>>代码{self.name} 运行时间: {self.elapsed_time:.4f} 秒")
768
+
769
+
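The Timer above just brackets a block with wall-clock timestamps and prints the elapsed seconds on exit; a one-line usage sketch (assuming `mesh` is any trimesh.Trimesh):

with Timer("sample_surface"):
    pts, fidx = trimesh.sample.sample_surface(mesh, 100000)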
770
+ ###################### NUMBA 加速 ######################
771
+ @njit
772
+ def build_adjacent_faces_numba(face_adjacency):
773
+ """
774
+ 使用 Numba 加速构建邻接面片数组。
775
+ :param face_adjacency: (N, 2) numpy 数组,包含邻接面片对。
776
+ :return: adjacent_faces: (n_faces, max_degree) 的 int32 数组,第 i 行为面片 i 的邻接面片,空位以 -1 填充。
779
+ """
780
+ n_faces = np.max(face_adjacency) + 1 # 总面片数
781
+ n_edges = face_adjacency.shape[0] # 总邻接边数
782
+
783
+ # 第一步:统计每个面片的邻接数量(度数)
784
+ degrees = np.zeros(n_faces, dtype=np.int32)
785
+ for i in range(n_edges):
786
+ f1, f2 = face_adjacency[i]
787
+ degrees[f1] += 1
788
+ degrees[f2] += 1
789
+ max_degree = np.max(degrees) # 最大度数
790
+
791
+ adjacent_faces = np.ones((n_faces, max_degree), dtype=np.int32) * -1 # 邻接面片数组
792
+ adjacent_faces_count = np.zeros(n_faces, dtype=np.int32) # 邻接面片计数器
793
+ for i in range(n_edges):
794
+ f1, f2 = face_adjacency[i]
795
+ adjacent_faces[f1, adjacent_faces_count[f1]] = f2
796
+ adjacent_faces_count[f1] += 1
797
+ adjacent_faces[f2, adjacent_faces_count[f2]] = f1
798
+ adjacent_faces_count[f2] += 1
799
+ return adjacent_faces
800
+
801
+
802
+ ###################### NUMBA 加速 ######################
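The adjacency table built above is a fixed-width int32 matrix rather than a dict: row i holds the neighbours of face i, padded with -1, which is why the loops elsewhere in this file break on the first -1. A short reading sketch (assumes `mesh` is a trimesh.Trimesh):

adjacent_faces = build_adjacent_faces_numba(mesh.face_adjacency)
neighbors_of_0 = [j for j in adjacent_faces[0] if j != -1]  # actual neighbours of face 0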
803
+
804
+
805
+ def mesh_sam(
806
+ model,
807
+ mesh,
808
+ save_path,
809
+ point_num=100000,
810
+ prompt_num=400,
811
+ save_mid_res=False,
812
+ show_info=False,
813
+ post_process=False,
814
+ threshold=0.95,
815
+ clean_mesh_flag=True,
816
+ seed=42,
817
+ ):
818
+ with Timer("加载mesh"):
819
+ model, model_parallel = model
820
+ if clean_mesh_flag:
821
+ mesh = clean_mesh(mesh)
822
+ mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, process=False)
823
+ if show_info:
824
+ print(f"点数:{mesh.vertices.shape[0]} 面片数:{mesh.faces.shape[0]}")
825
+
826
828
+ with Timer("获取邻接面片"):
829
+ # 获取邻接面片
830
+ face_adjacency = mesh.face_adjacency
831
+ with Timer("处理邻接面片"):
832
+ # adjacent_faces = defaultdict(list)
833
+ # for face1, face2 in face_adjacency:
834
+ # adjacent_faces[face1].append(face2)
835
+ # adjacent_faces[face2].append(face1)
836
+ # adj_list, offsets = build_adjacent_faces_numba(face_adjacency)
837
+ adjacent_faces = build_adjacent_faces_numba(face_adjacency)
838
+ # with Timer("处理邻接面片2"):
839
+ # adjacent_faces = to_adj_dict(adj_list, offsets)
840
+
841
+ with Timer("采样点云"):
842
+ _points, face_idx = trimesh.sample.sample_surface(mesh, point_num, seed=seed)
843
+ _points_org = _points.copy()
844
+ _points = normalize_pc(_points)
845
+ normals = mesh.face_normals[face_idx]
846
+ # _points = _points + np.random.normal(0, 1, size=_points.shape) * 0.01
847
+ # normals = normals * 0. # debug no normal
848
+ if show_info:
849
+ print(f"点数:{point_num} 面片数:{mesh.faces.shape[0]}")
850
+
851
+ with Timer("获取特征"):
852
+ _feats = get_feat(model, _points, normals)
853
+ if show_info:
854
+ print("预处理特征")
855
+
856
+ if save_mid_res:
857
+ feat_save = _feats.float().detach().cpu().numpy()
858
+ data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
859
+ pca = PCA(n_components=3)
860
+ data_reduced = pca.fit_transform(data_scaled)
861
+ data_reduced = (data_reduced - data_reduced.min()) / (
862
+ data_reduced.max() - data_reduced.min()
863
+ )
864
+ _colors_pca = (data_reduced * 255).astype(np.uint8)
865
+ pc_save = trimesh.points.PointCloud(_points, colors=_colors_pca)
866
+ pc_save.export(os.path.join(save_path, "point_pca.glb"))
867
+ pc_save.export(os.path.join(save_path, "point_pca.ply"))
868
+ if show_info:
869
+ print("PCA获取特征颜色")
870
+
871
+ with Timer("FPS采样提示点"):
872
+ fps_idx = fpsample.fps_sampling(_points, prompt_num)
873
+ _point_prompts = _points[fps_idx]
874
+ if save_mid_res:
875
+ trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
876
+ os.path.join(save_path, "point_prompts_pca.glb")
877
+ )
878
+ trimesh.points.PointCloud(_point_prompts, colors=_colors_pca[fps_idx]).export(
879
+ os.path.join(save_path, "point_prompts_pca.ply")
880
+ )
881
+ if show_info:
882
+ print("采样完成")
883
+
884
+ with Timer("推理"):
885
+ bs = 64
886
+ step_num = prompt_num // bs + 1
887
+ mask_res = []
888
+ iou_res = []
889
+ for i in tqdm(range(step_num), disable=not show_info):
890
+ cur_propmt = _point_prompts[bs * i : bs * (i + 1)]
891
+ # pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask(
892
+ # model, _feats, _points, cur_propmt
893
+ # )
894
+ # pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = model_parallel(
895
+ # _feats, _points, cur_propmt
896
+ # )
897
+ # pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = parallel_run(
898
+ # model_parallel, _feats, _points, cur_propmt
899
+ # )
900
+ pred_mask_1, pred_mask_2, pred_mask_3, pred_iou = get_mask(
901
+ model_parallel, _feats, _points, cur_propmt
902
+ )
903
+ # print(pred_mask_1.shape, pred_mask_2.shape, pred_mask_3.shape, pred_iou.shape)
904
+ pred_mask = np.stack(
905
+ [pred_mask_1, pred_mask_2, pred_mask_3], axis=-1
906
+ ) # [N, K, 3]
907
+ max_idx = np.argmax(pred_iou, axis=-1) # [K]
908
+ for j in range(max_idx.shape[0]):
909
+ mask_res.append(pred_mask[:, j, max_idx[j]])
910
+ iou_res.append(pred_iou[j, max_idx[j]])
911
+ mask_res = np.stack(mask_res, axis=-1) # [N, K]
912
+ if show_info:
913
+ print("prompt 推理完成")
914
+
915
+ with Timer("根据IOU排序"):
916
+ iou_res = np.array(iou_res).tolist()
917
+ mask_iou = [[mask_res[:, i], iou_res[i]] for i in range(prompt_num)]
918
+ mask_iou_sorted = sorted(mask_iou, key=lambda x: x[1], reverse=True)
919
+ mask_sorted = [mask_iou_sorted[i][0] for i in range(prompt_num)]
920
+ iou_sorted = [mask_iou_sorted[i][1] for i in range(prompt_num)]
921
+
922
+ # clusters = {}
923
+ # for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
924
+ # _mask = mask_sorted[i]
925
+ # union_flag = False
926
+ # for j in clusters.keys():
927
+ # if cal_iou(_mask, mask_sorted[j]) > 0.9:
928
+ # clusters[j].append(i)
929
+ # union_flag = True
930
+ # break
931
+ # if not union_flag:
932
+ # clusters[i] = [i]
933
+ with Timer("NMS"):
934
+ clusters = defaultdict(list)
935
+ with ThreadPoolExecutor(max_workers=20) as executor:
936
+ for i in tqdm(range(prompt_num), desc="NMS", disable=not show_info):
937
+ _mask = mask_sorted[i]
938
+ futures = []
939
+ for j in clusters.keys():
940
+ futures.append(executor.submit(cal_iou, _mask, mask_sorted[j]))
941
+
942
+ for j, future in zip(clusters.keys(), futures):
943
+ if future.result() > 0.9:
944
+ clusters[j].append(i)
945
+ break
946
+ else:
947
+ clusters[i].append(i)
948
+
949
+ # print(clusters)
950
+ if show_info:
951
+ print(f"NMS完成,mask数量:{len(clusters)}")
952
+
953
+ if save_mid_res:
954
+ part_mask_save_path = os.path.join(save_path, "part_mask")
955
+ if os.path.exists(part_mask_save_path):
956
+ shutil.rmtree(part_mask_save_path)
957
+ os.makedirs(part_mask_save_path, exist_ok=True)
958
+ for i in tqdm(clusters.keys(), desc="保存mask", disable=not show_info):
959
+ cluster_num = len(clusters[i])
960
+ cluster_iou = iou_sorted[i]
961
+ cluster_area = np.sum(mask_sorted[i])
962
+ if cluster_num <= 2:
963
+ continue
964
+ mask_save = mask_sorted[i]
965
+ mask_save = np.expand_dims(mask_save, axis=-1)
966
+ mask_save = np.repeat(mask_save, 3, axis=-1)
967
+ mask_save = (mask_save * 255).astype(np.uint8)
968
+ point_save = trimesh.points.PointCloud(_points, colors=mask_save)
969
+ point_save.export(
970
+ os.path.join(
971
+ part_mask_save_path,
972
+ f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
973
+ )
974
+ )
975
+
976
+ # 过滤只有一个mask的cluster
977
+ with Timer("过滤只有一个mask的cluster"):
978
+ filtered_clusters = []
979
+ other_clusters = []
980
+ for i in clusters.keys():
981
+ if len(clusters[i]) > 2:
982
+ filtered_clusters.append(i)
983
+ else:
984
+ other_clusters.append(i)
985
+ if show_info:
986
+ print(
987
+ f"过滤前:{len(clusters)} 个cluster,"
988
+ f"过滤后:{len(filtered_clusters)} 个cluster"
989
+ )
990
+
991
+ # 再次合并
992
+ with Timer("再次合并"):
993
+ filtered_clusters_num = len(filtered_clusters)
994
+ cluster2 = {}
995
+ is_union = [False] * filtered_clusters_num
996
+ for i in range(filtered_clusters_num):
997
+ if is_union[i]:
998
+ continue
999
+ cur_cluster = filtered_clusters[i]
1000
+ cluster2[cur_cluster] = [cur_cluster]
1001
+ for j in range(i + 1, filtered_clusters_num):
1002
+ if is_union[j]:
1003
+ continue
1004
+ tar_cluster = filtered_clusters[j]
1005
+ # if cal_single_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.9:
1006
+ # if cal_iou(mask_sorted[tar_cluster], mask_sorted[cur_cluster]) > 0.5:
1007
+ if (
1008
+ cal_bbox_iou(
1009
+ _points, mask_sorted[tar_cluster], mask_sorted[cur_cluster]
1010
+ )
1011
+ > 0.5
1012
+ ):
1013
+ cluster2[cur_cluster].append(tar_cluster)
1014
+ is_union[j] = True
1015
+ if show_info:
1016
+ print(f"再次合并,合并数量:{len(cluster2.keys())}")
1017
+
1018
+ with Timer("计算没有mask的点"):
1019
+ no_mask = np.ones(point_num)
1020
+ for i in cluster2:
1021
+ part_mask = mask_sorted[i]
1022
+ no_mask[part_mask] = 0
1023
+ if show_info:
1024
+ print(
1025
+ f"{np.sum(no_mask == 1)} 个点没有mask,"
1026
+ f" 占比:{np.sum(no_mask == 1) / point_num:.4f}"
1027
+ )
1028
+
1029
+ with Timer("修补遗漏mask"):
1030
+ # 查询漏掉的mask
1031
+ for i in tqdm(range(len(mask_sorted)), desc="漏掉mask", disable=not show_info):
1032
+ if i in cluster2:
1033
+ continue
1034
+ part_mask = mask_sorted[i]
1035
+ _iou = cal_single_iou(part_mask, no_mask)
1036
+ if _iou > 0.7:
1037
+ cluster2[i] = [i]
1038
+ no_mask[part_mask] = 0
1039
+ if save_mid_res:
1040
+ mask_save = mask_sorted[i]
1041
+ mask_save = np.expand_dims(mask_save, axis=-1)
1042
+ mask_save = np.repeat(mask_save, 3, axis=-1)
1043
+ mask_save = (mask_save * 255).astype(np.uint8)
1044
+ point_save = trimesh.points.PointCloud(_points, colors=mask_save)
1045
+ cluster_iou = iou_sorted[i]
1046
+ cluster_area = int(np.sum(mask_sorted[i]))
1047
+ cluster_num = 1
1048
+ point_save.export(
1049
+ os.path.join(
1050
+ part_mask_save_path,
1051
+ f"mask_{i}_iou_{cluster_iou:.5f}_area_{cluster_area:.5f}_num_{cluster_num}.glb",
1052
+ )
1053
+ )
1054
+ # print(cluster2)
1055
+ # print(len(cluster2.keys()))
1056
+ if show_info:
1057
+ print(f"修补遗漏mask:{len(cluster2.keys())}")
1058
+
1059
+ with Timer("计算点云最终mask"):
1060
+ final_mask = list(cluster2.keys())
1061
+ final_mask_area = [int(np.sum(mask_sorted[i])) for i in final_mask]
1062
+ final_mask_area = [
1063
+ [final_mask[i], final_mask_area[i]] for i in range(len(final_mask))
1064
+ ]
1065
+ final_mask_area_sorted = sorted(
1066
+ final_mask_area, key=lambda x: x[1], reverse=True
1067
+ )
1068
+ final_mask_sorted = [
1069
+ final_mask_area_sorted[i][0] for i in range(len(final_mask_area))
1070
+ ]
1071
+ final_mask_area_sorted = [
1072
+ final_mask_area_sorted[i][1] for i in range(len(final_mask_area))
1073
+ ]
1074
+ # print(final_mask_sorted)
1075
+ # print(final_mask_area_sorted)
1076
+ if show_info:
1077
+ print(f"最终mask数量:{len(final_mask_sorted)}")
1078
+
1079
+ with Timer("点云上色"):
1080
+ # 生成color map
1081
+ color_map = {}
1082
+ for i in final_mask_sorted:
1083
+ part_color = np.random.rand(3) * 255
1084
+ color_map[i] = part_color
1085
+ # print(color_map)
1086
+
1087
+ result_mask = -np.ones(point_num, dtype=np.int64)
1088
+ for i in final_mask_sorted:
1089
+ part_mask = mask_sorted[i]
1090
+ result_mask[part_mask] = i
1091
+ if save_mid_res:
1092
+ # 保存点云结果
1093
+ result_colors = np.zeros_like(_colors_pca)
1094
+ for i in final_mask_sorted:
1095
+ part_color = color_map[i]
1096
+ part_mask = mask_sorted[i]
1097
+ result_colors[part_mask, :3] = part_color
1098
+ trimesh.points.PointCloud(_points, colors=result_colors).export(
1099
+ os.path.join(save_path, "auto_mask_cluster.glb")
1100
+ )
1101
+ trimesh.points.PointCloud(_points, colors=result_colors).export(
1102
+ os.path.join(save_path, "auto_mask_cluster.ply")
1103
+ )
1104
+ if show_info:
1105
+ print("保存点云完成")
1106
+
1107
+ with Timer("投影Mesh并统计label"):
1108
+ # 保存mesh结果
1109
+ face_seg_res = {}
1110
+ for i in final_mask_sorted:
1111
+ _part_mask = result_mask == i
1112
+ _face_idx = face_idx[_part_mask]
1113
+ for k in _face_idx:
1114
+ if k not in face_seg_res:
1115
+ face_seg_res[k] = []
1116
+ face_seg_res[k].append(i)
1117
+ _part_mask = result_mask == -1
1118
+ _face_idx = face_idx[_part_mask]
1119
+ for k in _face_idx:
1120
+ if k not in face_seg_res:
1121
+ face_seg_res[k] = []
1122
+ face_seg_res[k].append(-1)
1123
+
1124
+ face_ids = -np.ones(len(mesh.faces), dtype=np.int64) * 2
1125
+ for i in tqdm(face_seg_res, leave=False, disable=True):
1126
+ _seg_ids = np.array(face_seg_res[i])
1127
+ # 获取最多的seg_id
1128
+ _max_id = np.argmax(np.bincount(_seg_ids + 2)) - 2
1129
+ face_ids[i] = _max_id
1130
+ face_ids_org = face_ids.copy()
1131
+ if show_info:
1132
+ print("生成face_ids完成")
1133
+
1134
+ # 获取邻接面片
1135
+ # face_adjacency = mesh.face_adjacency
1136
+ # adjacent_faces = {}
1137
+ # for face1, face2 in face_adjacency:
1138
+ # if face1 not in adjacent_faces:
1139
+ # adjacent_faces[face1] = []
1140
+ # if face2 not in adjacent_faces:
1141
+ # adjacent_faces[face2] = []
1142
+ # adjacent_faces[face1].append(face2)
1143
+ # adjacent_faces[face2].append(face1)
1144
+
1145
+ with Timer("第一次修复face_ids"):
1146
+ face_ids += 1
1147
+ # face_ids = fix_label(face_ids, adjacent_faces, use_aabb=True, mesh=mesh, show_info=show_info)
1148
+ face_ids = fix_label(face_ids, adjacent_faces, mesh=mesh, show_info=show_info)
1149
+ face_ids -= 1
1150
+ if show_info:
1151
+ print("修复face_ids完成")
1152
+
1153
+ color_map[-1] = np.array([255, 0, 0], dtype=np.uint8)
1154
+
1155
+ if save_mid_res:
1156
+ save_mesh(
1157
+ os.path.join(save_path, "auto_mask_mesh.glb"), mesh, face_ids, color_map
1158
+ )
1159
+ save_mesh(
1160
+ os.path.join(save_path, "auto_mask_mesh_org.glb"),
1161
+ mesh,
1162
+ face_ids_org,
1163
+ color_map,
1164
+ )
1165
+ if show_info:
1166
+ print("保存mesh结果完成")
1167
+
1168
+ with Timer("计算连通区域"):
1169
+ face_areas = calculate_face_areas(mesh)
1170
+ mesh_total_area = np.sum(face_areas)
1171
+ parts = get_connected_region(face_ids, adjacent_faces)
1172
+ connected_parts, _face_connected_parts_ids = get_connected_region(
1173
+ np.ones_like(face_ids), adjacent_faces, return_face_part_ids=True
1174
+ )
1175
+ if show_info:
1176
+ print(f"共{len(parts)}个mesh")
1177
+ with Timer("排序连通区域"):
1178
+ parts_cp_idx = []
1179
+ for x in parts:
1180
+ _face_idx = x[0]
1181
+ parts_cp_idx.append(_face_connected_parts_ids[_face_idx])
1182
+ parts_cp_idx = np.array(parts_cp_idx)
1183
+ parts_areas = [float(np.sum(face_areas[x])) for x in parts]
1184
+ connected_parts_areas = [float(np.sum(face_areas[x])) for x in connected_parts]
1185
+ parts_cp_areas = [connected_parts_areas[x] for x in parts_cp_idx]
1186
+ parts_sorted, parts_areas_sorted, parts_cp_areas_sorted = sort_multi_list(
1187
+ [parts, parts_areas, parts_cp_areas], key=lambda x: x[1], reverse=True
1188
+ )
1189
+
1190
+ with Timer("去除面积过小的区域"):
1191
+ filtered_parts = []
1192
+ other_parts = []
1193
+ for i in range(len(parts_sorted)):
1194
+ parts = parts_sorted[i]
1195
+ area = parts_areas_sorted[i]
1196
+ cp_area = parts_cp_areas_sorted[i]
1197
+ if area / (cp_area + 1e-7) > 0.001:
1198
+ filtered_parts.append(i)
1199
+ else:
1200
+ other_parts.append(i)
1201
+ if show_info:
1202
+ print(f"保留{len(filtered_parts)}个mesh, 其他{len(other_parts)}个mesh")
1203
+
1204
+ with Timer("去除面积过小区域的label"):
1205
+ face_ids_2 = face_ids.copy()
1206
+ part_num = len(cluster2.keys())
1207
+ for j in other_parts:
1208
+ parts = parts_sorted[j]
1209
+ for i in parts:
1210
+ face_ids_2[i] = -1
1211
+
1212
+ with Timer("第二次修复face_ids"):
1213
+ face_ids_3 = face_ids_2.copy()
1214
+ # face_ids_3 = fix_label(face_ids_3, adjacent_faces, use_aabb=True, mesh=mesh, show_info=show_info)
1215
+ face_ids_3 = fix_label(
1216
+ face_ids_3, adjacent_faces, mesh=mesh, show_info=show_info
1217
+ )
1218
+
1219
+ if save_mid_res:
1220
+ save_mesh(
1221
+ os.path.join(save_path, "auto_mask_mesh_filtered_2.glb"),
1222
+ mesh,
1223
+ face_ids_3,
1224
+ color_map,
1225
+ )
1226
+ if show_info:
1227
+ print("保存mesh结果完成")
1228
+
1229
+ with Timer("第二次计算连通区域"):
1230
+ parts_2 = get_connected_region(face_ids_3, adjacent_faces)
1231
+ parts_areas_2 = [float(np.sum(face_areas[x])) for x in parts_2]
1232
+ parts_ids_2 = [face_ids_3[x[0]] for x in parts_2]
1233
+
1234
+ with Timer("添加过大的缺失part"):
1235
+ color_map_2 = copy.deepcopy(color_map)
1236
+ max_id = np.max(parts_ids_2)
1237
+ for i in range(len(parts_2)):
1238
+ _parts = parts_2[i]
1239
+ _area = parts_areas_2[i]
1240
+ _parts_id = face_ids_3[_parts[0]]
1241
+ if _area / mesh_total_area > 0.001:
1242
+ if _parts_id == -1 or _parts_id == -2:
1243
+ parts_ids_2[i] = max_id + 1
1244
+ max_id += 1
1245
+ color_map_2[max_id] = np.random.rand(3) * 255
1246
+ if show_info:
1247
+ print(f"新增part {max_id}")
1248
+ # else:
1249
+ # parts_ids_2[i] = -1
1250
+
1251
+ with Timer("赋值新的face_ids"):
1252
+ face_ids_4 = face_ids_3.copy()
1253
+ for i in range(len(parts_2)):
1254
+ _parts = parts_2[i]
1255
+ _parts_id = parts_ids_2[i]
1256
+ for j in _parts:
1257
+ face_ids_4[j] = _parts_id
1258
+ with Timer("计算part和label的aabb"):
1259
+ ids_aabb = {}
1260
+ unique_ids = np.unique(face_ids_4)
1261
+ for i in unique_ids:
1262
+ if i < 0:
1263
+ continue
1264
+ _part_mask = face_ids_4 == i
1265
+ _faces = mesh.faces[_part_mask]
1266
+ _faces = np.reshape(_faces, (-1))
1267
+ _points = mesh.vertices[_faces]
1268
+ min_xyz = np.min(_points, axis=0)
1269
+ max_xyz = np.max(_points, axis=0)
1270
+ ids_aabb[i] = [min_xyz, max_xyz]
1271
+
1272
+ parts_2_aabb = []
1273
+ for i in range(len(parts_2)):
1274
+ _parts = parts_2[i]
1275
+ _faces = mesh.faces[_parts]
1276
+ _faces = np.reshape(_faces, (-1))
1277
+ _points = mesh.vertices[_faces]
1278
+ min_xyz = np.min(_points, axis=0)
1279
+ max_xyz = np.max(_points, axis=0)
1280
+ parts_2_aabb.append([min_xyz, max_xyz])
1281
+
1282
+ with Timer("计算part的邻居"):
1283
+ parts_2_neighbor = find_neighbor_part(
1284
+ parts_2, adjacent_faces, parts_2_aabb, parts_ids_2
1285
+ )
1286
+ with Timer("合并无mask区域"):
1287
+ for i in range(len(parts_2)):
1288
+ _parts = parts_2[i]
1289
+ _ids = parts_ids_2[i]
1290
+ if _ids == -1 or _ids == -2:
1291
+ _cur_aabb = parts_2_aabb[i]
1292
+ _min_aabb_increase = 1e10
1293
+ _min_id = -1
1294
+ for j in parts_2_neighbor[i]:
1295
+ if parts_ids_2[j] == -1 or parts_ids_2[j] == -2:
1296
+ continue
1297
+ _tar_id = parts_ids_2[j]
1298
+ _tar_aabb = ids_aabb[_tar_id]
1299
+ _min_increase, _max_increase = aabb_increase(_tar_aabb, _cur_aabb)
1300
+ _increase = max(np.max(_min_increase), np.max(_max_increase))
1301
+ if _min_aabb_increase > _increase:
1302
+ _min_aabb_increase = _increase
1303
+ _min_id = _tar_id
1304
+ if _min_id >= 0:
1305
+ parts_ids_2[i] = _min_id
1306
+
1307
+ with Timer("再次赋值新的face_ids"):
1308
+ face_ids_4 = face_ids_3.copy()
1309
+ for i in range(len(parts_2)):
1310
+ _parts = parts_2[i]
1311
+ _parts_id = parts_ids_2[i]
1312
+ for j in _parts:
1313
+ face_ids_4[j] = _parts_id
1314
+
1315
+ final_face_ids = face_ids_4
1316
+ if save_mid_res:
1317
+ save_mesh(
1318
+ os.path.join(save_path, "auto_mask_mesh_final.glb"),
1319
+ mesh,
1320
+ face_ids_4,
1321
+ color_map_2,
1322
+ )
1323
+
1324
+ if post_process:
1325
+ parts = get_connected_region(final_face_ids, adjacent_faces)
1326
+ final_face_ids = do_no_mask_process(parts, final_face_ids)
1327
+ face_ids_5 = do_post_process(
1328
+ face_areas,
1329
+ parts,
1330
+ adjacent_faces,
1331
+ face_ids_4,
1332
+ threshold,
1333
+ show_info=show_info,
1334
+ )
1335
+ if save_mid_res:
1336
+ save_mesh(
1337
+ os.path.join(save_path, "auto_mask_mesh_final_post.glb"),
1338
+ mesh,
1339
+ face_ids_5,
1340
+ color_map_2,
1341
+ )
1342
+ final_face_ids = face_ids_5
1343
+ with Timer("计算最后的aabb"):
1344
+ aabb = get_aabb_from_face_ids(mesh, final_face_ids)
1345
+ return aabb, final_face_ids, mesh
1346
+
1347
+
1348
+ class AutoMask:
1349
+ def __init__(
1350
+ self,
1351
+ ckpt_path,
1352
+ point_num=100000,
1353
+ prompt_num=400,
1354
+ threshold=0.95,
1355
+ post_process=True,
1356
+ ):
1357
+ """
1358
+ ckpt_path: str, 模型路径
1359
+ point_num: int, 采样点数量
1360
+ prompt_num: int, 提示数量
1361
+ threshold: float, 阈值
1362
+ post_process: bool, 是否后处理
1363
+ """
1364
+ self.model = YSAM()
1365
+ self.model.load_state_dict(
1366
+ state_dict=torch.load(ckpt_path, map_location="cpu")["state_dict"]
1367
+ )
1368
+ self.model.eval()
1369
+ self.model_parallel = torch.nn.DataParallel(self.model)
1370
+ self.model.cuda()
1371
+ self.model_parallel.cuda()
1372
+ self.point_num = point_num
1373
+ self.prompt_num = prompt_num
1374
+ self.threshold = threshold
1375
+ self.post_process = post_process
1376
+
1377
+ def predict_aabb(
1378
+ self,
1379
+ mesh,
1380
+ point_num=None,
1381
+ prompt_num=None,
1382
+ threshold=None,
1383
+ post_process=None,
1384
+ save_path=None,
1385
+ save_mid_res=False,
1386
+ show_info=True,
1387
+ clean_mesh_flag=True,
1388
+ seed=42,
1389
+ ):
1390
+ """
1391
+ Parameters:
1392
+ mesh: trimesh.Trimesh, 输入网格
1393
+ point_num: int, 采样点数量
1394
+ prompt_num: int, 提示数量
1395
+ threshold: float, 阈值
1396
+ post_process: bool, 是否后处理
1397
+ Returns:
1398
+ aabb: np.ndarray, 包围盒
1399
+ face_ids: np.ndarray, 面id
1400
+ """
1401
+ point_num = point_num if point_num is not None else self.point_num
1402
+ prompt_num = prompt_num if prompt_num is not None else self.prompt_num
1403
+ threshold = threshold if threshold is not None else self.threshold
1404
+ post_process = post_process if post_process is not None else self.post_process
1405
+ return mesh_sam(
1406
+ [self.model, self.model_parallel],
1407
+ mesh,
1408
+ save_path=save_path,
1409
+ point_num=point_num,
1410
+ prompt_num=prompt_num,
1411
+ threshold=threshold,
1412
+ post_process=post_process,
1413
+ show_info=show_info,
1414
+ save_mid_res=save_mid_res,
1415
+ clean_mesh_flag=clean_mesh_flag,
1416
+ seed=seed,
1417
+ )
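A minimal end-to-end sketch of the AutoMask API above; the checkpoint path mirrors the one in infer.yaml below and the mesh path is one of the bundled samples, both of which may need adjusting to your checkout:

import trimesh

automask = AutoMask(ckpt_path="checkpoints/p3sam.ckpt")  # path taken from the inference config
mesh = trimesh.load("XPart/data/000.glb", force="mesh")
aabb, face_ids, processed_mesh = automask.predict_aabb(
    mesh, save_path="outputs/000", save_mid_res=False, show_info=True
)
# aabb: (P, 2, 3) per-part bounding boxes; face_ids: per-face part labels on processed_mesh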
XPart/partgen/config/infer.yaml ADDED
@@ -0,0 +1,122 @@
1
+ name: "Xpart Pipeline release"
2
+
3
+ ckpt_path: checkpoints/xpart.pt
4
+
5
+ shapevae:
6
+ target: partgen.models.autoencoders.VolumeDecoderShapeVAE
7
+ params:
8
+ num_latents: &num_latents 1024
9
+ embed_dim: 64
10
+ num_freqs: 8
11
+ include_pi: false
12
+ heads: 16
13
+ width: 1024
14
+ num_encoder_layers: 8
15
+ num_decoder_layers: 16
16
+ qkv_bias: false
17
+ qk_norm: true
18
+ scale_factor: &z_scale_factor 1.0039506158752403
19
+ geo_decoder_mlp_expand_ratio: 4
20
+ geo_decoder_downsample_ratio: 1
21
+ geo_decoder_ln_post: true
22
+ point_feats: 4
23
+ pc_size: &pc_size 81920
24
+ pc_sharpedge_size: &pc_sharpedge_size 0
25
+
26
+ bbox_predictor:
27
+ target: partgen.bbox_estimator.auto_mask_api.AutoMask
28
+ params:
29
+ ckpt_path: checkpoints/p3sam.ckpt
30
+ conditioner:
31
+ target: partgen.models.conditioner.condioner_release.Conditioner
32
+ params:
33
+ use_geo: true
34
+ use_obj: true
35
+ use_seg_feat: true
36
+ geo_cfg:
37
+ target: partgen.models.conditioner.part_encoders.PartEncoder
38
+ output_dim: &cross2_output_dim 1024
39
+ params:
40
+ use_local: true
41
+ local_feat_type: latents_shape # [latents,miche-point-query-structural-vae]
42
+ num_tokens_cond: &num_tokens_cond 4096 # num_tokens :2048 for holopart conditioner
43
+ local_geo_cfg:
44
+ target: partgen.models.autoencoders.VolumeDecoderShapeVAE
45
+ params:
46
+ num_latents: *num_tokens_cond
47
+ embed_dim: 64
48
+ num_freqs: 8
49
+ include_pi: false
50
+ heads: 16
51
+ width: 1024
52
+ num_encoder_layers: 8
53
+ num_decoder_layers: 16
54
+ qkv_bias: false
55
+ qk_norm: true
56
+ scale_factor: *z_scale_factor
57
+ geo_decoder_mlp_expand_ratio: 4
58
+ geo_decoder_downsample_ratio: 1
59
+ geo_decoder_ln_post: true
60
+ point_feats: 4
61
+ pc_size: &pc_size_bbox 81920
62
+ pc_sharpedge_size: &pc_sharpedge_size_bbox 0
63
+
64
+ obj_encoder_cfg:
65
+ target: partgen.models.autoencoders.VolumeDecoderShapeVAE
66
+ output_dim: &cross1_output_dim 1024
67
+ params:
68
+ num_latents: 4096
69
+ embed_dim: 64
70
+ num_freqs: 8
71
+ include_pi: false
72
+ heads: 16
73
+ width: 1024
74
+ num_encoder_layers: 8
75
+ num_decoder_layers: 16
76
+ qkv_bias: false
77
+ qk_norm: true
78
+ scale_factor: 1.0039506158752403
79
+ geo_decoder_mlp_expand_ratio: 4
80
+ geo_decoder_downsample_ratio: 1
81
+ geo_decoder_ln_post: true
82
+ point_feats: 4
83
+ pc_size: *pc_size
84
+ pc_sharpedge_size: *pc_sharpedge_size
85
+ seg_feat_cfg:
86
+ target: partgen.models.conditioner.sonata_extractor.SonataFeatureExtractor
87
+
88
+ model:
89
+ target: partgen.models.partformer_dit.PartFormerDITPlain
90
+ params:
91
+ use_self_attention: true
92
+ use_cross_attention: true
93
+ use_cross_attention_2: true
94
+ # cond
95
+ use_bbox_cond: false
96
+ num_freqs: 8
97
+ use_part_embed: true
98
+ valid_num: 50 #*valid_num
99
+ # para
100
+ input_size: *num_latents
101
+ in_channels: 64
102
+ hidden_size: 2048
103
+ encoder_hidden_dim: *cross1_output_dim # for object mesh
104
+ encoder_hidden2_dim: *cross2_output_dim # for part in bbox
105
+ depth: 21
106
+ num_heads: 16
107
+ qk_norm: true
108
+ qkv_bias: false
109
+ qk_norm_type: 'rms'
110
+ with_decoupled_ca: false
111
+ decoupled_ca_dim: *num_tokens_cond
112
+ decoupled_ca_weight: 1.0
113
+ use_attention_pooling: false
114
+ use_pos_emb: false
115
+ num_moe_layers: 6
116
+ num_experts: 8
117
+ moe_top_k: 2
118
+
119
+ scheduler:
120
+ target: partgen.models.diffusion.schedulers.FlowMatchEulerDiscreteScheduler
121
+ params:
122
+ num_train_timesteps: 1000
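The target/params layout above is the usual "instantiate a class from its dotted import path" pattern. A hedged sketch of resolving one entry (the instantiate helper is illustrative and not a function shipped in this repo; PyYAML is assumed to be available):

import importlib
import yaml

def instantiate(node):
    # illustrative helper: split "pkg.module.Class", import it, and pass the params dict
    module_path, cls_name = node["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**node.get("params", {}))

with open("XPart/partgen/config/infer.yaml") as f:
    cfg = yaml.safe_load(f)

bbox_predictor = instantiate(cfg["bbox_predictor"])  # -> partgen.bbox_estimator.auto_mask_api.AutoMask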
XPart/partgen/config/sonata.json ADDED
@@ -0,0 +1,58 @@
1
+ {
2
+ "in_channels": 9,
3
+ "order": [
4
+ "z",
5
+ "z-trans",
6
+ "hilbert",
7
+ "hilbert-trans"
8
+ ],
9
+ "stride": [
10
+ 2,
11
+ 2,
12
+ 2,
13
+ 2
14
+ ],
15
+ "enc_depths": [
16
+ 3,
17
+ 3,
18
+ 3,
19
+ 12,
20
+ 3
21
+ ],
22
+ "enc_channels": [
23
+ 48,
24
+ 96,
25
+ 192,
26
+ 384,
27
+ 512
28
+ ],
29
+ "enc_num_head": [
30
+ 3,
31
+ 6,
32
+ 12,
33
+ 24,
34
+ 32
35
+ ],
36
+ "enc_patch_size": [
37
+ 1024,
38
+ 1024,
39
+ 1024,
40
+ 1024,
41
+ 1024
42
+ ],
43
+ "mlp_ratio": 4,
44
+ "qkv_bias": true,
45
+ "qk_scale": null,
46
+ "attn_drop": 0.0,
47
+ "proj_drop": 0.0,
48
+ "drop_path": 0.3,
49
+ "shuffle_orders": true,
50
+ "pre_norm": true,
51
+ "enable_rpe": false,
52
+ "enable_flash": true,
53
+ "upcast_attention": false,
54
+ "upcast_softmax": false,
55
+ "traceable": true,
56
+ "enc_mode": true,
57
+ "mask_token": true
58
+ }
XPart/partgen/models/autoencoders/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
+ # except for the third-party components listed below.
3
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
+ # in the respective licenses of these third-party components.
5
+ # Users must comply with all terms and conditions of original licenses of these third-party
6
+ # components and must ensure that the usage of the third party components adheres to
7
+ # all relevant laws and regulations.
8
+
9
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
10
+ # their software and algorithms, including trained model weights, parameters (including
11
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
13
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
+
15
+ from .attention_blocks import CrossAttentionDecoder
16
+ from .attention_processors import (
17
+ CrossAttentionProcessor,
18
+ )
19
+ from .model import VectsetVAE, VolumeDecoderShapeVAE
20
+
21
+ from .surface_extractors import (
22
+ SurfaceExtractors,
23
+ MCSurfaceExtractor,
24
+ DMCSurfaceExtractor,
25
+ Latent2MeshOutput,
26
+ )
27
+ from .volume_decoders import (
28
+ VanillaVolumeDecoder,
29
+ )
XPart/partgen/models/autoencoders/attention_blocks.py ADDED
@@ -0,0 +1,770 @@
1
+ # Open Source Model Licensed under the Apache License Version 2.0
2
+ # and Other Licenses of the Third-Party Components therein:
3
+ # The below Model in this distribution may have been modified by THL A29 Limited
4
+ # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
+
6
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
+ # The below software and/or models in this distribution may have been
8
+ # modified by THL A29 Limited ("Tencent Modifications").
9
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
+
11
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
+ # except for the third-party components listed below.
13
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
+ # in the repsective licenses of these third-party components.
15
+ # Users must comply with all terms and conditions of original licenses of these third-party
16
+ # components and must ensure that the usage of the third party components adheres to
17
+ # all relevant laws and regulations.
18
+
19
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
20
+ # their software and algorithms, including trained model weights, parameters (including
21
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
23
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
+
25
+
26
+ import os
27
+ from typing import Optional, Union, List
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from einops import rearrange
32
+ from torch import Tensor
33
+
34
+ from .attention_processors import CrossAttentionProcessor
35
+ from ...utils.misc import logger
36
+
37
+ scaled_dot_product_attention = nn.functional.scaled_dot_product_attention
38
+
39
+ if os.environ.get("USE_SAGEATTN", "0") == "1":
40
+ try:
41
+ from sageattention import sageattn
42
+ except ImportError:
43
+ raise ImportError(
44
+ 'Please install the package "sageattention" to use this USE_SAGEATTN.'
45
+ )
46
+ scaled_dot_product_attention = sageattn
47
+
48
+
49
+ class FourierEmbedder(nn.Module):
50
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
51
+ each feature dimension of `x[..., i]` into:
52
+ [
53
+ sin(x[..., i]),
54
+ sin(f_1*x[..., i]),
55
+ sin(f_2*x[..., i]),
56
+ ...
57
+ sin(f_N * x[..., i]),
58
+ cos(x[..., i]),
59
+ cos(f_1*x[..., i]),
60
+ cos(f_2*x[..., i]),
61
+ ...
62
+ cos(f_N * x[..., i]),
63
+ x[..., i] # only present if include_input is True.
64
+ ], here f_i is the frequency.
65
+
66
+ The frequencies f_1, ..., f_N are determined by `num_freqs` and `logspace`:
67
+ if logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)];
68
+ otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1).
69
+
70
+ Args:
71
+ num_freqs (int): the number of frequencies, default is 6;
72
+ logspace (bool): If logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)],
73
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
74
+ input_dim (int): the input dimension, default is 3;
75
+ include_input (bool): include the input tensor or not, default is True.
76
+
77
+ Attributes:
78
+ frequencies (torch.Tensor): If logspace is True, the frequencies are [2^0, 2^1, ..., 2^(num_freqs - 1)],
79
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
80
+
81
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
82
+ otherwise, it is input_dim * num_freqs * 2.
83
+
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ num_freqs: int = 6,
89
+ logspace: bool = True,
90
+ input_dim: int = 3,
91
+ include_input: bool = True,
92
+ include_pi: bool = True,
93
+ ) -> None:
94
+ """The initialization"""
95
+
96
+ super().__init__()
97
+
98
+ if logspace:
99
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
100
+ else:
101
+ frequencies = torch.linspace(
102
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
103
+ )
104
+
105
+ if include_pi:
106
+ frequencies *= torch.pi
107
+
108
+ self.register_buffer("frequencies", frequencies, persistent=False)
109
+ self.include_input = include_input
110
+ self.num_freqs = num_freqs
111
+
112
+ self.out_dim = self.get_dims(input_dim)
113
+
114
+ def get_dims(self, input_dim):
115
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
116
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
117
+
118
+ return out_dim
119
+
120
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
121
+ """Forward process.
122
+
123
+ Args:
124
+ x: tensor of shape [..., dim]
125
+
126
+ Returns:
127
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
128
+ where temp is 1 if include_input is True and 0 otherwise.
129
+ """
130
+
131
+ if self.num_freqs > 0:
132
+ embed = (x[..., None].contiguous() * self.frequencies).view(
133
+ *x.shape[:-1], -1
134
+ )
135
+ if self.include_input:
136
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
137
+ else:
138
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
139
+ else:
140
+ return x
141
+
142
+
143
+ class DropPath(nn.Module):
144
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
145
+
146
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
147
+ super(DropPath, self).__init__()
148
+ self.drop_prob = drop_prob
149
+ self.scale_by_keep = scale_by_keep
150
+
151
+ def forward(self, x):
152
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
153
+
154
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
155
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
156
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
157
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
158
+ 'survival rate' as the argument.
159
+
160
+ """
161
+ if self.drop_prob == 0.0 or not self.training:
162
+ return x
163
+ keep_prob = 1 - self.drop_prob
164
+ shape = (x.shape[0],) + (1,) * (
165
+ x.ndim - 1
166
+ ) # work with diff dim tensors, not just 2D ConvNets
167
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
168
+ if keep_prob > 0.0 and self.scale_by_keep:
169
+ random_tensor.div_(keep_prob)
170
+ return x * random_tensor
171
+
172
+ def extra_repr(self):
173
+ return f"drop_prob={round(self.drop_prob, 3):0.3f}"
174
+
175
+
176
+ class MLP(nn.Module):
177
+ def __init__(
178
+ self,
179
+ *,
180
+ width: int,
181
+ expand_ratio: int = 4,
182
+ output_width: int = None,
183
+ drop_path_rate: float = 0.0,
184
+ ):
185
+ super().__init__()
186
+ self.width = width
187
+ self.c_fc = nn.Linear(width, width * expand_ratio)
188
+ self.c_proj = nn.Linear(
189
+ width * expand_ratio, output_width if output_width is not None else width
190
+ )
191
+ self.gelu = nn.GELU()
192
+ self.drop_path = (
193
+ DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
194
+ )
195
+
196
+ def forward(self, x):
197
+ return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
198
+
199
+
200
+ class QKVMultiheadCrossAttention(nn.Module):
201
+ def __init__(
202
+ self,
203
+ *,
204
+ heads: int,
205
+ width=None,
206
+ qk_norm=False,
207
+ norm_layer=nn.LayerNorm,
208
+ ):
209
+ super().__init__()
210
+ self.heads = heads
211
+ self.q_norm = (
212
+ norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
213
+ if qk_norm
214
+ else nn.Identity()
215
+ )
216
+ self.k_norm = (
217
+ norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
218
+ if qk_norm
219
+ else nn.Identity()
220
+ )
221
+
222
+ self.attn_processor = CrossAttentionProcessor()
223
+
224
+ def forward(self, q, kv):
225
+ _, n_ctx, _ = q.shape
226
+ bs, n_data, width = kv.shape
227
+ attn_ch = width // self.heads // 2
228
+ q = q.view(bs, n_ctx, self.heads, -1)
229
+ kv = kv.view(bs, n_data, self.heads, -1)
230
+ k, v = torch.split(kv, attn_ch, dim=-1)
231
+
232
+ q = self.q_norm(q)
233
+ k = self.k_norm(k)
234
+ q, k, v = map(
235
+ lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v)
236
+ )
237
+ out = self.attn_processor(self, q, k, v)
238
+ out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
239
+ return out
240
+
241
+
242
+ class MultiheadCrossAttention(nn.Module):
243
+ def __init__(
244
+ self,
245
+ *,
246
+ width: int,
247
+ heads: int,
248
+ qkv_bias: bool = True,
249
+ data_width: Optional[int] = None,
250
+ norm_layer=nn.LayerNorm,
251
+ qk_norm: bool = False,
252
+ kv_cache: bool = False,
253
+ ):
254
+ super().__init__()
255
+ self.width = width
256
+ self.heads = heads
257
+ self.data_width = width if data_width is None else data_width
258
+ self.c_q = nn.Linear(width, width, bias=qkv_bias)
259
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
260
+ self.c_proj = nn.Linear(width, width)
261
+ self.attention = QKVMultiheadCrossAttention(
262
+ heads=heads,
263
+ width=width,
264
+ norm_layer=norm_layer,
265
+ qk_norm=qk_norm,
266
+ )
267
+ self.kv_cache = kv_cache
268
+ self.data = None
269
+
270
+ def forward(self, x, data):
271
+ x = self.c_q(x)
272
+ if self.kv_cache:
273
+ if self.data is None:
274
+ self.data = self.c_kv(data)
275
+ logger.info(
276
+ "Saving kv cache; this should be called only once per mesh"
277
+ )
278
+ data = self.data
279
+ else:
280
+ data = self.c_kv(data)
281
+ x = self.attention(x, data)
282
+ x = self.c_proj(x)
283
+ return x
284
+
285
+
286
+ class ResidualCrossAttentionBlock(nn.Module):
287
+ def __init__(
288
+ self,
289
+ *,
290
+ width: int,
291
+ heads: int,
292
+ mlp_expand_ratio: int = 4,
293
+ data_width: Optional[int] = None,
294
+ qkv_bias: bool = True,
295
+ norm_layer=nn.LayerNorm,
296
+ qk_norm: bool = False,
297
+ ):
298
+ super().__init__()
299
+
300
+ if data_width is None:
301
+ data_width = width
302
+
303
+ self.attn = MultiheadCrossAttention(
304
+ width=width,
305
+ heads=heads,
306
+ data_width=data_width,
307
+ qkv_bias=qkv_bias,
308
+ norm_layer=norm_layer,
309
+ qk_norm=qk_norm,
310
+ )
311
+ self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
312
+ self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
313
+ self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
314
+ self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
315
+
316
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
317
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
318
+ x = x + self.mlp(self.ln_3(x))
319
+ return x
320
+
321
+
322
+ class QKVMultiheadAttention(nn.Module):
323
+ def __init__(
324
+ self, *, heads: int, width=None, qk_norm=False, norm_layer=nn.LayerNorm
325
+ ):
326
+ super().__init__()
327
+ self.heads = heads
328
+ self.q_norm = (
329
+ norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
330
+ if qk_norm
331
+ else nn.Identity()
332
+ )
333
+ self.k_norm = (
334
+ norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
335
+ if qk_norm
336
+ else nn.Identity()
337
+ )
338
+
339
+ def forward(self, qkv):
340
+ bs, n_ctx, width = qkv.shape
341
+ attn_ch = width // self.heads // 3
342
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
343
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
344
+
345
+ q = self.q_norm(q)
346
+ k = self.k_norm(k)
347
+
348
+ q, k, v = map(
349
+ lambda t: rearrange(t, "b n h d -> b h n d", h=self.heads), (q, k, v)
350
+ )
351
+ out = (
352
+ scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
353
+ )
354
+ return out
355
+
356
+
357
+ class MultiheadAttention(nn.Module):
358
+ def __init__(
359
+ self,
360
+ *,
361
+ width: int,
362
+ heads: int,
363
+ qkv_bias: bool,
364
+ norm_layer=nn.LayerNorm,
365
+ qk_norm: bool = False,
366
+ drop_path_rate: float = 0.0,
367
+ ):
368
+ super().__init__()
369
+ self.width = width
370
+ self.heads = heads
371
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
372
+ self.c_proj = nn.Linear(width, width)
373
+ self.attention = QKVMultiheadAttention(
374
+ heads=heads,
375
+ width=width,
376
+ norm_layer=norm_layer,
377
+ qk_norm=qk_norm,
378
+ )
379
+ self.drop_path = (
380
+ DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
381
+ )
382
+
383
+ def forward(self, x):
384
+ x = self.c_qkv(x)
385
+ x = self.attention(x)
386
+ x = self.drop_path(self.c_proj(x))
387
+ return x
388
+
389
+
390
+ class ResidualAttentionBlock(nn.Module):
391
+ def __init__(
392
+ self,
393
+ *,
394
+ width: int,
395
+ heads: int,
396
+ qkv_bias: bool = True,
397
+ norm_layer=nn.LayerNorm,
398
+ qk_norm: bool = False,
399
+ drop_path_rate: float = 0.0,
400
+ ):
401
+ super().__init__()
402
+ self.attn = MultiheadAttention(
403
+ width=width,
404
+ heads=heads,
405
+ qkv_bias=qkv_bias,
406
+ norm_layer=norm_layer,
407
+ qk_norm=qk_norm,
408
+ drop_path_rate=drop_path_rate,
409
+ )
410
+ self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
411
+ self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
412
+ self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
413
+
414
+ def forward(self, x: torch.Tensor):
415
+ x = x + self.attn(self.ln_1(x))
416
+ x = x + self.mlp(self.ln_2(x))
417
+ return x
418
+
419
+
420
+ class Transformer(nn.Module):
421
+ def __init__(
422
+ self,
423
+ *,
424
+ width: int,
425
+ layers: int,
426
+ heads: int,
427
+ qkv_bias: bool = True,
428
+ norm_layer=nn.LayerNorm,
429
+ qk_norm: bool = False,
430
+ drop_path_rate: float = 0.0,
431
+ ):
432
+ super().__init__()
433
+ self.width = width
434
+ self.layers = layers
435
+ self.resblocks = nn.ModuleList([
436
+ ResidualAttentionBlock(
437
+ width=width,
438
+ heads=heads,
439
+ qkv_bias=qkv_bias,
440
+ norm_layer=norm_layer,
441
+ qk_norm=qk_norm,
442
+ drop_path_rate=drop_path_rate,
443
+ )
444
+ for _ in range(layers)
445
+ ])
446
+
447
+ def forward(self, x: torch.Tensor):
448
+ for block in self.resblocks:
449
+ x = block(x)
450
+ return x
451
+
452
+
453
+ class CrossAttentionDecoder(nn.Module):
454
+
455
+ def __init__(
456
+ self,
457
+ *,
458
+ out_channels: int,
459
+ fourier_embedder: FourierEmbedder,
460
+ width: int,
461
+ heads: int,
462
+ mlp_expand_ratio: int = 4,
463
+ downsample_ratio: int = 1,
464
+ enable_ln_post: bool = True,
465
+ qkv_bias: bool = True,
466
+ qk_norm: bool = False,
467
+ label_type: str = "binary",
468
+ ):
469
+ super().__init__()
470
+
471
+ self.enable_ln_post = enable_ln_post
472
+ self.fourier_embedder = fourier_embedder
473
+ self.downsample_ratio = downsample_ratio
474
+ self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width)
475
+ if self.downsample_ratio != 1:
476
+ self.latents_proj = nn.Linear(width * downsample_ratio, width)
477
+ if not self.enable_ln_post:
478
+ qk_norm = False
479
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
480
+ width=width,
481
+ mlp_expand_ratio=mlp_expand_ratio,
482
+ heads=heads,
483
+ qkv_bias=qkv_bias,
484
+ qk_norm=qk_norm,
485
+ )
486
+
487
+ if self.enable_ln_post:
488
+ self.ln_post = nn.LayerNorm(width)
489
+ self.output_proj = nn.Linear(width, out_channels)
490
+ self.label_type = label_type
491
+ self.count = 0
492
+
493
+ def set_cross_attention_processor(self, processor):
494
+ self.cross_attn_decoder.attn.attention.attn_processor = processor
495
+
496
+ # def set_default_cross_attention_processor(self):
497
+ # self.cross_attn_decoder.attn.attention.attn_processor = CrossAttentionProcessor
498
+
499
+ def forward(self, queries=None, query_embeddings=None, latents=None):
500
+ if query_embeddings is None:
501
+ query_embeddings = self.query_proj(
502
+ self.fourier_embedder(queries).to(latents.dtype)
503
+ )
504
+ self.count += query_embeddings.shape[1]
505
+ if self.downsample_ratio != 1:
506
+ latents = self.latents_proj(latents)
507
+ x = self.cross_attn_decoder(query_embeddings, latents)
508
+ if self.enable_ln_post:
509
+ x = self.ln_post(x)
510
+ occ = self.output_proj(x)
511
+ return occ
512
+
513
+
514
+ def fps(
515
+ src: torch.Tensor,
516
+ batch: Optional[Tensor] = None,
517
+ ratio: Optional[Union[Tensor, float]] = None,
518
+ random_start: bool = True,
519
+ batch_size: Optional[int] = None,
520
+ ptr: Optional[Union[Tensor, List[int]]] = None,
521
+ ):
522
+ src = src.float()
523
+ from torch_cluster import fps as fps_fn
524
+
525
+ output = fps_fn(src, batch, ratio, random_start, batch_size, ptr)
526
+ return output
527
+
528
+
529
+ class PointCrossAttentionEncoder(nn.Module):
530
+
531
+ def __init__(
532
+ self,
533
+ *,
534
+ num_latents: int,
535
+ downsample_ratio: float,
536
+ pc_size: int,
537
+ pc_sharpedge_size: int,
538
+ fourier_embedder: FourierEmbedder,
539
+ point_feats: int,
540
+ width: int,
541
+ heads: int,
542
+ layers: int,
543
+ normal_pe: bool = False,
544
+ qkv_bias: bool = True,
545
+ use_ln_post: bool = False,
546
+ use_checkpoint: bool = False,
547
+ qk_norm: bool = False,
548
+ ):
549
+
550
+ super().__init__()
551
+
552
+ self.use_checkpoint = use_checkpoint
553
+ self.num_latents = num_latents
554
+ self.downsample_ratio = downsample_ratio
555
+ self.point_feats = point_feats
556
+ self.normal_pe = normal_pe
557
+
558
+ if pc_sharpedge_size == 0:
559
+ print(
560
+ f"PointCrossAttentionEncoder INFO: pc_sharpedge_size is 0, using"
561
+ f" only random surface points (pc_size={pc_size})"
562
+ )
563
+ else:
564
+ print(
565
+ "PointCrossAttentionEncoder INFO: pc_sharpedge_size is given, using"
566
+ f" pc_size={pc_size}, pc_sharpedge_size={pc_sharpedge_size}"
567
+ )
568
+
569
+ self.pc_size = pc_size
570
+ self.pc_sharpedge_size = pc_sharpedge_size
571
+
572
+ self.fourier_embedder = fourier_embedder
573
+
574
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
575
+ self.cross_attn = ResidualCrossAttentionBlock(
576
+ width=width, heads=heads, qkv_bias=qkv_bias, qk_norm=qk_norm
577
+ )
578
+
579
+ self.self_attn = None
580
+ if layers > 0:
581
+ self.self_attn = Transformer(
582
+ width=width,
583
+ layers=layers,
584
+ heads=heads,
585
+ qkv_bias=qkv_bias,
586
+ qk_norm=qk_norm,
587
+ )
588
+
589
+ if use_ln_post:
590
+ self.ln_post = nn.LayerNorm(width)
591
+ else:
592
+ self.ln_post = None
593
+
594
+ def sample_points_and_latents(
595
+ self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None
596
+ ):
597
+ B, N, D = pc.shape
598
+ num_pts = self.num_latents * self.downsample_ratio
599
+
600
+ # Compute number of latents
601
+ num_latents = int(num_pts / self.downsample_ratio)
602
+
603
+ # Compute the number of random and sharpedge latents
604
+ num_random_query = (
605
+ self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
606
+ )
607
+ num_sharpedge_query = num_latents - num_random_query
608
+
609
+ # Split random and sharpedge surface points
610
+ random_pc, sharpedge_pc = torch.split(
611
+ pc, [self.pc_size, self.pc_sharpedge_size], dim=1
612
+ )
613
+ assert (
614
+ random_pc.shape[1] <= self.pc_size
615
+ ), "Random surface points size must be less than or equal to pc_size"
616
+ assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, (
617
+ "Sharpedge surface points size must be less than or equal to"
618
+ " pc_sharpedge_size"
619
+ )
620
+
621
+ # Randomly select random surface points and random query points
622
+ input_random_pc_size = int(num_random_query * self.downsample_ratio)
623
+ random_query_ratio = num_random_query / input_random_pc_size
624
+ idx_random_pc = torch.randperm(random_pc.shape[1], device=random_pc.device)[
625
+ :input_random_pc_size
626
+ ]
627
+ input_random_pc = random_pc[:, idx_random_pc, :]
628
+ flatten_input_random_pc = input_random_pc.view(B * input_random_pc_size, D)
629
+ N_down = int(flatten_input_random_pc.shape[0] / B)
630
+ batch_down = torch.arange(B).to(pc.device)
631
+ batch_down = torch.repeat_interleave(batch_down, N_down)
632
+ idx_query_random = fps(
633
+ flatten_input_random_pc, batch_down, ratio=random_query_ratio
634
+ )
635
+ query_random_pc = flatten_input_random_pc[idx_query_random].view(B, -1, D)
636
+
637
+ # Randomly select sharpedge surface points and sharpedge query points
638
+ input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
639
+ if input_sharpedge_pc_size == 0:
640
+ input_sharpedge_pc = torch.zeros(B, 0, D, dtype=input_random_pc.dtype).to(
641
+ pc.device
642
+ )
643
+ query_sharpedge_pc = torch.zeros(B, 0, D, dtype=query_random_pc.dtype).to(
644
+ pc.device
645
+ )
646
+ else:
647
+ sharpedge_query_ratio = num_sharpedge_query / input_sharpedge_pc_size
648
+ idx_sharpedge_pc = torch.randperm(
649
+ sharpedge_pc.shape[1], device=sharpedge_pc.device
650
+ )[:input_sharpedge_pc_size]
651
+ input_sharpedge_pc = sharpedge_pc[:, idx_sharpedge_pc, :]
652
+ flatten_input_sharpedge_surface_points = input_sharpedge_pc.view(
653
+ B * input_sharpedge_pc_size, D
654
+ )
655
+ N_down = int(flatten_input_sharpedge_surface_points.shape[0] / B)
656
+ batch_down = torch.arange(B).to(pc.device)
657
+ batch_down = torch.repeat_interleave(batch_down, N_down)
658
+ idx_query_sharpedge = fps(
659
+ flatten_input_sharpedge_surface_points,
660
+ batch_down,
661
+ ratio=sharpedge_query_ratio,
662
+ )
663
+ query_sharpedge_pc = flatten_input_sharpedge_surface_points[
664
+ idx_query_sharpedge
665
+ ].view(B, -1, D)
666
+
667
+ # Concatenate random and sharpedge surface points and query points
668
+ query_pc = torch.cat([query_random_pc, query_sharpedge_pc], dim=1)
669
+ input_pc = torch.cat([input_random_pc, input_sharpedge_pc], dim=1)
670
+
671
+ # PE
672
+ query = self.fourier_embedder(query_pc)
673
+ data = self.fourier_embedder(input_pc)
674
+
675
+ # Concat normal if given
676
+ if self.point_feats != 0:
677
+
678
+ random_surface_feats, sharpedge_surface_feats = torch.split(
679
+ feats, [self.pc_size, self.pc_sharpedge_size], dim=1
680
+ )
681
+ input_random_surface_feats = random_surface_feats[:, idx_random_pc, :]
682
+ flatten_input_random_surface_feats = input_random_surface_feats.view(
683
+ B * input_random_pc_size, -1
684
+ )
685
+ query_random_feats = flatten_input_random_surface_feats[
686
+ idx_query_random
687
+ ].view(B, -1, flatten_input_random_surface_feats.shape[-1])
688
+
689
+ if input_sharpedge_pc_size == 0:
690
+ input_sharpedge_surface_feats = torch.zeros(
691
+ B, 0, self.point_feats, dtype=input_random_surface_feats.dtype
692
+ ).to(pc.device)
693
+ query_sharpedge_feats = torch.zeros(
694
+ B, 0, self.point_feats, dtype=query_random_feats.dtype
695
+ ).to(pc.device)
696
+ else:
697
+ input_sharpedge_surface_feats = sharpedge_surface_feats[
698
+ :, idx_sharpedge_pc, :
699
+ ]
700
+ flatten_input_sharpedge_surface_feats = (
701
+ input_sharpedge_surface_feats.view(B * input_sharpedge_pc_size, -1)
702
+ )
703
+ query_sharpedge_feats = flatten_input_sharpedge_surface_feats[
704
+ idx_query_sharpedge
705
+ ].view(B, -1, flatten_input_sharpedge_surface_feats.shape[-1])
706
+
707
+ query_feats = torch.cat([query_random_feats, query_sharpedge_feats], dim=1)
708
+ input_feats = torch.cat(
709
+ [input_random_surface_feats, input_sharpedge_surface_feats], dim=1
710
+ )
711
+
712
+ if self.normal_pe:
713
+ query_normal_pe = self.fourier_embedder(query_feats[..., :3])
714
+ input_normal_pe = self.fourier_embedder(input_feats[..., :3])
715
+ query_feats = torch.cat([query_normal_pe, query_feats[..., 3:]], dim=-1)
716
+ input_feats = torch.cat([input_normal_pe, input_feats[..., 3:]], dim=-1)
717
+
718
+ query = torch.cat([query, query_feats], dim=-1)
719
+ data = torch.cat([data, input_feats], dim=-1)
720
+
721
+ if input_sharpedge_pc_size == 0:
722
+ query_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
723
+ input_sharpedge_pc = torch.zeros(B, 1, D).to(pc.device)
724
+
725
+ # print(f'query_pc: {query_pc.shape}')
726
+ # print(f'input_pc: {input_pc.shape}')
727
+ # print(f'query_random_pc: {query_random_pc.shape}')
728
+ # print(f'input_random_pc: {input_random_pc.shape}')
729
+ # print(f'query_sharpedge_pc: {query_sharpedge_pc.shape}')
730
+ # print(f'input_sharpedge_pc: {input_sharpedge_pc.shape}')
731
+
732
+ return (
733
+ query.view(B, -1, query.shape[-1]),
734
+ data.view(B, -1, data.shape[-1]),
735
+ [
736
+ query_pc,
737
+ input_pc,
738
+ query_random_pc,
739
+ input_random_pc,
740
+ query_sharpedge_pc,
741
+ input_sharpedge_pc,
742
+ ],
743
+ )
744
+
745
+ def forward(self, pc, feats):
746
+ """
747
+
748
+ Args:
749
+ pc (torch.FloatTensor): [B, N, 3]
750
+ feats (torch.FloatTensor or None): [B, N, C]
751
+
752
+ Returns:
753
+
754
+ """
755
+
756
+ query, data, pc_infos = self.sample_points_and_latents(pc, feats)
757
+
758
+ query = self.input_proj(query)
759
+ data = self.input_proj(data)
762
+
763
+ latents = self.cross_attn(query, data)
764
+ if self.self_attn is not None:
765
+ latents = self.self_attn(latents)
766
+
767
+ if self.ln_post is not None:
768
+ latents = self.ln_post(latents)
769
+
770
+ return latents, pc_infos
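
As a quick illustration of the Fourier embedding defined above, the following self-contained sketch reproduces the default logspace / include_input / include_pi path with plain torch ops; the shapes are illustrative, and num_freqs=8 matches what VolumeDecoderShapeVAE passes later in this commit.

import torch

num_freqs, input_dim = 8, 3
freqs = (2.0 ** torch.arange(num_freqs, dtype=torch.float32)) * torch.pi  # logspace, include_pi=True

x = torch.rand(2, 1024, input_dim)                      # e.g. a batch of query points
embed = (x[..., None] * freqs).view(*x.shape[:-1], -1)  # [2, 1024, input_dim * num_freqs]
out = torch.cat((x, embed.sin(), embed.cos()), dim=-1)  # include_input=True

# out_dim = input_dim * (2 * num_freqs + 1) = 3 * 17 = 51, matching FourierEmbedder.get_dims(3)
assert out.shape[-1] == input_dim * (2 * num_freqs + 1)
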
XPart/partgen/models/autoencoders/attention_processors.py ADDED
@@ -0,0 +1,32 @@
1
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
+ # except for the third-party components listed below.
3
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
+ # in the respective licenses of these third-party components.
5
+ # Users must comply with all terms and conditions of original licenses of these third-party
6
+ # components and must ensure that the usage of the third party components adheres to
7
+ # all relevant laws and regulations.
8
+
9
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
10
+ # their software and algorithms, including trained model weights, parameters (including
11
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
13
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
+
15
+ import os
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ scaled_dot_product_attention = F.scaled_dot_product_attention
21
+ if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
22
+ try:
23
+ from sageattention import sageattn
24
+ except ImportError:
25
+ raise ImportError('Please install the package "sageattention" to use this USE_SAGEATTN.')
26
+ scaled_dot_product_attention = sageattn
27
+
28
+
29
+ class CrossAttentionProcessor:
30
+ def __call__(self, attn, q, k, v):
31
+ out = scaled_dot_product_attention(q, k, v)
32
+ return out
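
CrossAttentionProcessor above is a plain callable, so CrossAttentionDecoder.set_cross_attention_processor (defined earlier in this commit) can accept any object with the same __call__(attn, q, k, v) signature. A hypothetical sketch, assuming PyTorch >= 2.1 for the explicit scale argument:

import torch.nn.functional as F

class ScaledCrossAttentionProcessor:
    """Hypothetical drop-in processor that forwards to SDPA with a custom scale."""

    def __init__(self, scale=None):
        self.scale = scale  # None keeps the default 1/sqrt(head_dim)

    def __call__(self, attn, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, scale=self.scale)

# usage (assuming `decoder` is a CrossAttentionDecoder instance):
# decoder.set_cross_attention_processor(ScaledCrossAttentionProcessor())
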
XPart/partgen/models/autoencoders/model.py ADDED
@@ -0,0 +1,452 @@
1
+ # Open Source Model Licensed under the Apache License Version 2.0
2
+ # and Other Licenses of the Third-Party Components therein:
3
+ # The below Model in this distribution may have been modified by THL A29 Limited
4
+ # ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
5
+
6
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
7
+ # The below software and/or models in this distribution may have been
8
+ # modified by THL A29 Limited ("Tencent Modifications").
9
+ # All Tencent Modifications are Copyright (C) THL A29 Limited.
10
+
11
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
12
+ # except for the third-party components listed below.
13
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
14
+ # in the respective licenses of these third-party components.
15
+ # Users must comply with all terms and conditions of original licenses of these third-party
16
+ # components and must ensure that the usage of the third party components adheres to
17
+ # all relevant laws and regulations.
18
+
19
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
20
+ # their software and algorithms, including trained model weights, parameters (including
21
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
22
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
23
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
+
25
+
26
+ import os
27
+ from typing import Tuple, List, Union
28
+
29
+ from functools import partial
30
+
31
+ import copy
32
+ import numpy as np
33
+ import torch
34
+ import torch.nn as nn
35
+ import yaml
36
+
37
+ from .attention_blocks import (
38
+ FourierEmbedder,
39
+ Transformer,
40
+ CrossAttentionDecoder,
41
+ PointCrossAttentionEncoder,
42
+ )
43
+ from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors, Latent2MeshOutput
44
+ from .volume_decoders import (
45
+ VanillaVolumeDecoder,
46
+ )
47
+ from ...utils.misc import logger, synchronize_timer, smart_load_model
48
+ from ...utils.mesh_utils import extract_geometry_fast
49
+
50
+
51
+ class DiagonalGaussianDistribution(object):
52
+ def __init__(
53
+ self,
54
+ parameters: Union[torch.Tensor, List[torch.Tensor]],
55
+ deterministic=False,
56
+ feat_dim=1,
57
+ ):
58
+ """
59
+ Initialize a diagonal Gaussian distribution with mean and log-variance parameters.
60
+
61
+ Args:
62
+ parameters (Union[torch.Tensor, List[torch.Tensor]]):
63
+ Either a single tensor containing concatenated mean and log-variance along `feat_dim`,
64
+ or a list of two tensors [mean, logvar].
65
+ deterministic (bool, optional): If True, the distribution is deterministic (zero variance).
66
+ Default is False.
+ feat_dim (int, optional): Dimension along which mean and logvar are
67
+ concatenated if parameters is a single tensor. Default is 1.
68
+ """
69
+ self.feat_dim = feat_dim
70
+ self.parameters = parameters
71
+
72
+ if isinstance(parameters, list):
73
+ self.mean = parameters[0]
74
+ self.logvar = parameters[1]
75
+ else:
76
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
77
+
78
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
79
+ self.deterministic = deterministic
80
+ self.std = torch.exp(0.5 * self.logvar)
81
+ self.var = torch.exp(self.logvar)
82
+ if self.deterministic:
83
+ self.var = self.std = torch.zeros_like(self.mean)
84
+
85
+ def sample(self):
86
+ """
87
+ Sample from the diagonal Gaussian distribution.
88
+
89
+ Returns:
90
+ torch.Tensor: A sample tensor with the same shape as the mean.
91
+ """
92
+ x = self.mean + self.std * torch.randn_like(self.mean)
93
+ return x
94
+
95
+ def kl(self, other=None, dims=(1, 2, 3)):
96
+ """
97
+ Compute the Kullback-Leibler (KL) divergence between this distribution and another.
98
+
99
+ If `other` is None, compute KL divergence to a standard normal distribution N(0, I).
100
+
101
+ Args:
102
+ other (DiagonalGaussianDistribution, optional): Another diagonal Gaussian distribution.
103
+ dims (tuple, optional): Dimensions along which to compute the mean KL divergence.
104
+ Default is (1, 2, 3).
105
+
106
+ Returns:
107
+ torch.Tensor: The mean KL divergence value.
108
+ """
109
+ if self.deterministic:
110
+ return torch.Tensor([0.0])
111
+ else:
112
+ if other is None:
113
+ return 0.5 * torch.mean(
114
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims
115
+ )
116
+ else:
117
+ return 0.5 * torch.mean(
118
+ torch.pow(self.mean - other.mean, 2) / other.var
119
+ + self.var / other.var
120
+ - 1.0
121
+ - self.logvar
122
+ + other.logvar,
123
+ dim=dims,
124
+ )
125
+
126
+ def nll(self, sample, dims=(1, 2, 3)):
127
+ if self.deterministic:
128
+ return torch.Tensor([0.0])
129
+ logtwopi = np.log(2.0 * np.pi)
130
+ return 0.5 * torch.sum(
131
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
132
+ dim=dims,
133
+ )
134
+
135
+ def mode(self):
136
+ return self.mean
137
+
138
+
139
+ class VectsetVAE(nn.Module):
140
+
141
+ @classmethod
142
+ @synchronize_timer("VectsetVAE Model Loading")
143
+ def from_single_file(
144
+ cls,
145
+ ckpt_path,
146
+ config_path,
147
+ device="cuda",
148
+ dtype=torch.float16,
149
+ use_safetensors=None,
150
+ **kwargs,
151
+ ):
152
+ # load config
153
+ with open(config_path, "r") as f:
154
+ config = yaml.safe_load(f)
155
+
156
+ # load ckpt
157
+ if use_safetensors:
158
+ ckpt_path = ckpt_path.replace(".ckpt", ".safetensors")
159
+ if not os.path.exists(ckpt_path):
160
+ raise FileNotFoundError(f"Model file {ckpt_path} not found")
161
+
162
+ logger.info(f"Loading model from {ckpt_path}")
163
+ if use_safetensors:
164
+ import safetensors.torch
165
+
166
+ ckpt = safetensors.torch.load_file(ckpt_path, device="cpu")
167
+ else:
168
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
169
+
170
+ model_kwargs = config["params"]
171
+ model_kwargs.update(kwargs)
172
+
173
+ model = cls(**model_kwargs)
174
+ model.load_state_dict(ckpt)
175
+ model.to(device=device, dtype=dtype)
176
+ return model
177
+
178
+ @classmethod
179
+ def from_pretrained(
180
+ cls,
181
+ model_path,
182
+ device="cuda",
183
+ dtype=torch.float16,
184
+ use_safetensors=False,
185
+ variant="fp16",
186
+ subfolder="hunyuan3d-vae-v2-1",
187
+ **kwargs,
188
+ ):
189
+ config_path, ckpt_path = smart_load_model(
190
+ model_path,
191
+ subfolder=subfolder,
192
+ use_safetensors=use_safetensors,
193
+ variant=variant,
194
+ )
195
+
196
+ return cls.from_single_file(
197
+ ckpt_path,
198
+ config_path,
199
+ device=device,
200
+ dtype=dtype,
201
+ use_safetensors=use_safetensors,
202
+ **kwargs,
203
+ )
204
+
205
+ def init_from_ckpt(self, path, ignore_keys=()):
206
+ state_dict = torch.load(path, map_location="cpu")
207
+ state_dict = state_dict.get("state_dict", state_dict)
208
+ keys = list(state_dict.keys())
209
+ for k in keys:
210
+ for ik in ignore_keys:
211
+ if k.startswith(ik):
212
+ print("Deleting key {} from state_dict.".format(k))
213
+ del state_dict[k]
214
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
215
+ print(
216
+ f"Restored from {path} with {len(missing)} missing and"
217
+ f" {len(unexpected)} unexpected keys"
218
+ )
219
+ if len(missing) > 0:
220
+ print(f"Missing Keys: {missing}")
221
+ print(f"Unexpected Keys: {unexpected}")
222
+
223
+ def __init__(self, volume_decoder=None, surface_extractor=None):
224
+ super().__init__()
225
+ if volume_decoder is None:
226
+ volume_decoder = VanillaVolumeDecoder()
227
+ if surface_extractor is None:
228
+ surface_extractor = MCSurfaceExtractor()
229
+ self.volume_decoder = volume_decoder
230
+ self.surface_extractor = surface_extractor
231
+
232
+ def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
233
+ with synchronize_timer("Volume decoding"):
234
+ grid_logits = self.volume_decoder(latents, self.geo_decoder, **kwargs)
235
+ with synchronize_timer("Surface extraction"):
236
+ outputs = self.surface_extractor(grid_logits, **kwargs)
237
+ return outputs
238
+
239
+
240
+ class VolumeDecoderShapeVAE(VectsetVAE):
241
+ def __init__(
242
+ self,
243
+ *,
244
+ num_latents: int,
245
+ embed_dim: int,
246
+ width: int,
247
+ heads: int,
248
+ num_decoder_layers: int,
249
+ num_encoder_layers: int = 8,
250
+ pc_size: int = 5120,
251
+ pc_sharpedge_size: int = 5120,
252
+ point_feats: int = 3,
253
+ downsample_ratio: int = 20,
254
+ geo_decoder_downsample_ratio: int = 1,
255
+ geo_decoder_mlp_expand_ratio: int = 4,
256
+ geo_decoder_ln_post: bool = True,
257
+ num_freqs: int = 8,
258
+ include_pi: bool = True,
259
+ qkv_bias: bool = True,
260
+ qk_norm: bool = False,
261
+ label_type: str = "binary",
262
+ drop_path_rate: float = 0.0,
263
+ scale_factor: float = 1.0,
264
+ use_ln_post: bool = True,
265
+ ckpt_path=None,
266
+ volume_decoder=None,
267
+ surface_extractor=None,
268
+ ):
269
+ super().__init__(volume_decoder, surface_extractor)
270
+ self.geo_decoder_ln_post = geo_decoder_ln_post
271
+ self.downsample_ratio = downsample_ratio
272
+
273
+ self.fourier_embedder = FourierEmbedder(
274
+ num_freqs=num_freqs, include_pi=include_pi
275
+ )
276
+
277
+ self.encoder = PointCrossAttentionEncoder(
278
+ fourier_embedder=self.fourier_embedder,
279
+ num_latents=num_latents,
280
+ downsample_ratio=self.downsample_ratio,
281
+ pc_size=pc_size,
282
+ pc_sharpedge_size=pc_sharpedge_size,
283
+ point_feats=point_feats,
284
+ width=width,
285
+ heads=heads,
286
+ layers=num_encoder_layers,
287
+ qkv_bias=qkv_bias,
288
+ use_ln_post=use_ln_post,
289
+ qk_norm=qk_norm,
290
+ )
291
+
292
+ self.pre_kl = nn.Linear(width, embed_dim * 2)
293
+ self.post_kl = nn.Linear(embed_dim, width)
294
+
295
+ self.transformer = Transformer(
296
+ width=width,
297
+ layers=num_decoder_layers,
298
+ heads=heads,
299
+ qkv_bias=qkv_bias,
300
+ qk_norm=qk_norm,
301
+ drop_path_rate=drop_path_rate,
302
+ )
303
+
304
+ self.geo_decoder = CrossAttentionDecoder(
305
+ fourier_embedder=self.fourier_embedder,
306
+ out_channels=1,
307
+ mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
308
+ downsample_ratio=geo_decoder_downsample_ratio,
309
+ enable_ln_post=self.geo_decoder_ln_post,
310
+ width=width // geo_decoder_downsample_ratio,
311
+ heads=heads // geo_decoder_downsample_ratio,
312
+ qkv_bias=qkv_bias,
313
+ qk_norm=qk_norm,
314
+ label_type=label_type,
315
+ )
316
+
317
+ self.scale_factor = scale_factor
318
+ self.latent_shape = (num_latents, embed_dim)
319
+
320
+ if ckpt_path is not None:
321
+ self.init_from_ckpt(ckpt_path)
322
+
323
+ def forward(self, latents):
324
+ latents = self.post_kl(latents)
325
+ latents = self.transformer(latents)
326
+ return latents
327
+
328
+ def encode(self, surface, sample_posterior=True, return_pc_info=False):
329
+ pc, feats = surface[:, :, :3], surface[:, :, 3:]
330
+ latents, pc_infos = self.encoder(pc, feats)
331
+ # print(latents.shape, self.pre_kl.weight.shape)
332
+ moments = self.pre_kl(latents)
333
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
334
+ if sample_posterior:
335
+ latents = posterior.sample()
336
+ else:
337
+ latents = posterior.mode()
338
+ if return_pc_info:
339
+ return latents, pc_infos
340
+ else:
341
+ return latents
342
+
343
+ def encode_shape(self, surface, return_pc_info=False):
344
+ pc, feats = surface[:, :, :3], surface[:, :, 3:]
345
+ latents, pc_infos = self.encoder(pc, feats)
346
+ if return_pc_info:
347
+ return latents, pc_infos
348
+ else:
349
+ return latents
350
+
351
+ def decode(self, latents):
352
+ latents = self.post_kl(latents)
353
+ latents = self.transformer(latents)
354
+ return latents
355
+
356
+ def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
357
+ logits = self.geo_decoder(queries=queries, latents=latents).squeeze(-1)
358
+ return logits
359
+
360
+ def latents2mesh(self, latents: torch.FloatTensor, **kwargs):
361
+ coarse_kwargs = copy.deepcopy(kwargs)
362
+ coarse_kwargs["octree_resolution"] = 256
363
+
364
+ with synchronize_timer("Coarse Volume decoding"):
365
+ coarse_grid_logits = self.volume_decoder(
366
+ latents, self.geo_decoder, **coarse_kwargs
367
+ )
368
+ with synchronize_timer("Coarse Surface extraction"):
369
+ coarse_mesh = self.surface_extractor(coarse_grid_logits, **coarse_kwargs)
370
+
371
+ assert len(coarse_mesh) == 1
372
+ bbox_gen_by_coarse_matching_cube_mesh = np.stack(
373
+ [coarse_mesh[0].mesh_v.max(0), coarse_mesh[0].mesh_v.min(0)]
374
+ )
375
+ bbox_gen_by_coarse_matching_cube_mesh_range = (
376
+ bbox_gen_by_coarse_matching_cube_mesh[0]
377
+ - bbox_gen_by_coarse_matching_cube_mesh[1]
378
+ )
379
+
380
+ # extend by 10%
381
+ bbox_gen_by_coarse_matching_cube_mesh[0] += (
382
+ bbox_gen_by_coarse_matching_cube_mesh_range * 0.1
383
+ )
384
+ bbox_gen_by_coarse_matching_cube_mesh[1] -= (
385
+ bbox_gen_by_coarse_matching_cube_mesh_range * 0.1
386
+ )
387
+ with synchronize_timer("Fine-grained Volume decoding"):
388
+ grid_logits = self.volume_decoder(
389
+ latents,
390
+ self.geo_decoder,
391
+ bbox_corner=bbox_gen_by_coarse_matching_cube_mesh[None],
392
+ **kwargs,
393
+ )
394
+ with synchronize_timer("Fine-grained Surface extraction"):
395
+ outputs = self.surface_extractor(
396
+ grid_logits,
397
+ bbox_corner=bbox_gen_by_coarse_matching_cube_mesh[None],
398
+ **kwargs,
399
+ )
400
+
401
+ return outputs
402
+
403
+ def latent2mesh_2(
404
+ self,
405
+ latents: torch.FloatTensor,
406
+ bounds: Union[Tuple[float], List[float], float] = 1.1,
407
+ octree_depth: int = 7,
408
+ num_chunks: int = 10000,
409
+ mc_level: float = -1 / 512,
410
+ octree_resolution: int = None,
411
+ mc_mode: str = "mc",
412
+ ) -> List[Latent2MeshOutput]:
413
+ """
414
+ Args:
415
+ latents: [bs, num_latents, dim]
416
+ bounds:
417
+ octree_depth:
418
+ num_chunks:
419
+ Returns:
420
+ mesh_outputs (List[MeshOutput]): the mesh outputs list.
421
+ """
422
+ outputs = []
423
+ geometric_func = partial(self.query_geometry, latents=latents)
424
+ # 2. decode geometry
425
+ device = latents.device
426
+ if mc_mode == "dmc" and not hasattr(self, "diffdmc"):
427
+ from diso import DiffDMC
428
+
429
+ self.diffdmc = DiffDMC(dtype=torch.float32).to(device)
430
+ mesh_v_f, has_surface = extract_geometry_fast(
431
+ geometric_func=geometric_func,
432
+ device=device,
433
+ batch_size=len(latents),
434
+ bounds=bounds,
435
+ octree_depth=octree_depth,
436
+ num_chunks=num_chunks,
437
+ disable=False,
438
+ mc_level=mc_level,
439
+ octree_resolution=octree_resolution,
440
+ diffdmc=self.diffdmc if mc_mode == "dmc" else None,
441
+ mc_mode=mc_mode,
442
+ )
443
+ # 3. decode texture
444
+ for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
445
+ if not is_surface:
446
+ outputs.append(None)
447
+ continue
448
+ out = Latent2MeshOutput()
449
+ out.mesh_v = mesh_v
450
+ out.mesh_f = mesh_f
451
+ outputs.append(out)
452
+ return outputs
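
A hedged usage sketch of the shape VAE above (not code from this repo): it assumes a checkpoint/config compatible with VolumeDecoderShapeVAE and a surface tensor laid out as [B, pc_size + pc_sharpedge_size, 3 + point_feats] with xyz in the first three channels, which is what encode() slices out; the resolution and level values are illustrative.

import torch

def reconstruct(vae, surface, octree_resolution=384):
    # surface: [B, pc_size + pc_sharpedge_size, 3 + point_feats]; shapes are assumptions
    with torch.no_grad():
        latents = vae.encode(surface, sample_posterior=False)  # [B, num_latents, embed_dim]
        latents = vae.decode(latents)                          # post_kl + transformer
        meshes = vae.latents2mesh(
            latents,
            bounds=1.01,
            mc_level=0.0,
            octree_resolution=octree_resolution,
            num_chunks=8000,
        )
    return meshes  # list of Latent2MeshOutput (mesh_v / mesh_f), or None where extraction failed
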
XPart/partgen/models/autoencoders/surface_extractors.py ADDED
@@ -0,0 +1,164 @@
1
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
+ # except for the third-party components listed below.
3
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
+ # in the respective licenses of these third-party components.
5
+ # Users must comply with all terms and conditions of original licenses of these third-party
6
+ # components and must ensure that the usage of the third party components adheres to
7
+ # all relevant laws and regulations.
8
+
9
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
10
+ # their software and algorithms, including trained model weights, parameters (including
11
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
13
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
+
15
+ from typing import Union, Tuple, List
16
+
17
+ import numpy as np
18
+ import torch
19
+ from skimage import measure
20
+
21
+
22
+ class Latent2MeshOutput:
23
+ def __init__(self, mesh_v=None, mesh_f=None):
24
+ self.mesh_v = mesh_v
25
+ self.mesh_f = mesh_f
26
+
27
+
28
+ def center_vertices(vertices):
29
+ """Translate the vertices so that bounding box is centered at zero."""
30
+ vert_min = vertices.min(dim=0)[0]
31
+ vert_max = vertices.max(dim=0)[0]
32
+ vert_center = 0.5 * (vert_min + vert_max)
33
+ return vertices - vert_center
34
+
35
+
36
+ class SurfaceExtractor:
37
+ def _compute_box_stat(self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int):
38
+ """
39
+ Compute grid size, bounding box minimum coordinates, and bounding box size based on input
40
+ bounds and resolution.
41
+
42
+ Args:
43
+ bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or a single
44
+ float representing half side length.
45
+ If float, bounds are assumed symmetric around zero in all axes.
46
+ Expected format if list/tuple: [xmin, ymin, zmin, xmax, ymax, zmax].
47
+ octree_resolution (int): Resolution of the octree grid.
48
+
49
+ Returns:
50
+ grid_size (List[int]): Grid size along each axis (x, y, z), each equal to octree_resolution + 1.
51
+ bbox_min (np.ndarray): Minimum coordinates of the bounding box (xmin, ymin, zmin).
52
+ bbox_size (np.ndarray): Size of the bounding box along each axis (xmax - xmin, etc.).
53
+ """
54
+ if isinstance(bounds, float):
55
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
56
+
57
+ bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
58
+ bbox_size = bbox_max - bbox_min
59
+ grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
60
+ return grid_size, bbox_min, bbox_size
61
+
62
+ def run(self, *args, **kwargs):
63
+ """
64
+ Abstract method to extract surface mesh from grid logits.
65
+
66
+ This method should be implemented by subclasses.
67
+
68
+ Raises:
69
+ NotImplementedError: Always, since this is an abstract method.
70
+ """
71
+ raise NotImplementedError
72
+
73
+ def __call__(self, grid_logits, **kwargs):
74
+ """
75
+ Process a batch of grid logits to extract surface meshes.
76
+
77
+ Args:
78
+ grid_logits (torch.Tensor): Batch of grid logits with shape (batch_size, ...).
79
+ **kwargs: Additional keyword arguments passed to the `run` method.
80
+
81
+ Returns:
82
+ List[Optional[Latent2MeshOutput]]: List of mesh outputs for each grid in the batch.
83
+ If extraction fails for a grid, None is appended at that position.
84
+ """
85
+ outputs = []
86
+ for i in range(grid_logits.shape[0]):
87
+ try:
88
+ vertices, faces = self.run(grid_logits[i], **kwargs)
89
+ vertices = vertices.astype(np.float32)
90
+ faces = np.ascontiguousarray(faces)
91
+ outputs.append(Latent2MeshOutput(mesh_v=vertices, mesh_f=faces))
92
+
93
+ except Exception:
94
+ import traceback
95
+ traceback.print_exc()
96
+ outputs.append(None)
97
+
98
+ return outputs
99
+
100
+
101
+ class MCSurfaceExtractor(SurfaceExtractor):
102
+ def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
103
+ """
104
+ Extract surface mesh using the Marching Cubes algorithm.
105
+
106
+ Args:
107
+ grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
108
+ mc_level (float): The level (iso-value) at which to extract the surface.
109
+ bounds (Union[Tuple[float], List[float], float]): Bounding box coordinates or half side length.
110
+ octree_resolution (int): Resolution of the octree grid.
111
+ **kwargs: Additional keyword arguments (ignored).
112
+
113
+ Returns:
114
+ Tuple[np.ndarray, np.ndarray]: Tuple containing:
115
+ - vertices (np.ndarray): Extracted mesh vertices, scaled and translated to bounding
116
+ box coordinates.
117
+ - faces (np.ndarray): Extracted mesh faces (triangles).
118
+ """
119
+ vertices, faces, normals, _ = measure.marching_cubes(grid_logit.cpu().numpy(),
120
+ mc_level,
121
+ method="lewiner")
122
+ grid_size, bbox_min, bbox_size = self._compute_box_stat(bounds, octree_resolution)
123
+ vertices = vertices / grid_size * bbox_size + bbox_min
124
+ return vertices, faces
125
+
126
+
127
+ class DMCSurfaceExtractor(SurfaceExtractor):
128
+ def run(self, grid_logit, *, octree_resolution, **kwargs):
129
+ """
130
+ Extract surface mesh using Differentiable Marching Cubes (DMC) algorithm.
131
+
132
+ Args:
133
+ grid_logit (torch.Tensor): 3D grid logits tensor representing the scalar field.
134
+ octree_resolution (int): Resolution of the octree grid.
135
+ **kwargs: Additional keyword arguments (ignored).
136
+
137
+ Returns:
138
+ Tuple[np.ndarray, np.ndarray]: Tuple containing:
139
+ - vertices (np.ndarray): Extracted mesh vertices, centered and converted to numpy.
140
+ - faces (np.ndarray): Extracted mesh faces (triangles), with reversed vertex order.
141
+
142
+ Raises:
143
+ ImportError: If the 'diso' package is not installed.
144
+ """
145
+ device = grid_logit.device
146
+ if not hasattr(self, 'dmc'):
147
+ try:
148
+ from diso import DiffDMC
149
+ self.dmc = DiffDMC(dtype=torch.float32).to(device)
150
+ except ImportError:
151
+ raise ImportError("Please install diso via `pip install diso`, or set mc_algo to 'mc'")
152
+ sdf = -grid_logit / octree_resolution
153
+ sdf = sdf.to(torch.float32).contiguous()
154
+ verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
155
+ verts = center_vertices(verts)
156
+ vertices = verts.detach().cpu().numpy()
157
+ faces = faces.detach().cpu().numpy()[:, ::-1]
158
+ return vertices, faces
159
+
160
+
161
+ SurfaceExtractors = {
162
+ 'mc': MCSurfaceExtractor,
163
+ 'dmc': DMCSurfaceExtractor,
164
+ }
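
A minimal sketch of the batch-in/list-out contract of the extractors above, run on a synthetic logit grid (positive inside a sphere) rather than real decoder output; it assumes MCSurfaceExtractor from this file is in scope and scikit-image is installed, and the resolution and level values are arbitrary.

import torch

res = 64
axis = torch.linspace(-1.0, 1.0, res + 1)
x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
grid_logits = 0.5 - (x**2 + y**2 + z**2).sqrt()   # positive inside a sphere of radius 0.5
grid_logits = grid_logits.unsqueeze(0)            # [batch=1, res+1, res+1, res+1]

extractor = MCSurfaceExtractor()                  # defined above
meshes = extractor(grid_logits, mc_level=0.0, bounds=1.0, octree_resolution=res)
# meshes[0].mesh_v / meshes[0].mesh_f hold the extracted sphere mesh, or None on failure
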
XPart/partgen/models/autoencoders/volume_decoders.py ADDED
@@ -0,0 +1,107 @@
1
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
+ # except for the third-party components listed below.
3
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
+ # in the respective licenses of these third-party components.
5
+ # Users must comply with all terms and conditions of original licenses of these third-party
6
+ # components and must ensure that the usage of the third party components adheres to
7
+ # all relevant laws and regulations.
8
+
9
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
10
+ # their software and algorithms, including trained model weights, parameters (including
11
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
13
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
+
15
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
16
+ # their software and algorithms, including trained model weights, parameters (including
17
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
18
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
19
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
20
+
21
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
22
+ # their software and algorithms, including trained model weights, parameters (including
23
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
24
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
25
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
26
+
27
+
28
+ from typing import Union, Tuple, List, Callable
29
+
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from einops import repeat
35
+ from tqdm import tqdm
36
+
37
+ from .attention_blocks import CrossAttentionDecoder
38
+ from ...utils.misc import logger
39
+ from ...utils.mesh_utils import (
40
+ extract_near_surface_volume_fn,
41
+ generate_dense_grid_points,
42
+ )
43
+
44
+
45
+ class VanillaVolumeDecoder:
46
+ @torch.no_grad()
47
+ def __call__(
48
+ self,
49
+ latents: torch.FloatTensor,
50
+ geo_decoder: Callable,
51
+ bounds: Union[Tuple[float], List[float], float] = 1.01,
52
+ num_chunks: int = 10000,
53
+ octree_resolution: int = None,
54
+ enable_pbar: bool = True,
55
+ **kwargs,
56
+ ):
57
+
58
+ """
59
+ Perform volume decoding with a vanilla decoder
60
+ Args:
61
+ latents (torch.FloatTensor): Latent vectors to decode.
62
+ geo_decoder (Callable): The geometry decoder function.
63
+ bounds (Union[Tuple[float], List[float], float]): Bounding box for the volume.
64
+ num_chunks (int): Number of chunks to process at a time.
65
+ octree_resolution (int): Resolution of the octree for sampling points.
66
+ enable_pbar (bool): Whether to enable progress bar.
67
+ Returns:
68
+ grid_logits (torch.FloatTensor): Decoded 3D volume logits.
69
+ """
70
+ device = latents.device
71
+ dtype = latents.dtype
72
+ batch_size = latents.shape[0]
73
+
74
+ # 1. generate query points
75
+ if isinstance(bounds, float):
76
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
77
+
78
+ bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
79
+ xyz_samples, grid_size, length = generate_dense_grid_points(
80
+ bbox_min=bbox_min,
81
+ bbox_max=bbox_max,
82
+ octree_resolution=octree_resolution,
83
+ indexing="ij",
84
+ )
85
+ xyz_samples = (
86
+ torch.from_numpy(xyz_samples)
87
+ .to(device, dtype=dtype)
88
+ .contiguous()
89
+ .reshape(-1, 3)
90
+ )
91
+
92
+ # 2. latents to 3d volume
93
+ batch_logits = []
94
+ for start in tqdm(
95
+ range(0, xyz_samples.shape[0], num_chunks),
96
+ desc="Volume Decoding",
97
+ disable=not enable_pbar,
98
+ ):
99
+ chunk_queries = xyz_samples[start : start + num_chunks, :]
100
+ chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
101
+ logits = geo_decoder(queries=chunk_queries, latents=latents)
102
+ batch_logits.append(logits)
103
+
104
+ grid_logits = torch.cat(batch_logits, dim=1)
105
+ grid_logits = grid_logits.view((batch_size, *grid_size)).float()
106
+
107
+ return grid_logits
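
The decoder above boils down to "build a dense (res+1)^3 query grid, decode it in chunks, reshape to a volume". A self-contained sketch of that pattern with a stand-in for geo_decoder (the real one is the CrossAttentionDecoder from this commit); sizes are illustrative only.

import torch

def dummy_geo_decoder(queries, latents):
    # stand-in logits with the real decoder's [B, P, 1] output shape
    return 0.5 - queries.norm(dim=-1, keepdim=True)

res, num_chunks, batch_size = 32, 4096, 1
axis = torch.linspace(-1.01, 1.01, res + 1)
xyz = torch.stack(torch.meshgrid(axis, axis, axis, indexing="ij"), dim=-1).reshape(-1, 3)

batch_logits = []
for start in range(0, xyz.shape[0], num_chunks):
    chunk = xyz[start:start + num_chunks].unsqueeze(0).expand(batch_size, -1, -1)
    batch_logits.append(dummy_geo_decoder(queries=chunk, latents=None))
grid_logits = torch.cat(batch_logits, dim=1).view(batch_size, res + 1, res + 1, res + 1)
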
XPart/partgen/models/conditioner/condioner_release.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
27
+
28
+ import torch
29
+ from .part_encoders import PartEncoder
30
+ from ..autoencoders import VolumeDecoderShapeVAE
31
+ from ...utils.misc import (
32
+ instantiate_from_config,
33
+ instantiate_non_trainable_model,
34
+ )
35
+ from .sonata_extractor import SonataFeatureExtractor
37
+
38
+
39
+ def debug_sonata_feat(points, feats):
40
+ from sklearn.decomposition import PCA
41
+ import numpy as np
42
+ import trimesh
43
+ import os
44
+
45
+ point_num = points.shape[0]
46
+ feat_save = feats.float().detach().cpu().numpy()
47
+ data_scaled = feat_save / np.linalg.norm(feat_save, axis=-1, keepdims=True)
48
+ pca = PCA(n_components=3)
49
+ data_reduced = pca.fit_transform(data_scaled)
50
+ data_reduced = (data_reduced - data_reduced.min()) / (
51
+ data_reduced.max() - data_reduced.min()
52
+ )
53
+ colors_255 = (data_reduced * 255).astype(np.uint8)
54
+ colors_255 = np.concatenate(
55
+ [colors_255, np.ones((point_num, 1), dtype=np.uint8) * 255], axis=-1
56
+ )
57
+ pc_save = trimesh.points.PointCloud(points, colors=colors_255)
58
+ return pc_save
59
+ # pc_save.export(os.path.join("debug", "point_pca.glb"))
60
+
61
+
62
+ class Conditioner(torch.nn.Module):
63
+
64
+ def __init__(
65
+ self,
66
+ use_image=False,
67
+ use_geo=True,
68
+ use_obj=True,
69
+ use_seg_feat=False,
70
+ geo_cfg=None,
71
+ obj_encoder_cfg=None,
72
+ seg_feat_cfg=None,
73
+ **kwargs
74
+ ):
75
+ super().__init__()
76
+ self.use_image = use_image
77
+ self.use_obj = use_obj
78
+ self.use_geo = use_geo
79
+ self.use_seg_feat = use_seg_feat
80
+ self.geo_cfg = geo_cfg
81
+ self.obj_encoder_cfg = obj_encoder_cfg
82
+ self.seg_feat_cfg = seg_feat_cfg
83
+ if use_geo and geo_cfg is not None:
84
+ self.geo_encoder: PartEncoder = instantiate_from_config(geo_cfg)
85
+ if hasattr(geo_cfg, "output_dim"):
86
+ self.geo_out_proj = torch.nn.Linear(1024 + 512, geo_cfg.output_dim)
87
+
88
+ if use_obj and obj_encoder_cfg is not None:
89
+ self.obj_encoder: VolumeDecoderShapeVAE = instantiate_non_trainable_model(
90
+ obj_encoder_cfg
91
+ )
92
+ if hasattr(obj_encoder_cfg, "output_dim"):
93
+ self.obj_out_proj = torch.nn.Linear(
94
+ 1024 + 512, obj_encoder_cfg.output_dim
95
+ )
96
+ if use_seg_feat and seg_feat_cfg is not None:
97
+ self.seg_feat_encoder: SonataFeatureExtractor = (
98
+ instantiate_non_trainable_model(seg_feat_cfg)
99
+ )
100
+ if hasattr(seg_feat_cfg, "output_dim"):
101
+ self.seg_feat_outproj = torch.nn.Linear(512, seg_feat_cfg.output_dim)
102
+
103
+ def forward(self, part_surface_inbbox, object_surface):
104
+ bz = part_surface_inbbox.shape[0]
105
+ context = {}
106
+ # geo_cond
107
+ if self.use_geo:
108
+ context["geo_cond"], local_pc_infos = self.geo_encoder(
109
+ part_surface_inbbox,
110
+ object_surface,
111
+ return_local_pc_info=True,
112
+ )
113
+ # obj cond
114
+ if self.use_obj:
115
+ with torch.no_grad():
116
+ context["obj_cond"], global_pc_infos = self.obj_encoder.encode_shape(
117
+ object_surface, return_pc_info=True
118
+ )
119
+
120
+ # seg feat cond
121
+ if self.use_seg_feat:
122
+ # TODO: batch size must be 1
123
+ num_parts = part_surface_inbbox.shape[0]
124
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
125
+ # encode sonata feature
126
+ # with torch.cuda.amp.autocast(enabled=False):
127
+ with torch.no_grad():
128
+ point, normal = (
129
+ object_surface[:1, ..., :3].float(),
130
+ object_surface[:1, ..., 3:6].float(),
131
+ )
132
+ point_feat = self.seg_feat_encoder(point, normal)
133
+ # local feat
134
+ if self.use_obj:
135
+ nearest_global_matches = torch.argmin(
136
+ torch.cdist(global_pc_infos[0], object_surface[..., :3]), dim=-1
137
+ )
138
+ # global feat
139
+ global_point_feats = point_feat.expand(num_parts, -1, -1).gather(
140
+ 1,
141
+ nearest_global_matches.unsqueeze(-1).expand(
142
+ -1, -1, point_feat.size(-1)
143
+ ),
144
+ )
145
+ context["obj_cond"] = torch.concat(
146
+ [context["obj_cond"], global_point_feats], dim=-1
147
+ ).to(dtype=self.obj_out_proj.weight.dtype)
148
+ if hasattr(self, "obj_out_proj"):
149
+ context["obj_cond"] = self.obj_out_proj(
150
+ context["obj_cond"]
151
+ ) # .float()
152
+ if self.use_geo:
153
+ nearest_local_matches = torch.argmin(
154
+ torch.cdist(local_pc_infos[0], object_surface[..., :3]), dim=-1
155
+ )
156
+ local_point_feats = point_feat.expand(num_parts, -1, -1).gather(
157
+ 1,
158
+ nearest_local_matches.unsqueeze(-1).expand(
159
+ -1, -1, point_feat.size(-1)
160
+ ),
161
+ )
162
+ context["geo_cond"] = torch.concat(
163
+ [context["geo_cond"], local_point_feats],
164
+ dim=-1,
165
+ ).to(dtype=self.geo_out_proj.weight.dtype)
166
+ if hasattr(self, "geo_out_proj"):
167
+ context["geo_cond"] = self.geo_out_proj(
168
+ context["geo_cond"]
169
+ ) # .float()
170
+ return context
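
The seg-feat branch above transfers per-point Sonata features onto the VAE query points with a nearest-neighbour lookup (`torch.cdist` + `argmin` + `gather`). A self-contained shape-level sketch of that lookup; all sizes are illustrative:

```python
# Nearest-neighbour feature transfer: for each query point, take the feature
# of the closest object-surface point. Shapes are illustrative.
import torch

B, Q, N, C = 2, 512, 4096, 512
query_xyz   = torch.rand(B, Q, 3)    # e.g. global/local pc_infos from the VAE
surface_xyz = torch.rand(B, N, 3)    # object surface points
point_feat  = torch.rand(1, N, C)    # per-point features (batch size 1 upstream)

nearest = torch.argmin(torch.cdist(query_xyz, surface_xyz), dim=-1)        # (B, Q)
feats = point_feat.expand(B, -1, -1).gather(
    1, nearest.unsqueeze(-1).expand(-1, -1, C)
)                                                                           # (B, Q, C)
print(feats.shape)
```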
XPart/partgen/models/conditioner/part_encoders.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch.nn as nn
2
+ from ...utils.misc import (
3
+ instantiate_from_config,
4
+ instantiate_non_trainable_model,
5
+ )
6
+ from ..autoencoders.model import (
7
+ VolumeDecoderShapeVAE,
8
+ )
9
+
10
+
11
+ class PartEncoder(nn.Module):
12
+ def __init__(
13
+ self,
14
+ use_local=True,
15
+ local_global_feat_dim=None,
16
+ local_geo_cfg=None,
17
+ local_feat_type="latents",
18
+ num_tokens_cond=2048,
19
+ ):
20
+ super().__init__()
21
+ self.local_global_feat_dim = local_global_feat_dim
22
+ self.local_feat_type = local_feat_type
23
+ self.num_tokens_cond = num_tokens_cond
24
+ # local
25
+ self.use_local = use_local
26
+ if use_local:
27
+ if local_geo_cfg is None:
28
+ raise ValueError(
29
+ "local_geo_cfg must be provided when use_local is True"
30
+ )
31
+ assert (
32
+ "ShapeVAE" in local_geo_cfg.get("target").split(".")[-1]
33
+ ), "local_geo_cfg must be a ShapeVAE config"
34
+ self.local_encoder: VolumeDecoderShapeVAE = instantiate_from_config(
35
+ local_geo_cfg
36
+ )
37
+ if self.local_global_feat_dim is not None:
38
+ self.local_out_layer = nn.Linear(
39
+ (
40
+ local_geo_cfg.params.embed_dim
41
+ if self.local_feat_type == "latents"
42
+ else local_geo_cfg.params.width
43
+ ),
44
+ self.local_global_feat_dim,
45
+ bias=True,
46
+ )
47
+
48
+ def forward(self, part_surface_inbbox, object_surface, return_local_pc_info=False):
49
+ """
50
+ Args:
51
+ aabb: (B, 2, 3) tensor representing the axis-aligned bounding box
52
+ object_surface: (B, N, 3) tensor representing the surface points of the object
53
+ Returns:
54
+ local_features: (B, num_tokens_cond, C) tensor of local features
55
+ global_features: (B,num_tokens_cond, C) tensor of global features
56
+ """
57
+ # random selection if more than num_tokens_cond points
58
+ if self.use_local:
59
+ # with torch.autocast(
60
+ # device_type=part_surface_inbbox.device.type,
61
+ # dtype=torch.float16,
62
+ # ):
63
+ # with torch.no_grad():
64
+ if self.local_feat_type == "latents":
65
+ local_features, local_pc_infos = self.local_encoder.encode(
66
+ part_surface_inbbox, sample_posterior=True, return_pc_info=True
67
+ ) # (B, num_tokens_cond, C)
68
+ elif self.local_feat_type == "latents_shape":
69
+ local_features, local_pc_infos = self.local_encoder.encode_shape(
70
+ part_surface_inbbox, return_pc_info=True
71
+ ) # (B, num_tokens_cond, C)
72
+ elif self.local_feat_type == "miche-point-query-structural-vae":
73
+ local_features, local_pc_infos = self.local_encoder.encode(
74
+ part_surface_inbbox, sample_posterior=True, return_pc_info=True
75
+ )
76
+ local_features = self.local_encoder(local_features)
77
+ else:
78
+ raise ValueError(
79
+ f"local_feat_type {self.local_feat_type} not supported"
80
+ )
81
+ # output layer
82
+ geo_features = (
83
+ self.local_out_layer(local_features)
84
+ if hasattr(self, "local_out_layer")
85
+ else local_features
86
+ )
87
+ if return_local_pc_info:
88
+ return geo_features, local_pc_infos
89
+ return geo_features
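
A hypothetical instantiation sketch for the encoder above. The `target` path and config keys below are assumptions for illustration (the shipped settings live in `XPart/partgen/config/infer.yaml`), not a verified configuration:

```python
# Hypothetical PartEncoder config; keys are assumptions, not the shipped YAML.
from omegaconf import OmegaConf

geo_cfg = OmegaConf.create({
    "target": "partgen.models.conditioner.part_encoders.PartEncoder",
    "params": {
        "use_local": True,
        "local_feat_type": "latents",   # or "latents_shape"
        "num_tokens_cond": 2048,
        # "local_geo_cfg": {...}        # a ShapeVAE config is required here
    },
})
# encoder = instantiate_from_config(geo_cfg)
# geo_feats, pc_info = encoder(part_surface_inbbox, object_surface,
#                              return_local_pc_info=True)
```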
XPart/partgen/models/conditioner/sonata_extractor.py ADDED
@@ -0,0 +1,315 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from .. import sonata
4
+
5
+ from typing import Dict, Union, Optional
6
+ from pathlib import Path
7
+
8
+
9
+ class SonataFeatureExtractor(nn.Module):
10
+ """
11
+ Feature extractor using Sonata backbone with MLP projection.
12
+ Supports batch processing and gradient computation.
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ ckpt_path: Optional[str] = "",
18
+ ):
19
+ super().__init__()
20
+
21
+ # Load Sonata model
22
+ self.sonata = sonata.load_by_config(
23
+ str(Path(__file__).parent.parent.parent / "config" / "sonata.json")
24
+ )
25
+
26
+ # Store original dtype for later reference
27
+ # self._original_dtype = next(self.parameters()).dtype
28
+
29
+ # Define MLP projection head (same as in train-sonata.py)
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(1232, 512),
32
+ nn.GELU(),
33
+ nn.Linear(512, 512),
34
+ nn.GELU(),
35
+ nn.Linear(512, 512),
36
+ )
37
+
38
+ # Define transform
39
+ self.transform = sonata.transform.default()
40
+
41
+ # Load checkpoint if provided
42
+ if ckpt_path:
43
+ self.load_checkpoint(ckpt_path)
44
+
45
+ def load_checkpoint(self, checkpoint_path: str):
46
+ """Load model weights from checkpoint."""
47
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
48
+
49
+ # Extract state dict from Lightning checkpoint
50
+ if "state_dict" in checkpoint:
51
+ state_dict = checkpoint["state_dict"]
52
+ # Remove 'model.' prefix if present from Lightning
53
+ state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
54
+ else:
55
+ state_dict = checkpoint
56
+
57
+ # Debug: Show all keys in checkpoint
58
+ print("\n=== Checkpoint Keys ===")
59
+ print(f"Total keys in checkpoint: {len(state_dict)}")
60
+ print("\nSample keys:")
61
+ for i, key in enumerate(list(state_dict.keys())[:10]):
62
+ print(f" {key}")
63
+ if len(state_dict) > 10:
64
+ print(f" ... and {len(state_dict) - 10} more keys")
65
+
66
+ # Load only the relevant weights
67
+ sonata_dict = {
68
+ k.replace("sonata.", ""): v
69
+ for k, v in state_dict.items()
70
+ if k.startswith("sonata.")
71
+ }
72
+ mlp_dict = {
73
+ k.replace("mlp.", ""): v
74
+ for k, v in state_dict.items()
75
+ if k.startswith("mlp.")
76
+ }
77
+
78
+ print(f"\nFound {len(sonata_dict)} Sonata keys")
79
+ print(f"Found {len(mlp_dict)} MLP keys")
80
+
81
+ # Load Sonata weights and show missing/unexpected keys
82
+ if sonata_dict:
83
+ print("\n=== Loading Sonata Weights ===")
84
+ result = self.sonata.load_state_dict(sonata_dict, strict=False)
85
+ if result.missing_keys:
86
+ print(f"\nMissing keys ({len(result.missing_keys)}):")
87
+ for key in result.missing_keys[:20]: # Show first 20
88
+ print(f" - {key}")
89
+ if len(result.missing_keys) > 20:
90
+ print(f" ... and {len(result.missing_keys) - 20} more")
91
+ else:
92
+ print("No missing keys!")
93
+
94
+ if result.unexpected_keys:
95
+ print(f"\nUnexpected keys ({len(result.unexpected_keys)}):")
96
+ for key in result.unexpected_keys[:20]: # Show first 20
97
+ print(f" - {key}")
98
+ if len(result.unexpected_keys) > 20:
99
+ print(f" ... and {len(result.unexpected_keys) - 20} more")
100
+ else:
101
+ print("No unexpected keys!")
102
+
103
+ # Load MLP weights
104
+ if mlp_dict:
105
+ print("\n=== Loading MLP Weights ===")
106
+ result = self.mlp.load_state_dict(mlp_dict, strict=False)
107
+ if result.missing_keys:
108
+ print(f"\nMissing keys: {result.missing_keys}")
109
+ if result.unexpected_keys:
110
+ print(f"Unexpected keys: {result.unexpected_keys}")
111
+ print("MLP weights loaded successfully!")
112
+
113
+ print(f"\n✓ Loaded checkpoint from {checkpoint_path}")
114
+
115
+ def prepare_batch_data(
116
+ self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
117
+ ) -> Dict:
118
+ """
119
+ Prepare batch data for Sonata model.
120
+
121
+ Args:
122
+ points: [B, N, 3] or [N, 3] tensor of point coordinates
123
+ normals: [B, N, 3] or [N, 3] tensor of normals (optional)
124
+
125
+ Returns:
126
+ Dictionary formatted for Sonata input
127
+ """
128
+ # Handle single batch case
129
+ if points.dim() == 2:
130
+ points = points.unsqueeze(0)
131
+ if normals is not None:
132
+ normals = normals.unsqueeze(0)
133
+ # print('Sonata points shape: ', points.shape)
134
+ B, N, _ = points.shape
135
+
136
+ # Prepare batch indices
137
+ batch_idx = torch.arange(B).view(-1, 1).repeat(1, N).reshape(-1)
138
+
139
+ # Flatten points for Sonata format
140
+ coord = points.reshape(B * N, 3)
141
+
142
+ if normals is not None:
143
+ normal = normals.reshape(B * N, 3)
144
+ else:
145
+ # Generate dummy normals if not provided
146
+ normal = torch.ones_like(coord)
147
+
148
+ # Generate dummy colors
149
+ color = torch.ones_like(coord)
150
+
151
+ # Function to convert tensor to numpy array, handling BFloat16
152
+ def to_numpy(tensor):
153
+ # First convert to CPU if needed
154
+ if tensor.is_cuda:
155
+ tensor = tensor.cpu()
156
+ # Convert BFloat16 or other unsupported dtypes to float32
157
+ if tensor.dtype not in [
158
+ torch.float32,
159
+ torch.float64,
160
+ torch.int32,
161
+ torch.int64,
162
+ torch.uint8,
163
+ torch.int8,
164
+ torch.int16,
165
+ ]:
166
+ tensor = tensor.to(torch.float32)
167
+ # Then convert to numpy
168
+ return tensor.numpy()
169
+
170
+ # Create data dict
171
+ data_dict = {
172
+ "coord": to_numpy(coord),
173
+ "normal": to_numpy(normal),
174
+ "color": to_numpy(color),
175
+ "batch": to_numpy(batch_idx),
176
+ }
177
+
178
+ # Apply transform
179
+ data_dict = self.transform(data_dict)
180
+
181
+ return data_dict, B, N
182
+
183
+ def forward(
184
+ self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
185
+ ) -> torch.Tensor:
186
+ """
187
+ Extract features from point clouds.
188
+
189
+ Args:
190
+ points: [B, N, 3] or [N, 3] tensor of point coordinates
191
+ normals: [B, N, 3] or [N, 3] tensor of normals (optional)
192
+
193
+ Returns:
194
+ features: [B, N, 512] or [N, 512] tensor of features
195
+ """
196
+ # Store original shape
197
+ original_shape = points.shape
198
+ single_batch = points.dim() == 2
199
+
200
+ # Prepare data for Sonata
201
+ data_dict, B, N = self.prepare_batch_data(points, normals)
202
+
203
+ # Move to GPU if needed and convert to appropriate dtype
204
+ device = points.device
205
+ dtype = points.dtype
206
+
207
+ # Make sure the entire model is in the correct dtype
208
+ # if dtype != self._original_dtype:
209
+ # self.to(dtype)
210
+ # self._original_dtype = dtype
211
+
212
+ for key in data_dict.keys():
213
+ if isinstance(data_dict[key], torch.Tensor):
214
+ # Convert tensors to the right device and dtype if they're floating point
215
+ if data_dict[key].is_floating_point():
216
+ data_dict[key] = data_dict[key].to(device=device, dtype=dtype)
217
+ else:
218
+ # For integer tensors, just move to device without changing dtype
219
+ data_dict[key] = data_dict[key].to(device)
220
+
221
+ # Extract Sonata features
222
+ point = self.sonata(data_dict)
223
+
224
+ # Handle pooling layers (same as in train-sonata.py)
225
+ while "pooling_parent" in point.keys():
226
+ assert "pooling_inverse" in point.keys()
227
+ parent = point.pop("pooling_parent")
228
+ inverse = point.pop("pooling_inverse")
229
+ parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
230
+ point = parent
231
+
232
+ # Get features and apply MLP
233
+ feat = point.feat # [M, 1232]
234
+ feat = self.mlp(feat) # [M, 512]
235
+
236
+ # Map back to original points
237
+ feat = feat[point.inverse] # [B*N, 512]
238
+
239
+ # Reshape to batch format
240
+ feat = feat.reshape(B, -1, feat.shape[-1]) # [B, N, 512]
241
+
242
+ # Return in original format
243
+ if single_batch:
244
+ feat = feat.squeeze(0) # [N, 512]
245
+
246
+ return feat
247
+
248
+ def extract_features_batch(
249
+ self,
250
+ points_list: list,
251
+ normals_list: Optional[list] = None,
252
+ batch_size: int = 8,
253
+ ) -> list:
254
+ """
255
+ Extract features for multiple point clouds in batches.
256
+
257
+ Args:
258
+ points_list: List of [N_i, 3] tensors
259
+ normals_list: List of [N_i, 3] tensors (optional)
260
+ batch_size: Batch size for processing
261
+
262
+ Returns:
263
+ List of [N_i, 512] feature tensors
264
+ """
265
+ features_list = []
266
+
267
+ # Process in batches
268
+ for i in range(0, len(points_list), batch_size):
269
+ batch_points = points_list[i : i + batch_size]
270
+ batch_normals = normals_list[i : i + batch_size] if normals_list else None
271
+
272
+ # Find max points in batch
273
+ max_n = max(p.shape[0] for p in batch_points)
274
+
275
+ # Pad to same size
276
+ padded_points = []
277
+ masks = []
278
+ for points in batch_points:
279
+ n = points.shape[0]
280
+ if n < max_n:
281
+ padding = torch.zeros(max_n - n, 3, device=points.device)
282
+ points = torch.cat([points, padding], dim=0)
283
+ padded_points.append(points)
284
+ mask = torch.zeros(max_n, dtype=torch.bool, device=points.device)
285
+ mask[:n] = True
286
+ masks.append(mask)
287
+
288
+ # Stack batch
289
+ batch_tensor = torch.stack(padded_points) # [B, max_n, 3]
290
+
291
+ # Handle normals similarly if provided
292
+ if batch_normals:
293
+ padded_normals = []
294
+ for j, normals in enumerate(batch_normals):
295
+ n = normals.shape[0]
296
+ if n < max_n:
297
+ padding = torch.ones(max_n - n, 3, device=normals.device)
298
+ normals = torch.cat([normals, padding], dim=0)
299
+ padded_normals.append(normals)
300
+ normals_tensor = torch.stack(padded_normals)
301
+ else:
302
+ normals_tensor = None
303
+
304
+ # Extract features
305
+ with torch.cuda.amp.autocast(enabled=True):
306
+ batch_features = self.forward(
307
+ batch_tensor, normals_tensor
308
+ ) # [B, max_n, 512]
309
+
310
+ # Unpad and add to results
311
+ for j, (feat, mask) in enumerate(zip(batch_features, masks)):
312
+ features_list.append(feat[mask])
313
+
314
+ return features_list
315
+
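
The pad-and-mask batching in `extract_features_batch` is a generic pattern for ragged point clouds; here is a small standalone sketch where random features stand in for the 512-d Sonata output:

```python
# Pad ragged point clouds to a common length, batch them, then unpad the
# per-point outputs with boolean masks (features below are random stand-ins).
import torch

clouds = [torch.rand(n, 3) for n in (1000, 1500, 800)]
max_n = max(c.shape[0] for c in clouds)

padded, masks = [], []
for c in clouds:
    padded.append(torch.cat([c, torch.zeros(max_n - c.shape[0], 3)], dim=0))
    mask = torch.zeros(max_n, dtype=torch.bool)
    mask[: c.shape[0]] = True
    masks.append(mask)

batch = torch.stack(padded)                        # [3, max_n, 3]
feats = torch.rand(batch.shape[0], max_n, 512)     # stand-in for forward(batch)
unpadded = [f[m] for f, m in zip(feats, masks)]
print([u.shape for u in unpadded])                 # [1000, 512], [1500, 512], [800, 512]
```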
XPart/partgen/models/diffusion/schedulers.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
16
+ # except for the third-party components listed below.
17
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
18
+ # in the respective licenses of these third-party components.
19
+ # Users must comply with all terms and conditions of original licenses of these third-party
20
+ # components and must ensure that the usage of the third party components adheres to
21
+ # all relevant laws and regulations.
22
+
23
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
24
+ # their software and algorithms, including trained model weights, parameters (including
25
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
26
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
27
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
28
+
29
+ import math
30
+ from dataclasses import dataclass
31
+ from typing import List, Optional, Tuple, Union
32
+
33
+ import numpy as np
34
+ import torch
35
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
36
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
37
+ from diffusers.utils import BaseOutput, logging
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ @dataclass
43
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
44
+ """
45
+ Output class for the scheduler's `step` function output.
46
+
47
+ Args:
48
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
49
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
50
+ denoising loop.
51
+ """
52
+
53
+ prev_sample: torch.FloatTensor
54
+
55
+
56
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
57
+ """
58
+ NOTE: this is very similar to diffusers.FlowMatchEulerDiscreteScheduler, except our timesteps are reversed.
59
+
60
+ Euler scheduler.
61
+
62
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
63
+ methods the library implements for all schedulers such as loading and saving.
64
+
65
+ Args:
66
+ num_train_timesteps (`int`, defaults to 1000):
67
+ The number of diffusion steps to train the model.
68
+ timestep_spacing (`str`, defaults to `"linspace"`):
69
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
70
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
71
+ shift (`float`, defaults to 1.0):
72
+ The shift value for the timestep schedule.
73
+ """
74
+
75
+ _compatibles = []
76
+ order = 1
77
+
78
+ @register_to_config
79
+ def __init__(
80
+ self,
81
+ num_train_timesteps: int = 1000,
82
+ shift: float = 1.0,
83
+ use_dynamic_shifting=False,
84
+ ):
85
+ timesteps = np.linspace(
86
+ 1, num_train_timesteps, num_train_timesteps, dtype=np.float32
87
+ ).copy()
88
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
89
+
90
+ sigmas = timesteps / num_train_timesteps
91
+ if not use_dynamic_shifting:
92
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
93
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
94
+
95
+ self.timesteps = sigmas * num_train_timesteps
96
+
97
+ self._step_index = None
98
+ self._begin_index = None
99
+
100
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
101
+ self.sigma_min = self.sigmas[-1].item()
102
+ self.sigma_max = self.sigmas[0].item()
103
+
104
+ @property
105
+ def step_index(self):
106
+ """
107
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
108
+ """
109
+ return self._step_index
110
+
111
+ @property
112
+ def begin_index(self):
113
+ """
114
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
115
+ """
116
+ return self._begin_index
117
+
118
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
119
+ def set_begin_index(self, begin_index: int = 0):
120
+ """
121
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
122
+
123
+ Args:
124
+ begin_index (`int`):
125
+ The begin index for the scheduler.
126
+ """
127
+ self._begin_index = begin_index
128
+
129
+ def scale_noise(
130
+ self,
131
+ sample: torch.FloatTensor,
132
+ timestep: Union[float, torch.FloatTensor],
133
+ noise: Optional[torch.FloatTensor] = None,
134
+ ) -> torch.FloatTensor:
135
+ """
136
+ Forward process in flow-matching
137
+
138
+ Args:
139
+ sample (`torch.FloatTensor`):
140
+ The input sample.
141
+ timestep (`int`, *optional*):
142
+ The current timestep in the diffusion chain.
143
+
144
+ Returns:
145
+ `torch.FloatTensor`:
146
+ A scaled input sample.
147
+ """
148
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
149
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
150
+
151
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
152
+ # mps does not support float64
153
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
154
+ timestep = timestep.to(sample.device, dtype=torch.float32)
155
+ else:
156
+ schedule_timesteps = self.timesteps.to(sample.device)
157
+ timestep = timestep.to(sample.device)
158
+
159
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
160
+ if self.begin_index is None:
161
+ step_indices = [
162
+ self.index_for_timestep(t, schedule_timesteps) for t in timestep
163
+ ]
164
+ elif self.step_index is not None:
165
+ # add_noise is called after first denoising step (for inpainting)
166
+ step_indices = [self.step_index] * timestep.shape[0]
167
+ else:
168
+ # add noise is called before first denoising step to create initial latent(img2img)
169
+ step_indices = [self.begin_index] * timestep.shape[0]
170
+
171
+ sigma = sigmas[step_indices].flatten()
172
+ while len(sigma.shape) < len(sample.shape):
173
+ sigma = sigma.unsqueeze(-1)
174
+
175
+ sample = sigma * noise + (1.0 - sigma) * sample
176
+
177
+ return sample
178
+
179
+ def _sigma_to_t(self, sigma):
180
+ return sigma * self.config.num_train_timesteps
181
+
182
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
183
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
184
+
185
+ def set_timesteps(
186
+ self,
187
+ num_inference_steps: int = None,
188
+ device: Union[str, torch.device] = None,
189
+ sigmas: Optional[List[float]] = None,
190
+ mu: Optional[float] = None,
191
+ ):
192
+ """
193
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
194
+
195
+ Args:
196
+ num_inference_steps (`int`):
197
+ The number of diffusion steps used when generating samples with a pre-trained model.
198
+ device (`str` or `torch.device`, *optional*):
199
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
200
+ """
201
+
202
+ if self.config.use_dynamic_shifting and mu is None:
203
+ raise ValueError(
204
+ " you have a pass a value for `mu` when `use_dynamic_shifting` is set"
205
+ " to be `True`"
206
+ )
207
+
208
+ if sigmas is None:
209
+ self.num_inference_steps = num_inference_steps
210
+ timesteps = np.linspace(
211
+ self._sigma_to_t(self.sigma_max),
212
+ self._sigma_to_t(self.sigma_min),
213
+ num_inference_steps,
214
+ )
215
+
216
+ sigmas = timesteps / self.config.num_train_timesteps
217
+
218
+ if self.config.use_dynamic_shifting:
219
+ sigmas = self.time_shift(mu, 1.0, sigmas)
220
+ else:
221
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
222
+
223
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
224
+ timesteps = sigmas * self.config.num_train_timesteps
225
+
226
+ self.timesteps = timesteps.to(device=device)
227
+ self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
228
+
229
+ self._step_index = None
230
+ self._begin_index = None
231
+
232
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
233
+ if schedule_timesteps is None:
234
+ schedule_timesteps = self.timesteps
235
+
236
+ indices = (schedule_timesteps == timestep).nonzero()
237
+
238
+ # The sigma index that is taken for the **very** first `step`
239
+ # is always the second index (or the last index if there is only 1)
240
+ # This way we can ensure we don't accidentally skip a sigma in
241
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
242
+ pos = 1 if len(indices) > 1 else 0
243
+
244
+ return indices[pos].item()
245
+
246
+ def _init_step_index(self, timestep):
247
+ if self.begin_index is None:
248
+ if isinstance(timestep, torch.Tensor):
249
+ timestep = timestep.to(self.timesteps.device)
250
+ self._step_index = self.index_for_timestep(timestep)
251
+ else:
252
+ self._step_index = self._begin_index
253
+
254
+ def step(
255
+ self,
256
+ model_output: torch.FloatTensor,
257
+ timestep: Union[float, torch.FloatTensor],
258
+ sample: torch.FloatTensor,
259
+ s_churn: float = 0.0,
260
+ s_tmin: float = 0.0,
261
+ s_tmax: float = float("inf"),
262
+ s_noise: float = 1.0,
263
+ generator: Optional[torch.Generator] = None,
264
+ return_dict: bool = True,
265
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
266
+ """
267
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
268
+ process from the learned model outputs (most often the predicted noise).
269
+
270
+ Args:
271
+ model_output (`torch.FloatTensor`):
272
+ The direct output from learned diffusion model.
273
+ timestep (`float`):
274
+ The current discrete timestep in the diffusion chain.
275
+ sample (`torch.FloatTensor`):
276
+ A current instance of a sample created by the diffusion process.
277
+ s_churn (`float`):
278
+ s_tmin (`float`):
279
+ s_tmax (`float`):
280
+ s_noise (`float`, defaults to 1.0):
281
+ Scaling factor for noise added to the sample.
282
+ generator (`torch.Generator`, *optional*):
283
+ A random number generator.
284
+ return_dict (`bool`):
285
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
286
+ tuple.
287
+
288
+ Returns:
289
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
290
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
291
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
292
+ """
293
+
294
+ if (
295
+ isinstance(timestep, int)
296
+ or isinstance(timestep, torch.IntTensor)
297
+ or isinstance(timestep, torch.LongTensor)
298
+ ):
299
+ raise ValueError(
300
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as"
301
+ " timesteps to `EulerDiscreteScheduler.step()` is not supported. Make"
302
+ " sure to pass one of the `scheduler.timesteps` as a timestep.",
303
+ )
304
+
305
+ if self.step_index is None:
306
+ self._init_step_index(timestep)
307
+
308
+ # Upcast to avoid precision issues when computing prev_sample
309
+ sample = sample.to(torch.float32)
310
+
311
+ sigma = self.sigmas[self.step_index]
312
+ sigma_next = self.sigmas[self.step_index + 1]
313
+
314
+ prev_sample = sample + (sigma_next - sigma) * model_output
315
+
316
+ # Cast sample back to model compatible dtype
317
+ prev_sample = prev_sample.to(model_output.dtype)
318
+
319
+ # upon completion increase step index by one
320
+ self._step_index += 1
321
+
322
+ if not return_dict:
323
+ return (prev_sample,)
324
+
325
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
326
+
327
+ def __len__(self):
328
+ return self.config.num_train_timesteps
329
+
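
A minimal denoising-loop sketch against the scheduler above: each `step` applies the Euler flow-matching update x ← x + (σ_{i+1} − σ_i)·v. The velocity model and latent shape below are dummies, not the repository's model:

```python
# Euler flow-matching sampling loop (dummy velocity model, illustrative shapes).
import torch

sched = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0)
sched.set_timesteps(num_inference_steps=10, device="cpu")

x = torch.randn(1, 2048, 64)              # latent tokens (shape is illustrative)
velocity_model = lambda x, t: -x          # dummy: pull latents toward zero
for t in sched.timesteps:
    v = velocity_model(x, t)
    x = sched.step(v, t, x).prev_sample
print(x.shape)
```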
XPart/partgen/models/diffusion/transport/__init__.py ADDED
@@ -0,0 +1,97 @@
1
+ # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
+ # which is licensed under the MIT License.
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ from .transport import Transport, ModelType, WeightType, PathType, Sampler
27
+
28
+
29
+ def create_transport(
30
+ path_type='Linear',
31
+ prediction="velocity",
32
+ loss_weight=None,
33
+ train_eps=None,
34
+ sample_eps=None,
35
+ train_sample_type="uniform",
36
+ mean = 0.0,
37
+ std = 1.0,
38
+ shift_scale = 1.0,
39
+ ):
40
+ """function for creating Transport object
41
+ **Note**: model prediction defaults to velocity
42
+ Args:
43
+ - path_type: type of path to use; default to linear
44
+ - learn_score: set model prediction to score
45
+ - learn_noise: set model prediction to noise
46
+ - velocity_weighted: weight loss by velocity weight
47
+ - likelihood_weighted: weight loss by likelihood weight
48
+ - train_eps: small epsilon for avoiding instability during training
49
+ - sample_eps: small epsilon for avoiding instability during sampling
50
+ """
51
+
52
+ if prediction == "noise":
53
+ model_type = ModelType.NOISE
54
+ elif prediction == "score":
55
+ model_type = ModelType.SCORE
56
+ else:
57
+ model_type = ModelType.VELOCITY
58
+
59
+ if loss_weight == "velocity":
60
+ loss_type = WeightType.VELOCITY
61
+ elif loss_weight == "likelihood":
62
+ loss_type = WeightType.LIKELIHOOD
63
+ else:
64
+ loss_type = WeightType.NONE
65
+
66
+ path_choice = {
67
+ "Linear": PathType.LINEAR,
68
+ "GVP": PathType.GVP,
69
+ "VP": PathType.VP,
70
+ }
71
+
72
+ path_type = path_choice[path_type]
73
+
74
+ if (path_type in [PathType.VP]):
75
+ train_eps = 1e-5 if train_eps is None else train_eps
76
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
77
+ elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
78
+ train_eps = 1e-3 if train_eps is None else train_eps
79
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
80
+ else: # velocity & [GVP, LINEAR] is stable everywhere
81
+ train_eps = 0
82
+ sample_eps = 0
83
+
84
+ # create flow state
85
+ state = Transport(
86
+ model_type=model_type,
87
+ path_type=path_type,
88
+ loss_type=loss_type,
89
+ train_eps=train_eps,
90
+ sample_eps=sample_eps,
91
+ train_sample_type=train_sample_type,
92
+ mean=mean,
93
+ std=std,
94
+ shift_scale=shift_scale,
95
+ )
96
+
97
+ return state
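
A self-contained sketch of the factory in use (assuming `create_transport` is importable from this package): build a velocity-prediction transport over the linear path and evaluate the flow-matching training loss with a dummy model.

```python
# Flow-matching training loss with a dummy velocity model (shapes illustrative).
import torch

transport = create_transport(path_type="Linear", prediction="velocity")
x1 = torch.randn(4, 256, 8)                            # data latents
dummy_model = lambda x, t, **kw: torch.zeros_like(x)   # predicts zero velocity
terms = transport.training_losses(dummy_model, x1)
print(terms["loss"].shape)                             # one loss value per sample
```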
XPart/partgen/models/diffusion/transport/integrators.py ADDED
@@ -0,0 +1,142 @@
1
+ # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
+ # which is licensed under the MIT License.
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ import numpy as np
27
+ import torch as th
28
+ import torch.nn as nn
29
+ from torchdiffeq import odeint
30
+ from functools import partial
31
+ from tqdm import tqdm
32
+
33
+ class sde:
34
+ """SDE solver class"""
35
+ def __init__(
36
+ self,
37
+ drift,
38
+ diffusion,
39
+ *,
40
+ t0,
41
+ t1,
42
+ num_steps,
43
+ sampler_type,
44
+ ):
45
+ assert t0 < t1, "SDE sampler has to be in forward time"
46
+
47
+ self.num_timesteps = num_steps
48
+ self.t = th.linspace(t0, t1, num_steps)
49
+ self.dt = self.t[1] - self.t[0]
50
+ self.drift = drift
51
+ self.diffusion = diffusion
52
+ self.sampler_type = sampler_type
53
+
54
+ def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
55
+ w_cur = th.randn(x.size()).to(x)
56
+ t = th.ones(x.size(0)).to(x) * t
57
+ dw = w_cur * th.sqrt(self.dt)
58
+ drift = self.drift(x, t, model, **model_kwargs)
59
+ diffusion = self.diffusion(x, t)
60
+ mean_x = x + drift * self.dt
61
+ x = mean_x + th.sqrt(2 * diffusion) * dw
62
+ return x, mean_x
63
+
64
+ def __Heun_step(self, x, _, t, model, **model_kwargs):
65
+ w_cur = th.randn(x.size()).to(x)
66
+ dw = w_cur * th.sqrt(self.dt)
67
+ t_cur = th.ones(x.size(0)).to(x) * t
68
+ diffusion = self.diffusion(x, t_cur)
69
+ xhat = x + th.sqrt(2 * diffusion) * dw
70
+ K1 = self.drift(xhat, t_cur, model, **model_kwargs)
71
+ xp = xhat + self.dt * K1
72
+ K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
73
+ return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
74
+
75
+ def __forward_fn(self):
76
+ """TODO: generalize here by adding all private functions ending with steps to it"""
77
+ sampler_dict = {
78
+ "Euler": self.__Euler_Maruyama_step,
79
+ "Heun": self.__Heun_step,
80
+ }
81
+
82
+ try:
83
+ sampler = sampler_dict[self.sampler_type]
84
+ except KeyError:
85
+ raise NotImplementedError("Smapler type not implemented.")
86
+
87
+ return sampler
88
+
89
+ def sample(self, init, model, **model_kwargs):
90
+ """forward loop of sde"""
91
+ x = init
92
+ mean_x = init
93
+ samples = []
94
+ sampler = self.__forward_fn()
95
+ for ti in self.t[:-1]:
96
+ with th.no_grad():
97
+ x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
98
+ samples.append(x)
99
+
100
+ return samples
101
+
102
+ class ode:
103
+ """ODE solver class"""
104
+ def __init__(
105
+ self,
106
+ drift,
107
+ *,
108
+ t0,
109
+ t1,
110
+ sampler_type,
111
+ num_steps,
112
+ atol,
113
+ rtol,
114
+ ):
115
+ assert t0 < t1, "ODE sampler has to be in forward time"
116
+
117
+ self.drift = drift
118
+ self.t = th.linspace(t0, t1, num_steps)
119
+ self.atol = atol
120
+ self.rtol = rtol
121
+ self.sampler_type = sampler_type
122
+
123
+ def sample(self, x, model, **model_kwargs):
124
+
125
+ device = x[0].device if isinstance(x, tuple) else x.device
126
+ def _fn(t, x):
127
+ t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
128
+ model_output = self.drift(x, t, model, **model_kwargs)
129
+ return model_output
130
+
131
+ t = self.t.to(device)
132
+ atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
133
+ rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
134
+ samples = odeint(
135
+ _fn,
136
+ x,
137
+ t,
138
+ method=self.sampler_type,
139
+ atol=atol,
140
+ rtol=rtol
141
+ )
142
+ return samples
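
A self-contained sketch of the `ode` wrapper above (requires `torchdiffeq`): integrate a toy exponential-decay drift from t0=0 to t1=1 with the dopri5 solver.

```python
# Toy ODE: dx/dt = -x, so x(1) ≈ exp(-1) * x(0).
import torch as th

def toy_drift(x, t, model, **kwargs):
    return -x

solver = ode(drift=toy_drift, t0=0.0, t1=1.0,
             sampler_type="dopri5", num_steps=50, atol=1e-6, rtol=1e-3)
trajectory = solver.sample(th.ones(2, 3), model=None)   # [50, 2, 3]
print(trajectory[-1])                                   # ≈ 0.3679 everywhere
```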
XPart/partgen/models/diffusion/transport/path.py ADDED
@@ -0,0 +1,220 @@
1
+ # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
+ # which is licensed under the MIT License.
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ import torch as th
27
+ import numpy as np
28
+ from functools import partial
29
+
30
+ def expand_t_like_x(t, x):
31
+ """Function to reshape time t to broadcastable dimension of x
32
+ Args:
33
+ t: [batch_dim,], time vector
34
+ x: [batch_dim,...], data point
35
+ """
36
+ dims = [1] * (len(x.size()) - 1)
37
+ t = t.view(t.size(0), *dims)
38
+ return t
39
+
40
+
41
+ #################### Coupling Plans ####################
42
+
43
+ class ICPlan:
44
+ """Linear Coupling Plan"""
45
+ def __init__(self, sigma=0.0):
46
+ self.sigma = sigma
47
+
48
+ def compute_alpha_t(self, t):
49
+ """Compute the data coefficient along the path"""
50
+ return t, 1
51
+
52
+ def compute_sigma_t(self, t):
53
+ """Compute the noise coefficient along the path"""
54
+ return 1 - t, -1
55
+
56
+ def compute_d_alpha_alpha_ratio_t(self, t):
57
+ """Compute the ratio between d_alpha and alpha"""
58
+ return 1 / t
59
+
60
+ def compute_drift(self, x, t):
61
+ """We always output sde according to score parametrization; """
62
+ t = expand_t_like_x(t, x)
63
+ alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
64
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
65
+ drift = alpha_ratio * x
66
+ diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
67
+
68
+ return -drift, diffusion
69
+
70
+ def compute_diffusion(self, x, t, form="constant", norm=1.0):
71
+ """Compute the diffusion term of the SDE
72
+ Args:
73
+ x: [batch_dim, ...], data point
74
+ t: [batch_dim,], time vector
75
+ form: str, form of the diffusion term
76
+ norm: float, norm of the diffusion term
77
+ """
78
+ t = expand_t_like_x(t, x)
79
+ choices = {
80
+ "constant": norm,
81
+ "SBDM": norm * self.compute_drift(x, t)[1],
82
+ "sigma": norm * self.compute_sigma_t(t)[0],
83
+ "linear": norm * (1 - t),
84
+ "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
85
+ "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
86
+ }
87
+
88
+ try:
89
+ diffusion = choices[form]
90
+ except KeyError:
91
+ raise NotImplementedError(f"Diffusion form {form} not implemented")
92
+
93
+ return diffusion
94
+
95
+ def get_score_from_velocity(self, velocity, x, t):
96
+ """Wrapper function: transfrom velocity prediction model to score
97
+ Args:
98
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
99
+ x: [batch_dim, ...] shaped tensor; x_t data point
100
+ t: [batch_dim,] time tensor
101
+ """
102
+ t = expand_t_like_x(t, x)
103
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
104
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
105
+ mean = x
106
+ reverse_alpha_ratio = alpha_t / d_alpha_t
107
+ var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
108
+ score = (reverse_alpha_ratio * velocity - mean) / var
109
+ return score
110
+
111
+ def get_noise_from_velocity(self, velocity, x, t):
112
+ """Wrapper function: transfrom velocity prediction model to denoiser
113
+ Args:
114
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
115
+ x: [batch_dim, ...] shaped tensor; x_t data point
116
+ t: [batch_dim,] time tensor
117
+ """
118
+ t = expand_t_like_x(t, x)
119
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
120
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
121
+ mean = x
122
+ reverse_alpha_ratio = alpha_t / d_alpha_t
123
+ var = reverse_alpha_ratio * d_sigma_t - sigma_t
124
+ noise = (reverse_alpha_ratio * velocity - mean) / var
125
+ return noise
126
+
127
+ def get_velocity_from_score(self, score, x, t):
128
+ """Wrapper function: transfrom score prediction model to velocity
129
+ Args:
130
+ score: [batch_dim, ...] shaped tensor; score model output
131
+ x: [batch_dim, ...] shaped tensor; x_t data point
132
+ t: [batch_dim,] time tensor
133
+ """
134
+ t = expand_t_like_x(t, x)
135
+ drift, var = self.compute_drift(x, t)
136
+ velocity = var * score - drift
137
+ return velocity
138
+
139
+ def compute_mu_t(self, t, x0, x1):
140
+ """Compute the mean of time-dependent density p_t"""
141
+ t = expand_t_like_x(t, x1)
142
+ alpha_t, _ = self.compute_alpha_t(t)
143
+ sigma_t, _ = self.compute_sigma_t(t)
144
+ # t*x1 + (1-t)*x0 ; t=0 x0; t=1 x1
145
+ return alpha_t * x1 + sigma_t * x0
146
+
147
+ def compute_xt(self, t, x0, x1):
148
+ """Sample xt from time-dependent density p_t; rng is required"""
149
+ xt = self.compute_mu_t(t, x0, x1)
150
+ return xt
151
+
152
+ def compute_ut(self, t, x0, x1, xt):
153
+ """Compute the vector field corresponding to p_t"""
154
+ t = expand_t_like_x(t, x1)
155
+ _, d_alpha_t = self.compute_alpha_t(t)
156
+ _, d_sigma_t = self.compute_sigma_t(t)
157
+ return d_alpha_t * x1 + d_sigma_t * x0
158
+
159
+ def plan(self, t, x0, x1):
160
+ xt = self.compute_xt(t, x0, x1)
161
+ ut = self.compute_ut(t, x0, x1, xt)
162
+ return t, xt, ut
163
+
164
+
165
+ class VPCPlan(ICPlan):
166
+ """class for VP path flow matching"""
167
+
168
+ def __init__(self, sigma_min=0.1, sigma_max=20.0):
169
+ self.sigma_min = sigma_min
170
+ self.sigma_max = sigma_max
171
+ self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * \
172
+ (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
173
+ self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * \
174
+ (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
175
+
176
+
177
+ def compute_alpha_t(self, t):
178
+ """Compute coefficient of x1"""
179
+ alpha_t = self.log_mean_coeff(t)
180
+ alpha_t = th.exp(alpha_t)
181
+ d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
182
+ return alpha_t, d_alpha_t
183
+
184
+ def compute_sigma_t(self, t):
185
+ """Compute coefficient of x0"""
186
+ p_sigma_t = 2 * self.log_mean_coeff(t)
187
+ sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
188
+ d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
189
+ return sigma_t, d_sigma_t
190
+
191
+ def compute_d_alpha_alpha_ratio_t(self, t):
192
+ """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
193
+ return self.d_log_mean_coeff(t)
194
+
195
+ def compute_drift(self, x, t):
196
+ """Compute the drift term of the SDE"""
197
+ t = expand_t_like_x(t, x)
198
+ beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
199
+ return -0.5 * beta_t * x, beta_t / 2
200
+
201
+
202
+ class GVPCPlan(ICPlan):
203
+ def __init__(self, sigma=0.0):
204
+ super().__init__(sigma)
205
+
206
+ def compute_alpha_t(self, t):
207
+ """Compute coefficient of x1"""
208
+ alpha_t = th.sin(t * np.pi / 2)
209
+ d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
210
+ return alpha_t, d_alpha_t
211
+
212
+ def compute_sigma_t(self, t):
213
+ """Compute coefficient of x0"""
214
+ sigma_t = th.cos(t * np.pi / 2)
215
+ d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
216
+ return sigma_t, d_sigma_t
217
+
218
+ def compute_d_alpha_alpha_ratio_t(self, t):
219
+ """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
220
+ return np.pi / (2 * th.tan(t * np.pi / 2))
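
A quick numerical check of the linear coupling plan defined above: with α_t = t and σ_t = 1 − t, the interpolant is x_t = t·x1 + (1 − t)·x0 and the target velocity is u_t = x1 − x0.

```python
# Sanity-check ICPlan's linear interpolant and velocity target.
import torch as th

plan = ICPlan()
x0, x1 = th.randn(4, 8), th.randn(4, 8)
t = th.rand(4)
_, xt, ut = plan.plan(t, x0, x1)
assert th.allclose(xt, t.view(-1, 1) * x1 + (1 - t.view(-1, 1)) * x0)
assert th.allclose(ut, x1 - x0)
print("linear plan OK")
```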
XPart/partgen/models/diffusion/transport/transport.py ADDED
@@ -0,0 +1,506 @@
1
+ # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
+ # which is licensed under the MIT License.
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ import torch as th
27
+ import numpy as np
28
+ import logging
29
+
30
+ import enum
31
+
32
+ from . import path
33
+ from .utils import EasyDict, log_state, mean_flat
34
+ from .integrators import ode, sde
35
+
36
+
37
+ class ModelType(enum.Enum):
38
+ """
39
+ Which type of output the model predicts.
40
+ """
41
+
42
+ NOISE = enum.auto() # the model predicts epsilon
43
+ SCORE = enum.auto() # the model predicts \nabla \log p(x)
44
+ VELOCITY = enum.auto() # the model predicts v(x)
45
+
46
+
47
+ class PathType(enum.Enum):
48
+ """
49
+ Which type of path to use.
50
+ """
51
+
52
+ LINEAR = enum.auto()
53
+ GVP = enum.auto()
54
+ VP = enum.auto()
55
+
56
+
57
+ class WeightType(enum.Enum):
58
+ """
59
+ Which type of weighting to use.
60
+ """
61
+
62
+ NONE = enum.auto()
63
+ VELOCITY = enum.auto()
64
+ LIKELIHOOD = enum.auto()
65
+
66
+
67
+ class Transport:
68
+
69
+ def __init__(
70
+ self,
71
+ *,
72
+ model_type,
73
+ path_type,
74
+ loss_type,
75
+ train_eps,
76
+ sample_eps,
77
+ train_sample_type="uniform",
78
+ **kwargs,
79
+ ):
80
+ path_options = {
81
+ PathType.LINEAR: path.ICPlan,
82
+ PathType.GVP: path.GVPCPlan,
83
+ PathType.VP: path.VPCPlan,
84
+ }
85
+
86
+ self.loss_type = loss_type
87
+ self.model_type = model_type
88
+ self.path_sampler = path_options[path_type]()
89
+ self.train_eps = train_eps
90
+ self.sample_eps = sample_eps
91
+ self.train_sample_type = train_sample_type
92
+ if self.train_sample_type == "logit_normal":
93
+ self.mean = kwargs["mean"]
94
+ self.std = kwargs["std"]
95
+ self.shift_scale = kwargs["shift_scale"]
96
+ print(f"using logit normal sample, shift scale is {self.shift_scale}")
97
+
98
+ def prior_logp(self, z):
99
+ """
100
+ Standard multivariate normal prior
101
+ Assume z is batched
102
+ """
103
+ shape = th.tensor(z.size())
104
+ N = th.prod(shape[1:])
105
+ _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
106
+ return th.vmap(_fn)(z)
107
+
108
+ def check_interval(
109
+ self,
110
+ train_eps,
111
+ sample_eps,
112
+ *,
113
+ diffusion_form="SBDM",
114
+ sde=False,
115
+ reverse=False,
116
+ eval=False,
117
+ last_step_size=0.0,
118
+ ):
119
+ t0 = 0
120
+ t1 = 1
121
+ eps = train_eps if not eval else sample_eps
122
+ if type(self.path_sampler) in [path.VPCPlan]:
123
+
124
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
125
+
126
+ elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
127
+ self.model_type != ModelType.VELOCITY or sde
128
+ ): # avoid numerical issue by taking a first semi-implicit step
129
+
130
+ t0 = (
131
+ eps
132
+ if (diffusion_form == "SBDM" and sde)
133
+ or self.model_type != ModelType.VELOCITY
134
+ else 0
135
+ )
136
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
137
+
138
+ if reverse:
139
+ t0, t1 = 1 - t0, 1 - t1
140
+
141
+ return t0, t1
142
+
143
+ def sample(self, x1):
144
+ """Sampling x0 & t based on shape of x1 (if needed)
145
+ Args:
146
+ x1 - data point; [batch, *dim]
147
+ """
148
+
149
+ x0 = th.randn_like(x1)
150
+ if self.train_sample_type == "uniform":
151
+ t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
152
+ t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
153
+ t = t.to(x1)
154
+ elif self.train_sample_type == "logit_normal":
155
+ t = th.randn((x1.shape[0],)) * self.std + self.mean
156
+ t = t.to(x1)
157
+ t = 1 / (1 + th.exp(-t))
158
+
159
+ t = (
160
+ np.sqrt(self.shift_scale)
161
+ * t
162
+ / (1 + (np.sqrt(self.shift_scale) - 1) * t)
163
+ )
164
+
165
+ return t, x0, x1
166
+
167
+ def training_losses(self, model, x1, model_kwargs=None):
168
+ """Loss for training the score model
169
+ Args:
170
+ - model: backbone model; could be score, noise, or velocity
171
+ - x1: datapoint
172
+ - model_kwargs: additional arguments for the model
173
+ """
174
+ if model_kwargs is None:
175
+ model_kwargs = {}
176
+
177
+ t, x0, x1 = self.sample(x1)
178
+ t, xt, ut = self.path_sampler.plan(t, x0, x1)
179
+ model_output = model(xt, t, **model_kwargs)
180
+ B, *_, C = xt.shape
181
+ assert model_output.size() == (B, *xt.size()[1:-1], C)
182
+
183
+ terms = {}
184
+ terms["pred"] = model_output
185
+ if self.model_type == ModelType.VELOCITY:
186
+ terms["loss"] = mean_flat(((model_output - ut) ** 2))
187
+ else:
188
+ _, drift_var = self.path_sampler.compute_drift(xt, t)
189
+ sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
190
+ if self.loss_type in [WeightType.VELOCITY]:
191
+ weight = (drift_var / sigma_t) ** 2
192
+ elif self.loss_type in [WeightType.LIKELIHOOD]:
193
+ weight = drift_var / (sigma_t**2)
194
+ elif self.loss_type in [WeightType.NONE]:
195
+ weight = 1
196
+ else:
197
+ raise NotImplementedError()
198
+
199
+ if self.model_type == ModelType.NOISE:
200
+ terms["loss"] = mean_flat(weight * ((model_output - x0) ** 2))
201
+ else:
202
+ terms["loss"] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
203
+
204
+ return terms
205
+
206
+ def get_drift(self):
207
+ """member function for obtaining the drift of the probability flow ODE"""
208
+
209
+ def score_ode(x, t, model, **model_kwargs):
210
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
211
+ model_output = model(x, t, **model_kwargs)
212
+ return -drift_mean + drift_var * model_output # by change of variable
213
+
214
+ def noise_ode(x, t, model, **model_kwargs):
215
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
216
+ sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
217
+ model_output = model(x, t, **model_kwargs)
218
+ score = model_output / -sigma_t
219
+ return -drift_mean + drift_var * score
220
+
221
+ def velocity_ode(x, t, model, **model_kwargs):
222
+ model_output = model(x, t, **model_kwargs)
223
+ return model_output
224
+
225
+ if self.model_type == ModelType.NOISE:
226
+ drift_fn = noise_ode
227
+ elif self.model_type == ModelType.SCORE:
228
+ drift_fn = score_ode
229
+ else:
230
+ drift_fn = velocity_ode
231
+
232
+ def body_fn(x, t, model, **model_kwargs):
233
+ model_output = drift_fn(x, t, model, **model_kwargs)
234
+ assert (
235
+ model_output.shape == x.shape
236
+ ), "Output shape from ODE solver must match input shape"
237
+ return model_output
238
+
239
+ return body_fn
240
+
241
+ def get_score(
242
+ self,
243
+ ):
244
+ """member function for obtaining score of
245
+ x_t = alpha_t * x + sigma_t * eps"""
246
+ if self.model_type == ModelType.NOISE:
247
+ score_fn = (
248
+ lambda x, t, model, **kwargs: model(x, t, **kwargs)
249
+ / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
250
+ )
251
+ elif self.model_type == ModelType.SCORE:
252
+ score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
253
+ elif self.model_type == ModelType.VELOCITY:
254
+ score_fn = (
255
+ lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
256
+ model(x, t, **kwargs), x, t
257
+ )
258
+ )
259
+ else:
260
+ raise NotImplementedError()
261
+
262
+ return score_fn
263
+
264
+
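A minimal training-step sketch for the API above, assuming `transport` is an instance of this Transport class (constructor not shown here) built with ModelType.VELOCITY, and following the file's `torch as th` convention; the identity "network" is only a stand-in whose output shape matches xt:

model = lambda x, t, **kwargs: x          # toy velocity model, output shape == input shape
x1 = th.randn(4, 64, 8)                   # dummy data batch [B, N, C]
terms = transport.training_losses(model, x1)
loss = terms["loss"].mean()               # terms["loss"] is already reduced over non-batch dims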
265
+ class Sampler:
266
+ """Sampler class for the transport model"""
267
+
268
+ def __init__(
269
+ self,
270
+ transport,
271
+ ):
272
+ """Constructor for a general sampler; supporting different sampling methods
273
+ Args:
274
+ - transport: a transport object specifying model prediction & interpolant type
275
+ """
276
+
277
+ self.transport = transport
278
+ self.drift = self.transport.get_drift()
279
+ self.score = self.transport.get_score()
280
+
281
+ def __get_sde_diffusion_and_drift(
282
+ self,
283
+ *,
284
+ diffusion_form="SBDM",
285
+ diffusion_norm=1.0,
286
+ ):
287
+
288
+ def diffusion_fn(x, t):
289
+ diffusion = self.transport.path_sampler.compute_diffusion(
290
+ x, t, form=diffusion_form, norm=diffusion_norm
291
+ )
292
+ return diffusion
293
+
294
+ sde_drift = lambda x, t, model, **kwargs: self.drift(
295
+ x, t, model, **kwargs
296
+ ) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
297
+
298
+ sde_diffusion = diffusion_fn
299
+
300
+ return sde_drift, sde_diffusion
301
+
302
+ def __get_last_step(
303
+ self,
304
+ sde_drift,
305
+ *,
306
+ last_step,
307
+ last_step_size,
308
+ ):
309
+ """Get the last step function of the SDE solver"""
310
+
311
+ if last_step is None:
312
+ last_step_fn = lambda x, t, model, **model_kwargs: x
313
+ elif last_step == "Mean":
314
+ last_step_fn = (
315
+ lambda x, t, model, **model_kwargs: x
316
+ + sde_drift(x, t, model, **model_kwargs) * last_step_size
317
+ )
318
+ elif last_step == "Tweedie":
319
+ alpha = (
320
+ self.transport.path_sampler.compute_alpha_t
321
+ ) # simple aliasing; the original name was too long
322
+ sigma = self.transport.path_sampler.compute_sigma_t
323
+ last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (
324
+ sigma(t)[0][0] ** 2
325
+ ) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
326
+ elif last_step == "Euler":
327
+ last_step_fn = (
328
+ lambda x, t, model, **model_kwargs: x
329
+ + self.drift(x, t, model, **model_kwargs) * last_step_size
330
+ )
331
+ else:
332
+ raise NotImplementedError()
333
+
334
+ return last_step_fn
335
+
336
+ def sample_sde(
337
+ self,
338
+ *,
339
+ sampling_method="Euler",
340
+ diffusion_form="SBDM",
341
+ diffusion_norm=1.0,
342
+ last_step="Mean",
343
+ last_step_size=0.04,
344
+ num_steps=250,
345
+ ):
346
+ """returns a sampling function with given SDE settings
347
+ Args:
348
+ - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
349
+ - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
350
+ - diffusion_norm: function magnitude of diffusion coefficient; default to 1
351
+ - last_step: type of the last step; default to identity
352
+ - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
353
+ - num_steps: total integration step of SDE
354
+ """
355
+
356
+ if last_step is None:
357
+ last_step_size = 0.0
358
+
359
+ sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
360
+ diffusion_form=diffusion_form,
361
+ diffusion_norm=diffusion_norm,
362
+ )
363
+
364
+ t0, t1 = self.transport.check_interval(
365
+ self.transport.train_eps,
366
+ self.transport.sample_eps,
367
+ diffusion_form=diffusion_form,
368
+ sde=True,
369
+ eval=True,
370
+ reverse=False,
371
+ last_step_size=last_step_size,
372
+ )
373
+
374
+ _sde = sde(
375
+ sde_drift,
376
+ sde_diffusion,
377
+ t0=t0,
378
+ t1=t1,
379
+ num_steps=num_steps,
380
+ sampler_type=sampling_method,
381
+ )
382
+
383
+ last_step_fn = self.__get_last_step(
384
+ sde_drift, last_step=last_step, last_step_size=last_step_size
385
+ )
386
+
387
+ def _sample(init, model, **model_kwargs):
388
+ xs = _sde.sample(init, model, **model_kwargs)
389
+ ts = th.ones(init.size(0), device=init.device) * t1
390
+ x = last_step_fn(xs[-1], ts, model, **model_kwargs)
391
+ xs.append(x)
392
+
393
+ assert len(xs) == num_steps, "Number of samples does not match the number of steps"
394
+
395
+ return xs
396
+
397
+ return _sample
398
+
399
+ def sample_ode(
400
+ self,
401
+ *,
402
+ sampling_method="dopri5",
403
+ num_steps=50,
404
+ atol=1e-6,
405
+ rtol=1e-3,
406
+ reverse=False,
407
+ ):
408
+ """returns a sampling function with given ODE settings
409
+ Args:
410
+ - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
411
+ - num_steps:
412
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
413
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
414
+ - atol: absolute error tolerance for the solver
415
+ - rtol: relative error tolerance for the solver
416
+ - reverse: whether solving the ODE in reverse (data to noise); default to False
417
+ """
418
+ if reverse:
419
+ drift = lambda x, t, model, **kwargs: self.drift(
420
+ x, th.ones_like(t) * (1 - t), model, **kwargs
421
+ )
422
+ else:
423
+ drift = self.drift
424
+
425
+ t0, t1 = self.transport.check_interval(
426
+ self.transport.train_eps,
427
+ self.transport.sample_eps,
428
+ sde=False,
429
+ eval=True,
430
+ reverse=reverse,
431
+ last_step_size=0.0,
432
+ )
433
+
434
+ _ode = ode(
435
+ drift=drift,
436
+ t0=t0,
437
+ t1=t1,
438
+ sampler_type=sampling_method,
439
+ num_steps=num_steps,
440
+ atol=atol,
441
+ rtol=rtol,
442
+ )
443
+
444
+ return _ode.sample
445
+
446
+
447
+ def sample_ode_likelihood(
448
+ self,
449
+ *,
450
+ sampling_method="dopri5",
451
+ num_steps=50,
452
+ atol=1e-6,
453
+ rtol=1e-3,
454
+ ):
455
+ """returns a sampling function for calculating likelihood with given ODE settings
456
+ Args:
457
+ - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
458
+ - num_steps:
459
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
460
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
461
+ - atol: absolute error tolerance for the solver
462
+ - rtol: relative error tolerance for the solver
463
+ """
464
+
465
+ def _likelihood_drift(x, t, model, **model_kwargs):
466
+ x, _ = x
467
+ eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
468
+ t = th.ones_like(t) * (1 - t)
469
+ with th.enable_grad():
470
+ x.requires_grad = True
471
+ grad = th.autograd.grad(
472
+ th.sum(self.drift(x, t, model, **model_kwargs) * eps), x
473
+ )[0]
474
+ logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
475
+ drift = self.drift(x, t, model, **model_kwargs)
476
+ return (-drift, logp_grad)
477
+
478
+ t0, t1 = self.transport.check_interval(
479
+ self.transport.train_eps,
480
+ self.transport.sample_eps,
481
+ sde=False,
482
+ eval=True,
483
+ reverse=False,
484
+ last_step_size=0.0,
485
+ )
486
+
487
+ _ode = ode(
488
+ drift=_likelihood_drift,
489
+ t0=t0,
490
+ t1=t1,
491
+ sampler_type=sampling_method,
492
+ num_steps=num_steps,
493
+ atol=atol,
494
+ rtol=rtol,
495
+ )
496
+
497
+ def _sample_fn(x, model, **model_kwargs):
498
+ init_logp = th.zeros(x.size(0)).to(x)
499
+ input = (x, init_logp)
500
+ drift, delta_logp = _ode.sample(input, model, **model_kwargs)
501
+ drift, delta_logp = drift[-1], delta_logp[-1]
502
+ prior_logp = self.transport.prior_logp(drift)
503
+ logp = prior_logp - delta_logp
504
+ return logp, drift
505
+
506
+ return _sample_fn
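A usage sketch for the Sampler class above; `transport` and `model` are assumed to come from elsewhere in the repo, and shapes are illustrative:

sampler = Sampler(transport)
ode_sample_fn = sampler.sample_ode(sampling_method="dopri5", num_steps=50)
z = th.randn(2, 64, 8)                      # initial noise in the latent shape the model expects
trajectory = ode_sample_fn(z, model)        # saved integration states; trajectory[-1] is the sample

sde_sample_fn = sampler.sample_sde(last_step="Mean", last_step_size=0.04, num_steps=250)
x = sde_sample_fn(z, model)[-1]             # Euler-Maruyama path with a mean last step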
XPart/partgen/models/diffusion/transport/utils.py ADDED
@@ -0,0 +1,54 @@
1
+ # This file includes code derived from the SiT project (https://github.com/willisma/SiT),
2
+ # which is licensed under the MIT License.
3
+ #
4
+ # MIT License
5
+ #
6
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+ #
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+ #
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ # SOFTWARE.
25
+
26
+ import torch as th
27
+
28
+ class EasyDict:
29
+
30
+ def __init__(self, sub_dict):
31
+ for k, v in sub_dict.items():
32
+ setattr(self, k, v)
33
+
34
+ def __getitem__(self, key):
35
+ return getattr(self, key)
36
+
37
+ def mean_flat(x):
38
+ """
39
+ Take the mean over all non-batch dimensions.
40
+ """
41
+ return th.mean(x, dim=list(range(1, len(x.size()))))
42
+
43
+ def log_state(state):
44
+ result = []
45
+
46
+ sorted_state = dict(sorted(state.items()))
47
+ for key, value in sorted_state.items():
48
+ # Check if the value is an instance of a class
49
+ if "<object" in str(value) or "object at" in str(value):
50
+ result.append(f"{key}: [{value.__class__.__name__}]")
51
+ else:
52
+ result.append(f"{key}: {value}")
53
+
54
+ return '\n'.join(result)
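A small illustration of the helpers above (shapes and keys chosen arbitrarily):

x = th.randn(4, 16, 8)
per_sample = mean_flat((x - 1.0) ** 2)          # shape [4]: mean over every non-batch dim
cfg = EasyDict({"num_steps": 250, "sampler": "Euler"})
print(cfg.num_steps, cfg["sampler"])            # attribute and item access both work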
XPart/partgen/models/moe_layers.py ADDED
@@ -0,0 +1,209 @@
1
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
2
+ # except for the third-party components listed below.
3
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
4
+ # in the respective licenses of these third-party components.
5
+ # Users must comply with all terms and conditions of original licenses of these third-party
6
+ # components and must ensure that the usage of the third party components adheres to
7
+ # all relevant laws and regulations.
8
+
9
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
10
+ # their software and algorithms, including trained model weights, parameters (including
11
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
12
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
13
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import math
18
+
19
+ import torch.nn.functional as F
20
+ from diffusers.models.attention import FeedForward
21
+
22
+
23
+ class AddAuxiliaryLoss(torch.autograd.Function):
24
+ """
25
+ The trick function of adding auxiliary (aux) loss,
26
+ which includes the gradient of the aux loss during backpropagation.
27
+ """
28
+
29
+ @staticmethod
30
+ def forward(ctx, x, loss):
31
+ assert loss.numel() == 1
32
+ ctx.dtype = loss.dtype
33
+ ctx.required_aux_loss = loss.requires_grad
34
+ return x
35
+
36
+ @staticmethod
37
+ def backward(ctx, grad_output):
38
+ grad_loss = None
39
+ if ctx.required_aux_loss:
40
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
41
+ return grad_output, grad_loss
42
+
43
+
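In effect, AddAuxiliaryLoss.apply(y, aux_loss) returns y unchanged in the forward pass but injects a unit gradient into aux_loss during backward, so the gate's load-balancing term is trained without perturbing the main output. A minimal check (illustrative only; the stand-in aux term is reshaped to one element to satisfy the assert above):

gate_logits = torch.randn(2, 4, requires_grad=True)
aux = gate_logits.softmax(-1).var().reshape(1)   # stand-in balancing loss, numel() == 1
out = AddAuxiliaryLoss.apply(gate_logits * 2.0, aux)
out.sum().backward()                             # gate_logits.grad now also carries d(aux)/d(gate_logits)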
44
+ class MoEGate(nn.Module):
45
+ def __init__(
46
+ self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01
47
+ ):
48
+ super().__init__()
49
+ self.top_k = num_experts_per_tok
50
+ self.n_routed_experts = num_experts
51
+
52
+ self.scoring_func = "softmax"
53
+ self.alpha = aux_loss_alpha
54
+ self.seq_aux = False
55
+
56
+ # topk selection algorithm
57
+ self.norm_topk_prob = False
58
+ self.gating_dim = embed_dim
59
+ self.weight = nn.Parameter(
60
+ torch.empty((self.n_routed_experts, self.gating_dim))
61
+ )
62
+ self.reset_parameters()
63
+
64
+ def reset_parameters(self) -> None:
65
+ import torch.nn.init as init
66
+
67
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
68
+
69
+ def forward(self, hidden_states):
70
+ bsz, seq_len, h = hidden_states.shape
71
+ # print(bsz, seq_len, h)
72
+ ### compute gating score
73
+ hidden_states = hidden_states.view(-1, h)
74
+ logits = F.linear(hidden_states, self.weight, None)
75
+ if self.scoring_func == "softmax":
76
+ scores = logits.softmax(dim=-1)
77
+ else:
78
+ raise NotImplementedError(
79
+ f"insupportable scoring function for MoE gating: {self.scoring_func}"
80
+ )
81
+
82
+ ### select top-k experts
83
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
84
+
85
+ ### norm gate to sum 1
86
+ if self.top_k > 1 and self.norm_topk_prob:
87
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
88
+ topk_weight = topk_weight / denominator
89
+
90
+ ### expert-level computation auxiliary loss
91
+ if self.training and self.alpha > 0.0:
92
+ scores_for_aux = scores
93
+ aux_topk = self.top_k
94
+ # always compute aux loss based on the naive greedy topk method
95
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
96
+ if self.seq_aux:
97
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
98
+ ce = torch.zeros(
99
+ bsz, self.n_routed_experts, device=hidden_states.device
100
+ )
101
+ ce.scatter_add_(
102
+ 1,
103
+ topk_idx_for_aux_loss,
104
+ torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
105
+ ).div_(seq_len * aux_topk / self.n_routed_experts)
106
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean()
107
+ aux_loss = aux_loss * self.alpha
108
+ else:
109
+ mask_ce = F.one_hot(
110
+ topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
111
+ )
112
+ ce = mask_ce.float().mean(0)
113
+ Pi = scores_for_aux.mean(0)
114
+ fi = ce * self.n_routed_experts
115
+ aux_loss = (Pi * fi).sum() * self.alpha
116
+ else:
117
+ aux_loss = None
118
+ return topk_idx, topk_weight, aux_loss
119
+
120
+
121
+ class MoEBlock(nn.Module):
122
+ def __init__(
123
+ self,
124
+ dim,
125
+ num_experts=8,
126
+ moe_top_k=2,
127
+ activation_fn="gelu",
128
+ dropout=0.0,
129
+ final_dropout=False,
130
+ ff_inner_dim=None,
131
+ ff_bias=True,
132
+ ):
133
+ super().__init__()
134
+ self.moe_top_k = moe_top_k
135
+ self.experts = nn.ModuleList([
136
+ FeedForward(
137
+ dim,
138
+ dropout=dropout,
139
+ activation_fn=activation_fn,
140
+ final_dropout=final_dropout,
141
+ inner_dim=ff_inner_dim,
142
+ bias=ff_bias,
143
+ )
144
+ for i in range(num_experts)
145
+ ])
146
+ self.gate = MoEGate(
147
+ embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k
148
+ )
149
+
150
+ self.shared_experts = FeedForward(
151
+ dim,
152
+ dropout=dropout,
153
+ activation_fn=activation_fn,
154
+ final_dropout=final_dropout,
155
+ inner_dim=ff_inner_dim,
156
+ bias=ff_bias,
157
+ )
158
+
159
+ def initialize_weight(self):
160
+ pass
161
+
162
+ def forward(self, hidden_states):
163
+ identity = hidden_states
164
+ orig_shape = hidden_states.shape
165
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
166
+
167
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
168
+ flat_topk_idx = topk_idx.view(-1)
169
+ if self.training:
170
+ hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
171
+ y = torch.empty_like(hidden_states, dtype=hidden_states.dtype)
172
+ for i, expert in enumerate(self.experts):
173
+ tmp = expert(hidden_states[flat_topk_idx == i])
174
+ y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
175
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
176
+ y = y.view(*orig_shape)
177
+ y = AddAuxiliaryLoss.apply(y, aux_loss)
178
+ else:
179
+ y = self.moe_infer(
180
+ hidden_states, flat_topk_idx, topk_weight.view(-1, 1)
181
+ ).view(*orig_shape)
182
+ y = y + self.shared_experts(identity)
183
+ return y
184
+
185
+ @torch.no_grad()
186
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
187
+ expert_cache = torch.zeros_like(x)
188
+ idxs = flat_expert_indices.argsort()
189
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
190
+ token_idxs = idxs // self.moe_top_k
191
+ for i, end_idx in enumerate(tokens_per_expert):
192
+ start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
193
+ if start_idx == end_idx:
194
+ continue
195
+ expert = self.experts[i]
196
+ exp_token_idx = token_idxs[start_idx:end_idx]
197
+ expert_tokens = x[exp_token_idx]
198
+ expert_out = expert(expert_tokens)
199
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
200
+
201
+ # for fp16 and other dtype
202
+ expert_cache = expert_cache.to(expert_out.dtype)
203
+ expert_cache.scatter_reduce_(
204
+ 0,
205
+ exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
206
+ expert_out,
207
+ reduce="sum",
208
+ )
209
+ return expert_cache
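A shape-level usage sketch for MoEGate/MoEBlock above; dimensions are arbitrary, and the diffusers FeedForward import above must be available:

block = MoEBlock(dim=64, num_experts=8, moe_top_k=2).eval()
tokens = torch.randn(2, 10, 64)                  # [batch, seq_len, dim]
with torch.no_grad():
    out = block(tokens)                          # inference path: moe_infer routing plus the shared expert
assert out.shape == tokens.shape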
XPart/partgen/models/partformer_dit.py ADDED
@@ -0,0 +1,756 @@
1
+ # Newest version: add local&global context (cross-attn), and local&global attn (self-attn)
2
+ import math
3
+
4
+ import torch.nn.functional as F
5
+
6
+ import torch.nn as nn
7
+ import torch
8
+ from typing import Optional
9
+ from einops import rearrange
10
+ from .moe_layers import MoEBlock
11
+ import numpy as np
12
+
13
+
14
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
15
+ """
16
+ embed_dim: output dimension for each position
17
+ pos: a list of positions to be encoded: size (M,)
18
+ out: (M, D)
19
+ """
20
+ assert embed_dim % 2 == 0
21
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
22
+ omega /= embed_dim / 2.0
23
+ omega = 1.0 / 10000**omega # (D/2,)
24
+
25
+ pos = pos.reshape(-1) # (M,)
26
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
27
+
28
+ emb_sin = np.sin(out) # (M, D/2)
29
+ emb_cos = np.cos(out) # (M, D/2)
30
+
31
+ return np.concatenate([emb_sin, emb_cos], axis=1)
32
+
33
+
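For reference, a quick check of the helper above (positions chosen arbitrarily):

pos = np.arange(4, dtype=np.float32)
emb = get_1d_sincos_pos_embed_from_grid(16, pos)   # shape (4, 16): sin half followed by cos half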
34
+ class Timesteps(nn.Module):
35
+ def __init__(
36
+ self,
37
+ num_channels: int,
38
+ downscale_freq_shift: float = 0.0,
39
+ scale: int = 1,
40
+ max_period: int = 10000,
41
+ ):
42
+ super().__init__()
43
+ self.num_channels = num_channels
44
+ self.downscale_freq_shift = downscale_freq_shift
45
+ self.scale = scale
46
+ self.max_period = max_period
47
+
48
+ def forward(self, timesteps):
49
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
50
+ embedding_dim = self.num_channels
51
+ half_dim = embedding_dim // 2
52
+ exponent = -math.log(self.max_period) * torch.arange(
53
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
54
+ )
55
+ exponent = exponent / (half_dim - self.downscale_freq_shift)
56
+ emb = torch.exp(exponent)
57
+ emb = timesteps[:, None].float() * emb[None, :]
58
+ emb = self.scale * emb
59
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
60
+ if embedding_dim % 2 == 1:
61
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
62
+ return emb
63
+
64
+
65
+ class TimestepEmbedder(nn.Module):
66
+ """
67
+ Embeds scalar timesteps into vector representations.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ hidden_size,
73
+ frequency_embedding_size=256,
74
+ cond_proj_dim=None,
75
+ out_size=None,
76
+ ):
77
+ super().__init__()
78
+ if out_size is None:
79
+ out_size = hidden_size
80
+ self.mlp = nn.Sequential(
81
+ nn.Linear(hidden_size, frequency_embedding_size, bias=True),
82
+ nn.GELU(),
83
+ nn.Linear(frequency_embedding_size, out_size, bias=True),
84
+ )
85
+ self.frequency_embedding_size = frequency_embedding_size
86
+
87
+ if cond_proj_dim is not None:
88
+ self.cond_proj = nn.Linear(
89
+ cond_proj_dim, frequency_embedding_size, bias=False
90
+ )
91
+
92
+ self.time_embed = Timesteps(hidden_size)
93
+
94
+ def forward(self, t, condition):
95
+
96
+ t_freq = self.time_embed(t).type(self.mlp[0].weight.dtype)
97
+
98
+ # t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
99
+ if condition is not None:
100
+ t_freq = t_freq + self.cond_proj(condition)
101
+
102
+ t = self.mlp(t_freq)
103
+ t = t.unsqueeze(dim=1)
104
+ return t
105
+
106
+
107
+ class MLP(nn.Module):
108
+ def __init__(self, *, width: int):
109
+ super().__init__()
110
+ self.width = width
111
+ self.fc1 = nn.Linear(width, width * 4)
112
+ self.fc2 = nn.Linear(width * 4, width)
113
+ self.gelu = nn.GELU()
114
+
115
+ def forward(self, x):
116
+ return self.fc2(self.gelu(self.fc1(x)))
117
+
118
+
119
+ class CrossAttention(nn.Module):
120
+ def __init__(
121
+ self,
122
+ qdim,
123
+ kdim,
124
+ num_heads,
125
+ qkv_bias=True,
126
+ qk_norm=False,
127
+ norm_layer=nn.LayerNorm,
128
+ with_decoupled_ca=False,
129
+ decoupled_ca_dim=16,
130
+ decoupled_ca_weight=1.0,
131
+ **kwargs,
132
+ ):
133
+ super().__init__()
134
+ self.qdim = qdim
135
+ self.kdim = kdim
136
+ self.num_heads = num_heads
137
+ assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
138
+ self.head_dim = self.qdim // num_heads
139
+ assert (
140
+ self.head_dim % 8 == 0 and self.head_dim <= 128
141
+ ), "Only support head_dim <= 128 and divisible by 8"
142
+ self.scale = self.head_dim**-0.5
143
+
144
+ self.to_q = nn.Linear(qdim, qdim, bias=qkv_bias)
145
+ self.to_k = nn.Linear(kdim, qdim, bias=qkv_bias)
146
+ self.to_v = nn.Linear(kdim, qdim, bias=qkv_bias)
147
+
148
+ # TODO: eps should be 1 / 65530 if using fp16
149
+ self.q_norm = (
150
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
151
+ if qk_norm
152
+ else nn.Identity()
153
+ )
154
+ self.k_norm = (
155
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
156
+ if qk_norm
157
+ else nn.Identity()
158
+ )
159
+ self.out_proj = nn.Linear(qdim, qdim, bias=True)
160
+
161
+ self.with_dca = with_decoupled_ca
162
+ if self.with_dca:
163
+ self.kv_proj_dca = nn.Linear(kdim, 2 * qdim, bias=qkv_bias)
164
+ self.k_norm_dca = (
165
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
166
+ if qk_norm
167
+ else nn.Identity()
168
+ )
169
+ self.dca_dim = decoupled_ca_dim
170
+ self.dca_weight = decoupled_ca_weight
171
+ # zero init
172
+ nn.init.zeros_(self.out_proj.weight)
173
+ nn.init.zeros_(self.out_proj.bias)
174
+
175
+ def forward(self, x, y):
176
+ """
177
+ Parameters
178
+ ----------
179
+ x: torch.Tensor
180
+ (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
181
+ y: torch.Tensor
182
+ (batch, seqlen2, hidden_dim2)
183
+ freqs_cis_img: torch.Tensor
184
+ (batch, hidden_dim // 2), RoPE for image
185
+ """
186
+ b, s1, c = x.shape # [b, s1, D]
187
+
188
+ if self.with_dca:
189
+ token_len = y.shape[1]
190
+ context_dca = y[:, -self.dca_dim :, :]
191
+ kv_dca = self.kv_proj_dca(context_dca).view(
192
+ b, self.dca_dim, 2, self.num_heads, self.head_dim
193
+ )
194
+ k_dca, v_dca = kv_dca.unbind(dim=2) # [b, s, h, d]
195
+ k_dca = self.k_norm_dca(k_dca)
196
+ y = y[:, : (token_len - self.dca_dim), :]
197
+
198
+ _, s2, c = y.shape # [b, s2, 1024]
199
+ q = self.to_q(x)
200
+ k = self.to_k(y)
201
+ v = self.to_v(y)
202
+
203
+ kv = torch.cat((k, v), dim=-1)
204
+ split_size = kv.shape[-1] // self.num_heads // 2
205
+ kv = kv.view(1, -1, self.num_heads, split_size * 2)
206
+ k, v = torch.split(kv, split_size, dim=-1)
207
+
208
+ q = q.view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
209
+ k = k.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
210
+ v = v.view(b, s2, self.num_heads, self.head_dim) # [b, s2, h, d]
211
+
212
+ q = self.q_norm(q)
213
+ k = self.k_norm(k)
214
+
215
+ with torch.backends.cuda.sdp_kernel(
216
+ enable_flash=True, enable_math=False, enable_mem_efficient=True
217
+ ):
218
+ q, k, v = map(
219
+ lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads),
220
+ (q, k, v),
221
+ )
222
+ context = (
223
+ F.scaled_dot_product_attention(q, k, v)
224
+ .transpose(1, 2)
225
+ .reshape(b, s1, -1)
226
+ )
227
+
228
+ if self.with_dca:
229
+ with torch.backends.cuda.sdp_kernel(
230
+ enable_flash=True, enable_math=False, enable_mem_efficient=True
231
+ ):
232
+ k_dca, v_dca = map(
233
+ lambda t: rearrange(t, "b n h d -> b h n d", h=self.num_heads),
234
+ (k_dca, v_dca),
235
+ )
236
+ context_dca = (
237
+ F.scaled_dot_product_attention(q, k_dca, v_dca)
238
+ .transpose(1, 2)
239
+ .reshape(b, s1, -1)
240
+ )
241
+
242
+ context = context + self.dca_weight * context_dca
243
+
244
+ out = self.out_proj(context) # context.reshape - B, L1, -1
245
+
246
+ return out
247
+
248
+
249
+ class Attention(nn.Module):
250
+ """
251
+ We rename some layer names to align with flash attention
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ dim,
257
+ num_heads,
258
+ qkv_bias=True,
259
+ qk_norm=False,
260
+ norm_layer=nn.LayerNorm,
261
+ use_global_processor=False,
262
+ ):
263
+ super().__init__()
264
+ self.use_global_processor = use_global_processor
265
+ self.dim = dim
266
+ self.num_heads = num_heads
267
+ assert self.dim % num_heads == 0, "dim should be divisible by num_heads"
268
+ self.head_dim = self.dim // num_heads
269
+ # This assertion is aligned with flash attention
270
+ assert (
271
+ self.head_dim % 8 == 0 and self.head_dim <= 128
272
+ ), "Only support head_dim <= 128 and divisible by 8"
273
+ self.scale = self.head_dim**-0.5
274
+
275
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
276
+ self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
277
+ self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
278
+ # TODO: eps should be 1 / 65530 if using fp16
279
+ self.q_norm = (
280
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
281
+ if qk_norm
282
+ else nn.Identity()
283
+ )
284
+ self.k_norm = (
285
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
286
+ if qk_norm
287
+ else nn.Identity()
288
+ )
289
+ self.out_proj = nn.Linear(dim, dim)
290
+
291
+ # set processor
292
+ self.processor = LocalGlobalProcessor(use_global=use_global_processor)
293
+
294
+ def forward(self, x):
295
+ return self.processor(self, x)
296
+
297
+
298
+ class AttentionPool(nn.Module):
299
+ def __init__(
300
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
301
+ ):
302
+ super().__init__()
303
+ self.positional_embedding = nn.Parameter(
304
+ torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5
305
+ )
306
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
307
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
308
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
309
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
310
+ self.num_heads = num_heads
311
+
312
+ def forward(self, x, attention_mask=None):
313
+ x = x.permute(1, 0, 2) # NLC -> LNC
314
+ if attention_mask is not None:
315
+ attention_mask = attention_mask.unsqueeze(-1).permute(1, 0, 2)
316
+ global_emb = (x * attention_mask).sum(dim=0) / attention_mask.sum(dim=0)
317
+ x = torch.cat([global_emb[None,], x], dim=0)
318
+
319
+ else:
320
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
321
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
322
+ x, _ = F.multi_head_attention_forward(
323
+ query=x[:1],
324
+ key=x,
325
+ value=x,
326
+ embed_dim_to_check=x.shape[-1],
327
+ num_heads=self.num_heads,
328
+ q_proj_weight=self.q_proj.weight,
329
+ k_proj_weight=self.k_proj.weight,
330
+ v_proj_weight=self.v_proj.weight,
331
+ in_proj_weight=None,
332
+ in_proj_bias=torch.cat(
333
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
334
+ ),
335
+ bias_k=None,
336
+ bias_v=None,
337
+ add_zero_attn=False,
338
+ dropout_p=0,
339
+ out_proj_weight=self.c_proj.weight,
340
+ out_proj_bias=self.c_proj.bias,
341
+ use_separate_proj_weight=True,
342
+ training=self.training,
343
+ need_weights=False,
344
+ )
345
+ return x.squeeze(0)
346
+
347
+
348
+ class LocalGlobalProcessor:
349
+ def __init__(self, use_global=False):
350
+ self.use_global = use_global
351
+
352
+ def __call__(
353
+ self,
354
+ attn: Attention,
355
+ hidden_states: torch.Tensor,
356
+ ):
357
+ """
358
+ hidden_states: [B, L, C]
359
+ """
360
+ if self.use_global:
361
+ B_old, N_old, C_old = hidden_states.shape
362
+ hidden_states = hidden_states.reshape(1, -1, C_old)
363
+ B, N, C = hidden_states.shape
364
+
365
+ q = attn.to_q(hidden_states)
366
+ k = attn.to_k(hidden_states)
367
+ v = attn.to_v(hidden_states)
368
+
369
+ qkv = torch.cat((q, k, v), dim=-1)
370
+ split_size = qkv.shape[-1] // attn.num_heads // 3
371
+ qkv = qkv.view(1, -1, attn.num_heads, split_size * 3)
372
+ q, k, v = torch.split(qkv, split_size, dim=-1)
373
+
374
+ q = q.reshape(B, N, attn.num_heads, attn.head_dim).transpose(
375
+ 1, 2
376
+ ) # [b, h, s, d]
377
+ k = k.reshape(B, N, attn.num_heads, attn.head_dim).transpose(
378
+ 1, 2
379
+ ) # [b, h, s, d]
380
+ v = v.reshape(B, N, attn.num_heads, attn.head_dim).transpose(1, 2)
381
+
382
+ q = attn.q_norm(q) # [b, h, s, d]
383
+ k = attn.k_norm(k) # [b, h, s, d]
384
+
385
+ with torch.backends.cuda.sdp_kernel(
386
+ enable_flash=True, enable_math=False, enable_mem_efficient=True
387
+ ):
388
+ hidden_states = F.scaled_dot_product_attention(q, k, v)
389
+ hidden_states = hidden_states.transpose(1, 2).reshape(B, N, -1)
390
+
391
+ hidden_states = attn.out_proj(hidden_states)
392
+ if self.use_global:
393
+ hidden_states = hidden_states.reshape(B_old, N_old, -1)
394
+ return hidden_states
395
+
396
+
397
+ class PartFormerDitBlock(nn.Module):
398
+
399
+ def __init__(
400
+ self,
401
+ hidden_size,
402
+ num_heads,
403
+ use_self_attention: bool = True,
404
+ use_cross_attention: bool = False,
405
+ use_cross_attention_2: bool = False,
406
+ encoder_hidden_dim=1024, # cross-attn encoder_hidden_states dim
407
+ encoder_hidden2_dim=1024, # cross-attn 2 encoder_hidden_states dim
408
+ # cross_attn2_weight=0.0,
409
+ qkv_bias=True,
410
+ qk_norm=False,
411
+ norm_layer=nn.LayerNorm,
412
+ qk_norm_layer=nn.RMSNorm,
413
+ with_decoupled_ca=False,
414
+ decoupled_ca_dim=16,
415
+ decoupled_ca_weight=1.0,
416
+ skip_connection=False,
417
+ timested_modulate=False,
418
+ c_emb_size=0, # time embedding size
419
+ use_moe: bool = False,
420
+ num_experts: int = 8,
421
+ moe_top_k: int = 2,
422
+ ):
423
+ super().__init__()
424
+ # self.cross_attn2_weight = cross_attn2_weight
425
+ use_ele_affine = True
426
+ # ========================= Self-Attention =========================
427
+ self.use_self_attention = use_self_attention
428
+ if self.use_self_attention:
429
+ self.norm1 = norm_layer(
430
+ hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
431
+ )
432
+ self.attn1 = Attention(
433
+ hidden_size,
434
+ num_heads=num_heads,
435
+ qkv_bias=qkv_bias,
436
+ qk_norm=qk_norm,
437
+ norm_layer=qk_norm_layer,
438
+ )
439
+
440
+ # ========================= Add =========================
441
+ # Simply use add like SDXL.
442
+ self.timested_modulate = timested_modulate
443
+ if self.timested_modulate:
444
+ self.default_modulation = nn.Sequential(
445
+ nn.SiLU(), nn.Linear(c_emb_size, hidden_size, bias=True)
446
+ )
447
+ # ========================= Cross-Attention =========================
448
+ self.use_cross_attention = use_cross_attention
449
+ if self.use_cross_attention:
450
+ self.norm2 = norm_layer(
451
+ hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
452
+ )
453
+ self.attn2 = CrossAttention(
454
+ hidden_size,
455
+ encoder_hidden_dim,
456
+ num_heads=num_heads,
457
+ qkv_bias=qkv_bias,
458
+ qk_norm=qk_norm,
459
+ norm_layer=qk_norm_layer,
460
+ with_decoupled_ca=False,
461
+ )
462
+ self.use_cross_attention_2 = use_cross_attention_2
463
+ if self.use_cross_attention_2:
464
+ self.norm2_2 = norm_layer(
465
+ hidden_size, elementwise_affine=use_ele_affine, eps=1e-6
466
+ )
467
+ self.attn2_2 = CrossAttention(
468
+ hidden_size,
469
+ encoder_hidden2_dim,
470
+ num_heads=num_heads,
471
+ qkv_bias=qkv_bias,
472
+ qk_norm=qk_norm,
473
+ norm_layer=qk_norm_layer,
474
+ with_decoupled_ca=with_decoupled_ca,
475
+ decoupled_ca_dim=decoupled_ca_dim,
476
+ decoupled_ca_weight=decoupled_ca_weight,
477
+ )
478
+ # ========================= FFN =========================
479
+ self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
480
+ self.use_moe = use_moe
481
+ if self.use_moe:
482
+ print("using moe")
483
+ self.moe = MoEBlock(
484
+ hidden_size,
485
+ num_experts=num_experts,
486
+ moe_top_k=moe_top_k,
487
+ dropout=0.0,
488
+ activation_fn="gelu",
489
+ final_dropout=False,
490
+ ff_inner_dim=int(hidden_size * 4.0),
491
+ ff_bias=True,
492
+ )
493
+ else:
494
+ self.mlp = MLP(width=hidden_size)
495
+ # ========================= skip FFN =========================
496
+ if skip_connection:
497
+ self.skip_norm = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
498
+ self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
499
+ else:
500
+ self.skip_linear = None
501
+
502
+ def forward(
503
+ self,
504
+ hidden_states: torch.Tensor,
505
+ encoder_hidden_states: Optional[torch.Tensor] = None,
506
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
507
+ temb: Optional[torch.Tensor] = None,
508
+ skip_value: torch.Tensor = None,
509
+ ):
510
+ # skip connection
511
+ if self.skip_linear is not None:
512
+ cat = torch.cat([skip_value, hidden_states], dim=-1)
513
+ hidden_states = self.skip_linear(cat)
514
+ hidden_states = self.skip_norm(hidden_states)
515
+ # local global attn (self-attn)
516
+ if self.timested_modulate:
517
+ shift_msa = self.default_modulation(temb).unsqueeze(dim=1)
518
+ hidden_states = hidden_states + shift_msa
519
+ if self.use_self_attention:
520
+ attn_output = self.attn1(self.norm1(hidden_states))
521
+ hidden_states = hidden_states + attn_output
522
+ # image cross attn
523
+ if self.use_cross_attention:
524
+ original_cross_out = self.attn2(
525
+ self.norm2(hidden_states),
526
+ encoder_hidden_states,
527
+ )
528
+ # added local-global cross attn
529
+ # 2. Cross-Attention
530
+ if self.use_cross_attention_2:
531
+ cross_out_2 = self.attn2_2(
532
+ self.norm2_2(hidden_states),
533
+ encoder_hidden_states_2,
534
+ )
535
+ hidden_states = (
536
+ hidden_states
537
+ + (original_cross_out if self.use_cross_attention else 0)
538
+ + (cross_out_2 if self.use_cross_attention_2 else 0)
539
+ )
540
+
541
+ # FFN Layer
542
+ mlp_inputs = self.norm3(hidden_states)
543
+
544
+ if self.use_moe:
545
+ hidden_states = hidden_states + self.moe(mlp_inputs)
546
+ else:
547
+ hidden_states = hidden_states + self.mlp(mlp_inputs)
548
+
549
+ return hidden_states
550
+
551
+
552
+ class FinalLayer(nn.Module):
553
+ """
554
+ The final layer of HunYuanDiT.
555
+ """
556
+
557
+ def __init__(self, final_hidden_size, out_channels):
558
+ super().__init__()
559
+ self.final_hidden_size = final_hidden_size
560
+ self.norm_final = nn.LayerNorm(
561
+ final_hidden_size, elementwise_affine=True, eps=1e-6
562
+ )
563
+ self.linear = nn.Linear(final_hidden_size, out_channels, bias=True)
564
+
565
+ def forward(self, x):
566
+ x = self.norm_final(x)
567
+ x = x[:, 1:]
568
+ x = self.linear(x)
569
+ return x
570
+
571
+
572
+ class PartFormerDITPlain(nn.Module):
573
+
574
+ def __init__(
575
+ self,
576
+ input_size=1024,
577
+ in_channels=4,
578
+ hidden_size=1024,
579
+ use_self_attention=True,
580
+ use_cross_attention=True,
581
+ use_cross_attention_2=True,
582
+ encoder_hidden_dim=1024, # cross-attn encoder_hidden_states dim
583
+ encoder_hidden2_dim=1024, # cross-attn 2 encoder_hidden_states dim
584
+ depth=24,
585
+ num_heads=16,
586
+ qk_norm=False,
587
+ qkv_bias=True,
588
+ norm_type="layer",
589
+ qk_norm_type="rms",
590
+ with_decoupled_ca=False,
591
+ decoupled_ca_dim=16,
592
+ decoupled_ca_weight=1.0,
593
+ use_pos_emb=False,
594
+ # use_attention_pooling=True,
595
+ guidance_cond_proj_dim=None,
596
+ num_moe_layers: int = 6,
597
+ num_experts: int = 8,
598
+ moe_top_k: int = 2,
599
+ **kwargs,
600
+ ):
601
+ super().__init__()
602
+
603
+ self.input_size = input_size
604
+ self.depth = depth
605
+ self.in_channels = in_channels
606
+ self.out_channels = in_channels
607
+ self.num_heads = num_heads
608
+
609
+ self.hidden_size = hidden_size
610
+ self.norm = nn.LayerNorm if norm_type == "layer" else nn.RMSNorm
611
+ self.qk_norm = nn.RMSNorm if qk_norm_type == "rms" else nn.LayerNorm
612
+ # embedding
613
+ self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
614
+ self.t_embedder = TimestepEmbedder(
615
+ hidden_size, hidden_size * 4, cond_proj_dim=guidance_cond_proj_dim
616
+ )
617
+ # Will use fixed sin-cos embedding:
618
+ self.use_pos_emb = use_pos_emb
619
+ if self.use_pos_emb:
620
+ self.register_buffer("pos_embed", torch.zeros(1, input_size, hidden_size))
621
+ pos = np.arange(self.input_size, dtype=np.float32)
622
+ pos_embed = get_1d_sincos_pos_embed_from_grid(self.pos_embed.shape[-1], pos)
623
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
624
+
625
+ # self.use_attention_pooling = use_attention_pooling
626
+ # if use_attention_pooling:
627
+
628
+ # self.pooler = AttentionPool(
629
+ # self.text_len, encoder_hidden_dim, num_heads=8, output_dim=1024
630
+ # )
631
+ # self.extra_embedder = nn.Sequential(
632
+ # nn.Linear(1024, hidden_size * 4),
633
+ # nn.SiLU(),
634
+ # nn.Linear(hidden_size * 4, hidden_size, bias=True),
635
+ # )
636
+ # for part embedding
637
+ self.use_bbox_cond = kwargs.get("use_bbox_cond", False)
638
+ if self.use_bbox_cond:
639
+ self.bbox_conditioner = BboxEmbedder(
640
+ out_size=hidden_size,
641
+ num_freqs=kwargs.get("num_freqs", 8),
642
+ )
643
+ self.use_part_embed = kwargs.get("use_part_embed", False)
644
+ if self.use_part_embed:
645
+ self.valid_num = kwargs.get("valid_num", 50)
646
+ self.part_embed = nn.Parameter(torch.randn(self.valid_num, hidden_size))
647
+ # zero init part_embed
648
+ self.part_embed.data.zero_()
649
+ # transformer blocks
650
+ self.blocks = nn.ModuleList([
651
+ PartFormerDitBlock(
652
+ hidden_size,
653
+ num_heads,
654
+ use_self_attention=use_self_attention,
655
+ use_cross_attention=use_cross_attention,
656
+ use_cross_attention_2=use_cross_attention_2,
657
+ encoder_hidden_dim=encoder_hidden_dim, # cross-attn encoder_hidden_states dim
658
+ encoder_hidden2_dim=encoder_hidden2_dim, # cross-attn 2 encoder_hidden_states dim
659
+ # cross_attn2_weight=cross_attn2_weight,
660
+ qkv_bias=qkv_bias,
661
+ qk_norm=qk_norm,
662
+ norm_layer=self.norm,
663
+ qk_norm_layer=self.qk_norm,
664
+ with_decoupled_ca=with_decoupled_ca,
665
+ decoupled_ca_dim=decoupled_ca_dim,
666
+ decoupled_ca_weight=decoupled_ca_weight,
667
+ skip_connection=layer > depth // 2,
668
+ use_moe=True if depth - layer <= num_moe_layers else False,
669
+ num_experts=num_experts,
670
+ moe_top_k=moe_top_k,
671
+ )
672
+ for layer in range(depth)
673
+ ])
674
+ # set local-global processor
675
+ for layer, block in enumerate(self.blocks):
676
+ if hasattr(block, "attn1") and (layer + 1) % 2 == 0:
677
+ block.attn1.processor = LocalGlobalProcessor(use_global=True)
678
+
679
+ self.depth = depth
680
+
681
+ self.final_layer = FinalLayer(hidden_size, self.out_channels)
682
+
683
+ def forward(self, x, t, contexts: dict, **kwargs):
684
+ """
685
+
686
+ x: [B, N, C]
687
+ t: [B]
688
+ contexts: dict
689
+ image_context: [B, K*ni, C]
690
+ geo_context: [B, K*ng, C] or [B, K*ng, C*2]
691
+ aabb: [B, K, 2, 3]
692
+ num_tokens: [B, N]
693
+
694
+ N = K * num_tokens
695
+
696
+ For parts pretrain : K = 1
697
+ """
698
+ # prepare input
699
+ aabb: torch.Tensor = kwargs.get("aabb", None)
700
+ # image_context = contexts.get("image_un_cond", None)
701
+ object_context = contexts.get("obj_cond", None)
702
+ geo_context = contexts.get("geo_cond", None)
703
+ num_tokens: torch.Tensor = kwargs.get("num_tokens", None)
704
+ # timeembedding and input projection
705
+ t = self.t_embedder(t, condition=kwargs.get("guidance_cond"))
706
+ x = self.x_embedder(x)
707
+
708
+ if self.use_pos_emb:
709
+ pos_embed = self.pos_embed.to(x.dtype)
710
+ x = x + pos_embed
711
+
712
+ # c is time embedding (adding pooling context or not)
713
+ # if self.use_attention_pooling:
714
+ # # TODO: attention_pooling for all contexts
715
+ # extra_vec = self.pooler(image_context, None)
716
+ # c = t + self.extra_embedder(extra_vec) # [B, D]
717
+ # else:
718
+ # c = t
719
+ c = t
720
+ # bounding box
721
+ if self.use_bbox_cond:
722
+ center_extent = torch.cat(
723
+ [torch.mean(aabb, dim=-2), aabb[..., 1, :] - aabb[..., 0, :]], dim=-1
724
+ )
725
+ bbox_embeds = self.bbox_conditioner(center_extent)
726
+ # TODO: now only support batch_size=1
727
+ bbox_embeds = torch.repeat_interleave(
728
+ bbox_embeds, repeats=num_tokens[0], dim=1
729
+ )
730
+ x = x + bbox_embeds
731
+ # part id embedding
732
+ if self.use_part_embed:
733
+ num_parts = aabb.shape[1]
734
+ random_idx = torch.randperm(self.valid_num)[:num_parts]
735
+ part_embeds = self.part_embed[random_idx].unsqueeze(1)
736
+ # import pdb
737
+
738
+ # pdb.set_trace()
739
+ x = x + part_embeds
740
+ x = torch.cat([c, x], dim=1)
741
+ skip_value_list = []
742
+ for layer, block in enumerate(self.blocks):
743
+ skip_value = None if layer <= self.depth // 2 else skip_value_list.pop()
744
+ x = block(
745
+ hidden_states=x,
746
+ # encoder_hidden_states=image_context,
747
+ encoder_hidden_states=object_context,
748
+ encoder_hidden_states_2=geo_context,
749
+ temb=c,
750
+ skip_value=skip_value,
751
+ )
752
+ if layer < self.depth // 2:
753
+ skip_value_list.append(x)
754
+
755
+ x = self.final_layer(x)
756
+ return x
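A shape-level sketch of driving the model above; every dimension and context key here is illustrative (taken from the forward docstring), and a recent PyTorch, a working scaled-dot-product attention backend, and diffusers (for the MoE layers) are assumed:

dit = PartFormerDITPlain(in_channels=4, hidden_size=1024, depth=4, num_heads=16)
x = torch.randn(1, 256, 4)                        # latent tokens [B, N, C]
t = torch.randint(0, 1000, (1,))                  # one diffusion timestep per sample
contexts = {
    "obj_cond": torch.randn(1, 77, 1024),         # first cross-attention context
    "geo_cond": torch.randn(1, 64, 1024),         # second cross-attention context
}
out = dit(x, t, contexts)                         # [B, N, C]; the time token is dropped in FinalLayer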
XPart/partgen/models/sonata/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from .model import load, load_by_config
17
+
18
+ from . import model
19
+ from . import module
20
+ from . import structure
21
+ from . import data
22
+ from . import transform
23
+ from . import utils
24
+ from . import registry
25
+
26
+ __all__ = [
27
+ "load",
28
+ "load_by_config",
29
+ "model",
30
+ "module",
31
+ "structure",
32
+ "transform",
33
+ "registry",
34
+ "utils",
35
+ ]
XPart/partgen/models/sonata/data.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import numpy as np
18
+ import torch
19
+ from collections.abc import Mapping, Sequence
20
+ from huggingface_hub import hf_hub_download
21
+
22
+
23
+ DATAS = ["sample1", "sample1_high_res", "sample1_dino"]
24
+
25
+
26
+ def load(
27
+ name: str = "sonata",
28
+ download_root: str = None,
29
+ ):
30
+ if name in DATAS:
31
+ print(f"Loading data from HuggingFace: {name} ...")
32
+ data_path = hf_hub_download(
33
+ repo_id="pointcept/demo",
34
+ filename=f"{name}.npz",
35
+ repo_type="dataset",
36
+ revision="main",
37
+ local_dir=download_root or os.path.expanduser("~/.cache/sonata/data"),
38
+ )
39
+ elif os.path.isfile(name):
40
+ print(f"Loading data in local path: {name} ...")
41
+ data_path = name
42
+ else:
43
+ raise RuntimeError(f"Data {name} not found; available models = {DATAS}")
44
+ return dict(np.load(data_path))
45
+
46
+
47
+ from torch.utils.data.dataloader import default_collate
48
+
49
+
50
+ def collate_fn(batch):
51
+ """
52
+ collate function for point clouds; supports dict and list batches,
53
+ 'coord' is necessary to determine 'offset'
54
+ """
55
+ if not isinstance(batch, Sequence):
56
+ raise TypeError(f"{batch.dtype} is not supported.")
57
+
58
+ if isinstance(batch[0], torch.Tensor):
59
+ return torch.cat(list(batch))
60
+ elif isinstance(batch[0], str):
61
+ # str is also a kind of Sequence, so this check must come before the Sequence branch
62
+ return list(batch)
63
+ elif isinstance(batch[0], Sequence):
64
+ for data in batch:
65
+ data.append(torch.tensor([data[0].shape[0]]))
66
+ batch = [collate_fn(samples) for samples in zip(*batch)]
67
+ batch[-1] = torch.cumsum(batch[-1], dim=0).int()
68
+ return batch
69
+ elif isinstance(batch[0], Mapping):
70
+ batch = {
71
+ key: (
72
+ collate_fn([d[key] for d in batch])
73
+ if "offset" not in key
74
+ # offset -> bincount -> concat bincount-> concat offset
75
+ else torch.cumsum(
76
+ collate_fn([d[key].diff(prepend=torch.tensor([0])) for d in batch]),
77
+ dim=0,
78
+ )
79
+ )
80
+ for key in batch[0]
81
+ }
82
+ return batch
83
+ else:
84
+ return default_collate(batch)
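An illustrative call of the collate function above (keys and point counts made up):

a = {"coord": torch.randn(5, 3), "feat": torch.randn(5, 6), "offset": torch.tensor([5])}
b = {"coord": torch.randn(3, 3), "feat": torch.randn(3, 6), "offset": torch.tensor([3])}
batch = collate_fn([a, b])
# batch["coord"]: (8, 3) concatenated points; batch["offset"]: tensor([5, 8]) cumulative counts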
XPart/partgen/models/sonata/model.py ADDED
@@ -0,0 +1,874 @@
1
+ """
2
+ Point Transformer - V3 Mode2 - Sonata
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
87
+
88
+
89
+ import os
90
+ from packaging import version
91
+ from huggingface_hub import hf_hub_download, PyTorchModelHubMixin
92
+ from addict import Dict
93
+ import torch
94
+ import torch.nn as nn
95
+ from torch.nn.init import trunc_normal_
96
+ import spconv.pytorch as spconv
97
+ import torch_scatter
98
+ from timm.layers import DropPath
99
+ import json
100
+
101
+ try:
102
+ import flash_attn
103
+ except ImportError:
104
+ flash_attn = None
105
+
106
+ from .structure import Point
107
+ from .module import PointSequential, PointModule
108
+ from .utils import offset2bincount
109
+
110
+ MODELS = [
111
+ "sonata",
112
+ "sonata_small",
113
+ "sonata_linear_prob_head_sc",
114
+ ]
115
+
116
+
117
+ class LayerScale(nn.Module):
118
+ def __init__(
119
+ self,
120
+ dim: int,
121
+ init_values: float = 1e-5,
122
+ inplace: bool = False,
123
+ ) -> None:
124
+ super().__init__()
125
+ self.inplace = inplace
126
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
127
+
128
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
129
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
130
+
131
+
132
+ class RPE(torch.nn.Module):
133
+ def __init__(self, patch_size, num_heads):
134
+ super().__init__()
135
+ self.patch_size = patch_size
136
+ self.num_heads = num_heads
137
+ self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
138
+ self.rpe_num = 2 * self.pos_bnd + 1
139
+ self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
140
+ torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
141
+
142
+ def forward(self, coord):
143
+ idx = (
144
+ coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
145
+ + self.pos_bnd # relative position to positive index
146
+ + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
147
+ )
148
+ out = self.rpe_table.index_select(0, idx.reshape(-1))
149
+ out = out.view(idx.shape + (-1,)).sum(3)
150
+ out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
151
+ return out
152
+
153
+
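The RPE table above holds one learnable bias row per discretized relative offset and per axis (3 * rpe_num rows, num_heads columns); the forward pass looks up one row per axis and sums them. A shape sketch with arbitrary numbers:

rpe = RPE(patch_size=48, num_heads=8)
rel = torch.randint(-3, 4, (2, 48, 48, 3))       # per-patch grid-coordinate differences
bias = rpe(rel)                                  # (2, 8, 48, 48), shaped to match the per-patch attention logits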
154
+ class SerializedAttention(PointModule):
155
+ def __init__(
156
+ self,
157
+ channels,
158
+ num_heads,
159
+ patch_size,
160
+ qkv_bias=True,
161
+ qk_scale=None,
162
+ attn_drop=0.0,
163
+ proj_drop=0.0,
164
+ order_index=0,
165
+ enable_rpe=False,
166
+ enable_flash=True,
167
+ upcast_attention=True,
168
+ upcast_softmax=True,
169
+ ):
170
+ super().__init__()
171
+ assert channels % num_heads == 0
172
+ self.channels = channels
173
+ self.num_heads = num_heads
174
+ self.scale = qk_scale or (channels // num_heads) ** -0.5
175
+ self.order_index = order_index
176
+ self.upcast_attention = upcast_attention
177
+ self.upcast_softmax = upcast_softmax
178
+ self.enable_rpe = enable_rpe
179
+ self.enable_flash = enable_flash
180
+ if enable_flash:
181
+ assert (
182
+ enable_rpe is False
183
+ ), "Set enable_rpe to False when enable Flash Attention"
184
+ assert (
185
+ upcast_attention is False
186
+ ), "Set upcast_attention to False when enable Flash Attention"
187
+ assert (
188
+ upcast_softmax is False
189
+ ), "Set upcast_softmax to False when enable Flash Attention"
190
+ assert flash_attn is not None, "Make sure flash_attn is installed."
191
+ self.patch_size = patch_size
192
+ self.attn_drop = attn_drop
193
+ else:
194
+ # when flash attention is disabled, we still avoid using an attention mask;
195
+ # the patch size is therefore automatically set to the
196
+ # minimum of patch_size_max and the number of points
197
+ self.patch_size_max = patch_size
198
+ self.patch_size = 0
199
+ self.attn_drop = torch.nn.Dropout(attn_drop)
200
+
201
+ self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
202
+ self.proj = torch.nn.Linear(channels, channels)
203
+ self.proj_drop = torch.nn.Dropout(proj_drop)
204
+ self.softmax = torch.nn.Softmax(dim=-1)
205
+ self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
206
+
207
+ @torch.no_grad()
208
+ def get_rel_pos(self, point, order):
209
+ K = self.patch_size
210
+ rel_pos_key = f"rel_pos_{self.order_index}"
211
+ if rel_pos_key not in point.keys():
212
+ grid_coord = point.grid_coord[order]
213
+ grid_coord = grid_coord.reshape(-1, K, 3)
214
+ point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
215
+ return point[rel_pos_key]
216
+
217
+ @torch.no_grad()
218
+ def get_padding_and_inverse(self, point):
219
+ pad_key = "pad"
220
+ unpad_key = "unpad"
221
+ cu_seqlens_key = "cu_seqlens_key"
222
+ if (
223
+ pad_key not in point.keys()
224
+ or unpad_key not in point.keys()
225
+ or cu_seqlens_key not in point.keys()
226
+ ):
227
+ offset = point.offset
228
+ bincount = offset2bincount(offset)
229
+ bincount_pad = (
230
+ torch.div(
231
+ bincount + self.patch_size - 1,
232
+ self.patch_size,
233
+ rounding_mode="trunc",
234
+ )
235
+ * self.patch_size
236
+ )
237
+ # only pad a sample when its number of points is larger than patch_size
238
+ mask_pad = bincount > self.patch_size
239
+ bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
240
+ _offset = nn.functional.pad(offset, (1, 0))
241
+ _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
242
+ pad = torch.arange(_offset_pad[-1], device=offset.device)
243
+ unpad = torch.arange(_offset[-1], device=offset.device)
244
+ cu_seqlens = []
245
+ for i in range(len(offset)):
246
+ unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
247
+ if bincount[i] != bincount_pad[i]:
248
+ pad[
249
+ _offset_pad[i + 1]
250
+ - self.patch_size
251
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
252
+ ] = pad[
253
+ _offset_pad[i + 1]
254
+ - 2 * self.patch_size
255
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
256
+ - self.patch_size
257
+ ]
258
+ pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
259
+ cu_seqlens.append(
260
+ torch.arange(
261
+ _offset_pad[i],
262
+ _offset_pad[i + 1],
263
+ step=self.patch_size,
264
+ dtype=torch.int32,
265
+ device=offset.device,
266
+ )
267
+ )
268
+ point[pad_key] = pad
269
+ point[unpad_key] = unpad
270
+ point[cu_seqlens_key] = nn.functional.pad(
271
+ torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
272
+ )
273
+ return point[pad_key], point[unpad_key], point[cu_seqlens_key]
274
+
275
+ def forward(self, point):
276
+ if not self.enable_flash:
277
+ self.patch_size = min(
278
+ offset2bincount(point.offset).min().tolist(), self.patch_size_max
279
+ )
280
+
281
+ H = self.num_heads
282
+ K = self.patch_size
283
+ C = self.channels
284
+
285
+ pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
286
+
287
+ order = point.serialized_order[self.order_index][pad]
288
+ inverse = unpad[point.serialized_inverse[self.order_index]]
289
+
290
+ # padding and reshape feat and batch for serialized point patch
291
+ qkv = self.qkv(point.feat)[order]
292
+
293
+ if not self.enable_flash:
294
+ # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
295
+ q, k, v = (
296
+ qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
297
+ )
298
+ # attn
299
+ if self.upcast_attention:
300
+ q = q.float()
301
+ k = k.float()
302
+ attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K)
303
+ if self.enable_rpe:
304
+ attn = attn + self.rpe(self.get_rel_pos(point, order))
305
+ if self.upcast_softmax:
306
+ attn = attn.float()
307
+ attn = self.softmax(attn)
308
+ attn = self.attn_drop(attn).to(qkv.dtype)
309
+ feat = (attn @ v).transpose(1, 2).reshape(-1, C)
310
+ else:
311
+ feat = flash_attn.flash_attn_varlen_qkvpacked_func(
312
+ qkv.half().reshape(-1, 3, H, C // H),
313
+ cu_seqlens,
314
+ max_seqlen=self.patch_size,
315
+ dropout_p=self.attn_drop if self.training else 0,
316
+ softmax_scale=self.scale,
317
+ ).reshape(-1, C)
318
+ feat = feat.to(qkv.dtype)
319
+ feat = feat[inverse]
320
+
321
+ # ffn
322
+ feat = self.proj(feat)
323
+ feat = self.proj_drop(feat)
324
+ point.feat = feat
325
+ return point
326
+
327
+
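SerializedAttention attends within fixed-size patches taken along the serialized point order, so get_padding_and_inverse rounds each sample that exceeds patch_size up to a multiple of patch_size and records the pad/unpad index maps plus cu_seqlens for the varlen flash-attention call. A minimal sketch of just the padded-length computation, using an illustrative offset tensor and patch_size:

import torch

patch_size = 1024
offset = torch.tensor([1500, 3800])                          # cumulative point counts per sample
bincount = torch.diff(offset, prepend=offset.new_zeros(1))   # points per sample: [1500, 2300]

# round up to a multiple of patch_size, but only for samples with more than patch_size points
bincount_pad = torch.div(
    bincount + patch_size - 1, patch_size, rounding_mode="trunc"
) * patch_size
mask_pad = bincount > patch_size
bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
print(bincount_pad.tolist())                                 # [2048, 3072]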
328
+ class MLP(nn.Module):
329
+ def __init__(
330
+ self,
331
+ in_channels,
332
+ hidden_channels=None,
333
+ out_channels=None,
334
+ act_layer=nn.GELU,
335
+ drop=0.0,
336
+ ):
337
+ super().__init__()
338
+ out_channels = out_channels or in_channels
339
+ hidden_channels = hidden_channels or in_channels
340
+ self.fc1 = nn.Linear(in_channels, hidden_channels)
341
+ self.act = act_layer()
342
+ self.fc2 = nn.Linear(hidden_channels, out_channels)
343
+ self.drop = nn.Dropout(drop)
344
+
345
+ def forward(self, x):
346
+ x = self.fc1(x)
347
+ x = self.act(x)
348
+ x = self.drop(x)
349
+ x = self.fc2(x)
350
+ x = self.drop(x)
351
+ return x
352
+
353
+
354
+ class Block(PointModule):
355
+ def __init__(
356
+ self,
357
+ channels,
358
+ num_heads,
359
+ patch_size=48,
360
+ mlp_ratio=4.0,
361
+ qkv_bias=True,
362
+ qk_scale=None,
363
+ attn_drop=0.0,
364
+ proj_drop=0.0,
365
+ drop_path=0.0,
366
+ layer_scale=None,
367
+ norm_layer=nn.LayerNorm,
368
+ act_layer=nn.GELU,
369
+ pre_norm=True,
370
+ order_index=0,
371
+ cpe_indice_key=None,
372
+ enable_rpe=False,
373
+ enable_flash=True,
374
+ upcast_attention=True,
375
+ upcast_softmax=True,
376
+ ):
377
+ super().__init__()
378
+ self.channels = channels
379
+ self.pre_norm = pre_norm
380
+
381
+ self.cpe = PointSequential(
382
+ spconv.SubMConv3d(
383
+ channels,
384
+ channels,
385
+ kernel_size=3,
386
+ bias=True,
387
+ indice_key=cpe_indice_key,
388
+ ),
389
+ nn.Linear(channels, channels),
390
+ norm_layer(channels),
391
+ )
392
+
393
+ self.norm1 = PointSequential(norm_layer(channels))
394
+ self.ls1 = PointSequential(
395
+ LayerScale(channels, init_values=layer_scale)
396
+ if layer_scale is not None
397
+ else nn.Identity()
398
+ )
399
+ self.attn = SerializedAttention(
400
+ channels=channels,
401
+ patch_size=patch_size,
402
+ num_heads=num_heads,
403
+ qkv_bias=qkv_bias,
404
+ qk_scale=qk_scale,
405
+ attn_drop=attn_drop,
406
+ proj_drop=proj_drop,
407
+ order_index=order_index,
408
+ enable_rpe=enable_rpe,
409
+ enable_flash=enable_flash,
410
+ upcast_attention=upcast_attention,
411
+ upcast_softmax=upcast_softmax,
412
+ )
413
+ self.norm2 = PointSequential(norm_layer(channels))
414
+ self.ls2 = PointSequential(
415
+ LayerScale(channels, init_values=layer_scale)
416
+ if layer_scale is not None
417
+ else nn.Identity()
418
+ )
419
+ self.mlp = PointSequential(
420
+ MLP(
421
+ in_channels=channels,
422
+ hidden_channels=int(channels * mlp_ratio),
423
+ out_channels=channels,
424
+ act_layer=act_layer,
425
+ drop=proj_drop,
426
+ )
427
+ )
428
+ self.drop_path = PointSequential(
429
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
430
+ )
431
+
432
+ def forward(self, point: Point):
433
+ shortcut = point.feat
434
+ point = self.cpe(point)
435
+ point.feat = shortcut + point.feat
436
+ shortcut = point.feat
437
+ if self.pre_norm:
438
+ point = self.norm1(point)
439
+ point = self.drop_path(self.ls1(self.attn(point)))
440
+ point.feat = shortcut + point.feat
441
+ if not self.pre_norm:
442
+ point = self.norm1(point)
443
+
444
+ shortcut = point.feat
445
+ if self.pre_norm:
446
+ point = self.norm2(point)
447
+ point = self.drop_path(self.ls2(self.mlp(point)))
448
+ point.feat = shortcut + point.feat
449
+ if not self.pre_norm:
450
+ point = self.norm2(point)
451
+ point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
452
+ return point
453
+
454
+
455
+ class GridPooling(PointModule):
456
+ def __init__(
457
+ self,
458
+ in_channels,
459
+ out_channels,
460
+ stride=2,
461
+ norm_layer=None,
462
+ act_layer=None,
463
+ reduce="max",
464
+ shuffle_orders=True,
465
+ traceable=True, # record parent and cluster
466
+ ):
467
+ super().__init__()
468
+ self.in_channels = in_channels
469
+ self.out_channels = out_channels
470
+
471
+ self.stride = stride
472
+ assert reduce in ["sum", "mean", "min", "max"]
473
+ self.reduce = reduce
474
+ self.shuffle_orders = shuffle_orders
475
+ self.traceable = traceable
476
+
477
+ self.proj = nn.Linear(in_channels, out_channels)
478
+ if norm_layer is not None:
479
+ self.norm = PointSequential(norm_layer(out_channels))
480
+ if act_layer is not None:
481
+ self.act = PointSequential(act_layer())
482
+
483
+ def forward(self, point: Point):
484
+ if "grid_coord" in point.keys():
485
+ grid_coord = point.grid_coord
486
+ elif {"coord", "grid_size"}.issubset(point.keys()):
487
+ grid_coord = torch.div(
488
+ point.coord - point.coord.min(0)[0],
489
+ point.grid_size,
490
+ rounding_mode="trunc",
491
+ ).int()
492
+ else:
493
+ raise AssertionError(
494
+ "[gird_coord] or [coord, grid_size] should be include in the Point"
495
+ )
496
+ grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc")
497
+ grid_coord = grid_coord | point.batch.view(-1, 1) << 48
498
+ grid_coord, cluster, counts = torch.unique(
499
+ grid_coord,
500
+ sorted=True,
501
+ return_inverse=True,
502
+ return_counts=True,
503
+ dim=0,
504
+ )
505
+ grid_coord = grid_coord & ((1 << 48) - 1)
506
+ # indices of point sorted by cluster, for torch_scatter.segment_csr
507
+ _, indices = torch.sort(cluster)
508
+ # index pointer for sorted point, for torch_scatter.segment_csr
509
+ idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
510
+ # head_indices of each cluster, for reduce attr e.g. code, batch
511
+ head_indices = indices[idx_ptr[:-1]]
512
+ point_dict = Dict(
513
+ feat=torch_scatter.segment_csr(
514
+ self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
515
+ ),
516
+ coord=torch_scatter.segment_csr(
517
+ point.coord[indices], idx_ptr, reduce="mean"
518
+ ),
519
+ grid_coord=grid_coord,
520
+ batch=point.batch[head_indices],
521
+ )
522
+ if "origin_coord" in point.keys():
523
+ point_dict["origin_coord"] = torch_scatter.segment_csr(
524
+ point.origin_coord[indices], idx_ptr, reduce="mean"
525
+ )
526
+ if "condition" in point.keys():
527
+ point_dict["condition"] = point.condition
528
+ if "context" in point.keys():
529
+ point_dict["context"] = point.context
530
+ if "name" in point.keys():
531
+ point_dict["name"] = point.name
532
+ if "split" in point.keys():
533
+ point_dict["split"] = point.split
534
+ if "color" in point.keys():
535
+ point_dict["color"] = torch_scatter.segment_csr(
536
+ point.color[indices], idx_ptr, reduce="mean"
537
+ )
538
+ if "grid_size" in point.keys():
539
+ point_dict["grid_size"] = point.grid_size * self.stride
540
+
541
+ if self.traceable:
542
+ point_dict["pooling_inverse"] = cluster
543
+ point_dict["pooling_parent"] = point
544
+ order = point.order
545
+ point = Point(point_dict)
546
+ if self.norm is not None:
547
+ point = self.norm(point)
548
+ if self.act is not None:
549
+ point = self.act(point)
550
+ point.serialization(order=order, shuffle_orders=self.shuffle_orders)
551
+ point.sparsify()
552
+ return point
553
+
554
+
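GridPooling voxelizes the grid coordinates at the given stride and packs the batch index into the bits above bit 48, so a single torch.unique call clusters points per voxel and per sample at once; features are then reduced per cluster with torch_scatter.segment_csr. A minimal sketch of the pack / unique / unpack step on toy coordinates:

import torch

grid_coord = torch.tensor([[0, 0, 1], [0, 0, 1], [2, 3, 1]], dtype=torch.int64)
batch = torch.tensor([0, 0, 1])

packed = grid_coord | batch.view(-1, 1) << 48        # fuse batch id into the high bits
uniq, cluster, counts = torch.unique(
    packed, sorted=True, return_inverse=True, return_counts=True, dim=0
)
coords = uniq & ((1 << 48) - 1)                      # recover the voxel coordinates
print(cluster.tolist(), counts.tolist())             # [0, 0, 1] [2, 1]
print(coords.tolist())                               # [[0, 0, 1], [2, 3, 1]]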
555
+ class GridUnpooling(PointModule):
556
+ def __init__(
557
+ self,
558
+ in_channels,
559
+ skip_channels,
560
+ out_channels,
561
+ norm_layer=None,
562
+ act_layer=None,
563
+ traceable=False, # record parent and cluster
564
+ ):
565
+ super().__init__()
566
+ self.proj = PointSequential(nn.Linear(in_channels, out_channels))
567
+ self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
568
+
569
+ if norm_layer is not None:
570
+ self.proj.add(norm_layer(out_channels))
571
+ self.proj_skip.add(norm_layer(out_channels))
572
+
573
+ if act_layer is not None:
574
+ self.proj.add(act_layer())
575
+ self.proj_skip.add(act_layer())
576
+
577
+ self.traceable = traceable
578
+
579
+ def forward(self, point):
580
+ assert "pooling_parent" in point.keys()
581
+ assert "pooling_inverse" in point.keys()
582
+ parent = point.pop("pooling_parent")
583
+ inverse = point.pooling_inverse
584
+ feat = point.feat
585
+
586
+ parent = self.proj_skip(parent)
587
+ parent.feat = parent.feat + self.proj(point).feat[inverse]
588
+ parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat)
589
+
590
+ if self.traceable:
591
+ point.feat = feat
592
+ parent["unpooling_parent"] = point
593
+ return parent
594
+
595
+
596
+ class Embedding(PointModule):
597
+ def __init__(
598
+ self,
599
+ in_channels,
600
+ embed_channels,
601
+ norm_layer=None,
602
+ act_layer=None,
603
+ mask_token=False,
604
+ ):
605
+ super().__init__()
606
+ self.in_channels = in_channels
607
+ self.embed_channels = embed_channels
608
+
609
+ self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels))
610
+ if norm_layer is not None:
611
+ self.stem.add(norm_layer(embed_channels), name="norm")
612
+ if act_layer is not None:
613
+ self.stem.add(act_layer(), name="act")
614
+
615
+ if mask_token:
616
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_channels))
617
+ else:
618
+ self.mask_token = None
619
+
620
+ def forward(self, point: Point):
621
+ point = self.stem(point)
622
+ if "mask" in point.keys():
623
+ point.feat = torch.where(
624
+ point.mask.unsqueeze(-1),
625
+ self.mask_token.to(point.feat.dtype),
626
+ point.feat,
627
+ )
628
+ return point
629
+
630
+
631
+ class PointTransformerV3(PointModule, PyTorchModelHubMixin):
632
+ def __init__(
633
+ self,
634
+ in_channels=6,
635
+ order=("z", "z-trans"),
636
+ stride=(2, 2, 2, 2),
637
+ enc_depths=(3, 3, 3, 12, 3),
638
+ enc_channels=(48, 96, 192, 384, 512),
639
+ enc_num_head=(3, 6, 12, 24, 32),
640
+ enc_patch_size=(1024, 1024, 1024, 1024, 1024),
641
+ dec_depths=(3, 3, 3, 3),
642
+ dec_channels=(96, 96, 192, 384),
643
+ dec_num_head=(6, 6, 12, 32),
644
+ dec_patch_size=(1024, 1024, 1024, 1024),
645
+ mlp_ratio=4,
646
+ qkv_bias=True,
647
+ qk_scale=None,
648
+ attn_drop=0.0,
649
+ proj_drop=0.0,
650
+ drop_path=0.3,
651
+ layer_scale=None,
652
+ pre_norm=True,
653
+ shuffle_orders=True,
654
+ enable_rpe=False,
655
+ enable_flash=True,
656
+ upcast_attention=False,
657
+ upcast_softmax=False,
658
+ traceable=False,
659
+ mask_token=False,
660
+ enc_mode=False,
661
+ freeze_encoder=False,
662
+ ):
663
+ super().__init__()
664
+ self.num_stages = len(enc_depths)
665
+ self.order = [order] if isinstance(order, str) else order
666
+ self.enc_mode = enc_mode
667
+ self.shuffle_orders = shuffle_orders
668
+ self.freeze_encoder = freeze_encoder
669
+
670
+ assert self.num_stages == len(stride) + 1
671
+ assert self.num_stages == len(enc_depths)
672
+ assert self.num_stages == len(enc_channels)
673
+ assert self.num_stages == len(enc_num_head)
674
+ assert self.num_stages == len(enc_patch_size)
675
+ assert self.enc_mode or self.num_stages == len(dec_depths) + 1
676
+ assert self.enc_mode or self.num_stages == len(dec_channels) + 1
677
+ assert self.enc_mode or self.num_stages == len(dec_num_head) + 1
678
+ assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1
679
+
680
+ print(f"flash attention: {enable_flash}")
681
+
682
+ # normalization layer
683
+ ln_layer = nn.LayerNorm
684
+ # activation layers
685
+ act_layer = nn.GELU
686
+
687
+ self.embedding = Embedding(
688
+ in_channels=in_channels,
689
+ embed_channels=enc_channels[0],
690
+ norm_layer=ln_layer,
691
+ act_layer=act_layer,
692
+ mask_token=mask_token,
693
+ )
694
+
695
+ # encoder
696
+ enc_drop_path = [
697
+ x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
698
+ ]
699
+ self.enc = PointSequential()
700
+ for s in range(self.num_stages):
701
+ enc_drop_path_ = enc_drop_path[
702
+ sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
703
+ ]
704
+ enc = PointSequential()
705
+ if s > 0:
706
+ enc.add(
707
+ GridPooling(
708
+ in_channels=enc_channels[s - 1],
709
+ out_channels=enc_channels[s],
710
+ stride=stride[s - 1],
711
+ norm_layer=ln_layer,
712
+ act_layer=act_layer,
713
+ ),
714
+ name="down",
715
+ )
716
+ for i in range(enc_depths[s]):
717
+ enc.add(
718
+ Block(
719
+ channels=enc_channels[s],
720
+ num_heads=enc_num_head[s],
721
+ patch_size=enc_patch_size[s],
722
+ mlp_ratio=mlp_ratio,
723
+ qkv_bias=qkv_bias,
724
+ qk_scale=qk_scale,
725
+ attn_drop=attn_drop,
726
+ proj_drop=proj_drop,
727
+ drop_path=enc_drop_path_[i],
728
+ layer_scale=layer_scale,
729
+ norm_layer=ln_layer,
730
+ act_layer=act_layer,
731
+ pre_norm=pre_norm,
732
+ order_index=i % len(self.order),
733
+ cpe_indice_key=f"stage{s}",
734
+ enable_rpe=enable_rpe,
735
+ enable_flash=enable_flash,
736
+ upcast_attention=upcast_attention,
737
+ upcast_softmax=upcast_softmax,
738
+ ),
739
+ name=f"block{i}",
740
+ )
741
+ if len(enc) != 0:
742
+ self.enc.add(module=enc, name=f"enc{s}")
743
+
744
+ # decoder
745
+ if not self.enc_mode:
746
+ dec_drop_path = [
747
+ x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
748
+ ]
749
+ self.dec = PointSequential()
750
+ dec_channels = list(dec_channels) + [enc_channels[-1]]
751
+ for s in reversed(range(self.num_stages - 1)):
752
+ dec_drop_path_ = dec_drop_path[
753
+ sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
754
+ ]
755
+ dec_drop_path_.reverse()
756
+ dec = PointSequential()
757
+ dec.add(
758
+ GridUnpooling(
759
+ in_channels=dec_channels[s + 1],
760
+ skip_channels=enc_channels[s],
761
+ out_channels=dec_channels[s],
762
+ norm_layer=ln_layer,
763
+ act_layer=act_layer,
764
+ traceable=traceable,
765
+ ),
766
+ name="up",
767
+ )
768
+ for i in range(dec_depths[s]):
769
+ dec.add(
770
+ Block(
771
+ channels=dec_channels[s],
772
+ num_heads=dec_num_head[s],
773
+ patch_size=dec_patch_size[s],
774
+ mlp_ratio=mlp_ratio,
775
+ qkv_bias=qkv_bias,
776
+ qk_scale=qk_scale,
777
+ attn_drop=attn_drop,
778
+ proj_drop=proj_drop,
779
+ drop_path=dec_drop_path_[i],
780
+ layer_scale=layer_scale,
781
+ norm_layer=ln_layer,
782
+ act_layer=act_layer,
783
+ pre_norm=pre_norm,
784
+ order_index=i % len(self.order),
785
+ cpe_indice_key=f"stage{s}",
786
+ enable_rpe=enable_rpe,
787
+ enable_flash=enable_flash,
788
+ upcast_attention=upcast_attention,
789
+ upcast_softmax=upcast_softmax,
790
+ ),
791
+ name=f"block{i}",
792
+ )
793
+ self.dec.add(module=dec, name=f"dec{s}")
794
+ if self.freeze_encoder:
795
+ for p in self.embedding.parameters():
796
+ p.requires_grad = False
797
+ for p in self.enc.parameters():
798
+ p.requires_grad = False
799
+ self.apply(self._init_weights)
800
+
801
+ @staticmethod
802
+ def _init_weights(module):
803
+ if isinstance(module, nn.Linear):
804
+ trunc_normal_(module.weight, std=0.02)
805
+ if module.bias is not None:
806
+ nn.init.zeros_(module.bias)
807
+ elif isinstance(module, spconv.SubMConv3d):
808
+ trunc_normal_(module.weight, std=0.02)
809
+ if module.bias is not None:
810
+ nn.init.zeros_(module.bias)
811
+
812
+ def forward(self, data_dict):
813
+ point = Point(data_dict)
814
+ point = self.embedding(point)
815
+
816
+ point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
817
+ point.sparsify()
818
+
819
+ point = self.enc(point)
820
+ if not self.enc_mode:
821
+ point = self.dec(point)
822
+ return point
823
+
824
+
825
+ def load(
826
+ name: str = "sonata",
827
+ repo_id="facebook/sonata",
828
+ download_root: str = None,
829
+ custom_config: dict = None,
830
+ ckpt_only: bool = False,
831
+ ):
832
+ if name in MODELS:
833
+ print(f"Loading checkpoint from HuggingFace: {name} ...")
834
+ ckpt_path = hf_hub_download(
835
+ repo_id=repo_id,
836
+ filename=f"{name}.pth",
837
+ repo_type="model",
838
+ revision="main",
839
+ local_dir=download_root or os.path.expanduser("~/.cache/sonata/ckpt"),
840
+ )
841
+ elif os.path.isfile(name):
842
+ print(f"Loading checkpoint in local path: {name} ...")
843
+ ckpt_path = name
844
+ else:
845
+ raise RuntimeError(f"Model {name} not found; available models = {MODELS}")
846
+
847
+ if version.parse(torch.__version__) >= version.parse("2.4"):
848
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
849
+ else:
850
+ ckpt = torch.load(ckpt_path, map_location="cpu")
851
+ if custom_config is not None:
852
+ for key, value in custom_config.items():
853
+ ckpt["config"][key] = value
854
+
855
+ if ckpt_only:
856
+ return ckpt
857
+
858
+ # optionally disable flash attention:
859
+ # ckpt["config"]['enable_flash'] = False
860
+
861
+ model = PointTransformerV3(**ckpt["config"])
862
+ model.load_state_dict(ckpt["state_dict"])
863
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
864
+ print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}")
865
+ return model
866
+
867
+
868
+ def load_by_config(config_path: str):
869
+ with open(config_path, "r") as f:
870
+ config = json.load(f)
871
+ model = PointTransformerV3(**config)
872
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
873
+ print(f"Model params: {n_parameters / 1e6:.2f}M {n_parameters}")
874
+ return model
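For orientation, a minimal usage sketch of the loaders above; the input keys follow the Point conventions documented in structure.py, while the point count, feature width, grid size and device handling here are illustrative assumptions rather than values taken from this repository:

import torch

model = load("sonata").cuda().eval()     # or load_by_config("<path/to/config.json>")

n = 4096
data_dict = {
    "coord": torch.rand(n, 3).cuda(),                    # raw xyz coordinates
    "feat": torch.rand(n, 6).cuda(),                     # per-point features (in_channels=6 by default)
    "batch": torch.zeros(n, dtype=torch.long).cuda(),    # single sample in the batch
    "grid_size": 0.02,                                   # voxel size used to derive grid_coord
}
with torch.no_grad():
    point = model(data_dict)
print(point.feat.shape)                                  # per-point features after the decoder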
XPart/partgen/models/sonata/module.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Point Modules
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import sys
25
+ import torch.nn as nn
26
+ import spconv.pytorch as spconv
27
+ from collections import OrderedDict
28
+
29
+ from .structure import Point
30
+
31
+
32
+ class PointModule(nn.Module):
33
+ r"""PointModule
34
+ Placeholder base class; every module that subclasses it receives a Point when used inside PointSequential.
35
+ """
36
+
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+
40
+
41
+ class PointSequential(PointModule):
42
+ r"""A sequential container.
43
+ Modules will be added to it in the order they are passed in the constructor.
44
+ Alternatively, an ordered dict of modules can also be passed in.
45
+ """
46
+
47
+ def __init__(self, *args, **kwargs):
48
+ super().__init__()
49
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
50
+ for key, module in args[0].items():
51
+ self.add_module(key, module)
52
+ else:
53
+ for idx, module in enumerate(args):
54
+ self.add_module(str(idx), module)
55
+ for name, module in kwargs.items():
56
+ if sys.version_info < (3, 6):
57
+ raise ValueError("kwargs only supported in py36+")
58
+ if name in self._modules:
59
+ raise ValueError("name exists.")
60
+ self.add_module(name, module)
61
+
62
+ def __getitem__(self, idx):
63
+ if not (-len(self) <= idx < len(self)):
64
+ raise IndexError("index {} is out of range".format(idx))
65
+ if idx < 0:
66
+ idx += len(self)
67
+ it = iter(self._modules.values())
68
+ for i in range(idx):
69
+ next(it)
70
+ return next(it)
71
+
72
+ def __len__(self):
73
+ return len(self._modules)
74
+
75
+ def add(self, module, name=None):
76
+ if name is None:
77
+ name = str(len(self._modules))
78
+ if name in self._modules:
79
+ raise KeyError("name exists")
80
+ self.add_module(name, module)
81
+
82
+ def forward(self, input):
83
+ for k, module in self._modules.items():
84
+ # Point module
85
+ if isinstance(module, PointModule):
86
+ input = module(input)
87
+ # Spconv module
88
+ elif spconv.modules.is_spconv_module(module):
89
+ if isinstance(input, Point):
90
+ input.sparse_conv_feat = module(input.sparse_conv_feat)
91
+ input.feat = input.sparse_conv_feat.features
92
+ else:
93
+ input = module(input)
94
+ # PyTorch module
95
+ else:
96
+ if isinstance(input, Point):
97
+ input.feat = module(input.feat)
98
+ if "sparse_conv_feat" in input.keys():
99
+ input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
100
+ input.feat
101
+ )
102
+ elif isinstance(input, spconv.SparseConvTensor):
103
+ if input.indices.shape[0] != 0:
104
+ input = input.replace_feature(module(input.features))
105
+ else:
106
+ input = module(input)
107
+ return input
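PointSequential dispatches on the type of each contained module: PointModule subclasses receive the whole Point, spconv modules run on point.sparse_conv_feat, and plain torch modules are applied to point.feat (with sparse_conv_feat kept in sync). A minimal sketch of composing and indexing it; the layer sizes are arbitrary:

import torch.nn as nn

seq = PointSequential(nn.Linear(6, 32), nn.GELU())   # positional modules are named "0", "1"
seq.add(nn.LayerNorm(32), name="norm")               # named add; a duplicate name raises KeyError
assert len(seq) == 3 and isinstance(seq[0], nn.Linear)
# seq(point) would then apply the three modules to point.feat in insertion order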
XPart/partgen/models/sonata/registry.py ADDED
@@ -0,0 +1,340 @@
1
+ # @lint-ignore-every LICENSELINT
2
+ # Copyright (c) OpenMMLab. All rights reserved.
3
+ import inspect
4
+ import warnings
5
+ from functools import partial
6
+ from collections import abc
7
+
8
+
9
+ def is_seq_of(seq, expected_type, seq_type=None):
10
+ """Check whether it is a sequence of some type.
11
+
12
+ Args:
13
+ seq (Sequence): The sequence to be checked.
14
+ expected_type (type): Expected type of sequence items.
15
+ seq_type (type, optional): Expected sequence type.
16
+
17
+ Returns:
18
+ bool: Whether the sequence is valid.
19
+ """
20
+ if seq_type is None:
21
+ exp_seq_type = abc.Sequence
22
+ else:
23
+ assert isinstance(seq_type, type)
24
+ exp_seq_type = seq_type
25
+ if not isinstance(seq, exp_seq_type):
26
+ return False
27
+ for item in seq:
28
+ if not isinstance(item, expected_type):
29
+ return False
30
+ return True
31
+
32
+
33
+ def build_from_cfg(cfg, registry, default_args=None):
34
+ """Build a module from configs dict.
35
+
36
+ Args:
37
+ cfg (dict): Config dict. It should at least contain the key "type".
38
+ registry (:obj:`Registry`): The registry to search the type from.
39
+ default_args (dict, optional): Default initialization arguments.
40
+
41
+ Returns:
42
+ object: The constructed object.
43
+ """
44
+ if not isinstance(cfg, dict):
45
+ raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
46
+ if "type" not in cfg:
47
+ if default_args is None or "type" not in default_args:
48
+ raise KeyError(
49
+ '`cfg` or `default_args` must contain the key "type", '
50
+ f"but got {cfg}\n{default_args}"
51
+ )
52
+ if not isinstance(registry, Registry):
53
+ raise TypeError(
54
+ "registry must be an mmcv.Registry object, " f"but got {type(registry)}"
55
+ )
56
+ if not (isinstance(default_args, dict) or default_args is None):
57
+ raise TypeError(
58
+ "default_args must be a dict or None, " f"but got {type(default_args)}"
59
+ )
60
+
61
+ args = cfg.copy()
62
+
63
+ if default_args is not None:
64
+ for name, value in default_args.items():
65
+ args.setdefault(name, value)
66
+
67
+ obj_type = args.pop("type")
68
+ if isinstance(obj_type, str):
69
+ obj_cls = registry.get(obj_type)
70
+ if obj_cls is None:
71
+ raise KeyError(f"{obj_type} is not in the {registry.name} registry")
72
+ elif inspect.isclass(obj_type):
73
+ obj_cls = obj_type
74
+ else:
75
+ raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
76
+ try:
77
+ return obj_cls(**args)
78
+ except Exception as e:
79
+ # Normal TypeError does not print class name.
80
+ raise type(e)(f"{obj_cls.__name__}: {e}")
81
+
82
+
83
+ class Registry:
84
+ """A registry to map strings to classes.
85
+
86
+ Registered object could be built from registry.
87
+ Example:
88
+ >>> MODELS = Registry('models')
89
+ >>> @MODELS.register_module()
90
+ >>> class ResNet:
91
+ >>> pass
92
+ >>> resnet = MODELS.build(dict(type='ResNet'))
93
+
94
+ Please refer to
95
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
96
+ advanced usage.
97
+
98
+ Args:
99
+ name (str): Registry name.
100
+ build_func(func, optional): Build function to construct instance from
101
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
102
+ ``build_func`` is specified. If ``parent`` is specified and
103
+ ``build_func`` is not given, ``build_func`` will be inherited
104
+ from ``parent``. Default: None.
105
+ parent (Registry, optional): Parent registry. The class registered in
106
+ children registry could be built from parent. Default: None.
107
+ scope (str, optional): The scope of registry. It is the key to search
108
+ for children registry. If not specified, scope will be the name of
109
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
110
+ Default: None.
111
+ """
112
+
113
+ def __init__(self, name, build_func=None, parent=None, scope=None):
114
+ self._name = name
115
+ self._module_dict = dict()
116
+ self._children = dict()
117
+ self._scope = self.infer_scope() if scope is None else scope
118
+
119
+ # self.build_func will be set with the following priority:
120
+ # 1. build_func
121
+ # 2. parent.build_func
122
+ # 3. build_from_cfg
123
+ if build_func is None:
124
+ if parent is not None:
125
+ self.build_func = parent.build_func
126
+ else:
127
+ self.build_func = build_from_cfg
128
+ else:
129
+ self.build_func = build_func
130
+ if parent is not None:
131
+ assert isinstance(parent, Registry)
132
+ parent._add_children(self)
133
+ self.parent = parent
134
+ else:
135
+ self.parent = None
136
+
137
+ def __len__(self):
138
+ return len(self._module_dict)
139
+
140
+ def __contains__(self, key):
141
+ return self.get(key) is not None
142
+
143
+ def __repr__(self):
144
+ format_str = (
145
+ self.__class__.__name__ + f"(name={self._name}, "
146
+ f"items={self._module_dict})"
147
+ )
148
+ return format_str
149
+
150
+ @staticmethod
151
+ def infer_scope():
152
+ """Infer the scope of registry.
153
+
154
+ The name of the package where registry is defined will be returned.
155
+
156
+ Example:
157
+ # in mmdet/models/backbone/resnet.py
158
+ >>> MODELS = Registry('models')
159
+ >>> @MODELS.register_module()
160
+ >>> class ResNet:
161
+ >>> pass
162
+ The scope of ``ResNet`` will be ``mmdet``.
163
+
164
+
165
+ Returns:
166
+ scope (str): The inferred scope name.
167
+ """
168
+ # inspect.stack() trace where this function is called, the index-2
169
+ # indicates the frame where `infer_scope()` is called
170
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
171
+ split_filename = filename.split(".")
172
+ return split_filename[0]
173
+
174
+ @staticmethod
175
+ def split_scope_key(key):
176
+ """Split scope and key.
177
+
178
+ The first scope will be split from key.
179
+
180
+ Examples:
181
+ >>> Registry.split_scope_key('mmdet.ResNet')
182
+ 'mmdet', 'ResNet'
183
+ >>> Registry.split_scope_key('ResNet')
184
+ None, 'ResNet'
185
+
186
+ Return:
187
+ scope (str, None): The first scope.
188
+ key (str): The remaining key.
189
+ """
190
+ split_index = key.find(".")
191
+ if split_index != -1:
192
+ return key[:split_index], key[split_index + 1 :]
193
+ else:
194
+ return None, key
195
+
196
+ @property
197
+ def name(self):
198
+ return self._name
199
+
200
+ @property
201
+ def scope(self):
202
+ return self._scope
203
+
204
+ @property
205
+ def module_dict(self):
206
+ return self._module_dict
207
+
208
+ @property
209
+ def children(self):
210
+ return self._children
211
+
212
+ def get(self, key):
213
+ """Get the registry record.
214
+
215
+ Args:
216
+ key (str): The class name in string format.
217
+
218
+ Returns:
219
+ class: The corresponding class.
220
+ """
221
+ scope, real_key = self.split_scope_key(key)
222
+ if scope is None or scope == self._scope:
223
+ # get from self
224
+ if real_key in self._module_dict:
225
+ return self._module_dict[real_key]
226
+ else:
227
+ # get from self._children
228
+ if scope in self._children:
229
+ return self._children[scope].get(real_key)
230
+ else:
231
+ # goto root
232
+ parent = self.parent
233
+ while parent.parent is not None:
234
+ parent = parent.parent
235
+ return parent.get(key)
236
+
237
+ def build(self, *args, **kwargs):
238
+ return self.build_func(*args, **kwargs, registry=self)
239
+
240
+ def _add_children(self, registry):
241
+ """Add children for a registry.
242
+
243
+ The ``registry`` will be added as children based on its scope.
244
+ The parent registry could build objects from children registry.
245
+
246
+ Example:
247
+ >>> models = Registry('models')
248
+ >>> mmdet_models = Registry('models', parent=models)
249
+ >>> @mmdet_models.register_module()
250
+ >>> class ResNet:
251
+ >>> pass
252
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
253
+ """
254
+
255
+ assert isinstance(registry, Registry)
256
+ assert registry.scope is not None
257
+ assert (
258
+ registry.scope not in self.children
259
+ ), f"scope {registry.scope} exists in {self.name} registry"
260
+ self.children[registry.scope] = registry
261
+
262
+ def _register_module(self, module_class, module_name=None, force=False):
263
+ if not inspect.isclass(module_class):
264
+ raise TypeError("module must be a class, " f"but got {type(module_class)}")
265
+
266
+ if module_name is None:
267
+ module_name = module_class.__name__
268
+ if isinstance(module_name, str):
269
+ module_name = [module_name]
270
+ for name in module_name:
271
+ if not force and name in self._module_dict:
272
+ raise KeyError(f"{name} is already registered " f"in {self.name}")
273
+ self._module_dict[name] = module_class
274
+
275
+ def deprecated_register_module(self, cls=None, force=False):
276
+ warnings.warn(
277
+ "The old API of register_module(module, force=False) "
278
+ "is deprecated and will be removed, please use the new API "
279
+ "register_module(name=None, force=False, module=None) instead."
280
+ )
281
+ if cls is None:
282
+ return partial(self.deprecated_register_module, force=force)
283
+ self._register_module(cls, force=force)
284
+ return cls
285
+
286
+ def register_module(self, name=None, force=False, module=None):
287
+ """Register a module.
288
+
289
+ A record will be added to `self._module_dict`, whose key is the class
290
+ name or the specified name, and value is the class itself.
291
+ It can be used as a decorator or a normal function.
292
+
293
+ Example:
294
+ >>> backbones = Registry('backbone')
295
+ >>> @backbones.register_module()
296
+ >>> class ResNet:
297
+ >>> pass
298
+
299
+ >>> backbones = Registry('backbone')
300
+ >>> @backbones.register_module(name='mnet')
301
+ >>> class MobileNet:
302
+ >>> pass
303
+
304
+ >>> backbones = Registry('backbone')
305
+ >>> class ResNet:
306
+ >>> pass
307
+ >>> backbones.register_module(ResNet)
308
+
309
+ Args:
310
+ name (str | None): The module name to be registered. If not
311
+ specified, the class name will be used.
312
+ force (bool, optional): Whether to override an existing class with
313
+ the same name. Default: False.
314
+ module (type): Module class to be registered.
315
+ """
316
+ if not isinstance(force, bool):
317
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
318
+ # NOTE: This is a workaround to stay compatible with the old API,
319
+ # while it may introduce unexpected bugs.
320
+ if isinstance(name, type):
321
+ return self.deprecated_register_module(name, force=force)
322
+
323
+ # raise the error ahead of time
324
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
325
+ raise TypeError(
326
+ "name must be either of None, an instance of str or a sequence"
327
+ f" of str, but got {type(name)}"
328
+ )
329
+
330
+ # use it as a normal method: x.register_module(module=SomeClass)
331
+ if module is not None:
332
+ self._register_module(module_class=module, module_name=name, force=force)
333
+ return module
334
+
335
+ # use it as a decorator: @x.register_module()
336
+ def _register(cls):
337
+ self._register_module(module_class=cls, module_name=name, force=force)
338
+ return cls
339
+
340
+ return _register
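The registry maps the string under the "type" key of a config dict to a registered class, so components can be built from configuration alone. A minimal usage sketch; the Backbone class and its channels argument are made up for illustration:

BACKBONES = Registry("backbone")

@BACKBONES.register_module()
class Backbone:
    def __init__(self, channels=32):
        self.channels = channels

net = BACKBONES.build(dict(type="Backbone", channels=64))
assert isinstance(net, Backbone) and net.channels == 64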
XPart/partgen/models/sonata/serialization/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # __init__.py
2
+ from .default import (
3
+ encode,
4
+ decode,
5
+ z_order_encode,
6
+ z_order_decode,
7
+ hilbert_encode,
8
+ hilbert_decode,
9
+ )
XPart/partgen/models/sonata/serialization/default.py ADDED
@@ -0,0 +1,82 @@
1
+ """
2
+ Serialization Encoding
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import torch
25
+ from .z_order import xyz2key as z_order_encode_
26
+ from .z_order import key2xyz as z_order_decode_
27
+ from .hilbert import encode as hilbert_encode_
28
+ from .hilbert import decode as hilbert_decode_
29
+
30
+
31
+ @torch.inference_mode()
32
+ def encode(grid_coord, batch=None, depth=16, order="z"):
33
+ assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
34
+ if order == "z":
35
+ code = z_order_encode(grid_coord, depth=depth)
36
+ elif order == "z-trans":
37
+ code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
38
+ elif order == "hilbert":
39
+ code = hilbert_encode(grid_coord, depth=depth)
40
+ elif order == "hilbert-trans":
41
+ code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
42
+ else:
43
+ raise NotImplementedError
44
+ if batch is not None:
45
+ batch = batch.long()
46
+ code = batch << depth * 3 | code
47
+ return code
48
+
49
+
50
+ @torch.inference_mode()
51
+ def decode(code, depth=16, order="z"):
52
+ assert order in {"z", "hilbert"}
53
+ batch = code >> depth * 3
54
+ code = code & ((1 << depth * 3) - 1)
55
+ if order == "z":
56
+ grid_coord = z_order_decode(code, depth=depth)
57
+ elif order == "hilbert":
58
+ grid_coord = hilbert_decode(code, depth=depth)
59
+ else:
60
+ raise NotImplementedError
61
+ return grid_coord, batch
62
+
63
+
64
+ def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
65
+ x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
66
+ # batch support is disabled here; batched codes are maintained in the Point class
67
+ code = z_order_encode_(x, y, z, b=None, depth=depth)
68
+ return code
69
+
70
+
71
+ def z_order_decode(code: torch.Tensor, depth):
72
+ x, y, z, _ = z_order_decode_(code, depth=depth)  # drop the batch index returned by key2xyz
73
+ grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3)
74
+ return grid_coord
75
+
76
+
77
+ def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
78
+ return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
79
+
80
+
81
+ def hilbert_decode(code: torch.Tensor, depth: int = 16):
82
+ return hilbert_decode_(code, num_dims=3, num_bits=depth)
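encode packs the per-point space-filling-curve code into the low 3*depth bits and, when a batch index is supplied, shifts it into the bits above, which keeps codes from different samples in disjoint ranges. A minimal round-trip sketch using the Hilbert order (coordinates must already fit within depth bits per axis):

import torch

grid_coord = torch.tensor([[1, 2, 3], [40, 50, 60]], dtype=torch.int64)
batch = torch.tensor([0, 1])

code = encode(grid_coord, batch=batch, depth=16, order="hilbert")
decoded_coord, decoded_batch = decode(code, depth=16, order="hilbert")
assert torch.equal(decoded_coord, grid_coord) and torch.equal(decoded_batch, batch)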
XPart/partgen/models/sonata/serialization/hilbert.py ADDED
@@ -0,0 +1,318 @@
1
+ """
2
+ Hilbert Order
3
+ Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import torch
25
+
26
+
27
+ def right_shift(binary, k=1, axis=-1):
28
+ """Right shift an array of binary values.
29
+
30
+ Parameters:
31
+ -----------
32
+ binary: An ndarray of binary values.
33
+
34
+ k: The number of bits to shift. Default 1.
35
+
36
+ axis: The axis along which to shift. Default -1.
37
+
38
+ Returns:
39
+ --------
40
+ Returns an ndarray with zero prepended and the ends truncated, along
41
+ whatever axis was specified."""
42
+
43
+ # If we're shifting the whole thing, just return zeros.
44
+ if binary.shape[axis] <= k:
45
+ return torch.zeros_like(binary)
46
+
47
+ # Determine the padding pattern.
48
+ # padding = [(0,0)] * len(binary.shape)
49
+ # padding[axis] = (k,0)
50
+
51
+ # Determine the slicing pattern to eliminate just the last one.
52
+ slicing = [slice(None)] * len(binary.shape)
53
+ slicing[axis] = slice(None, -k)
54
+ shifted = torch.nn.functional.pad(
55
+ binary[tuple(slicing)], (k, 0), mode="constant", value=0
56
+ )
57
+
58
+ return shifted
59
+
60
+
61
+ def binary2gray(binary, axis=-1):
62
+ """Convert an array of binary values into Gray codes.
63
+
64
+ This uses the classic X ^ (X >> 1) trick to compute the Gray code.
65
+
66
+ Parameters:
67
+ -----------
68
+ binary: An ndarray of binary values.
69
+
70
+ axis: The axis along which to compute the gray code. Default=-1.
71
+
72
+ Returns:
73
+ --------
74
+ Returns an ndarray of Gray codes.
75
+ """
76
+ shifted = right_shift(binary, axis=axis)
77
+
78
+ # Do the X ^ (X >> 1) trick.
79
+ gray = torch.logical_xor(binary, shifted)
80
+
81
+ return gray
82
+
83
+
84
+ def gray2binary(gray, axis=-1):
85
+ """Convert an array of Gray codes back into binary values.
86
+
87
+ Parameters:
88
+ -----------
89
+ gray: An ndarray of gray codes.
90
+
91
+ axis: The axis along which to perform Gray decoding. Default=-1.
92
+
93
+ Returns:
94
+ --------
95
+ Returns an ndarray of binary values.
96
+ """
97
+
98
+ # Loop the log2(bits) number of times necessary, with shift and xor.
99
+ shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
100
+ while shift > 0:
101
+ gray = torch.logical_xor(gray, right_shift(gray, shift))
102
+ shift = torch.div(shift, 2, rounding_mode="floor")
103
+ return gray
104
+
105
+
106
+ def encode(locs, num_dims, num_bits):
107
+ """Decode an array of locations in a hypercube into a Hilbert integer.
108
+
109
+ This is a vectorized-ish version of the Hilbert curve implementation by John
110
+ Skilling as described in:
111
+
112
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
113
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
114
+
115
+ Params:
116
+ -------
117
+ locs - An ndarray of locations in a hypercube of num_dims dimensions, in
118
+ which each dimension runs from 0 to 2**num_bits-1. The shape can
119
+ be arbitrary, as long as the last dimension of the same has size
120
+ num_dims.
121
+
122
+ num_dims - The dimensionality of the hypercube. Integer.
123
+
124
+ num_bits - The number of bits for each dimension. Integer.
125
+
126
+ Returns:
127
+ --------
128
+ The output is an ndarray of uint64 integers with the same shape as the
129
+ input, excluding the last dimension, which needs to be num_dims.
130
+ """
131
+
132
+ # Keep around the original shape for later.
133
+ orig_shape = locs.shape
134
+ bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
135
+ bitpack_mask_rev = bitpack_mask.flip(-1)
136
+
137
+ if orig_shape[-1] != num_dims:
138
+ raise ValueError(
139
+ """
140
+ The shape of locs was surprising in that the last dimension was of size
141
+ %d, but num_dims=%d. These need to be equal.
142
+ """
143
+ % (orig_shape[-1], num_dims)
144
+ )
145
+
146
+ if num_dims * num_bits > 63:
147
+ raise ValueError(
148
+ """
149
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
150
+ into a int64. Are you sure you need that many points on your Hilbert
151
+ curve?
152
+ """
153
+ % (num_dims, num_bits, num_dims * num_bits)
154
+ )
155
+
156
+ # Treat the location integers as 64-bit unsigned and then split them up into
157
+ # a sequence of uint8s. Preserve the association by dimension.
158
+ locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
159
+
160
+ # Now turn these into bits and truncate to num_bits.
161
+ gray = (
162
+ locs_uint8.unsqueeze(-1)
163
+ .bitwise_and(bitpack_mask_rev)
164
+ .ne(0)
165
+ .byte()
166
+ .flatten(-2, -1)[..., -num_bits:]
167
+ )
168
+
169
+ # Run the decoding process the other way.
170
+ # Iterate forwards through the bits.
171
+ for bit in range(0, num_bits):
172
+ # Iterate forwards through the dimensions.
173
+ for dim in range(0, num_dims):
174
+ # Identify which ones have this bit active.
175
+ mask = gray[:, dim, bit]
176
+
177
+ # Where this bit is on, invert the 0 dimension for lower bits.
178
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
179
+ gray[:, 0, bit + 1 :], mask[:, None]
180
+ )
181
+
182
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
183
+ to_flip = torch.logical_and(
184
+ torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
185
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
186
+ )
187
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
188
+ gray[:, dim, bit + 1 :], to_flip
189
+ )
190
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
191
+
192
+ # Now flatten out.
193
+ gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
194
+
195
+ # Convert Gray back to binary.
196
+ hh_bin = gray2binary(gray)
197
+
198
+ # Pad back out to 64 bits.
199
+ extra_dims = 64 - num_bits * num_dims
200
+ padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
201
+
202
+ # Convert binary values into uint8s.
203
+ hh_uint8 = (
204
+ (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
205
+ .sum(2)
206
+ .squeeze()
207
+ .type(torch.uint8)
208
+ )
209
+
210
+ # Convert uint8s into uint64s.
211
+ hh_uint64 = hh_uint8.view(torch.int64).squeeze()
212
+
213
+ return hh_uint64
214
+
215
+
216
+ def decode(hilberts, num_dims, num_bits):
217
+ """Decode an array of Hilbert integers into locations in a hypercube.
218
+
219
+ This is a vectorized-ish version of the Hilbert curve implementation by John
220
+ Skilling as described in:
221
+
222
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
223
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
224
+
225
+ Params:
226
+ -------
227
+ hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
228
+ cannot have fewer bits than num_dims * num_bits.
229
+
230
+ num_dims - The dimensionality of the hypercube. Integer.
231
+
232
+ num_bits - The number of bits for each dimension. Integer.
233
+
234
+ Returns:
235
+ --------
236
+ The output is an ndarray of unsigned integers with the same shape as hilberts
237
+ but with an additional dimension of size num_dims.
238
+ """
239
+
240
+ if num_dims * num_bits > 64:
241
+ raise ValueError(
242
+ """
243
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
244
+ into a uint64. Are you sure you need that many points on your Hilbert
245
+ curve?
246
+ """
247
+ % (num_dims, num_bits)
248
+ )
249
+
250
+ # Handle the case where we got handed a naked integer.
251
+ hilberts = torch.atleast_1d(hilberts)
252
+
253
+ # Keep around the shape for later.
254
+ orig_shape = hilberts.shape
255
+ bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
256
+ bitpack_mask_rev = bitpack_mask.flip(-1)
257
+
258
+ # Treat each of the hilberts as a sequence of eight uint8.
259
+ # This treats all of the inputs as uint64 and makes things uniform.
260
+ hh_uint8 = (
261
+ hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
262
+ )
263
+
264
+ # Turn these lists of uints into lists of bits and then truncate to the size
265
+ # we actually need for using Skilling's procedure.
266
+ hh_bits = (
267
+ hh_uint8.unsqueeze(-1)
268
+ .bitwise_and(bitpack_mask_rev)
269
+ .ne(0)
270
+ .byte()
271
+ .flatten(-2, -1)[:, -num_dims * num_bits :]
272
+ )
273
+
274
+ # Take the sequence of bits and Gray-code it.
275
+ gray = binary2gray(hh_bits)
276
+
277
+ # There has got to be a better way to do this.
278
+ # I could index them differently, but the eventual packbits likes it this way.
279
+ gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
280
+
281
+ # Iterate backwards through the bits.
282
+ for bit in range(num_bits - 1, -1, -1):
283
+ # Iterate backwards through the dimensions.
284
+ for dim in range(num_dims - 1, -1, -1):
285
+ # Identify which ones have this bit active.
286
+ mask = gray[:, dim, bit]
287
+
288
+ # Where this bit is on, invert the 0 dimension for lower bits.
289
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
290
+ gray[:, 0, bit + 1 :], mask[:, None]
291
+ )
292
+
293
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
294
+ to_flip = torch.logical_and(
295
+ torch.logical_not(mask[:, None]),
296
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
297
+ )
298
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
299
+ gray[:, dim, bit + 1 :], to_flip
300
+ )
301
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
302
+
303
+ # Pad back out to 64 bits.
304
+ extra_dims = 64 - num_bits
305
+ padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
306
+
307
+ # Now chop these up into blocks of 8.
308
+ locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
309
+
310
+ # Take those blocks and turn them unto uint8s.
311
+ # from IPython import embed; embed()
312
+ locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
313
+
314
+ # Finally, treat these as uint64s.
315
+ flat_locs = locs_uint8.view(torch.int64)
316
+
317
+ # Return them in the expected shape.
318
+ return flat_locs.reshape((*orig_shape, num_dims))
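A small self-check of the curve produced by encode: enumerating every cell of a 4x4x4 grid and sorting by Hilbert index should visit the cells so that consecutive cells differ by exactly one step along exactly one axis, which is the locality property that makes this ordering useful for patch attention. A minimal sketch, assuming encode from this module is in scope:

import torch

r = torch.arange(4)
grid = torch.stack(torch.meshgrid(r, r, r, indexing="ij"), dim=-1).reshape(-1, 3)

order = encode(grid, num_dims=3, num_bits=2).argsort()
steps = (grid[order][1:] - grid[order][:-1]).abs().sum(dim=1)
assert bool((steps == 1).all())        # neighbouring Hilbert indices are spatial neighbours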
XPart/partgen/models/sonata/serialization/z_order.py ADDED
@@ -0,0 +1,145 @@
1
+ # @lint-ignore-every LICENSELINT
2
+ # --------------------------------------------------------
3
+ # Octree-based Sparse Convolutional Neural Networks
4
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Written by Peng-Shuai Wang
7
+ # --------------------------------------------------------
26
+
27
+ import torch
28
+ from typing import Optional, Union
29
+
30
+
31
+ class KeyLUT:
32
+ def __init__(self):
33
+ r256 = torch.arange(256, dtype=torch.int64)
34
+ r512 = torch.arange(512, dtype=torch.int64)
35
+ zero = torch.zeros(256, dtype=torch.int64)
36
+ device = torch.device("cpu")
37
+
38
+ self._encode = {
39
+ device: (
40
+ self.xyz2key(r256, zero, zero, 8),
41
+ self.xyz2key(zero, r256, zero, 8),
42
+ self.xyz2key(zero, zero, r256, 8),
43
+ )
44
+ }
45
+ self._decode = {device: self.key2xyz(r512, 9)}
46
+
47
+ def encode_lut(self, device=torch.device("cpu")):
48
+ if device not in self._encode:
49
+ cpu = torch.device("cpu")
50
+ self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
51
+ return self._encode[device]
52
+
53
+ def decode_lut(self, device=torch.device("cpu")):
54
+ if device not in self._decode:
55
+ cpu = torch.device("cpu")
56
+ self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
57
+ return self._decode[device]
58
+
59
+ def xyz2key(self, x, y, z, depth):
60
+ key = torch.zeros_like(x)
61
+ for i in range(depth):
62
+ mask = 1 << i
63
+ key = (
64
+ key
65
+ | ((x & mask) << (2 * i + 2))
66
+ | ((y & mask) << (2 * i + 1))
67
+ | ((z & mask) << (2 * i + 0))
68
+ )
69
+ return key
70
+
71
+ def key2xyz(self, key, depth):
72
+ x = torch.zeros_like(key)
73
+ y = torch.zeros_like(key)
74
+ z = torch.zeros_like(key)
75
+ for i in range(depth):
76
+ x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
77
+ y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
78
+ z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
79
+ return x, y, z
80
+
81
+
82
+ _key_lut = KeyLUT()
83
+
84
+
85
+ def xyz2key(
86
+ x: torch.Tensor,
87
+ y: torch.Tensor,
88
+ z: torch.Tensor,
89
+ b: Optional[Union[torch.Tensor, int]] = None,
90
+ depth: int = 16,
91
+ ):
92
+ """Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
93
+ based on pre-computed look up tables. The speed of this function is much
94
+ faster than the method based on for-loop.
95
+
96
+ Args:
97
+ x (torch.Tensor): The x coordinate.
98
+ y (torch.Tensor): The y coordinate.
99
+ z (torch.Tensor): The z coordinate.
100
+ b (torch.Tensor or int): The batch index of the coordinates, and should be
101
+ smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
102
+ :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
103
+ depth (int): The depth of the shuffled key; must be smaller than 17.
104
+ """
105
+
106
+ EX, EY, EZ = _key_lut.encode_lut(x.device)
107
+ x, y, z = x.long(), y.long(), z.long()
108
+
109
+ mask = 255 if depth > 8 else (1 << depth) - 1
110
+ key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
111
+ if depth > 8:
112
+ mask = (1 << (depth - 8)) - 1
113
+ key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
114
+ key = key16 << 24 | key
115
+
116
+ if b is not None:
117
+ b = b.long()
118
+ key = b << 48 | key
119
+
120
+ return key
121
+
122
+
123
+ def key2xyz(key: torch.Tensor, depth: int = 16):
124
+ r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
125
+ and the batch index based on pre-computed look up tables.
126
+
127
+ Args:
128
+ key (torch.Tensor): The shuffled key.
129
+ depth (int): The depth of the shuffled key; must be smaller than 17.
130
+ """
131
+
132
+ DX, DY, DZ = _key_lut.decode_lut(key.device)
133
+ x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
134
+
135
+ b = key >> 48
136
+ key = key & ((1 << 48) - 1)
137
+
138
+ n = (depth + 2) // 3
139
+ for i in range(n):
140
+ k = key >> (i * 9) & 511
141
+ x = x | (DX[k] << (i * 3))
142
+ y = y | (DY[k] << (i * 3))
143
+ z = z | (DZ[k] << (i * 3))
144
+
145
+ return x, y, z, b
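Editor's note: as a quick orientation to the z-order helpers above, the sketch below round-trips a few coordinates through `xyz2key`/`key2xyz`. It is illustrative only (not part of the patch), and the import path is an assumption based on the file location.

    import torch
    # assumed import path; adjust to match how the package is installed
    from XPart.partgen.models.sonata.serialization.z_order import xyz2key, key2xyz

    # Integer grid coordinates (depth <= 16) and their batch indices.
    x = torch.tensor([0, 3, 255, 1023])
    y = torch.tensor([1, 2, 128, 512])
    z = torch.tensor([2, 1, 64, 7])
    b = torch.tensor([0, 0, 1, 1])  # stored in the top bits of the key

    key = xyz2key(x, y, z, b=b, depth=16)    # interleave bits into shuffled keys
    x2, y2, z2, b2 = key2xyz(key, depth=16)  # decode keys back to coordinates
    assert all(torch.equal(a, c) for a, c in [(x, x2), (y, y2), (z, z2), (b, b2)])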
XPart/partgen/models/sonata/structure.py ADDED
@@ -0,0 +1,159 @@
1
+ """
2
+ Data structure for 3D point cloud
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import torch
24
+ import spconv.pytorch as spconv
25
+ from addict import Dict
26
+
27
+ from .serialization import encode
28
+ from .utils import offset2batch, batch2offset
29
+
30
+
31
+ class Point(Dict):
32
+ """
33
+ Point Structure of Pointcept
34
+
35
+ A Point (point cloud) in Pointcept is a dictionary that contains various properties of
36
+ a batched point cloud. Properties with the following names have a specific definition
37
+ as follows:
38
+
39
+ - "coord": original coordinate of point cloud;
40
+ - "grid_coord": grid coordinate for specific grid size (related to GridSampling);
41
+ Point also supports the following optional attributes:
43
+ - "offset": if absent, initialized assuming a batch size of 1;
44
+ - "batch": if absent, initialized assuming a batch size of 1;
44
+ - "feat": feature of point cloud, default input of model;
45
+ - "grid_size": Grid size of point cloud (related to GridSampling);
46
+ (related to Serialization)
47
+ - "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range;
48
+ - "serialized_code": a list of serialization codes;
49
+ - "serialized_order": a list of serialization order determined by code;
50
+ - "serialized_inverse": a list of inverse mapping determined by code;
51
+ (related to Sparsify: SpConv)
52
+ - "sparse_shape": Sparse shape for Sparse Conv Tensor;
53
+ - "sparse_conv_feat": SparseConvTensor init with information provide by Point;
54
+ """
55
+
56
+ def __init__(self, *args, **kwargs):
57
+ super().__init__(*args, **kwargs)
58
+ # If one of "offset" or "batch" does not exist, derive it from the other
59
+ if "batch" not in self.keys() and "offset" in self.keys():
60
+ self["batch"] = offset2batch(self.offset)
61
+ elif "offset" not in self.keys() and "batch" in self.keys():
62
+ self["offset"] = batch2offset(self.batch)
63
+
64
+ def serialization(self, order="z", depth=None, shuffle_orders=False):
65
+ """
66
+ Point Cloud Serialization
67
+
68
+ relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
69
+ """
70
+ self["order"] = order
71
+ assert "batch" in self.keys()
72
+ if "grid_coord" not in self.keys():
73
+ # if you don't want to operate GridSampling in data augmentation,
74
+ # please add the following augmentation into your pipeline:
75
+ # dict(type="Copy", keys_dict={"grid_size": 0.01}),
76
+ # (adjust `grid_size` to whatever you want)
77
+ assert {"grid_size", "coord"}.issubset(self.keys())
78
+
79
+ self["grid_coord"] = torch.div(
80
+ self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
81
+ ).int()
82
+
83
+ if depth is None:
84
+ # Adaptively measure the depth of the serialization cube (length = 2 ^ depth)
85
+ depth = int(self.grid_coord.max() + 1).bit_length()
86
+ self["serialized_depth"] = depth
87
+ # Maximum bit length for serialization code is 63 (int64)
88
+ assert depth * 3 + len(self.offset).bit_length() <= 63
89
+ # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
90
+ # Although depth is limited to at most 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
91
+ # cube with a grid size of 0.01 meter. We consider this enough for the current stage.
92
+ # We can unlock the limitation by optimizing the z-order encoding function if necessary.
93
+ assert depth <= 16
94
+
95
+ # The serialization codes are arranged as following structures:
96
+ # [Order1 ([n]),
97
+ # Order2 ([n]),
98
+ # ...
99
+ # OrderN ([n])] (k, n)
100
+ code = [
101
+ encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
102
+ ]
103
+ code = torch.stack(code)
104
+ order = torch.argsort(code)
105
+ inverse = torch.zeros_like(order).scatter_(
106
+ dim=1,
107
+ index=order,
108
+ src=torch.arange(0, code.shape[1], device=order.device).repeat(
109
+ code.shape[0], 1
110
+ ),
111
+ )
112
+
113
+ if shuffle_orders:
114
+ perm = torch.randperm(code.shape[0])
115
+ code = code[perm]
116
+ order = order[perm]
117
+ inverse = inverse[perm]
118
+
119
+ self["serialized_code"] = code
120
+ self["serialized_order"] = order
121
+ self["serialized_inverse"] = inverse
122
+
123
+ def sparsify(self, pad=96):
124
+ """
125
+ Point Cloud Sparsification
126
+
127
+ Point cloud is sparse, here we use "sparsify" to specifically refer to
128
+ preparing "spconv.SparseConvTensor" for SpConv.
129
+
130
+ relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
131
+
132
+ pad: padding added to the sparse shape.
133
+ """
134
+ assert {"feat", "batch"}.issubset(self.keys())
135
+ if "grid_coord" not in self.keys():
136
+ # if you don't want to operate GridSampling in data augmentation,
137
+ # please add the following augmentation into your pipeline:
138
+ # dict(type="Copy", keys_dict={"grid_size": 0.01}),
139
+ # (adjust `grid_size` to whatever you want)
140
+ assert {"grid_size", "coord"}.issubset(self.keys())
141
+ self["grid_coord"] = torch.div(
142
+ self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
143
+ ).int()
144
+ if "sparse_shape" in self.keys():
145
+ sparse_shape = self.sparse_shape
146
+ else:
147
+ sparse_shape = torch.add(
148
+ torch.max(self.grid_coord, dim=0).values, pad
149
+ ).tolist()
150
+ sparse_conv_feat = spconv.SparseConvTensor(
151
+ features=self.feat,
152
+ indices=torch.cat(
153
+ [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
154
+ ).contiguous(),
155
+ spatial_shape=sparse_shape,
156
+ batch_size=self.batch[-1].tolist() + 1,
157
+ )
158
+ self["sparse_shape"] = sparse_shape
159
+ self["sparse_conv_feat"] = sparse_conv_feat
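Editor's note: the sketch below (not part of the patch) shows how the `Point` structure defined above is typically driven: construct it from a dict of tensors, then call `serialization` and `sparsify`. The import path is assumed from the file location, and `spconv` must be installed.

    import torch
    from XPart.partgen.models.sonata.structure import Point  # assumed import path

    n = 1024
    point = Point(
        coord=torch.rand(n, 3),        # raw coordinates
        feat=torch.rand(n, 6),         # per-point features (e.g. color + normal)
        offset=torch.tensor([n]),      # one batch of n points; "batch" is derived in __init__
        grid_size=torch.tensor(0.01),  # used to derive "grid_coord" on demand
    )

    point.serialization(order=["z"])   # fills serialized_code / order / inverse / depth
    point.sparsify()                   # fills sparse_shape and sparse_conv_feat
    print(point.serialized_code.shape, point.sparse_conv_feat.spatial_shape)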
XPart/partgen/models/sonata/transform.py ADDED
@@ -0,0 +1,1330 @@
1
+ """
2
+ 3D point cloud augmentation
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ # Unless required by applicable law or agreed to in writing, software
22
+ # distributed under the License is distributed on an "AS IS" BASIS,
23
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24
+ # See the License for the specific language governing permissions and
25
+ # limitations under the License.
26
+ # Unless required by applicable law or agreed to in writing, software
27
+ # distributed under the License is distributed on an "AS IS" BASIS,
28
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29
+ # See the License for the specific language governing permissions and
30
+ # limitations under the License.
31
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
32
+ #
33
+ # Licensed under the Apache License, Version 2.0 (the "License");
34
+ # you may not use this file except in compliance with the License.
35
+ # You may obtain a copy of the License at
36
+ #
37
+ # http://www.apache.org/licenses/LICENSE-2.0
38
+ #
39
+ # Unless required by applicable law or agreed to in writing, software
40
+ # distributed under the License is distributed on an "AS IS" BASIS,
41
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
42
+ # See the License for the specific language governing permissions and
43
+ # limitations under the License.
44
+ # Unless required by applicable law or agreed to in writing, software
45
+ # distributed under the License is distributed on an "AS IS" BASIS,
46
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
47
+ # See the License for the specific language governing permissions and
48
+ # limitations under the License.
49
+ # Unless required by applicable law or agreed to in writing, software
50
+ # distributed under the License is distributed on an "AS IS" BASIS,
51
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52
+ # See the License for the specific language governing permissions and
53
+ # limitations under the License.
54
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
55
+ #
56
+ # Licensed under the Apache License, Version 2.0 (the "License");
57
+ # you may not use this file except in compliance with the License.
58
+ # You may obtain a copy of the License at
59
+ #
60
+ # http://www.apache.org/licenses/LICENSE-2.0
61
+ #
62
+ # Unless required by applicable law or agreed to in writing, software
63
+ # distributed under the License is distributed on an "AS IS" BASIS,
64
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
65
+ # See the License for the specific language governing permissions and
66
+ # limitations under the License.
67
+ # Unless required by applicable law or agreed to in writing, software
68
+ # distributed under the License is distributed on an "AS IS" BASIS,
69
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70
+ # See the License for the specific language governing permissions and
71
+ # limitations under the License.
72
+ # Unless required by applicable law or agreed to in writing, software
73
+ # distributed under the License is distributed on an "AS IS" BASIS,
74
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
75
+ # See the License for the specific language governing permissions and
76
+ # limitations under the License.
77
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
78
+ #
79
+ # Licensed under the Apache License, Version 2.0 (the "License");
80
+ # you may not use this file except in compliance with the License.
81
+ # You may obtain a copy of the License at
82
+ #
83
+ # http://www.apache.org/licenses/LICENSE-2.0
84
+ #
85
+ # Unless required by applicable law or agreed to in writing, software
86
+ # distributed under the License is distributed on an "AS IS" BASIS,
87
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
88
+ # See the License for the specific language governing permissions and
89
+ # limitations under the License.
90
+ # Unless required by applicable law or agreed to in writing, software
91
+ # distributed under the License is distributed on an "AS IS" BASIS,
92
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
93
+ # See the License for the specific language governing permissions and
94
+ # limitations under the License.
95
+ # Unless required by applicable law or agreed to in writing, software
96
+ # distributed under the License is distributed on an "AS IS" BASIS,
97
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
98
+ # See the License for the specific language governing permissions and
99
+ # limitations under the License.
100
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
101
+ #
102
+ # Licensed under the Apache License, Version 2.0 (the "License");
103
+ # you may not use this file except in compliance with the License.
104
+ # You may obtain a copy of the License at
105
+ #
106
+ # http://www.apache.org/licenses/LICENSE-2.0
107
+ #
108
+ # Unless required by applicable law or agreed to in writing, software
109
+ # distributed under the License is distributed on an "AS IS" BASIS,
110
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
111
+ # See the License for the specific language governing permissions and
112
+ # limitations under the License.
113
+ # Unless required by applicable law or agreed to in writing, software
114
+ # distributed under the License is distributed on an "AS IS" BASIS,
115
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
116
+ # See the License for the specific language governing permissions and
117
+ # limitations under the License.
118
+ # Unless required by applicable law or agreed to in writing, software
119
+ # distributed under the License is distributed on an "AS IS" BASIS,
120
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121
+ # See the License for the specific language governing permissions and
122
+ # limitations under the License.
123
+
124
+
125
+
126
+ import random
127
+ import numbers
128
+ import scipy
129
+ import scipy.ndimage
130
+ import scipy.interpolate
131
+ import scipy.stats
132
+ import numpy as np
133
+ import torch
134
+ import copy
135
+ from collections.abc import Sequence, Mapping
136
+
137
+ from .registry import Registry
138
+
139
+ TRANSFORMS = Registry("transforms")
140
+
141
+
142
+ def index_operator(data_dict, index, duplicate=False):
143
+ # index selection operator for keys in "index_valid_keys"
144
+ # custom these keys by "Update" transform in config
145
+ if "index_valid_keys" not in data_dict:
146
+ data_dict["index_valid_keys"] = [
147
+ "coord",
148
+ "color",
149
+ "normal",
150
+ "strength",
151
+ "segment",
152
+ "instance",
153
+ ]
154
+ if not duplicate:
155
+ for key in data_dict["index_valid_keys"]:
156
+ if key in data_dict:
157
+ data_dict[key] = data_dict[key][index]
158
+ return data_dict
159
+ else:
160
+ data_dict_ = dict()
161
+ for key in data_dict.keys():
162
+ if key in data_dict["index_valid_keys"]:
163
+ data_dict_[key] = data_dict[key][index]
164
+ else:
165
+ data_dict_[key] = data_dict[key]
166
+ return data_dict_
167
+
168
+
169
+ @TRANSFORMS.register_module()
170
+ class Collect(object):
171
+ def __init__(self, keys, offset_keys_dict=None, **kwargs):
172
+ """
173
+ e.g. Collect(keys=[coord], feat_keys=[coord, color])
174
+ """
175
+ if offset_keys_dict is None:
176
+ offset_keys_dict = dict(offset="coord")
177
+ self.keys = keys
178
+ self.offset_keys = offset_keys_dict
179
+ self.kwargs = kwargs
180
+
181
+ def __call__(self, data_dict):
182
+ data = dict()
183
+ if isinstance(self.keys, str):
184
+ self.keys = [self.keys]
185
+ for key in self.keys:
186
+ data[key] = data_dict[key]
187
+ for key, value in self.offset_keys.items():
188
+ data[key] = torch.tensor([data_dict[value].shape[0]])
189
+ for name, keys in self.kwargs.items():
190
+ name = name.replace("_keys", "")
191
+ assert isinstance(keys, Sequence)
192
+ data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1)
193
+ return data
194
+
195
+
196
+ @TRANSFORMS.register_module()
197
+ class Copy(object):
198
+ def __init__(self, keys_dict=None):
199
+ if keys_dict is None:
200
+ keys_dict = dict(coord="origin_coord", segment="origin_segment")
201
+ self.keys_dict = keys_dict
202
+
203
+ def __call__(self, data_dict):
204
+ for key, value in self.keys_dict.items():
205
+ if isinstance(data_dict[key], np.ndarray):
206
+ data_dict[value] = data_dict[key].copy()
207
+ elif isinstance(data_dict[key], torch.Tensor):
208
+ data_dict[value] = data_dict[key].clone().detach()
209
+ else:
210
+ data_dict[value] = copy.deepcopy(data_dict[key])
211
+ return data_dict
212
+
213
+
214
+ @TRANSFORMS.register_module()
215
+ class Update(object):
216
+ def __init__(self, keys_dict=None):
217
+ if keys_dict is None:
218
+ keys_dict = dict()
219
+ self.keys_dict = keys_dict
220
+
221
+ def __call__(self, data_dict):
222
+ for key, value in self.keys_dict.items():
223
+ data_dict[key] = value
224
+ return data_dict
225
+
226
+
227
+ @TRANSFORMS.register_module()
228
+ class ToTensor(object):
229
+ def __call__(self, data):
230
+ if isinstance(data, torch.Tensor):
231
+ return data
232
+ elif isinstance(data, str):
233
+ # note that str is also a Sequence, so this check must come before the Sequence branch
234
+ return data
235
+ elif isinstance(data, int):
236
+ return torch.LongTensor([data])
237
+ elif isinstance(data, float):
238
+ return torch.FloatTensor([data])
239
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool):
240
+ return torch.from_numpy(data)
241
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer):
242
+ return torch.from_numpy(data).long()
243
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating):
244
+ return torch.from_numpy(data).float()
245
+ elif isinstance(data, Mapping):
246
+ result = {sub_key: self(item) for sub_key, item in data.items()}
247
+ return result
248
+ elif isinstance(data, Sequence):
249
+ result = [self(item) for item in data]
250
+ return result
251
+ else:
252
+ raise TypeError(f"type {type(data)} cannot be converted to tensor.")
253
+
254
+
255
+ @TRANSFORMS.register_module()
256
+ class NormalizeColor(object):
257
+ def __call__(self, data_dict):
258
+ if "color" in data_dict.keys():
259
+ data_dict["color"] = data_dict["color"] / 255
260
+ return data_dict
261
+
262
+
263
+ @TRANSFORMS.register_module()
264
+ class NormalizeCoord(object):
265
+ def __call__(self, data_dict):
266
+ if "coord" in data_dict.keys():
267
+ # modified from pointnet2
268
+ centroid = np.mean(data_dict["coord"], axis=0)
269
+ data_dict["coord"] -= centroid
270
+ m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, axis=1)))
271
+ data_dict["coord"] = data_dict["coord"] / m
272
+ return data_dict
273
+
274
+
275
+ @TRANSFORMS.register_module()
276
+ class PositiveShift(object):
277
+ def __call__(self, data_dict):
278
+ if "coord" in data_dict.keys():
279
+ coord_min = np.min(data_dict["coord"], 0)
280
+ data_dict["coord"] -= coord_min
281
+ return data_dict
282
+
283
+
284
+ @TRANSFORMS.register_module()
285
+ class CenterShift(object):
286
+ def __init__(self, apply_z=True):
287
+ self.apply_z = apply_z
288
+
289
+ def __call__(self, data_dict):
290
+ if "coord" in data_dict.keys():
291
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
292
+ x_max, y_max, _ = data_dict["coord"].max(axis=0)
293
+ if self.apply_z:
294
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min]
295
+ else:
296
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0]
297
+ data_dict["coord"] -= shift
298
+ return data_dict
299
+
300
+
301
+ @TRANSFORMS.register_module()
302
+ class RandomShift(object):
303
+ def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))):
304
+ self.shift = shift
305
+
306
+ def __call__(self, data_dict):
307
+ if "coord" in data_dict.keys():
308
+ shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1])
309
+ shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1])
310
+ shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1])
311
+ data_dict["coord"] += [shift_x, shift_y, shift_z]
312
+ return data_dict
313
+
314
+
315
+ @TRANSFORMS.register_module()
316
+ class PointClip(object):
317
+ def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)):
318
+ self.point_cloud_range = point_cloud_range
319
+
320
+ def __call__(self, data_dict):
321
+ if "coord" in data_dict.keys():
322
+ data_dict["coord"] = np.clip(
323
+ data_dict["coord"],
324
+ a_min=self.point_cloud_range[:3],
325
+ a_max=self.point_cloud_range[3:],
326
+ )
327
+ return data_dict
328
+
329
+
330
+ @TRANSFORMS.register_module()
331
+ class RandomDropout(object):
332
+ def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
333
+ """
334
+ upright_axis: axis index among x,y,z, i.e. 2 for z
335
+ """
336
+ self.dropout_ratio = dropout_ratio
337
+ self.dropout_application_ratio = dropout_application_ratio
338
+
339
+ def __call__(self, data_dict):
340
+ if random.random() < self.dropout_application_ratio:
341
+ n = len(data_dict["coord"])
342
+ idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False)
343
+ if "sampled_index" in data_dict:
344
+ # for ScanNet data efficient, we need to make sure labeled point is sampled.
345
+ idx = np.unique(np.append(idx, data_dict["sampled_index"]))
346
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
347
+ mask[data_dict["sampled_index"]] = True
348
+ data_dict["sampled_index"] = np.where(mask[idx])[0]
349
+ data_dict = index_operator(data_dict, idx)
350
+ return data_dict
351
+
352
+
353
+ @TRANSFORMS.register_module()
354
+ class RandomRotate(object):
355
+ def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5):
356
+ self.angle = [-1, 1] if angle is None else angle
357
+ self.axis = axis
358
+ self.always_apply = always_apply
359
+ self.p = p if not self.always_apply else 1
360
+ self.center = center
361
+
362
+ def __call__(self, data_dict):
363
+ if random.random() > self.p:
364
+ return data_dict
365
+ angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi
366
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
367
+ if self.axis == "x":
368
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
369
+ elif self.axis == "y":
370
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
371
+ elif self.axis == "z":
372
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
373
+ else:
374
+ raise NotImplementedError
375
+ if "coord" in data_dict.keys():
376
+ if self.center is None:
377
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
378
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
379
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
380
+ else:
381
+ center = self.center
382
+ data_dict["coord"] -= center
383
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
384
+ data_dict["coord"] += center
385
+ if "normal" in data_dict.keys():
386
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
387
+ return data_dict
388
+
389
+
390
+ @TRANSFORMS.register_module()
391
+ class RandomRotateTargetAngle(object):
392
+ def __init__(
393
+ self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75
394
+ ):
395
+ self.angle = angle
396
+ self.axis = axis
397
+ self.always_apply = always_apply
398
+ self.p = p if not self.always_apply else 1
399
+ self.center = center
400
+
401
+ def __call__(self, data_dict):
402
+ if random.random() > self.p:
403
+ return data_dict
404
+ angle = np.random.choice(self.angle) * np.pi
405
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
406
+ if self.axis == "x":
407
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
408
+ elif self.axis == "y":
409
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
410
+ elif self.axis == "z":
411
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
412
+ else:
413
+ raise NotImplementedError
414
+ if "coord" in data_dict.keys():
415
+ if self.center is None:
416
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
417
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
418
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
419
+ else:
420
+ center = self.center
421
+ data_dict["coord"] -= center
422
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
423
+ data_dict["coord"] += center
424
+ if "normal" in data_dict.keys():
425
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
426
+ return data_dict
427
+
428
+
429
+ @TRANSFORMS.register_module()
430
+ class RandomScale(object):
431
+ def __init__(self, scale=None, anisotropic=False):
432
+ self.scale = scale if scale is not None else [0.95, 1.05]
433
+ self.anisotropic = anisotropic
434
+
435
+ def __call__(self, data_dict):
436
+ if "coord" in data_dict.keys():
437
+ scale = np.random.uniform(
438
+ self.scale[0], self.scale[1], 3 if self.anisotropic else 1
439
+ )
440
+ data_dict["coord"] *= scale
441
+ return data_dict
442
+
443
+
444
+ @TRANSFORMS.register_module()
445
+ class RandomFlip(object):
446
+ def __init__(self, p=0.5):
447
+ self.p = p
448
+
449
+ def __call__(self, data_dict):
450
+ if np.random.rand() < self.p:
451
+ if "coord" in data_dict.keys():
452
+ data_dict["coord"][:, 0] = -data_dict["coord"][:, 0]
453
+ if "normal" in data_dict.keys():
454
+ data_dict["normal"][:, 0] = -data_dict["normal"][:, 0]
455
+ if np.random.rand() < self.p:
456
+ if "coord" in data_dict.keys():
457
+ data_dict["coord"][:, 1] = -data_dict["coord"][:, 1]
458
+ if "normal" in data_dict.keys():
459
+ data_dict["normal"][:, 1] = -data_dict["normal"][:, 1]
460
+ return data_dict
461
+
462
+
463
+ @TRANSFORMS.register_module()
464
+ class RandomJitter(object):
465
+ def __init__(self, sigma=0.01, clip=0.05):
466
+ assert clip > 0
467
+ self.sigma = sigma
468
+ self.clip = clip
469
+
470
+ def __call__(self, data_dict):
471
+ if "coord" in data_dict.keys():
472
+ jitter = np.clip(
473
+ self.sigma * np.random.randn(data_dict["coord"].shape[0], 3),
474
+ -self.clip,
475
+ self.clip,
476
+ )
477
+ data_dict["coord"] += jitter
478
+ return data_dict
479
+
480
+
481
+ @TRANSFORMS.register_module()
482
+ class ClipGaussianJitter(object):
483
+ def __init__(self, scalar=0.02, store_jitter=False):
484
+ self.scalar = scalar
485
+ self.mean = np.mean(3)
486
+ self.cov = np.identity(3)
487
+ self.quantile = 1.96
488
+ self.store_jitter = store_jitter
489
+
490
+ def __call__(self, data_dict):
491
+ if "coord" in data_dict.keys():
492
+ jitter = np.random.multivariate_normal(
493
+ self.mean, self.cov, data_dict["coord"].shape[0]
494
+ )
495
+ jitter = self.scalar * np.clip(jitter / 1.96, -1, 1)
496
+ data_dict["coord"] += jitter
497
+ if self.store_jitter:
498
+ data_dict["jitter"] = jitter
499
+ return data_dict
500
+
501
+
502
+ @TRANSFORMS.register_module()
503
+ class ChromaticAutoContrast(object):
504
+ def __init__(self, p=0.2, blend_factor=None):
505
+ self.p = p
506
+ self.blend_factor = blend_factor
507
+
508
+ def __call__(self, data_dict):
509
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
510
+ lo = np.min(data_dict["color"], 0, keepdims=True)
511
+ hi = np.max(data_dict["color"], 0, keepdims=True)
512
+ scale = 255 / (hi - lo)
513
+ contrast_feat = (data_dict["color"][:, :3] - lo) * scale
514
+ blend_factor = (
515
+ np.random.rand() if self.blend_factor is None else self.blend_factor
516
+ )
517
+ data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][
518
+ :, :3
519
+ ] + blend_factor * contrast_feat
520
+ return data_dict
521
+
522
+
523
+ @TRANSFORMS.register_module()
524
+ class ChromaticTranslation(object):
525
+ def __init__(self, p=0.95, ratio=0.05):
526
+ self.p = p
527
+ self.ratio = ratio
528
+
529
+ def __call__(self, data_dict):
530
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
531
+ tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio
532
+ data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255)
533
+ return data_dict
534
+
535
+
536
+ @TRANSFORMS.register_module()
537
+ class ChromaticJitter(object):
538
+ def __init__(self, p=0.95, std=0.005):
539
+ self.p = p
540
+ self.std = std
541
+
542
+ def __call__(self, data_dict):
543
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
544
+ noise = np.random.randn(data_dict["color"].shape[0], 3)
545
+ noise *= self.std * 255
546
+ data_dict["color"][:, :3] = np.clip(
547
+ noise + data_dict["color"][:, :3], 0, 255
548
+ )
549
+ return data_dict
550
+
551
+
552
+ @TRANSFORMS.register_module()
553
+ class RandomColorGrayScale(object):
554
+ def __init__(self, p):
555
+ self.p = p
556
+
557
+ @staticmethod
558
+ def rgb_to_grayscale(color, num_output_channels=1):
559
+ if color.shape[-1] < 3:
560
+ raise TypeError(
561
+ "Input color should have at least 3 dimensions, but found {}".format(
562
+ color.shape[-1]
563
+ )
564
+ )
565
+
566
+ if num_output_channels not in (1, 3):
567
+ raise ValueError("num_output_channels should be either 1 or 3")
568
+
569
+ r, g, b = color[..., 0], color[..., 1], color[..., 2]
570
+ gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype)
571
+ gray = np.expand_dims(gray, axis=-1)
572
+
573
+ if num_output_channels == 3:
574
+ gray = np.broadcast_to(gray, color.shape)
575
+
576
+ return gray
577
+
578
+ def __call__(self, data_dict):
579
+ if np.random.rand() < self.p:
580
+ data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3)
581
+ return data_dict
582
+
583
+
584
+ @TRANSFORMS.register_module()
585
+ class RandomColorJitter(object):
586
+ """
587
+ Random Color Jitter for 3D point cloud (refer torchvision)
588
+ """
589
+
590
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95):
591
+ self.brightness = self._check_input(brightness, "brightness")
592
+ self.contrast = self._check_input(contrast, "contrast")
593
+ self.saturation = self._check_input(saturation, "saturation")
594
+ self.hue = self._check_input(
595
+ hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
596
+ )
597
+ self.p = p
598
+
599
+ @staticmethod
600
+ def _check_input(
601
+ value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
602
+ ):
603
+ if isinstance(value, numbers.Number):
604
+ if value < 0:
605
+ raise ValueError(
606
+ "If {} is a single number, it must be non negative.".format(name)
607
+ )
608
+ value = [center - float(value), center + float(value)]
609
+ if clip_first_on_zero:
610
+ value[0] = max(value[0], 0.0)
611
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
612
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
613
+ raise ValueError("{} values should be between {}".format(name, bound))
614
+ else:
615
+ raise TypeError(
616
+ "{} should be a single number or a list/tuple with length 2.".format(
617
+ name
618
+ )
619
+ )
620
+
621
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
622
+ # or (0., 0.) for hue, do nothing
623
+ if value[0] == value[1] == center:
624
+ value = None
625
+ return value
626
+
627
+ @staticmethod
628
+ def blend(color1, color2, ratio):
629
+ ratio = float(ratio)
630
+ bound = 255.0
631
+ return (
632
+ (ratio * color1 + (1.0 - ratio) * color2)
633
+ .clip(0, bound)
634
+ .astype(color1.dtype)
635
+ )
636
+
637
+ @staticmethod
638
+ def rgb2hsv(rgb):
639
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
640
+ maxc = np.max(rgb, axis=-1)
641
+ minc = np.min(rgb, axis=-1)
642
+ eqc = maxc == minc
643
+ cr = maxc - minc
644
+ s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc))
645
+ cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc)
646
+ rc = (maxc - r) / cr_divisor
647
+ gc = (maxc - g) / cr_divisor
648
+ bc = (maxc - b) / cr_divisor
649
+
650
+ hr = (maxc == r) * (bc - gc)
651
+ hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
652
+ hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
653
+ h = hr + hg + hb
654
+ h = (h / 6.0 + 1.0) % 1.0
655
+ return np.stack((h, s, maxc), axis=-1)
656
+
657
+ @staticmethod
658
+ def hsv2rgb(hsv):
659
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
660
+ i = np.floor(h * 6.0)
661
+ f = (h * 6.0) - i
662
+ i = i.astype(np.int32)
663
+
664
+ p = np.clip((v * (1.0 - s)), 0.0, 1.0)
665
+ q = np.clip((v * (1.0 - s * f)), 0.0, 1.0)
666
+ t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
667
+ i = i % 6
668
+ mask = np.expand_dims(i, axis=-1) == np.arange(6)
669
+
670
+ a1 = np.stack((v, q, p, p, t, v), axis=-1)
671
+ a2 = np.stack((t, v, v, q, p, p), axis=-1)
672
+ a3 = np.stack((p, p, t, v, v, q), axis=-1)
673
+ a4 = np.stack((a1, a2, a3), axis=-1)
674
+
675
+ return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4)
676
+
677
+ def adjust_brightness(self, color, brightness_factor):
678
+ if brightness_factor < 0:
679
+ raise ValueError(
680
+ "brightness_factor ({}) is not non-negative.".format(brightness_factor)
681
+ )
682
+
683
+ return self.blend(color, np.zeros_like(color), brightness_factor)
684
+
685
+ def adjust_contrast(self, color, contrast_factor):
686
+ if contrast_factor < 0:
687
+ raise ValueError(
688
+ "contrast_factor ({}) is not non-negative.".format(contrast_factor)
689
+ )
690
+ mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color))
691
+ return self.blend(color, mean, contrast_factor)
692
+
693
+ def adjust_saturation(self, color, saturation_factor):
694
+ if saturation_factor < 0:
695
+ raise ValueError(
696
+ "saturation_factor ({}) is not non-negative.".format(saturation_factor)
697
+ )
698
+ gray = RandomColorGrayScale.rgb_to_grayscale(color)
699
+ return self.blend(color, gray, saturation_factor)
700
+
701
+ def adjust_hue(self, color, hue_factor):
702
+ if not (-0.5 <= hue_factor <= 0.5):
703
+ raise ValueError(
704
+ "hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)
705
+ )
706
+ orig_dtype = color.dtype
707
+ hsv = self.rgb2hsv(color / 255.0)
708
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
709
+ h = (h + hue_factor) % 1.0
710
+ hsv = np.stack((h, s, v), axis=-1)
711
+ color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype)
712
+ return color_hue_adj
713
+
714
+ @staticmethod
715
+ def get_params(brightness, contrast, saturation, hue):
716
+ fn_idx = torch.randperm(4)
717
+ b = (
718
+ None
719
+ if brightness is None
720
+ else np.random.uniform(brightness[0], brightness[1])
721
+ )
722
+ c = None if contrast is None else np.random.uniform(contrast[0], contrast[1])
723
+ s = (
724
+ None
725
+ if saturation is None
726
+ else np.random.uniform(saturation[0], saturation[1])
727
+ )
728
+ h = None if hue is None else np.random.uniform(hue[0], hue[1])
729
+ return fn_idx, b, c, s, h
730
+
731
+ def __call__(self, data_dict):
732
+ (
733
+ fn_idx,
734
+ brightness_factor,
735
+ contrast_factor,
736
+ saturation_factor,
737
+ hue_factor,
738
+ ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
739
+
740
+ for fn_id in fn_idx:
741
+ if (
742
+ fn_id == 0
743
+ and brightness_factor is not None
744
+ and np.random.rand() < self.p
745
+ ):
746
+ data_dict["color"] = self.adjust_brightness(
747
+ data_dict["color"], brightness_factor
748
+ )
749
+ elif (
750
+ fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p
751
+ ):
752
+ data_dict["color"] = self.adjust_contrast(
753
+ data_dict["color"], contrast_factor
754
+ )
755
+ elif (
756
+ fn_id == 2
757
+ and saturation_factor is not None
758
+ and np.random.rand() < self.p
759
+ ):
760
+ data_dict["color"] = self.adjust_saturation(
761
+ data_dict["color"], saturation_factor
762
+ )
763
+ elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p:
764
+ data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor)
765
+ return data_dict
766
+
767
+
768
+ @TRANSFORMS.register_module()
769
+ class HueSaturationTranslation(object):
770
+ @staticmethod
771
+ def rgb_to_hsv(rgb):
772
+ # Translated from source of colorsys.rgb_to_hsv
773
+ # r,g,b should be a numpy arrays with values between 0 and 255
774
+ # rgb_to_hsv returns an array of floats between 0.0 and 1.0.
775
+ rgb = rgb.astype("float")
776
+ hsv = np.zeros_like(rgb)
777
+ # in case an RGBA array was passed, just copy the A channel
778
+ hsv[..., 3:] = rgb[..., 3:]
779
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
780
+ maxc = np.max(rgb[..., :3], axis=-1)
781
+ minc = np.min(rgb[..., :3], axis=-1)
782
+ hsv[..., 2] = maxc
783
+ mask = maxc != minc
784
+ hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
785
+ rc = np.zeros_like(r)
786
+ gc = np.zeros_like(g)
787
+ bc = np.zeros_like(b)
788
+ rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
789
+ gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
790
+ bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
791
+ hsv[..., 0] = np.select(
792
+ [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc
793
+ )
794
+ hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
795
+ return hsv
796
+
797
+ @staticmethod
798
+ def hsv_to_rgb(hsv):
799
+ # Translated from source of colorsys.hsv_to_rgb
800
+ # h,s should be a numpy arrays with values between 0.0 and 1.0
801
+ # v should be a numpy array with values between 0.0 and 255.0
802
+ # hsv_to_rgb returns an array of uints between 0 and 255.
803
+ rgb = np.empty_like(hsv)
804
+ rgb[..., 3:] = hsv[..., 3:]
805
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
806
+ i = (h * 6.0).astype("uint8")
807
+ f = (h * 6.0) - i
808
+ p = v * (1.0 - s)
809
+ q = v * (1.0 - s * f)
810
+ t = v * (1.0 - s * (1.0 - f))
811
+ i = i % 6
812
+ conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
813
+ rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
814
+ rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
815
+ rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
816
+ return rgb.astype("uint8")
817
+
818
+ def __init__(self, hue_max=0.5, saturation_max=0.2):
819
+ self.hue_max = hue_max
820
+ self.saturation_max = saturation_max
821
+
822
+ def __call__(self, data_dict):
823
+ if "color" in data_dict.keys():
824
+ # Assume color[:, :3] is rgb
825
+ hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3])
826
+ hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max
827
+ sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max
828
+ hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
829
+ hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
830
+ data_dict["color"][:, :3] = np.clip(
831
+ HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255
832
+ )
833
+ return data_dict
834
+
835
+
836
+ @TRANSFORMS.register_module()
837
+ class RandomColorDrop(object):
838
+ def __init__(self, p=0.2, color_augment=0.0):
839
+ self.p = p
840
+ self.color_augment = color_augment
841
+
842
+ def __call__(self, data_dict):
843
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
844
+ data_dict["color"] *= self.color_augment
845
+ return data_dict
846
+
847
+ def __repr__(self):
848
+ return "RandomColorDrop(color_augment: {}, p: {})".format(
849
+ self.color_augment, self.p
850
+ )
851
+
852
+
853
+ @TRANSFORMS.register_module()
854
+ class ElasticDistortion(object):
855
+ def __init__(self, distortion_params=None):
856
+ self.distortion_params = (
857
+ [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params
858
+ )
859
+
860
+ @staticmethod
861
+ def elastic_distortion(coords, granularity, magnitude):
862
+ """
863
+ Apply elastic distortion on sparse coordinate space.
864
+ pointcloud: numpy array of (number of points, at least 3 spatial dims)
865
+ granularity: size of the noise grid (in same scale[m/cm] as the voxel grid)
866
+ magnitude: noise multiplier
867
+ """
868
+ blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3
869
+ blury = np.ones((1, 3, 1, 1)).astype("float32") / 3
870
+ blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3
871
+ coords_min = coords.min(0)
872
+
873
+ # Create Gaussian noise tensor of the size given by granularity.
874
+ noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
875
+ noise = np.random.randn(*noise_dim, 3).astype(np.float32)
876
+
877
+ # Smoothing.
878
+ for _ in range(2):
879
+ noise = scipy.ndimage.filters.convolve(
880
+ noise, blurx, mode="constant", cval=0
881
+ )
882
+ noise = scipy.ndimage.filters.convolve(
883
+ noise, blury, mode="constant", cval=0
884
+ )
885
+ noise = scipy.ndimage.filters.convolve(
886
+ noise, blurz, mode="constant", cval=0
887
+ )
888
+
889
+ # Trilinear interpolate noise filters for each spatial dimensions.
890
+ ax = [
891
+ np.linspace(d_min, d_max, d)
892
+ for d_min, d_max, d in zip(
893
+ coords_min - granularity,
894
+ coords_min + granularity * (noise_dim - 2),
895
+ noise_dim,
896
+ )
897
+ ]
898
+ interp = scipy.interpolate.RegularGridInterpolator(
899
+ ax, noise, bounds_error=False, fill_value=0
900
+ )
901
+ coords += interp(coords) * magnitude
902
+ return coords
903
+
904
+ def __call__(self, data_dict):
905
+ if "coord" in data_dict.keys() and self.distortion_params is not None:
906
+ if random.random() < 0.95:
907
+ for granularity, magnitude in self.distortion_params:
908
+ data_dict["coord"] = self.elastic_distortion(
909
+ data_dict["coord"], granularity, magnitude
910
+ )
911
+ return data_dict
912
+
913
+
914
+ @TRANSFORMS.register_module()
915
+ class GridSample(object):
916
+ def __init__(
917
+ self,
918
+ grid_size=0.05,
919
+ hash_type="fnv",
920
+ mode="train",
921
+ return_inverse=False,
922
+ return_grid_coord=False,
923
+ return_min_coord=False,
924
+ return_displacement=False,
925
+ project_displacement=False,
926
+ ):
927
+ self.grid_size = grid_size
928
+ self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec
929
+ assert mode in ["train", "test"]
930
+ self.mode = mode
931
+ self.return_inverse = return_inverse
932
+ self.return_grid_coord = return_grid_coord
933
+ self.return_min_coord = return_min_coord
934
+ self.return_displacement = return_displacement
935
+ self.project_displacement = project_displacement
936
+
937
+ def __call__(self, data_dict):
938
+ assert "coord" in data_dict.keys()
939
+ scaled_coord = data_dict["coord"] / np.array(self.grid_size)
940
+ grid_coord = np.floor(scaled_coord).astype(int)
941
+ min_coord = grid_coord.min(0)
942
+ grid_coord -= min_coord
943
+ scaled_coord -= min_coord
944
+ min_coord = min_coord * np.array(self.grid_size)
945
+ key = self.hash(grid_coord)
946
+ idx_sort = np.argsort(key)
947
+ key_sort = key[idx_sort]
948
+ _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True)
949
+ if self.mode == "train": # train mode
950
+ idx_select = (
951
+ np.cumsum(np.insert(count, 0, 0)[0:-1])
952
+ + np.random.randint(0, count.max(), count.size) % count
953
+ )
954
+ idx_unique = idx_sort[idx_select]
955
+ if "sampled_index" in data_dict:
956
+ # for ScanNet data efficient, we need to make sure labeled point is sampled.
957
+ idx_unique = np.unique(
958
+ np.append(idx_unique, data_dict["sampled_index"])
959
+ )
960
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
961
+ mask[data_dict["sampled_index"]] = True
962
+ data_dict["sampled_index"] = np.where(mask[idx_unique])[0]
963
+ data_dict = index_operator(data_dict, idx_unique)
964
+ if self.return_inverse:
965
+ data_dict["inverse"] = np.zeros_like(inverse)
966
+ data_dict["inverse"][idx_sort] = inverse
967
+ if self.return_grid_coord:
968
+ data_dict["grid_coord"] = grid_coord[idx_unique]
969
+ data_dict["index_valid_keys"].append("grid_coord")
970
+ if self.return_min_coord:
971
+ data_dict["min_coord"] = min_coord.reshape([1, 3])
972
+ if self.return_displacement:
973
+ displacement = (
974
+ scaled_coord - grid_coord - 0.5
975
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
976
+ if self.project_displacement:
977
+ displacement = np.sum(
978
+ displacement * data_dict["normal"], axis=-1, keepdims=True
979
+ )
980
+ data_dict["displacement"] = displacement[idx_unique]
981
+ data_dict["index_valid_keys"].append("displacement")
982
+ return data_dict
983
+
984
+ elif self.mode == "test": # test mode
985
+ data_part_list = []
986
+ for i in range(count.max()):
987
+ idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
988
+ idx_part = idx_sort[idx_select]
989
+ data_part = index_operator(data_dict, idx_part, duplicate=True)
990
+ data_part["index"] = idx_part
991
+ if self.return_inverse:
992
+ data_part["inverse"] = np.zeros_like(inverse)
993
+ data_part["inverse"][idx_sort] = inverse
994
+ if self.return_grid_coord:
995
+ data_part["grid_coord"] = grid_coord[idx_part]
996
+ data_dict["index_valid_keys"].append("grid_coord")
997
+ if self.return_min_coord:
998
+ data_part["min_coord"] = min_coord.reshape([1, 3])
999
+ if self.return_displacement:
1000
+ displacement = (
1001
+ scaled_coord - grid_coord - 0.5
1002
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
1003
+ if self.project_displacement:
1004
+ displacement = np.sum(
1005
+ displacement * data_dict["normal"], axis=-1, keepdims=True
1006
+ )
1007
+ data_dict["displacement"] = displacement[idx_part]
1008
+ data_dict["index_valid_keys"].append("displacement")
1009
+ data_part_list.append(data_part)
1010
+ return data_part_list
1011
+ else:
1012
+ raise NotImplementedError
1013
+
1014
+ @staticmethod
1015
+ def ravel_hash_vec(arr):
1016
+ """
1017
+ Ravel the coordinates after subtracting the min coordinates.
1018
+ """
1019
+ assert arr.ndim == 2
1020
+ arr = arr.copy()
1021
+ arr -= arr.min(0)
1022
+ arr = arr.astype(np.uint64, copy=False)
1023
+ arr_max = arr.max(0).astype(np.uint64) + 1
1024
+
1025
+ keys = np.zeros(arr.shape[0], dtype=np.uint64)
1026
+ # Fortran style indexing
1027
+ for j in range(arr.shape[1] - 1):
1028
+ keys += arr[:, j]
1029
+ keys *= arr_max[j + 1]
1030
+ keys += arr[:, -1]
1031
+ return keys
1032
+
1033
+ @staticmethod
1034
+ def fnv_hash_vec(arr):
1035
+ """
1036
+ FNV64-1A
1037
+ """
1038
+ assert arr.ndim == 2
1039
+ # Floor first for negative coordinates
1040
+ arr = arr.copy()
1041
+ arr = arr.astype(np.uint64, copy=False)
1042
+ hashed_arr = np.uint64(14695981039346656037) * np.ones(
1043
+ arr.shape[0], dtype=np.uint64
1044
+ )
1045
+ for j in range(arr.shape[1]):
1046
+ hashed_arr *= np.uint64(1099511628211)
1047
+ hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
1048
+ return hashed_arr
1049
+
1050
+
1051
+ @TRANSFORMS.register_module()
1052
+ class SphereCrop(object):
1053
+ def __init__(self, point_max=80000, sample_rate=None, mode="random"):
1054
+ self.point_max = point_max
1055
+ self.sample_rate = sample_rate
1056
+ assert mode in ["random", "center", "all"]
1057
+ self.mode = mode
1058
+
1059
+ def __call__(self, data_dict):
1060
+ point_max = (
1061
+ int(self.sample_rate * data_dict["coord"].shape[0])
1062
+ if self.sample_rate is not None
1063
+ else self.point_max
1064
+ )
1065
+
1066
+ assert "coord" in data_dict.keys()
1067
+ if data_dict["coord"].shape[0] > point_max:
1068
+ if self.mode == "random":
1069
+ center = data_dict["coord"][
1070
+ np.random.randint(data_dict["coord"].shape[0])
1071
+ ]
1072
+ elif self.mode == "center":
1073
+ center = data_dict["coord"][data_dict["coord"].shape[0] // 2]
1074
+ else:
1075
+ raise NotImplementedError
1076
+ idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[
1077
+ :point_max
1078
+ ]
1079
+ data_dict = index_operator(data_dict, idx_crop)
1080
+ return data_dict
1081
+
1082
+
1083
+ @TRANSFORMS.register_module()
1084
+ class ShufflePoint(object):
1085
+ def __call__(self, data_dict):
1086
+ assert "coord" in data_dict.keys()
1087
+ shuffle_index = np.arange(data_dict["coord"].shape[0])
1088
+ np.random.shuffle(shuffle_index)
1089
+ data_dict = index_operator(data_dict, shuffle_index)
1090
+ return data_dict
1091
+
1092
+
1093
+ @TRANSFORMS.register_module()
1094
+ class CropBoundary(object):
1095
+ def __call__(self, data_dict):
1096
+ assert "segment" in data_dict
1097
+ segment = data_dict["segment"].flatten()
1098
+ mask = (segment != 0) * (segment != 1)
1099
+ data_dict = index_operator(data_dict, mask)
1100
+ return data_dict
1101
+
1102
+
1103
+ @TRANSFORMS.register_module()
1104
+ class ContrastiveViewsGenerator(object):
1105
+ def __init__(
1106
+ self,
1107
+ view_keys=("coord", "color", "normal", "origin_coord"),
1108
+ view_trans_cfg=None,
1109
+ ):
1110
+ self.view_keys = view_keys
1111
+ self.view_trans = Compose(view_trans_cfg)
1112
+
1113
+ def __call__(self, data_dict):
1114
+ view1_dict = dict()
1115
+ view2_dict = dict()
1116
+ for key in self.view_keys:
1117
+ view1_dict[key] = data_dict[key].copy()
1118
+ view2_dict[key] = data_dict[key].copy()
1119
+ view1_dict = self.view_trans(view1_dict)
1120
+ view2_dict = self.view_trans(view2_dict)
1121
+ for key, value in view1_dict.items():
1122
+ data_dict["view1_" + key] = value
1123
+ for key, value in view2_dict.items():
1124
+ data_dict["view2_" + key] = value
1125
+ return data_dict
1126
+
1127
+
1128
+ @TRANSFORMS.register_module()
1129
+ class MultiViewGenerator(object):
1130
+ def __init__(
1131
+ self,
1132
+ global_view_num=2,
1133
+ global_view_scale=(0.4, 1.0),
1134
+ local_view_num=4,
1135
+ local_view_scale=(0.1, 0.4),
1136
+ global_shared_transform=None,
1137
+ global_transform=None,
1138
+ local_transform=None,
1139
+ max_size=65536,
1140
+ center_height_scale=(0, 1),
1141
+ shared_global_view=False,
1142
+ view_keys=("coord", "origin_coord", "color", "normal"),
1143
+ ):
1144
+ self.global_view_num = global_view_num
1145
+ self.global_view_scale = global_view_scale
1146
+ self.local_view_num = local_view_num
1147
+ self.local_view_scale = local_view_scale
1148
+ self.global_shared_transform = Compose(global_shared_transform)
1149
+ self.global_transform = Compose(global_transform)
1150
+ self.local_transform = Compose(local_transform)
1151
+ self.max_size = max_size
1152
+ self.center_height_scale = center_height_scale
1153
+ self.shared_global_view = shared_global_view
1154
+ self.view_keys = view_keys
1155
+ assert "coord" in view_keys
1156
+
1157
+ def get_view(self, point, center, scale):
1158
+ coord = point["coord"]
1159
+ max_size = min(self.max_size, coord.shape[0])
1160
+ size = int(np.random.uniform(*scale) * max_size)
1161
+ index = np.argsort(np.sum(np.square(coord - center), axis=-1))[:size]
1162
+ view = dict(index=index)
1163
+ for key in point.keys():
1164
+ if key in self.view_keys:
1165
+ view[key] = point[key][index]
1166
+
1167
+ if "index_valid_keys" in point.keys():
1168
+ # inherit index_valid_keys from point
1169
+ view["index_valid_keys"] = point["index_valid_keys"]
1170
+ return view
1171
+
1172
+ def __call__(self, data_dict):
1173
+ coord = data_dict["coord"]
1174
+ point = self.global_shared_transform(copy.deepcopy(data_dict))
1175
+ z_min = coord[:, 2].min()
1176
+ z_max = coord[:, 2].max()
1177
+ z_min_ = z_min + (z_max - z_min) * self.center_height_scale[0]
1178
+ z_max_ = z_min + (z_max - z_min) * self.center_height_scale[1]
1179
+ center_mask = np.logical_and(coord[:, 2] >= z_min_, coord[:, 2] <= z_max_)
1180
+ # get major global view
1181
+ major_center = coord[np.random.choice(np.where(center_mask)[0])]
1182
+ major_view = self.get_view(point, major_center, self.global_view_scale)
1183
+ major_coord = major_view["coord"]
1184
+ # get global views: restrict the center of left global view within the major global view
1185
+ if not self.shared_global_view:
1186
+ global_views = [
1187
+ self.get_view(
1188
+ point=point,
1189
+ center=major_coord[np.random.randint(major_coord.shape[0])],
1190
+ scale=self.global_view_scale,
1191
+ )
1192
+ for _ in range(self.global_view_num - 1)
1193
+ ]
1194
+ else:
1195
+ global_views = [
1196
+ {key: value.copy() for key, value in major_view.items()}
1197
+ for _ in range(self.global_view_num - 1)
1198
+ ]
1199
+
1200
+ global_views = [major_view] + global_views
1201
+
1202
+ # get local views: restrict the center of local view within the major global view
1203
+ cover_mask = np.zeros_like(major_view["index"], dtype=bool)
1204
+ local_views = []
1205
+ for i in range(self.local_view_num):
1206
+ if sum(~cover_mask) == 0:
1207
+ # reset cover mask if all points are sampled
1208
+ cover_mask[:] = False
1209
+ local_view = self.get_view(
1210
+ point=data_dict,
1211
+ center=major_coord[np.random.choice(np.where(~cover_mask)[0])],
1212
+ scale=self.local_view_scale,
1213
+ )
1214
+ local_views.append(local_view)
1215
+ cover_mask[np.isin(major_view["index"], local_view["index"])] = True
1216
+
1217
+ # augmentation and concat
1218
+ view_dict = {}
1219
+ for global_view in global_views:
1220
+ global_view.pop("index")
1221
+ global_view = self.global_transform(global_view)
1222
+ for key in self.view_keys:
1223
+ if f"global_{key}" in view_dict.keys():
1224
+ view_dict[f"global_{key}"].append(global_view[key])
1225
+ else:
1226
+ view_dict[f"global_{key}"] = [global_view[key]]
1227
+ view_dict["global_offset"] = np.cumsum(
1228
+ [data.shape[0] for data in view_dict["global_coord"]]
1229
+ )
1230
+ for local_view in local_views:
1231
+ local_view.pop("index")
1232
+ local_view = self.local_transform(local_view)
1233
+ for key in self.view_keys:
1234
+ if f"local_{key}" in view_dict.keys():
1235
+ view_dict[f"local_{key}"].append(local_view[key])
1236
+ else:
1237
+ view_dict[f"local_{key}"] = [local_view[key]]
1238
+ view_dict["local_offset"] = np.cumsum(
1239
+ [data.shape[0] for data in view_dict["local_coord"]]
1240
+ )
1241
+ for key in view_dict.keys():
1242
+ if "offset" not in key:
1243
+ view_dict[key] = np.concatenate(view_dict[key], axis=0)
1244
+ data_dict.update(view_dict)
1245
+ return data_dict
1246
+
1247
+
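Editorial note: the view sampling above reduces to a k-nearest crop around a randomly chosen center, with the crop size drawn as a random fraction of max_size. A minimal standalone sketch of that crop (the helper name crop_view is illustrative, not part of this repo):

import numpy as np

def crop_view(coord, center, scale=(0.4, 0.6), max_size=2048):
    # view size is a random fraction of max_size, capped by the point count
    max_size = min(max_size, coord.shape[0])
    size = int(np.random.uniform(*scale) * max_size)
    # keep the `size` points closest to the chosen center (squared Euclidean distance)
    return np.argsort(np.sum(np.square(coord - center), axis=-1))[:size]

coord = np.random.rand(10000, 3)
center = coord[np.random.randint(coord.shape[0])]
index = crop_view(coord, center)  # indices of one global/local view
view_coord = coord[index]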
1248
+ @TRANSFORMS.register_module()
1249
+ class InstanceParser(object):
1250
+ def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1):
1251
+ self.segment_ignore_index = segment_ignore_index
1252
+ self.instance_ignore_index = instance_ignore_index
1253
+
1254
+ def __call__(self, data_dict):
1255
+ coord = data_dict["coord"]
1256
+ segment = data_dict["segment"]
1257
+ instance = data_dict["instance"]
1258
+ mask = ~np.isin(segment, self.segment_ignore_index)
1259
+ # map ignored instances to the ignore index
1260
+ instance[~mask] = self.instance_ignore_index
1261
+ # re-index the remaining instances contiguously from 0
1262
+ unique, inverse = np.unique(instance[mask], return_inverse=True)
1263
+ instance_num = len(unique)
1264
+ instance[mask] = inverse
1265
+ # init instance information
1266
+ centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index
1267
+ bbox = np.ones((instance_num, 8)) * self.instance_ignore_index
1268
+ vacancy = [
1269
+ index for index in self.segment_ignore_index if index >= 0
1270
+ ] # vacate class index
1271
+
1272
+ for instance_id in range(instance_num):
1273
+ mask_ = instance == instance_id
1274
+ coord_ = coord[mask_]
1275
+ bbox_min = coord_.min(0)
1276
+ bbox_max = coord_.max(0)
1277
+ bbox_centroid = coord_.mean(0)
1278
+ bbox_center = (bbox_max + bbox_min) / 2
1279
+ bbox_size = bbox_max - bbox_min
1280
+ bbox_theta = np.zeros(1, dtype=coord_.dtype)
1281
+ bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype)
1282
+ # shift the class index down to fill the vacancies left by segment_ignore_index
1283
+ bbox_class -= np.greater(bbox_class, vacancy).sum()
1284
+
1285
+ centroid[mask_] = bbox_centroid
1286
+ bbox[instance_id] = np.concatenate(
1287
+ [bbox_center, bbox_size, bbox_theta, bbox_class]
1288
+ ) # 3 + 3 + 1 + 1 = 8
1289
+ data_dict["instance"] = instance
1290
+ data_dict["instance_centroid"] = centroid
1291
+ data_dict["bbox"] = bbox
1292
+ return data_dict
1293
+
1294
+
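Editorial note: InstanceParser packs each instance box into an 8-dim row laid out as center(3) + size(3) + heading theta(1) + class(1). A small sketch of reading such a row back (the helper name unpack_bbox is hypothetical):

import numpy as np

def unpack_bbox(bbox_row):
    # layout written by InstanceParser: 3 + 3 + 1 + 1 = 8
    center = bbox_row[0:3]
    size = bbox_row[3:6]
    theta = float(bbox_row[6])  # heading angle, always 0 for the axis-aligned boxes built above
    cls = int(bbox_row[7])      # class index after the vacancy shift
    return center, size, theta, cls

row = np.array([0.5, 0.2, 1.0, 0.4, 0.3, 1.8, 0.0, 4.0])
center, size, theta, cls = unpack_bbox(row)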
1295
+ class Compose(object):
1296
+ def __init__(self, cfg=None):
1297
+ self.cfg = cfg if cfg is not None else []
1298
+ self.transforms = []
1299
+ for t_cfg in self.cfg:
1300
+ self.transforms.append(TRANSFORMS.build(t_cfg))
1301
+
1302
+ def __call__(self, data_dict):
1303
+ for t in self.transforms:
1304
+ data_dict = t(data_dict)
1305
+ return data_dict
1306
+
1307
+
1308
+ def default():
1309
+ config = [
1310
+ dict(type="CenterShift", apply_z=True),
1311
+ dict(
1312
+ type="GridSample",
1313
+ # grid_size=0.02,
1314
+ # grid_size=0.01,
1315
+ grid_size=0.005,
1316
+ # grid_size=0.0025,
1317
+ hash_type="fnv",
1318
+ mode="train",
1319
+ return_grid_coord=True,
1320
+ return_inverse=True,
1321
+ ),
1322
+ dict(type="NormalizeColor"),
1323
+ dict(type="ToTensor"),
1324
+ dict(
1325
+ type="Collect",
1326
+ keys=("coord", "grid_coord", "color", "inverse"),
1327
+ feat_keys=("coord", "color", "normal"),
1328
+ ),
1329
+ ]
1330
+ return Compose(config)
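Editorial note: a minimal usage sketch of the default() pipeline above, run on a raw point-cloud dict with (N, 3) float arrays. The import path is an assumption for illustration, not part of the diff:

import numpy as np
from partgen.models.sonata import transform as T  # assumed import path

pipeline = T.default()
points = dict(
    coord=np.random.rand(10000, 3).astype(np.float32),
    color=np.random.randint(0, 256, (10000, 3)).astype(np.float32),
    normal=np.random.randn(10000, 3).astype(np.float32),
)
batch = pipeline(points)  # dict of tensors; exact keys follow the Collect transform above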
XPart/partgen/models/sonata/utils.py ADDED
@@ -0,0 +1,75 @@
1
+ """
2
+ General utils
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import os
24
+ import random
25
+ import numpy as np
26
+ import torch
27
+ import torch.backends.cudnn as cudnn
28
+ from datetime import datetime
29
+
30
+
31
+ @torch.no_grad()
32
+ def offset2bincount(offset):
33
+ return torch.diff(
34
+ offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)
35
+ )
36
+
37
+
38
+ @torch.no_grad()
39
+ def bincount2offset(bincount):
40
+ return torch.cumsum(bincount, dim=0)
41
+
42
+
43
+ @torch.no_grad()
44
+ def offset2batch(offset):
45
+ bincount = offset2bincount(offset)
46
+ return torch.arange(
47
+ len(bincount), device=offset.device, dtype=torch.long
48
+ ).repeat_interleave(bincount)
49
+
50
+
51
+ @torch.no_grad()
52
+ def batch2offset(batch):
53
+ return torch.cumsum(batch.bincount(), dim=0).long()
54
+
55
+
56
+ def get_random_seed():
57
+ seed = (
58
+ os.getpid()
59
+ + int(datetime.now().strftime("%S%f"))
60
+ + int.from_bytes(os.urandom(2), "big")
61
+ )
62
+ return seed
63
+
64
+
65
+ def set_seed(seed=None):
66
+ if seed is None:
67
+ seed = get_random_seed()
68
+ random.seed(seed)
69
+ np.random.seed(seed)
70
+ torch.manual_seed(seed)
71
+ torch.cuda.manual_seed(seed)
72
+ torch.cuda.manual_seed_all(seed)
73
+ cudnn.benchmark = False
74
+ cudnn.deterministic = True
75
+ os.environ["PYTHONHASHSEED"] = str(seed)
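Editorial note: a short worked example of the offset/batch helpers defined in this file (run with those functions in scope):

import torch

offset = torch.tensor([3, 5, 9])  # cumulative point counts of three clouds (3, 2, 4 points)
offset2bincount(offset)           # -> tensor([3, 2, 4])
offset2batch(offset)              # -> tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
torch.equal(batch2offset(offset2batch(offset)), offset)  # -> True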