xinjie.wang committed on
Commit 2442d05 · 0 Parent(s)

Initial clean commit

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +35 -0
  2. README.md +15 -0
  3. app.py +479 -0
  4. app_style.py +27 -0
  5. common.py +849 -0
  6. embodied_gen/data/asset_converter.py +663 -0
  7. embodied_gen/data/backproject.py +518 -0
  8. embodied_gen/data/backproject_v2.py +721 -0
  9. embodied_gen/data/convex_decomposer.py +190 -0
  10. embodied_gen/data/datasets.py +320 -0
  11. embodied_gen/data/differentiable_render.py +589 -0
  12. embodied_gen/data/mesh_operator.py +461 -0
  13. embodied_gen/data/utils.py +1039 -0
  14. embodied_gen/envs/pick_embodiedgen.py +420 -0
  15. embodied_gen/models/delight_model.py +202 -0
  16. embodied_gen/models/gs_model.py +511 -0
  17. embodied_gen/models/image_comm_model.py +236 -0
  18. embodied_gen/models/layout.py +510 -0
  19. embodied_gen/models/segment_model.py +379 -0
  20. embodied_gen/models/sr_model.py +174 -0
  21. embodied_gen/models/text_model.py +213 -0
  22. embodied_gen/models/texture_model.py +112 -0
  23. embodied_gen/scripts/compose_layout.py +79 -0
  24. embodied_gen/scripts/gen_layout.py +171 -0
  25. embodied_gen/scripts/gen_scene3d.py +191 -0
  26. embodied_gen/scripts/gen_texture.py +123 -0
  27. embodied_gen/scripts/imageto3d.py +350 -0
  28. embodied_gen/scripts/parallel_sim.py +166 -0
  29. embodied_gen/scripts/render_gs.py +164 -0
  30. embodied_gen/scripts/render_mv.py +198 -0
  31. embodied_gen/scripts/simulate_sapien.py +196 -0
  32. embodied_gen/scripts/text2image.py +162 -0
  33. embodied_gen/scripts/textto3d.py +282 -0
  34. embodied_gen/scripts/textto3d.sh +94 -0
  35. embodied_gen/scripts/texture_gen.sh +78 -0
  36. embodied_gen/trainer/gsplat_trainer.py +678 -0
  37. embodied_gen/trainer/pono2mesh_trainer.py +538 -0
  38. embodied_gen/utils/config.py +202 -0
  39. embodied_gen/utils/enum.py +108 -0
  40. embodied_gen/utils/gaussian.py +330 -0
  41. embodied_gen/utils/geometry.py +515 -0
  42. embodied_gen/utils/gpt_clients.py +218 -0
  43. embodied_gen/utils/gpt_config.yaml +14 -0
  44. embodied_gen/utils/log.py +48 -0
  45. embodied_gen/utils/monkey_patches.py +218 -0
  46. embodied_gen/utils/process_media.py +467 -0
  47. embodied_gen/utils/simulation.py +667 -0
  48. embodied_gen/utils/tags.py +1 -0
  49. embodied_gen/utils/trender.py +90 -0
  50. embodied_gen/validators/aesthetic_predictor.py +137 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,15 @@
+ ---
+ title: EmbodiedGen Text To 3D
+ emoji: 📝
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.33.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ short_description: Create 3D models from text descriptions
+ paper: https://huggingface.co/papers/2506.10600
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
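
For orientation, the front matter above tells the Hugging Face Hub to launch `app.py` with the Gradio SDK. A minimal, hypothetical sketch of such an entry point is shown below; it is not this Space's actual app.py, which is added in full in the next file.

    # Hypothetical minimal Gradio Space entry point; the real app.py below
    # wires the full text -> image -> 3D -> URDF pipeline.
    import gradio as gr

    with gr.Blocks() as demo:
        gr.Markdown("EmbodiedGen: Text-to-3D Asset")
        prompt = gr.Textbox(label="Text Prompt")

    if __name__ == "__main__":
        demo.launch()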
app.py ADDED
@@ -0,0 +1,479 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import os
19
+
20
+ os.environ["GRADIO_APP"] = "textto3d"
21
+
22
+
23
+ import gradio as gr
24
+ from app_style import custom_theme, image_css, lighting_css
25
+ from common import (
26
+ MAX_SEED,
27
+ VERSION,
28
+ active_btn_by_text_content,
29
+ end_session,
30
+ extract_3d_representations_v2,
31
+ extract_urdf,
32
+ get_cached_image,
33
+ get_seed,
34
+ get_selected_image,
35
+ image_to_3d,
36
+ start_session,
37
+ text2image_fn,
38
+ )
39
+
40
+ with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo:
41
+ gr.HTML(image_css, visible=False)
42
+ # gr.HTML(lighting_css, visible=False)
43
+ gr.Markdown(
44
+ """
45
+ ## ***EmbodiedGen***: Text-to-3D Asset
46
+ **🔖 Version**: {VERSION}
47
+ <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
48
+ <a href="https://horizonrobotics.github.io/EmbodiedGen">
49
+ <img alt="📖 Documentation" src="https://img.shields.io/badge/📖-Documentation-blue">
50
+ </a>
51
+ <a href="https://arxiv.org/abs/2506.10600">
52
+ <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
53
+ </a>
54
+ <a href="https://github.com/HorizonRobotics/EmbodiedGen">
55
+ <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
56
+ </a>
57
+ <a href="https://www.youtube.com/watch?v=rG4odybuJRk">
58
+ <img alt="🎥 Video" src="https://img.shields.io/badge/🎥-Video-red">
59
+ </a>
60
+ </p>
61
+
62
+ 📝 Create 3D assets from text descriptions for a wide range of geometry and styles.
63
+ """.format(
64
+ VERSION=VERSION
65
+ ),
66
+ elem_classes=["header"],
67
+ )
68
+
69
+ with gr.Row():
70
+ with gr.Column(scale=1):
71
+ raw_image_cache = gr.Image(
72
+ format="png",
73
+ image_mode="RGB",
74
+ type="pil",
75
+ visible=False,
76
+ )
77
+ text_prompt = gr.Textbox(
78
+ label="Text Prompt (Chinese or English)",
79
+ placeholder="Input text prompt here",
80
+ )
81
+ ip_image = gr.Image(
82
+ label="Reference Image(optional)",
83
+ format="png",
84
+ image_mode="RGB",
85
+ type="filepath",
86
+ height=250,
87
+ elem_classes=["image_fit"],
88
+ )
89
+ gr.Markdown(
90
+ "Note: The `reference image` is optional, if use, "
91
+ "please provide image in nearly square resolution."
92
+ )
93
+
94
+ with gr.Accordion(label="Image Generation Settings", open=False):
95
+ with gr.Row():
96
+ seed = gr.Slider(
97
+ 0, MAX_SEED, label="Seed", value=0, step=1
98
+ )
99
+ randomize_seed = gr.Checkbox(
100
+ label="Randomize Seed", value=False
101
+ )
102
+ rmbg_tag = gr.Radio(
103
+ choices=["rembg", "rmbg14"],
104
+ value="rembg",
105
+ label="Background Removal Model",
106
+ )
107
+ ip_adapt_scale = gr.Slider(
108
+ 0, 1, label="IP-adapter Scale", value=0.3, step=0.05
109
+ )
110
+ img_guidance_scale = gr.Slider(
111
+ 1, 30, label="Text Guidance Scale", value=12, step=0.2
112
+ )
113
+ img_inference_steps = gr.Slider(
114
+ 10, 100, label="Sampling Steps", value=50, step=5
115
+ )
116
+ img_resolution = gr.Slider(
117
+ 512,
118
+ 1536,
119
+ label="Image Resolution",
120
+ value=1024,
121
+ step=128,
122
+ )
123
+
124
+ generate_img_btn = gr.Button(
125
+ "🎨 1. Generate Images(~1min)",
126
+ variant="primary",
127
+ interactive=False,
128
+ )
129
+ dropdown = gr.Radio(
130
+ choices=["sample1", "sample2", "sample3"],
131
+ value="sample1",
132
+ label="Choose your favorite sample style.",
133
+ )
134
+ select_img = gr.Image(
135
+ visible=False,
136
+ format="png",
137
+ image_mode="RGBA",
138
+ type="pil",
139
+ height=300,
140
+ )
141
+
142
+ # text to 3d
143
+ with gr.Accordion(label="Generation Settings", open=False):
144
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
145
+ texture_size = gr.Slider(
146
+ 1024, 4096, label="UV texture size", value=2048, step=256
147
+ )
148
+ with gr.Row():
149
+ randomize_seed = gr.Checkbox(
150
+ label="Randomize Seed", value=False
151
+ )
152
+ project_delight = gr.Checkbox(
153
+ label="Back-project Delight", value=True
154
+ )
155
+ gr.Markdown("Geo Structure Generation")
156
+ with gr.Row():
157
+ ss_guidance_strength = gr.Slider(
158
+ 0.0,
159
+ 10.0,
160
+ label="Guidance Strength",
161
+ value=7.5,
162
+ step=0.1,
163
+ )
164
+ ss_sampling_steps = gr.Slider(
165
+ 1, 50, label="Sampling Steps", value=12, step=1
166
+ )
167
+ gr.Markdown("Visual Appearance Generation")
168
+ with gr.Row():
169
+ slat_guidance_strength = gr.Slider(
170
+ 0.0,
171
+ 10.0,
172
+ label="Guidance Strength",
173
+ value=3.0,
174
+ step=0.1,
175
+ )
176
+ slat_sampling_steps = gr.Slider(
177
+ 1, 50, label="Sampling Steps", value=12, step=1
178
+ )
179
+
180
+ generate_btn = gr.Button(
181
+ "🚀 2. Generate 3D(~0.5 mins)",
182
+ variant="primary",
183
+ interactive=False,
184
+ )
185
+ model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False)
186
+ with gr.Row():
187
+ extract_rep3d_btn = gr.Button(
188
+ "🔍 3. Extract 3D Representation(~1 mins)",
189
+ variant="primary",
190
+ interactive=False,
191
+ )
192
+ with gr.Accordion(
193
+ label="Enter Asset Attributes(optional)", open=False
194
+ ):
195
+ asset_cat_text = gr.Textbox(
196
+ label="Enter Asset Category (e.g., chair)"
197
+ )
198
+ height_range_text = gr.Textbox(
199
+ label="Enter Height Range in meter (e.g., 0.5-0.6)"
200
+ )
201
+ mass_range_text = gr.Textbox(
202
+ label="Enter Mass Range in kg (e.g., 1.1-1.2)"
203
+ )
204
+ asset_version_text = gr.Textbox(
205
+ label=f"Enter version (e.g., {VERSION})"
206
+ )
207
+ with gr.Row():
208
+ extract_urdf_btn = gr.Button(
209
+ "🧩 4. Extract URDF with physics(~1 mins)",
210
+ variant="primary",
211
+ interactive=False,
212
+ )
213
+ with gr.Row():
214
+ download_urdf = gr.DownloadButton(
215
+ label="⬇️ 5. Download URDF",
216
+ variant="primary",
217
+ interactive=False,
218
+ )
219
+
220
+ with gr.Column(scale=3):
221
+ with gr.Row():
222
+ image_sample1 = gr.Image(
223
+ label="sample1",
224
+ format="png",
225
+ image_mode="RGBA",
226
+ type="filepath",
227
+ height=300,
228
+ interactive=False,
229
+ elem_classes=["image_fit"],
230
+ )
231
+ image_sample2 = gr.Image(
232
+ label="sample2",
233
+ format="png",
234
+ image_mode="RGBA",
235
+ type="filepath",
236
+ height=300,
237
+ interactive=False,
238
+ elem_classes=["image_fit"],
239
+ )
240
+ image_sample3 = gr.Image(
241
+ label="sample3",
242
+ format="png",
243
+ image_mode="RGBA",
244
+ type="filepath",
245
+ height=300,
246
+ interactive=False,
247
+ elem_classes=["image_fit"],
248
+ )
249
+ usample1 = gr.Image(
250
+ format="png",
251
+ image_mode="RGBA",
252
+ type="filepath",
253
+ visible=False,
254
+ )
255
+ usample2 = gr.Image(
256
+ format="png",
257
+ image_mode="RGBA",
258
+ type="filepath",
259
+ visible=False,
260
+ )
261
+ usample3 = gr.Image(
262
+ format="png",
263
+ image_mode="RGBA",
264
+ type="filepath",
265
+ visible=False,
266
+ )
267
+ gr.Markdown(
268
+ "Generated image may be poor quality due to auto seg."
269
+ "Retry by adjusting text prompt, seed or switch seg model in `Image Gen Settings`."
270
+ )
271
+ with gr.Row():
272
+ video_output = gr.Video(
273
+ label="Generated 3D Asset",
274
+ autoplay=True,
275
+ loop=True,
276
+ height=300,
277
+ interactive=False,
278
+ )
279
+ model_output_gs = gr.Model3D(
280
+ label="Gaussian Representation",
281
+ height=300,
282
+ interactive=False,
283
+ )
284
+ aligned_gs = gr.Textbox(visible=False)
285
+
286
+ model_output_mesh = gr.Model3D(
287
+ label="Mesh Representation",
288
+ clear_color=[0.8, 0.8, 0.8, 1],
289
+ height=300,
290
+ interactive=False,
291
+ elem_id="lighter_mesh",
292
+ )
293
+
294
+ gr.Markdown("Estimated Asset 3D Attributes(No input required)")
295
+ with gr.Row():
296
+ est_type_text = gr.Textbox(
297
+ label="Asset category", interactive=False
298
+ )
299
+ est_height_text = gr.Textbox(
300
+ label="Real height(.m)", interactive=False
301
+ )
302
+ est_mass_text = gr.Textbox(
303
+ label="Mass(.kg)", interactive=False
304
+ )
305
+ est_mu_text = gr.Textbox(
306
+ label="Friction coefficient", interactive=False
307
+ )
308
+
309
+ prompt_examples = [
310
+ "satin gold tea cup with saucer",
311
+ "small bronze figurine of a lion",
312
+ "brown leather bag",
313
+ "Miniature cup with floral design",
314
+ "带木质底座, 具有经纬线的地球仪",
315
+ "橙色电动手钻, 有磨损细节",
316
+ "手工制作的皮革笔记本",
317
+ ]
318
+ examples = gr.Examples(
319
+ label="Gallery",
320
+ examples=prompt_examples,
321
+ inputs=[text_prompt],
322
+ examples_per_page=10,
323
+ )
324
+
325
+ output_buf = gr.State()
326
+
327
+ demo.load(start_session)
328
+ demo.unload(end_session)
329
+
330
+ text_prompt.change(
331
+ active_btn_by_text_content,
332
+ inputs=[text_prompt],
333
+ outputs=[generate_img_btn],
334
+ )
335
+
336
+ generate_img_btn.click(
337
+ lambda: tuple(
338
+ [
339
+ gr.Button(interactive=False),
340
+ gr.Button(interactive=False),
341
+ gr.Button(interactive=False),
342
+ gr.Button(interactive=False),
343
+ None,
344
+ "",
345
+ None,
346
+ None,
347
+ "",
348
+ "",
349
+ "",
350
+ "",
351
+ "",
352
+ "",
353
+ "",
354
+ "",
355
+ None,
356
+ None,
357
+ None,
358
+ ]
359
+ ),
360
+ outputs=[
361
+ extract_rep3d_btn,
362
+ extract_urdf_btn,
363
+ download_urdf,
364
+ generate_btn,
365
+ model_output_gs,
366
+ aligned_gs,
367
+ model_output_mesh,
368
+ video_output,
369
+ asset_cat_text,
370
+ height_range_text,
371
+ mass_range_text,
372
+ asset_version_text,
373
+ est_type_text,
374
+ est_height_text,
375
+ est_mass_text,
376
+ est_mu_text,
377
+ image_sample1,
378
+ image_sample2,
379
+ image_sample3,
380
+ ],
381
+ ).success(
382
+ text2image_fn,
383
+ inputs=[
384
+ text_prompt,
385
+ img_guidance_scale,
386
+ img_inference_steps,
387
+ ip_image,
388
+ ip_adapt_scale,
389
+ img_resolution,
390
+ rmbg_tag,
391
+ seed,
392
+ ],
393
+ outputs=[
394
+ image_sample1,
395
+ image_sample2,
396
+ image_sample3,
397
+ usample1,
398
+ usample2,
399
+ usample3,
400
+ ],
401
+ ).success(
402
+ lambda: gr.Button(interactive=True),
403
+ outputs=[generate_btn],
404
+ )
405
+
406
+ generate_btn.click(
407
+ get_seed,
408
+ inputs=[randomize_seed, seed],
409
+ outputs=[seed],
410
+ ).success(
411
+ get_selected_image,
412
+ inputs=[dropdown, usample1, usample2, usample3],
413
+ outputs=select_img,
414
+ ).success(
415
+ get_cached_image,
416
+ inputs=[select_img],
417
+ outputs=[raw_image_cache],
418
+ ).success(
419
+ image_to_3d,
420
+ inputs=[
421
+ select_img,
422
+ seed,
423
+ ss_guidance_strength,
424
+ ss_sampling_steps,
425
+ slat_guidance_strength,
426
+ slat_sampling_steps,
427
+ raw_image_cache,
428
+ ],
429
+ outputs=[output_buf, video_output],
430
+ ).success(
431
+ lambda: gr.Button(interactive=True),
432
+ outputs=[extract_rep3d_btn],
433
+ )
434
+
435
+ extract_rep3d_btn.click(
436
+ extract_3d_representations_v2,
437
+ inputs=[
438
+ output_buf,
439
+ project_delight,
440
+ texture_size,
441
+ ],
442
+ outputs=[
443
+ model_output_mesh,
444
+ model_output_gs,
445
+ model_output_obj,
446
+ aligned_gs,
447
+ ],
448
+ ).success(
449
+ lambda: gr.Button(interactive=True),
450
+ outputs=[extract_urdf_btn],
451
+ )
452
+
453
+ extract_urdf_btn.click(
454
+ extract_urdf,
455
+ inputs=[
456
+ aligned_gs,
457
+ model_output_obj,
458
+ asset_cat_text,
459
+ height_range_text,
460
+ mass_range_text,
461
+ asset_version_text,
462
+ ],
463
+ outputs=[
464
+ download_urdf,
465
+ est_type_text,
466
+ est_height_text,
467
+ est_mass_text,
468
+ est_mu_text,
469
+ ],
470
+ queue=True,
471
+ show_progress="full",
472
+ ).success(
473
+ lambda: gr.Button(interactive=True),
474
+ outputs=[download_urdf],
475
+ )
476
+
477
+
478
+ if __name__ == "__main__":
479
+ demo.launch()
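
The app.py added above gates each stage of the pipeline by chaining Gradio event listeners: a button's `.click()` is followed by `.success()` calls, and each step re-enables the next button only when the previous one finished without error. A minimal, hypothetical illustration of that chaining pattern (not code from this commit) is sketched below.

    # Hypothetical two-step chain showing the gating pattern used in app.py:
    # step 2 stays disabled until step 1 succeeds.
    import gradio as gr

    def slow_step(x):
        return f"processed: {x}"

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Input")
        out = gr.Textbox(label="Output")
        step1 = gr.Button("Step 1")
        step2 = gr.Button("Step 2", interactive=False)

        step1.click(slow_step, inputs=inp, outputs=out).success(
            lambda: gr.Button(interactive=True), outputs=step2
        )

    if __name__ == "__main__":
        demo.launch()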
app_style.py ADDED
@@ -0,0 +1,27 @@
+ from gradio.themes import Soft
+ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
+
+ lighting_css = """
+ <style>
+ #lighter_mesh canvas {
+ filter: brightness(1.9) !important;
+ }
+ </style>
+ """
+
+ image_css = """
+ <style>
+ .image_fit .image-frame {
+ object-fit: contain !important;
+ height: 100% !important;
+ }
+ </style>
+ """
+
+ custom_theme = Soft(
+ primary_hue=stone,
+ secondary_hue=gray,
+ radius_size="md",
+ text_size="sm",
+ spacing_size="sm",
+ )
common.py ADDED
@@ -0,0 +1,849 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import gc
18
+ import logging
19
+ import os
20
+ import shutil
21
+ import subprocess
22
+ import sys
23
+ from glob import glob
24
+
25
+ import cv2
26
+ import gradio as gr
27
+ import numpy as np
28
+ import spaces
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import trimesh
32
+ from easydict import EasyDict as edict
33
+ from PIL import Image
34
+ from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
35
+ from embodied_gen.data.differentiable_render import entrypoint as render_api
36
+ from embodied_gen.data.utils import trellis_preprocess, zip_files
37
+ from embodied_gen.models.delight_model import DelightingModel
38
+ from embodied_gen.models.gs_model import GaussianOperator
39
+ from embodied_gen.models.segment_model import (
40
+ BMGG14Remover,
41
+ RembgRemover,
42
+ SAMPredictor,
43
+ )
44
+ from embodied_gen.models.sr_model import ImageRealESRGAN, ImageStableSR
45
+ from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
46
+ from embodied_gen.scripts.render_mv import build_texture_gen_pipe, infer_pipe
47
+ from embodied_gen.scripts.text2image import (
48
+ build_text2img_ip_pipeline,
49
+ build_text2img_pipeline,
50
+ text2img_gen,
51
+ )
52
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
53
+ from embodied_gen.utils.process_media import (
54
+ filter_image_small_connected_components,
55
+ merge_images_video,
56
+ )
57
+ from embodied_gen.utils.tags import VERSION
58
+ from embodied_gen.utils.trender import render_video
59
+ from embodied_gen.validators.quality_checkers import (
60
+ BaseChecker,
61
+ ImageAestheticChecker,
62
+ ImageSegChecker,
63
+ MeshGeoChecker,
64
+ )
65
+ from embodied_gen.validators.urdf_convertor import URDFGenerator
66
+
67
+ current_file_path = os.path.abspath(__file__)
68
+ current_dir = os.path.dirname(current_file_path)
69
+ sys.path.append(os.path.join(current_dir, ".."))
70
+ from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
71
+ from thirdparty.TRELLIS.trellis.representations import (
72
+ Gaussian,
73
+ MeshExtractResult,
74
+ )
75
+ from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
76
+ build_scaling_rotation,
77
+ inverse_sigmoid,
78
+ strip_symmetric,
79
+ )
80
+ from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
81
+
82
+ logging.basicConfig(
83
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
84
+ )
85
+ logger = logging.getLogger(__name__)
86
+
87
+
88
+ os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
89
+ "~/.cache/torch_extensions"
90
+ )
91
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
92
+ os.environ["SPCONV_ALGO"] = "native"
93
+
94
+ MAX_SEED = 100000
95
+
96
+
97
+ def patched_setup_functions(self):
98
+ def inverse_softplus(x):
99
+ return x + torch.log(-torch.expm1(-x))
100
+
101
+ def build_covariance_from_scaling_rotation(
102
+ scaling, scaling_modifier, rotation
103
+ ):
104
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
105
+ actual_covariance = L @ L.transpose(1, 2)
106
+ symm = strip_symmetric(actual_covariance)
107
+ return symm
108
+
109
+ if self.scaling_activation_type == "exp":
110
+ self.scaling_activation = torch.exp
111
+ self.inverse_scaling_activation = torch.log
112
+ elif self.scaling_activation_type == "softplus":
113
+ self.scaling_activation = F.softplus
114
+ self.inverse_scaling_activation = inverse_softplus
115
+
116
+ self.covariance_activation = build_covariance_from_scaling_rotation
117
+ self.opacity_activation = torch.sigmoid
118
+ self.inverse_opacity_activation = inverse_sigmoid
119
+ self.rotation_activation = F.normalize
120
+
121
+ self.scale_bias = self.inverse_scaling_activation(
122
+ torch.tensor(self.scaling_bias)
123
+ ).to(self.device)
124
+ self.rots_bias = torch.zeros((4)).to(self.device)
125
+ self.rots_bias[0] = 1
126
+ self.opacity_bias = self.inverse_opacity_activation(
127
+ torch.tensor(self.opacity_bias)
128
+ ).to(self.device)
129
+
130
+
131
+ Gaussian.setup_functions = patched_setup_functions
132
+
133
+
134
+ DELIGHT = DelightingModel()
135
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
136
+ # IMAGESR_MODEL = ImageStableSR()
137
+ if os.getenv("GRADIO_APP") == "imageto3d":
138
+ RBG_REMOVER = RembgRemover()
139
+ RBG14_REMOVER = BMGG14Remover()
140
+ SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
141
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
142
+ "microsoft/TRELLIS-image-large"
143
+ )
144
+ # PIPELINE.cuda()
145
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
146
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
147
+ AESTHETIC_CHECKER = ImageAestheticChecker()
148
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
149
+ TMP_DIR = os.path.join(
150
+ os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
151
+ )
152
+ os.makedirs(TMP_DIR, exist_ok=True)
153
+ elif os.getenv("GRADIO_APP") == "textto3d":
154
+ RBG_REMOVER = RembgRemover()
155
+ RBG14_REMOVER = BMGG14Remover()
156
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
157
+ "microsoft/TRELLIS-image-large"
158
+ )
159
+ # PIPELINE.cuda()
160
+ text_model_dir = "weights/Kolors"
161
+ PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
162
+ PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
163
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
164
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
165
+ AESTHETIC_CHECKER = ImageAestheticChecker()
166
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
167
+ TMP_DIR = os.path.join(
168
+ os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
169
+ )
170
+ os.makedirs(TMP_DIR, exist_ok=True)
171
+ elif os.getenv("GRADIO_APP") == "texture_edit":
172
+ PIPELINE_IP = build_texture_gen_pipe(
173
+ base_ckpt_dir="./weights",
174
+ ip_adapt_scale=0.7,
175
+ device="cuda",
176
+ )
177
+ PIPELINE = build_texture_gen_pipe(
178
+ base_ckpt_dir="./weights",
179
+ ip_adapt_scale=0,
180
+ device="cuda",
181
+ )
182
+ TMP_DIR = os.path.join(
183
+ os.path.dirname(os.path.abspath(__file__)), "sessions/texture_edit"
184
+ )
185
+ os.makedirs(TMP_DIR, exist_ok=True)
186
+
187
+
188
+ def start_session(req: gr.Request) -> None:
189
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
190
+ os.makedirs(user_dir, exist_ok=True)
191
+
192
+
193
+ def end_session(req: gr.Request) -> None:
194
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
195
+ if os.path.exists(user_dir):
196
+ shutil.rmtree(user_dir)
197
+
198
+
199
+ @spaces.GPU
200
+ def preprocess_image_fn(
201
+ image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg"
202
+ ) -> tuple[Image.Image, Image.Image]:
203
+ if isinstance(image, str):
204
+ image = Image.open(image)
205
+ elif isinstance(image, np.ndarray):
206
+ image = Image.fromarray(image)
207
+
208
+ image_cache = image.copy().resize((512, 512))
209
+
210
+ bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER
211
+ image = bg_remover(image)
212
+ image = trellis_preprocess(image)
213
+
214
+ return image, image_cache
215
+
216
+
217
+ def preprocess_sam_image_fn(
218
+ image: Image.Image,
219
+ ) -> tuple[Image.Image, Image.Image]:
220
+ if isinstance(image, np.ndarray):
221
+ image = Image.fromarray(image)
222
+
223
+ sam_image = SAM_PREDICTOR.preprocess_image(image)
224
+ image_cache = Image.fromarray(sam_image).resize((512, 512))
225
+ SAM_PREDICTOR.predictor.set_image(sam_image)
226
+
227
+ return sam_image, image_cache
228
+
229
+
230
+ def active_btn_by_content(content: gr.Image) -> gr.Button:
231
+ interactive = True if content is not None else False
232
+
233
+ return gr.Button(interactive=interactive)
234
+
235
+
236
+ def active_btn_by_text_content(content: gr.Textbox) -> gr.Button:
237
+ if content is not None and len(content) > 0:
238
+ interactive = True
239
+ else:
240
+ interactive = False
241
+
242
+ return gr.Button(interactive=interactive)
243
+
244
+
245
+ def get_selected_image(
246
+ choice: str, sample1: str, sample2: str, sample3: str
247
+ ) -> str:
248
+ if choice == "sample1":
249
+ return sample1
250
+ elif choice == "sample2":
251
+ return sample2
252
+ elif choice == "sample3":
253
+ return sample3
254
+ else:
255
+ raise ValueError(f"Invalid choice: {choice}")
256
+
257
+
258
+ def get_cached_image(image_path: str) -> Image.Image:
259
+ if isinstance(image_path, Image.Image):
260
+ return image_path
261
+ return Image.open(image_path).resize((512, 512))
262
+
263
+
264
+ @spaces.GPU
265
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
266
+ return {
267
+ "gaussian": {
268
+ **gs.init_params,
269
+ "_xyz": gs._xyz.cpu().numpy(),
270
+ "_features_dc": gs._features_dc.cpu().numpy(),
271
+ "_scaling": gs._scaling.cpu().numpy(),
272
+ "_rotation": gs._rotation.cpu().numpy(),
273
+ "_opacity": gs._opacity.cpu().numpy(),
274
+ },
275
+ "mesh": {
276
+ "vertices": mesh.vertices.cpu().numpy(),
277
+ "faces": mesh.faces.cpu().numpy(),
278
+ },
279
+ }
280
+
281
+
282
+ def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
283
+ gs = Gaussian(
284
+ aabb=state["gaussian"]["aabb"],
285
+ sh_degree=state["gaussian"]["sh_degree"],
286
+ mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
287
+ scaling_bias=state["gaussian"]["scaling_bias"],
288
+ opacity_bias=state["gaussian"]["opacity_bias"],
289
+ scaling_activation=state["gaussian"]["scaling_activation"],
290
+ device=device,
291
+ )
292
+ gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
293
+ gs._features_dc = torch.tensor(
294
+ state["gaussian"]["_features_dc"], device=device
295
+ )
296
+ gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
297
+ gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
298
+ gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
299
+
300
+ mesh = edict(
301
+ vertices=torch.tensor(state["mesh"]["vertices"], device=device),
302
+ faces=torch.tensor(state["mesh"]["faces"], device=device),
303
+ )
304
+
305
+ return gs, mesh
306
+
307
+
308
+ def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
309
+ return np.random.randint(0, max_seed) if randomize_seed else seed
310
+
311
+
312
+ def select_point(
313
+ image: np.ndarray,
314
+ sel_pix: list,
315
+ point_type: str,
316
+ evt: gr.SelectData,
317
+ ):
318
+ if point_type == "foreground_point":
319
+ sel_pix.append((evt.index, 1)) # append the foreground_point
320
+ elif point_type == "background_point":
321
+ sel_pix.append((evt.index, 0)) # append the background_point
322
+ else:
323
+ sel_pix.append((evt.index, 1)) # default foreground_point
324
+
325
+ masks = SAM_PREDICTOR.generate_masks(image, sel_pix)
326
+ seg_image = SAM_PREDICTOR.get_segmented_image(image, masks)
327
+
328
+ for point, label in sel_pix:
329
+ color = (255, 0, 0) if label == 0 else (0, 255, 0)
330
+ marker_type = 1 if label == 0 else 5
331
+ cv2.drawMarker(
332
+ image,
333
+ point,
334
+ color,
335
+ markerType=marker_type,
336
+ markerSize=15,
337
+ thickness=10,
338
+ )
339
+
340
+ torch.cuda.empty_cache()
341
+
342
+ return (image, masks), seg_image
343
+
344
+
345
+ @spaces.GPU
346
+ def image_to_3d(
347
+ image: Image.Image,
348
+ seed: int,
349
+ ss_guidance_strength: float,
350
+ ss_sampling_steps: int,
351
+ slat_guidance_strength: float,
352
+ slat_sampling_steps: int,
353
+ raw_image_cache: Image.Image,
354
+ sam_image: Image.Image = None,
355
+ is_sam_image: bool = False,
356
+ req: gr.Request = None,
357
+ ) -> tuple[dict, str]:
358
+ if is_sam_image:
359
+ seg_image = filter_image_small_connected_components(sam_image)
360
+ seg_image = Image.fromarray(seg_image, mode="RGBA")
361
+ seg_image = trellis_preprocess(seg_image)
362
+ else:
363
+ seg_image = image
364
+
365
+ if isinstance(seg_image, np.ndarray):
366
+ seg_image = Image.fromarray(seg_image)
367
+
368
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
369
+ os.makedirs(output_root, exist_ok=True)
370
+ seg_image.save(f"{output_root}/seg_image.png")
371
+ raw_image_cache.save(f"{output_root}/raw_image.png")
372
+ PIPELINE.cuda()
373
+ outputs = PIPELINE.run(
374
+ seg_image,
375
+ seed=seed,
376
+ formats=["gaussian", "mesh"],
377
+ preprocess_image=False,
378
+ sparse_structure_sampler_params={
379
+ "steps": ss_sampling_steps,
380
+ "cfg_strength": ss_guidance_strength,
381
+ },
382
+ slat_sampler_params={
383
+ "steps": slat_sampling_steps,
384
+ "cfg_strength": slat_guidance_strength,
385
+ },
386
+ )
387
+ # Set to cpu for memory saving.
388
+ PIPELINE.cpu()
389
+
390
+ gs_model = outputs["gaussian"][0]
391
+ mesh_model = outputs["mesh"][0]
392
+ color_images = render_video(gs_model)["color"]
393
+ normal_images = render_video(mesh_model)["normal"]
394
+
395
+ video_path = os.path.join(output_root, "gs_mesh.mp4")
396
+ merge_images_video(color_images, normal_images, video_path)
397
+ state = pack_state(gs_model, mesh_model)
398
+
399
+ gc.collect()
400
+ torch.cuda.empty_cache()
401
+
402
+ return state, video_path
403
+
404
+
405
+ @spaces.GPU
406
+ def extract_3d_representations(
407
+ state: dict, enable_delight: bool, texture_size: int, req: gr.Request
408
+ ):
409
+ output_root = TMP_DIR
410
+ output_root = os.path.join(output_root, str(req.session_hash))
411
+ gs_model, mesh_model = unpack_state(state, device="cuda")
412
+
413
+ mesh = postprocessing_utils.to_glb(
414
+ gs_model,
415
+ mesh_model,
416
+ simplify=0.9,
417
+ texture_size=1024,
418
+ verbose=True,
419
+ )
420
+ filename = "sample"
421
+ gs_path = os.path.join(output_root, f"{filename}_gs.ply")
422
+ gs_model.save_ply(gs_path)
423
+
424
+ # Rotate mesh and GS by 90 degrees around Z-axis.
425
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
426
+ # Additional rotation for GS to align mesh.
427
+ gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
428
+ rot_matrix
429
+ )
430
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
431
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
432
+ GaussianOperator.resave_ply(
433
+ in_ply=gs_path,
434
+ out_ply=aligned_gs_path,
435
+ instance_pose=pose,
436
+ )
437
+
438
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
439
+ mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
440
+ mesh.export(mesh_obj_path)
441
+ mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
442
+ mesh.export(mesh_glb_path)
443
+
444
+ torch.cuda.empty_cache()
445
+
446
+ return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
447
+
448
+
449
+ def extract_3d_representations_v2(
450
+ state: dict,
451
+ enable_delight: bool,
452
+ texture_size: int,
453
+ req: gr.Request,
454
+ ):
455
+ output_root = TMP_DIR
456
+ user_dir = os.path.join(output_root, str(req.session_hash))
457
+ gs_model, mesh_model = unpack_state(state, device="cpu")
458
+
459
+ filename = "sample"
460
+ gs_path = os.path.join(user_dir, f"{filename}_gs.ply")
461
+ gs_model.save_ply(gs_path)
462
+
463
+ # Rotate mesh and GS by 90 degrees around Z-axis.
464
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
465
+ gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
466
+ mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
467
+
468
+ # Additional rotation for GS to align mesh.
469
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
470
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
471
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
472
+ GaussianOperator.resave_ply(
473
+ in_ply=gs_path,
474
+ out_ply=aligned_gs_path,
475
+ instance_pose=pose,
476
+ device="cpu",
477
+ )
478
+ color_path = os.path.join(user_dir, "color.png")
479
+ render_gs_api(
480
+ input_gs=aligned_gs_path,
481
+ output_path=color_path,
482
+ elevation=[20, -10, 60, -50],
483
+ num_images=12,
484
+ )
485
+
486
+ mesh = trimesh.Trimesh(
487
+ vertices=mesh_model.vertices.cpu().numpy(),
488
+ faces=mesh_model.faces.cpu().numpy(),
489
+ )
490
+ mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
491
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
492
+
493
+ mesh_obj_path = os.path.join(user_dir, f"{filename}.obj")
494
+ mesh.export(mesh_obj_path)
495
+
496
+ mesh = backproject_api(
497
+ delight_model=DELIGHT,
498
+ imagesr_model=IMAGESR_MODEL,
499
+ color_path=color_path,
500
+ mesh_path=mesh_obj_path,
501
+ output_path=mesh_obj_path,
502
+ skip_fix_mesh=False,
503
+ delight=enable_delight,
504
+ texture_wh=[texture_size, texture_size],
505
+ elevation=[20, -10, 60, -50],
506
+ num_images=12,
507
+ )
508
+
509
+ mesh_glb_path = os.path.join(user_dir, f"{filename}.glb")
510
+ mesh.export(mesh_glb_path)
511
+
512
+ return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
513
+
514
+
515
+ def extract_urdf(
516
+ gs_path: str,
517
+ mesh_obj_path: str,
518
+ asset_cat_text: str,
519
+ height_range_text: str,
520
+ mass_range_text: str,
521
+ asset_version_text: str,
522
+ req: gr.Request = None,
523
+ ):
524
+ output_root = TMP_DIR
525
+ if req is not None:
526
+ output_root = os.path.join(output_root, str(req.session_hash))
527
+
528
+ # Convert to URDF and recover attrs by GPT.
529
+ filename = "sample"
530
+ urdf_convertor = URDFGenerator(
531
+ GPT_CLIENT, render_view_num=4, decompose_convex=True
532
+ )
533
+ asset_attrs = {
534
+ "version": VERSION,
535
+ "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
536
+ }
537
+ if asset_version_text:
538
+ asset_attrs["version"] = asset_version_text
539
+ if asset_cat_text:
540
+ asset_attrs["category"] = asset_cat_text.lower()
541
+ if height_range_text:
542
+ try:
543
+ min_height, max_height = map(float, height_range_text.split("-"))
544
+ asset_attrs["min_height"] = min_height
545
+ asset_attrs["max_height"] = max_height
546
+ except ValueError:
547
+ return "Invalid height input format. Use the format: min-max."
548
+ if mass_range_text:
549
+ try:
550
+ min_mass, max_mass = map(float, mass_range_text.split("-"))
551
+ asset_attrs["min_mass"] = min_mass
552
+ asset_attrs["max_mass"] = max_mass
553
+ except ValueError:
554
+ return "Invalid mass input format. Use the format: min-max."
555
+
556
+ urdf_path = urdf_convertor(
557
+ mesh_path=mesh_obj_path,
558
+ output_root=f"{output_root}/URDF_{filename}",
559
+ **asset_attrs,
560
+ )
561
+
562
+ # Rescale GS and save to URDF/mesh folder.
563
+ real_height = urdf_convertor.get_attr_from_urdf(
564
+ urdf_path, attr_name="real_height"
565
+ )
566
+ out_gs = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
567
+ GaussianOperator.resave_ply(
568
+ in_ply=gs_path,
569
+ out_ply=out_gs,
570
+ real_height=real_height,
571
+ device="cpu",
572
+ )
573
+
574
+ # Quality check and update .urdf file.
575
+ mesh_out = f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
576
+ trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
577
+ # image_paths = render_asset3d(
578
+ # mesh_path=mesh_out,
579
+ # output_root=f"{output_root}/URDF_{filename}",
580
+ # output_subdir="qa_renders",
581
+ # num_images=8,
582
+ # elevation=(30, -30),
583
+ # distance=5.5,
584
+ # )
585
+
586
+ image_dir = f"{output_root}/URDF_{filename}/{urdf_convertor.output_render_dir}/image_color" # noqa
587
+ image_paths = glob(f"{image_dir}/*.png")
588
+ images_list = []
589
+ for checker in CHECKERS:
590
+ images = image_paths
591
+ if isinstance(checker, ImageSegChecker):
592
+ images = [
593
+ f"{TMP_DIR}/{req.session_hash}/raw_image.png",
594
+ f"{TMP_DIR}/{req.session_hash}/seg_image.png",
595
+ ]
596
+ images_list.append(images)
597
+
598
+ results = BaseChecker.validate(CHECKERS, images_list)
599
+ urdf_convertor.add_quality_tag(urdf_path, results)
600
+
601
+ # Zip urdf files
602
+ urdf_zip = zip_files(
603
+ input_paths=[
604
+ f"{output_root}/URDF_{filename}/{urdf_convertor.output_mesh_dir}",
605
+ f"{output_root}/URDF_{filename}/{filename}.urdf",
606
+ ],
607
+ output_zip=f"{output_root}/urdf_{filename}.zip",
608
+ )
609
+
610
+ estimated_type = urdf_convertor.estimated_attrs["category"]
611
+ estimated_height = urdf_convertor.estimated_attrs["height"]
612
+ estimated_mass = urdf_convertor.estimated_attrs["mass"]
613
+ estimated_mu = urdf_convertor.estimated_attrs["mu"]
614
+
615
+ return (
616
+ urdf_zip,
617
+ estimated_type,
618
+ estimated_height,
619
+ estimated_mass,
620
+ estimated_mu,
621
+ )
622
+
623
+
624
+ @spaces.GPU
625
+ def text2image_fn(
626
+ prompt: str,
627
+ guidance_scale: float,
628
+ infer_step: int = 50,
629
+ ip_image: Image.Image | str = None,
630
+ ip_adapt_scale: float = 0.3,
631
+ image_wh: int | tuple[int, int] = [1024, 1024],
632
+ rmbg_tag: str = "rembg",
633
+ seed: int = None,
634
+ n_sample: int = 3,
635
+ req: gr.Request = None,
636
+ ):
637
+ if isinstance(image_wh, int):
638
+ image_wh = (image_wh, image_wh)
639
+ output_root = TMP_DIR
640
+ if req is not None:
641
+ output_root = os.path.join(output_root, str(req.session_hash))
642
+ os.makedirs(output_root, exist_ok=True)
643
+
644
+ pipeline = PIPELINE_IMG if ip_image is None else PIPELINE_IMG_IP
645
+ if ip_image is not None:
646
+ pipeline.set_ip_adapter_scale([ip_adapt_scale])
647
+
648
+ images = text2img_gen(
649
+ prompt=prompt,
650
+ n_sample=n_sample,
651
+ guidance_scale=guidance_scale,
652
+ pipeline=pipeline,
653
+ ip_image=ip_image,
654
+ image_wh=image_wh,
655
+ infer_step=infer_step,
656
+ seed=seed,
657
+ )
658
+
659
+ for idx in range(len(images)):
660
+ image = images[idx]
661
+ images[idx], _ = preprocess_image_fn(image, rmbg_tag)
662
+
663
+ save_paths = []
664
+ for idx, image in enumerate(images):
665
+ save_path = f"{output_root}/sample_{idx}.png"
666
+ image.save(save_path)
667
+ save_paths.append(save_path)
668
+
669
+ logger.info(f"Images saved to {output_root}")
670
+
671
+ gc.collect()
672
+ torch.cuda.empty_cache()
673
+
674
+ return save_paths + save_paths
675
+
676
+
677
+ @spaces.GPU
678
+ def generate_condition(mesh_path: str, req: gr.Request, uuid: str = "sample"):
679
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
680
+
681
+ _ = render_api(
682
+ mesh_path=mesh_path,
683
+ output_root=f"{output_root}/condition",
684
+ uuid=str(uuid),
685
+ )
686
+
687
+ gc.collect()
688
+ torch.cuda.empty_cache()
689
+
690
+ return None, None, None
691
+
692
+
693
+ @spaces.GPU
694
+ def generate_texture_mvimages(
695
+ prompt: str,
696
+ controlnet_cond_scale: float = 0.55,
697
+ guidance_scale: float = 9,
698
+ strength: float = 0.9,
699
+ num_inference_steps: int = 50,
700
+ seed: int = 0,
701
+ ip_adapt_scale: float = 0,
702
+ ip_img_path: str = None,
703
+ uid: str = "sample",
704
+ sub_idxs: tuple[tuple[int]] = ((0, 1, 2), (3, 4, 5)),
705
+ req: gr.Request = None,
706
+ ) -> list[str]:
707
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
708
+ use_ip_adapter = True if ip_img_path and ip_adapt_scale > 0 else False
709
+ PIPELINE_IP.set_ip_adapter_scale([ip_adapt_scale])
710
+ img_save_paths = infer_pipe(
711
+ index_file=f"{output_root}/condition/index.json",
712
+ controlnet_cond_scale=controlnet_cond_scale,
713
+ guidance_scale=guidance_scale,
714
+ strength=strength,
715
+ num_inference_steps=num_inference_steps,
716
+ ip_adapt_scale=ip_adapt_scale,
717
+ ip_img_path=ip_img_path,
718
+ uid=uid,
719
+ prompt=prompt,
720
+ save_dir=f"{output_root}/multi_view",
721
+ sub_idxs=sub_idxs,
722
+ pipeline=PIPELINE_IP if use_ip_adapter else PIPELINE,
723
+ seed=seed,
724
+ )
725
+
726
+ gc.collect()
727
+ torch.cuda.empty_cache()
728
+
729
+ return img_save_paths + img_save_paths
730
+
731
+
732
+ def backproject_texture(
733
+ mesh_path: str,
734
+ input_image: str,
735
+ texture_size: int,
736
+ uuid: str = "sample",
737
+ req: gr.Request = None,
738
+ ) -> str:
739
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
740
+ output_dir = os.path.join(output_root, "texture_mesh")
741
+ os.makedirs(output_dir, exist_ok=True)
742
+ command = [
743
+ "backproject-cli",
744
+ "--mesh_path",
745
+ mesh_path,
746
+ "--input_image",
747
+ input_image,
748
+ "--output_root",
749
+ output_dir,
750
+ "--uuid",
751
+ f"{uuid}",
752
+ "--texture_size",
753
+ str(texture_size),
754
+ "--skip_fix_mesh",
755
+ ]
756
+
757
+ _ = subprocess.run(
758
+ command, capture_output=True, text=True, encoding="utf-8"
759
+ )
760
+ output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
761
+ output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
762
+ _ = trimesh.load(output_obj_mesh).export(output_glb_mesh)
763
+
764
+ zip_file = zip_files(
765
+ input_paths=[
766
+ output_glb_mesh,
767
+ output_obj_mesh,
768
+ os.path.join(output_dir, "material.mtl"),
769
+ os.path.join(output_dir, "material_0.png"),
770
+ ],
771
+ output_zip=os.path.join(output_dir, f"{uuid}.zip"),
772
+ )
773
+
774
+ gc.collect()
775
+ torch.cuda.empty_cache()
776
+
777
+ return output_glb_mesh, output_obj_mesh, zip_file
778
+
779
+
780
+ @spaces.GPU
781
+ def backproject_texture_v2(
782
+ mesh_path: str,
783
+ input_image: str,
784
+ texture_size: int,
785
+ enable_delight: bool = True,
786
+ fix_mesh: bool = False,
787
+ uuid: str = "sample",
788
+ req: gr.Request = None,
789
+ ) -> str:
790
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
791
+ output_dir = os.path.join(output_root, "texture_mesh")
792
+ os.makedirs(output_dir, exist_ok=True)
793
+
794
+ textured_mesh = backproject_api(
795
+ delight_model=DELIGHT,
796
+ imagesr_model=IMAGESR_MODEL,
797
+ color_path=input_image,
798
+ mesh_path=mesh_path,
799
+ output_path=f"{output_dir}/{uuid}.obj",
800
+ skip_fix_mesh=not fix_mesh,
801
+ delight=enable_delight,
802
+ texture_wh=[texture_size, texture_size],
803
+ )
804
+
805
+ output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj")
806
+ output_glb_mesh = os.path.join(output_dir, f"{uuid}.glb")
807
+ _ = textured_mesh.export(output_glb_mesh)
808
+
809
+ zip_file = zip_files(
810
+ input_paths=[
811
+ output_glb_mesh,
812
+ output_obj_mesh,
813
+ os.path.join(output_dir, "material.mtl"),
814
+ os.path.join(output_dir, "material_0.png"),
815
+ ],
816
+ output_zip=os.path.join(output_dir, f"{uuid}.zip"),
817
+ )
818
+
819
+ gc.collect()
820
+ torch.cuda.empty_cache()
821
+
822
+ return output_glb_mesh, output_obj_mesh, zip_file
823
+
824
+
825
+ @spaces.GPU
826
+ def render_result_video(
827
+ mesh_path: str, video_size: int, req: gr.Request, uuid: str = ""
828
+ ) -> str:
829
+ output_root = os.path.join(TMP_DIR, str(req.session_hash))
830
+ output_dir = os.path.join(output_root, "texture_mesh")
831
+
832
+ _ = render_api(
833
+ mesh_path=mesh_path,
834
+ output_root=output_dir,
835
+ num_images=90,
836
+ elevation=[20],
837
+ with_mtl=True,
838
+ pbr_light_factor=1,
839
+ uuid=str(uuid),
840
+ gen_color_mp4=True,
841
+ gen_glonormal_mp4=True,
842
+ distance=5.5,
843
+ resolution_hw=(video_size, video_size),
844
+ )
845
+
846
+ gc.collect()
847
+ torch.cuda.empty_cache()
848
+
849
+ return f"{output_dir}/color.mp4"
embodied_gen/data/asset_converter.py ADDED
@@ -0,0 +1,663 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import xml.etree.ElementTree as ET
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import dataclass
8
+ from glob import glob
9
+ from shutil import copy
10
+
11
+ import trimesh
12
+ from scipy.spatial.transform import Rotation
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ __all__ = [
19
+ "AssetConverterFactory",
20
+ "AssetType",
21
+ "MeshtoMJCFConverter",
22
+ "MeshtoUSDConverter",
23
+ "URDFtoUSDConverter",
24
+ ]
25
+
26
+
27
+ @dataclass
28
+ class AssetType(str):
29
+ """Asset type enumeration."""
30
+
31
+ MJCF = "mjcf"
32
+ USD = "usd"
33
+ URDF = "urdf"
34
+ MESH = "mesh"
35
+
36
+
37
+ class AssetConverterBase(ABC):
38
+ """Converter abstract base class."""
39
+
40
+ @abstractmethod
41
+ def convert(self, urdf_path: str, output_path: str, **kwargs) -> str:
42
+ pass
43
+
44
+ def transform_mesh(
45
+ self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element
46
+ ) -> None:
47
+ """Apply transform to the mesh based on the origin element in URDF."""
48
+ mesh = trimesh.load(input_mesh)
49
+ rpy = list(map(float, mesh_origin.get("rpy").split(" ")))
50
+ rotation = Rotation.from_euler("xyz", rpy, degrees=False)
51
+ offset = list(map(float, mesh_origin.get("xyz").split(" ")))
52
+ mesh.vertices = (mesh.vertices @ rotation.as_matrix().T) + offset
53
+
54
+ os.makedirs(os.path.dirname(output_mesh), exist_ok=True)
55
+ _ = mesh.export(output_mesh)
56
+
57
+ return
58
+
59
+ def __enter__(self):
60
+ return self
61
+
62
+ def __exit__(self, exc_type, exc_val, exc_tb):
63
+ return False
64
+
65
+
66
+ class MeshtoMJCFConverter(AssetConverterBase):
67
+ """Convert URDF files into MJCF format."""
68
+
69
+ def __init__(
70
+ self,
71
+ **kwargs,
72
+ ) -> None:
73
+ self.kwargs = kwargs
74
+
75
+ def _copy_asset_file(self, src: str, dst: str) -> None:
76
+ if os.path.exists(dst):
77
+ return
78
+ os.makedirs(os.path.dirname(dst), exist_ok=True)
79
+ copy(src, dst)
80
+
81
+ def add_geometry(
82
+ self,
83
+ mujoco_element: ET.Element,
84
+ link: ET.Element,
85
+ body: ET.Element,
86
+ tag: str,
87
+ input_dir: str,
88
+ output_dir: str,
89
+ mesh_name: str,
90
+ material: ET.Element | None = None,
91
+ is_collision: bool = False,
92
+ ) -> None:
93
+ """Add geometry to the MJCF body from the URDF link."""
94
+ element = link.find(tag)
95
+ geometry = element.find("geometry")
96
+ mesh = geometry.find("mesh")
97
+ filename = mesh.get("filename")
98
+ scale = mesh.get("scale", "1.0 1.0 1.0")
99
+
100
+ mesh_asset = ET.SubElement(
101
+ mujoco_element, "mesh", name=mesh_name, file=filename, scale=scale
102
+ )
103
+ geom = ET.SubElement(body, "geom", type="mesh", mesh=mesh_name)
104
+
105
+ self._copy_asset_file(
106
+ f"{input_dir}/{filename}",
107
+ f"{output_dir}/{filename}",
108
+ )
109
+
110
+ # Preprocess the mesh by applying rotation.
111
+ input_mesh = f"{input_dir}/{filename}"
112
+ output_mesh = f"{output_dir}/{filename}"
113
+ mesh_origin = element.find("origin")
114
+ if mesh_origin is not None:
115
+ self.transform_mesh(input_mesh, output_mesh, mesh_origin)
116
+
117
+ if material is not None:
118
+ geom.set("material", material.get("name"))
119
+
120
+ if is_collision:
121
+ geom.set("contype", "1")
122
+ geom.set("conaffinity", "1")
123
+ geom.set("rgba", "1 1 1 0")
124
+
125
+ def add_materials(
126
+ self,
127
+ mujoco_element: ET.Element,
128
+ link: ET.Element,
129
+ tag: str,
130
+ input_dir: str,
131
+ output_dir: str,
132
+ name: str,
133
+ reflectance: float = 0.2,
134
+ ) -> ET.Element:
135
+ """Add materials to the MJCF asset from the URDF link."""
136
+ element = link.find(tag)
137
+ geometry = element.find("geometry")
138
+ mesh = geometry.find("mesh")
139
+ filename = mesh.get("filename")
140
+ dirname = os.path.dirname(filename)
141
+
142
+ material = ET.SubElement(
143
+ mujoco_element,
144
+ "material",
145
+ name=f"material_{name}",
146
+ texture=f"texture_{name}",
147
+ reflectance=str(reflectance),
148
+ )
149
+
150
+ for path in glob(f"{input_dir}/{dirname}/*.png"):
151
+ file_name = os.path.basename(path)
152
+ self._copy_asset_file(
153
+ path,
154
+ f"{output_dir}/{dirname}/{file_name}",
155
+ )
156
+ ET.SubElement(
157
+ mujoco_element,
158
+ "texture",
159
+ name=f"texture_{name}_{os.path.splitext(file_name)[0]}",
160
+ type="2d",
161
+ file=f"{dirname}/{file_name}",
162
+ )
163
+
164
+ return material
165
+
166
+ def convert(self, urdf_path: str, mjcf_path: str):
167
+ """Convert a URDF file to MJCF format."""
168
+ tree = ET.parse(urdf_path)
169
+ root = tree.getroot()
170
+
171
+ mujoco_struct = ET.Element("mujoco")
172
+ mujoco_struct.set("model", root.get("name"))
173
+ mujoco_asset = ET.SubElement(mujoco_struct, "asset")
174
+ mujoco_worldbody = ET.SubElement(mujoco_struct, "worldbody")
175
+
176
+ input_dir = os.path.dirname(urdf_path)
177
+ output_dir = os.path.dirname(mjcf_path)
178
+ os.makedirs(output_dir, exist_ok=True)
179
+ for idx, link in enumerate(root.findall("link")):
180
+ link_name = link.get("name", "unnamed_link")
181
+ body = ET.SubElement(mujoco_worldbody, "body", name=link_name)
182
+
183
+ material = self.add_materials(
184
+ mujoco_asset,
185
+ link,
186
+ "visual",
187
+ input_dir,
188
+ output_dir,
189
+ name=str(idx),
190
+ )
191
+ self.add_geometry(
192
+ mujoco_asset,
193
+ link,
194
+ body,
195
+ "visual",
196
+ input_dir,
197
+ output_dir,
198
+ f"visual_mesh_{idx}",
199
+ material,
200
+ )
201
+ self.add_geometry(
202
+ mujoco_asset,
203
+ link,
204
+ body,
205
+ "collision",
206
+ input_dir,
207
+ output_dir,
208
+ f"collision_mesh_{idx}",
209
+ is_collision=True,
210
+ )
211
+
212
+ tree = ET.ElementTree(mujoco_struct)
213
+ ET.indent(tree, space=" ", level=0)
214
+
215
+ tree.write(mjcf_path, encoding="utf-8", xml_declaration=True)
216
+ logger.info(f"Successfully converted {urdf_path} → {mjcf_path}")
217
+
218
+
219
+ class URDFtoMJCFConverter(MeshtoMJCFConverter):
220
+ """Convert URDF files with joints to MJCF format, handling transformations from joints."""
221
+
222
+ def add_materials(
223
+ self,
224
+ mujoco_element: ET.Element,
225
+ link: ET.Element,
226
+ tag: str,
227
+ input_dir: str,
228
+ output_dir: str,
229
+ name: str,
230
+ reflectance: float = 0.2,
231
+ ) -> ET.Element:
232
+ """Add materials to the MJCF asset from the URDF link."""
233
+ element = link.find(tag)
234
+ geometry = element.find("geometry")
235
+ mesh = geometry.find("mesh")
236
+ filename = mesh.get("filename")
237
+ dirname = os.path.dirname(filename)
238
+
239
+ diffuse_texture = None
240
+ for path in glob(f"{input_dir}/{dirname}/*.png"):
241
+ file_name = os.path.basename(path)
242
+ self._copy_asset_file(
243
+ path,
244
+ f"{output_dir}/{dirname}/{file_name}",
245
+ )
246
+ texture_name = f"texture_{name}_{os.path.splitext(file_name)[0]}"
247
+ ET.SubElement(
248
+ mujoco_element,
249
+ "texture",
250
+ name=texture_name,
251
+ type="2d",
252
+ file=f"{dirname}/{file_name}",
253
+ )
254
+ if "diffuse" in file_name.lower():
255
+ diffuse_texture = texture_name
256
+
257
+ if diffuse_texture is None:
258
+ return None
259
+
260
+ material = ET.SubElement(
261
+ mujoco_element,
262
+ "material",
263
+ name=f"material_{name}",
264
+ texture=diffuse_texture,
265
+ reflectance=str(reflectance),
266
+ )
267
+
268
+ return material
269
+
270
+ def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str:
271
+ """Convert a URDF file with joints to MJCF format."""
272
+ tree = ET.parse(urdf_path)
273
+ root = tree.getroot()
274
+
275
+ mujoco_struct = ET.Element("mujoco")
276
+ mujoco_struct.set("model", root.get("name"))
277
+ mujoco_asset = ET.SubElement(mujoco_struct, "asset")
278
+ mujoco_worldbody = ET.SubElement(mujoco_struct, "worldbody")
279
+
280
+ input_dir = os.path.dirname(urdf_path)
281
+ output_dir = os.path.dirname(mjcf_path)
282
+ os.makedirs(output_dir, exist_ok=True)
283
+
284
+ # Create a dictionary to store body elements for each link
285
+ body_dict = {}
286
+
287
+ # Process all links first
288
+ for idx, link in enumerate(root.findall("link")):
289
+ link_name = link.get("name", f"unnamed_link_{idx}")
290
+ body = ET.SubElement(mujoco_worldbody, "body", name=link_name)
291
+ body_dict[link_name] = body
292
+
293
+ # Add materials and geometry
294
+ visual_element = link.find("visual")
295
+ if visual_element is not None:
296
+ material = self.add_materials(
297
+ mujoco_asset,
298
+ link,
299
+ "visual",
300
+ input_dir,
301
+ output_dir,
302
+ name=str(idx),
303
+ )
304
+ self.add_geometry(
305
+ mujoco_asset,
306
+ link,
307
+ body,
308
+ "visual",
309
+ input_dir,
310
+ output_dir,
311
+ f"visual_mesh_{idx}",
312
+ material,
313
+ )
314
+
315
+ collision_element = link.find("collision")
316
+ if collision_element is not None:
317
+ self.add_geometry(
318
+ mujoco_asset,
319
+ link,
320
+ body,
321
+ "collision",
322
+ input_dir,
323
+ output_dir,
324
+ f"collision_mesh_{idx}",
325
+ is_collision=True,
326
+ )
327
+
328
+ # Process joints to set transformations and hierarchy
329
+ for joint in root.findall("joint"):
330
+ joint_type = joint.get("type")
331
+ if joint_type != "fixed":
332
+ logger.warning(
333
+ f"Skipping non-fixed joint: {joint.get('name')}"
334
+ )
335
+ continue
336
+
337
+ parent_link = joint.find("parent").get("link")
338
+ child_link = joint.find("child").get("link")
339
+ origin = joint.find("origin")
340
+
341
+ if parent_link not in body_dict or child_link not in body_dict:
342
+ logger.warning(
343
+ f"Parent or child link not found for joint: {joint.get('name')}"
344
+ )
345
+ continue
346
+
347
+ # Move child body under parent body in MJCF hierarchy
348
+ child_body = body_dict[child_link]
349
+ mujoco_worldbody.remove(child_body)
350
+ parent_body = body_dict[parent_link]
351
+ parent_body.append(child_body)
352
+
353
+ # Apply joint origin transformation to child body
354
+ if origin is not None:
355
+ xyz = origin.get("xyz", "0 0 0")
356
+ rpy = origin.get("rpy", "0 0 0")
357
+ child_body.set("pos", xyz)
358
+ # Convert rpy to MJCF euler format (degrees)
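+ # For example, a URDF origin with rpy="1.5708 0 0" (radians) becomes
+ # approximately euler="90 0 0" (degrees) on the child MJCF body.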
359
+ rpy_floats = list(map(float, rpy.split()))
360
+ rotation = Rotation.from_euler(
361
+ "xyz", rpy_floats, degrees=False
362
+ )
363
+ euler_deg = rotation.as_euler("xyz", degrees=True)
364
+ child_body.set(
365
+ "euler", f"{euler_deg[0]} {euler_deg[1]} {euler_deg[2]}"
366
+ )
367
+
368
+ tree = ET.ElementTree(mujoco_struct)
369
+ ET.indent(tree, space=" ", level=0)
370
+ tree.write(mjcf_path, encoding="utf-8", xml_declaration=True)
371
+ logger.info(f"Successfully converted {urdf_path} → {mjcf_path}")
372
+
373
+ return mjcf_path
374
+
375
+
376
+ class MeshtoUSDConverter(AssetConverterBase):
377
+ """Convert Mesh file from URDF into USD format."""
378
+
379
+ DEFAULT_BIND_APIS = [
380
+ "MaterialBindingAPI",
381
+ "PhysicsMeshCollisionAPI",
382
+ "PhysicsCollisionAPI",
383
+ "PhysxCollisionAPI",
384
+ "PhysicsMassAPI",
385
+ "PhysicsRigidBodyAPI",
386
+ "PhysxRigidBodyAPI",
387
+ ]
388
+
389
+ def __init__(
390
+ self,
391
+ force_usd_conversion: bool = True,
392
+ make_instanceable: bool = False,
393
+ simulation_app=None,
394
+ **kwargs,
395
+ ):
396
+ self.usd_parms = dict(
397
+ force_usd_conversion=force_usd_conversion,
398
+ make_instanceable=make_instanceable,
399
+ **kwargs,
400
+ )
401
+ if simulation_app is not None:
402
+ self.simulation_app = simulation_app
403
+
404
+ def __enter__(self):
405
+ from isaaclab.app import AppLauncher
406
+
407
+ if not hasattr(self, "simulation_app"):
408
+ launch_args = dict(
409
+ headless=True,
410
+ no_splash=True,
411
+ fast_shutdown=True,
412
+ disable_gpu=True,
413
+ )
414
+ self.app_launcher = AppLauncher(launch_args)
415
+ self.simulation_app = self.app_launcher.app
416
+
417
+ return self
418
+
419
+ def __exit__(self, exc_type, exc_val, exc_tb):
420
+ # Close the simulation app if it was created here
421
+ if hasattr(self, "app_launcher"):
422
+ self.simulation_app.close()
423
+
424
+ if exc_val is not None:
425
+ logger.error(f"Exception occurred: {exc_val}.")
426
+
427
+ return False
428
+
429
+ def convert(self, urdf_path: str, output_file: str):
430
+ """Convert a URDF file to USD and post-process collision meshes."""
431
+ from isaaclab.sim.converters import MeshConverter, MeshConverterCfg
432
+ from pxr import PhysxSchema, Sdf, Usd, UsdShade
433
+
434
+ tree = ET.parse(urdf_path)
435
+ root = tree.getroot()
436
+ mesh_file = root.find("link/visual/geometry/mesh").get("filename")
437
+ input_mesh = os.path.join(os.path.dirname(urdf_path), mesh_file)
438
+ output_dir = os.path.abspath(os.path.dirname(output_file))
439
+ output_mesh = f"{output_dir}/mesh/{os.path.basename(mesh_file)}"
440
+ mesh_origin = root.find("link/visual/origin")
441
+ if mesh_origin is not None:
442
+ self.transform_mesh(input_mesh, output_mesh, mesh_origin)
443
+
444
+ cfg = MeshConverterCfg(
445
+ asset_path=output_mesh,
446
+ usd_dir=output_dir,
447
+ usd_file_name=os.path.basename(output_file),
448
+ **self.usd_parms,
449
+ )
450
+ urdf_converter = MeshConverter(cfg)
451
+ usd_path = urdf_converter.usd_path
452
+
453
+ stage = Usd.Stage.Open(usd_path)
454
+ layer = stage.GetRootLayer()
455
+ with Usd.EditContext(stage, layer):
456
+ for prim in stage.Traverse():
457
+ # Change texture path to relative path.
458
+ if prim.GetName() == "material_0":
459
+ shader = UsdShade.Shader(prim).GetInput("diffuse_texture")
460
+ if shader.Get() is not None:
461
+ relative_path = shader.Get().path.replace(
462
+ f"{output_dir}/", ""
463
+ )
464
+ shader.Set(Sdf.AssetPath(relative_path))
465
+
466
+ # Add convex decomposition collision and set ShrinkWrap.
467
+ elif prim.GetName() == "mesh":
468
+ approx_attr = prim.GetAttribute("physics:approximation")
469
+ if not approx_attr:
470
+ approx_attr = prim.CreateAttribute(
471
+ "physics:approximation", Sdf.ValueTypeNames.Token
472
+ )
473
+ approx_attr.Set("convexDecomposition")
474
+
475
+ physx_conv_api = (
476
+ PhysxSchema.PhysxConvexDecompositionCollisionAPI.Apply(
477
+ prim
478
+ )
479
+ )
480
+ physx_conv_api.GetShrinkWrapAttr().Set(True)
481
+
482
+ api_schemas = prim.GetMetadata("apiSchemas")
483
+ if api_schemas is None:
484
+ api_schemas = Sdf.TokenListOp()
485
+
486
+ api_list = list(api_schemas.GetAddedOrExplicitItems())
487
+ for api in self.DEFAULT_BIND_APIS:
488
+ if api not in api_list:
489
+ api_list.append(api)
490
+
491
+ api_schemas.appendedItems = api_list
492
+ prim.SetMetadata("apiSchemas", api_schemas)
493
+
494
+ layer.Save()
495
+ logger.info(f"Successfully converted {urdf_path} → {usd_path}")
496
+
497
+
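+ # Minimal usage sketch (illustrative paths only): the converter is used as a
+ # context manager so a headless Isaac Sim app is launched on __enter__ (when
+ # no running app is passed in) and closed on __exit__.
+ #
+ #     with MeshtoUSDConverter() as converter:
+ #         converter.convert("asset/result/asset.urdf", "asset/usd/asset.usd")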
498
+ class URDFtoUSDConverter(MeshtoUSDConverter):
499
+ """Convert URDF files into USD format.
500
+
501
+ Args:
502
+ fix_base (bool): Whether to fix the base link.
503
+ merge_fixed_joints (bool): Whether to merge fixed joints.
504
+ make_instanceable (bool): Whether to make prims instanceable.
505
+ force_usd_conversion (bool): Force conversion to USD.
506
+ collision_from_visuals (bool): Generate collisions from visuals if not provided.
507
+ """
508
+
509
+ def __init__(
510
+ self,
511
+ fix_base: bool = False,
512
+ merge_fixed_joints: bool = False,
513
+ make_instanceable: bool = True,
514
+ force_usd_conversion: bool = True,
515
+ collision_from_visuals: bool = True,
516
+ joint_drive=None,
517
+ rotate_wxyz: tuple[float] | None = None,
518
+ simulation_app=None,
519
+ **kwargs,
520
+ ):
521
+ self.usd_parms = dict(
522
+ fix_base=fix_base,
523
+ merge_fixed_joints=merge_fixed_joints,
524
+ make_instanceable=make_instanceable,
525
+ force_usd_conversion=force_usd_conversion,
526
+ collision_from_visuals=collision_from_visuals,
527
+ joint_drive=joint_drive,
528
+ **kwargs,
529
+ )
530
+ self.rotate_wxyz = rotate_wxyz
531
+ if simulation_app is not None:
532
+ self.simulation_app = simulation_app
533
+
534
+ def convert(self, urdf_path: str, output_file: str):
535
+ """Convert a URDF file to USD and post-process collision meshes."""
536
+ from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg
537
+ from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom
538
+
539
+ cfg = UrdfConverterCfg(
540
+ asset_path=urdf_path,
541
+ usd_dir=os.path.abspath(os.path.dirname(output_file)),
542
+ usd_file_name=os.path.basename(output_file),
543
+ **self.usd_parms,
544
+ )
545
+
546
+ urdf_converter = UrdfConverter(cfg)
547
+ usd_path = urdf_converter.usd_path
548
+
549
+ stage = Usd.Stage.Open(usd_path)
550
+ layer = stage.GetRootLayer()
551
+ with Usd.EditContext(stage, layer):
552
+ for prim in stage.Traverse():
553
+ if prim.GetName() == "collisions":
554
+ approx_attr = prim.GetAttribute("physics:approximation")
555
+ if not approx_attr:
556
+ approx_attr = prim.CreateAttribute(
557
+ "physics:approximation", Sdf.ValueTypeNames.Token
558
+ )
559
+ approx_attr.Set("convexDecomposition")
560
+
561
+ physx_conv_api = (
562
+ PhysxSchema.PhysxConvexDecompositionCollisionAPI.Apply(
563
+ prim
564
+ )
565
+ )
566
+ physx_conv_api.GetShrinkWrapAttr().Set(True)
567
+
568
+ api_schemas = prim.GetMetadata("apiSchemas")
569
+ if api_schemas is None:
570
+ api_schemas = Sdf.TokenListOp()
571
+
572
+ api_list = list(api_schemas.GetAddedOrExplicitItems())
573
+ for api in self.DEFAULT_BIND_APIS:
574
+ if api not in api_list:
575
+ api_list.append(api)
576
+
577
+ api_schemas.appendedItems = api_list
578
+ prim.SetMetadata("apiSchemas", api_schemas)
579
+
580
+ if self.rotate_wxyz is not None:
581
+ inner_prim = next(
582
+ p
583
+ for p in stage.GetDefaultPrim().GetChildren()
584
+ if p.IsA(UsdGeom.Xform)
585
+ )
586
+ xformable = UsdGeom.Xformable(inner_prim)
587
+ xformable.ClearXformOpOrder()
588
+ orient_op = xformable.AddOrientOp(UsdGeom.XformOp.PrecisionDouble)
589
+ orient_op.Set(Gf.Quatd(*self.rotate_wxyz))
590
+
591
+ layer.Save()
592
+ logger.info(f"Successfully converted {urdf_path} → {usd_path}")
593
+
594
+
595
+ class AssetConverterFactory:
596
+ """Factory class for creating asset converters based on target and source types."""
597
+
598
+ @staticmethod
599
+ def create(
600
+ target_type: AssetType, source_type: AssetType = "urdf", **kwargs
601
+ ) -> AssetConverterBase:
602
+ """Create an asset converter instance based on target and source types."""
603
+ if target_type == AssetType.MJCF and source_type == AssetType.URDF:
604
+ converter = MeshtoMJCFConverter(**kwargs)
605
+ elif target_type == AssetType.USD and source_type == AssetType.URDF:
606
+ converter = URDFtoUSDConverter(**kwargs)
607
+ elif target_type == AssetType.USD and source_type == AssetType.MESH:
608
+ converter = MeshtoUSDConverter(**kwargs)
609
+ else:
610
+ raise ValueError(
611
+ f"Unsupported converter type: {source_type} -> {target_type}."
612
+ )
613
+
614
+ return converter
615
+
616
+
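+ # Minimal usage sketch of the factory (illustrative paths only): pick the
+ # target/source formats, then run the conversion inside the context manager.
+ #
+ #     converter = AssetConverterFactory.create(
+ #         target_type=AssetType.USD, source_type=AssetType.URDF
+ #     )
+ #     with converter:
+ #         converter.convert("asset/result/asset.urdf", "asset/usd/asset.usd")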
617
+ if __name__ == "__main__":
618
+ # # target_asset_type = AssetType.MJCF
619
+ # target_asset_type = AssetType.USD
620
+
621
+ # urdf_paths = [
622
+ # "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf",
623
+ # ]
624
+
625
+ # if target_asset_type == AssetType.MJCF:
626
+ # output_files = [
627
+ # "outputs/embodiedgen_assets/demo_assets/remote_control/mjcf/remote_control.mjcf",
628
+ # ]
629
+ # asset_converter = AssetConverterFactory.create(
630
+ # target_type=AssetType.MJCF,
631
+ # source_type=AssetType.URDF,
632
+ # )
633
+
634
+ # elif target_asset_type == AssetType.USD:
635
+ # output_files = [
636
+ # "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd",
637
+ # ]
638
+ # asset_converter = AssetConverterFactory.create(
639
+ # target_type=AssetType.USD,
640
+ # source_type=AssetType.MESH,
641
+ # )
642
+
643
+ # with asset_converter:
644
+ # for urdf_path, output_file in zip(urdf_paths, output_files):
645
+ # asset_converter.convert(urdf_path, output_file)
646
+
647
+ # urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf"
648
+ # output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd"
649
+
650
+ # asset_converter = AssetConverterFactory.create(
651
+ # target_type=AssetType.USD,
652
+ # source_type=AssetType.URDF,
653
+ # rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis
654
+ # )
655
+
656
+ # with asset_converter:
657
+ # asset_converter.convert(urdf_path, output_file)
658
+
659
+ urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_simple_solve_nos_i_urdf/export_scene/scene.urdf"
660
+ output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_simple_solve_nos_i_urdf/mjcf/scene.urdf"
661
+ asset_converter = URDFtoMJCFConverter()
662
+ with asset_converter:
663
+ asset_converter.convert(urdf_path, output_file)
embodied_gen/data/backproject.py ADDED
@@ -0,0 +1,518 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import logging
20
+ import math
21
+ import os
22
+ from typing import List, Literal, Union
23
+
24
+ import cv2
25
+ import numpy as np
26
+ import nvdiffrast.torch as dr
27
+ import torch
28
+ import trimesh
29
+ import utils3d
30
+ import xatlas
31
+ from tqdm import tqdm
32
+ from embodied_gen.data.mesh_operator import MeshFixer
33
+ from embodied_gen.data.utils import (
34
+ CameraSetting,
35
+ get_images_from_grid,
36
+ init_kal_camera,
37
+ normalize_vertices_array,
38
+ post_process_texture,
39
+ save_mesh_with_mtl,
40
+ )
41
+ from embodied_gen.models.delight_model import DelightingModel
42
+
43
+ logging.basicConfig(
44
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
45
+ )
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ class TextureBaker(object):
50
+ """Baking textures onto a mesh from multiple observations.
51
+
52
+ This class takes 3D mesh data, camera settings, and texture-baking parameters
53
+ to generate a texture map by projecting images onto the mesh from different views.
54
+ It supports both a fast texture baking approach and a more optimized method
55
+ with total variation regularization.
56
+
57
+ Attributes:
58
+ vertices (torch.Tensor): The vertices of the mesh.
59
+ faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
60
+ uvs (torch.Tensor): The UV coordinates of the mesh.
61
+ camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
62
+ device (str): The device to run computations on ("cpu" or "cuda").
63
+ w2cs (torch.Tensor): World-to-camera transformation matrices.
64
+ projections (torch.Tensor): Camera projection matrices.
65
+
66
+ Example:
67
+ >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa
68
+ >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params)
69
+ >>> images = get_images_from_grid(args.color_path, image_size)
70
+ >>> texture = texture_backer.bake_texture(
71
+ ... images, texture_size=args.texture_size, mode=args.baker_mode
72
+ ... )
73
+ >>> texture = post_process_texture(texture)
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ vertices: np.ndarray,
79
+ faces: np.ndarray,
80
+ uvs: np.ndarray,
81
+ camera_params: CameraSetting,
82
+ device: str = "cuda",
83
+ ) -> None:
84
+ self.vertices = (
85
+ torch.tensor(vertices, device=device)
86
+ if isinstance(vertices, np.ndarray)
87
+ else vertices.to(device)
88
+ )
89
+ self.faces = (
90
+ torch.tensor(faces.astype(np.int32), device=device)
91
+ if isinstance(faces, np.ndarray)
92
+ else faces.to(device)
93
+ )
94
+ self.uvs = (
95
+ torch.tensor(uvs, device=device)
96
+ if isinstance(uvs, np.ndarray)
97
+ else uvs.to(device)
98
+ )
99
+ self.camera_params = camera_params
100
+ self.device = device
101
+
102
+ camera = init_kal_camera(camera_params)
103
+ matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
104
+ matrix_mv = kaolin_to_opencv_view(matrix_mv)
105
+ matrix_p = (
106
+ camera.intrinsics.projection_matrix()
107
+ ) # (n_cam 4 4) cam2pixel
108
+ self.w2cs = matrix_mv.to(self.device)
109
+ self.projections = matrix_p.to(self.device)
110
+
111
+ @staticmethod
112
+ def parametrize_mesh(
113
+ vertices: np.array, faces: np.array
114
+ ) -> Union[np.array, np.array, np.array]:
115
+ vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
116
+
117
+ vertices = vertices[vmapping]
118
+ faces = indices
119
+
120
+ return vertices, faces, uvs
121
+
122
+ def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
123
+ texture = torch.zeros(
124
+ (texture_size * texture_size, 3), dtype=torch.float32
125
+ ).cuda()
126
+ texture_weights = torch.zeros(
127
+ (texture_size * texture_size), dtype=torch.float32
128
+ ).cuda()
129
+ rastctx = utils3d.torch.RastContext(backend="cuda")
130
+ for observation, w2c, projection, view_mask in tqdm(
131
+ zip(observations, w2cs, projections, masks),
132
+ total=len(observations),
133
+ desc="Texture baking (fast)",
134
+ ):
135
+ with torch.no_grad():
136
+ rast = utils3d.torch.rasterize_triangle_faces(
137
+ rastctx,
138
+ self.vertices[None],
139
+ self.faces,
140
+ observation.shape[1],
141
+ observation.shape[0],
142
+ uv=self.uvs[None],
143
+ view=w2c,
144
+ projection=projection,
145
+ )
146
+ uv_map = rast["uv"][0].detach().flip(0)
147
+ mask = rast["mask"][0].detach().bool() & masks[0]
148
+
149
+ # nearest neighbor interpolation
150
+ uv_map = (uv_map * texture_size).floor().long()
151
+ obs = observation[mask]
152
+ uv_map = uv_map[mask]
153
+ idx = (
154
+ uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
155
+ )
156
+ texture = texture.scatter_add(
157
+ 0, idx.view(-1, 1).expand(-1, 3), obs
158
+ )
159
+ texture_weights = texture_weights.scatter_add(
160
+ 0,
161
+ idx,
162
+ torch.ones(
163
+ (obs.shape[0]), dtype=torch.float32, device=texture.device
164
+ ),
165
+ )
166
+
167
+ mask = texture_weights > 0
168
+ texture[mask] /= texture_weights[mask][:, None]
169
+ texture = np.clip(
170
+ texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
171
+ 0,
172
+ 255,
173
+ ).astype(np.uint8)
174
+
175
+ # inpaint
176
+ mask = (
177
+ (texture_weights == 0)
178
+ .cpu()
179
+ .numpy()
180
+ .astype(np.uint8)
181
+ .reshape(texture_size, texture_size)
182
+ )
183
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
184
+
185
+ return texture
186
+
187
+ def _bake_opt(
188
+ self,
189
+ observations,
190
+ w2cs,
191
+ projections,
192
+ texture_size,
193
+ lambda_tv,
194
+ masks,
195
+ total_steps,
196
+ ):
197
+ rastctx = utils3d.torch.RastContext(backend="cuda")
198
+ observations = [observations.flip(0) for observations in observations]
199
+ masks = [m.flip(0) for m in masks]
200
+ _uv = []
201
+ _uv_dr = []
202
+ for observation, w2c, projection in tqdm(
203
+ zip(observations, w2cs, projections),
204
+ total=len(w2cs),
205
+ ):
206
+ with torch.no_grad():
207
+ rast = utils3d.torch.rasterize_triangle_faces(
208
+ rastctx,
209
+ self.vertices[None],
210
+ self.faces,
211
+ observation.shape[1],
212
+ observation.shape[0],
213
+ uv=self.uvs[None],
214
+ view=w2c,
215
+ projection=projection,
216
+ )
217
+ _uv.append(rast["uv"].detach())
218
+ _uv_dr.append(rast["uv_dr"].detach())
219
+
220
+ texture = torch.nn.Parameter(
221
+ torch.zeros(
222
+ (1, texture_size, texture_size, 3), dtype=torch.float32
223
+ ).cuda()
224
+ )
225
+ optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
226
+
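+ # Cosine learning-rate schedule: lr(t) = end + 0.5 * (start - end) * (1 + cos(pi * t / T)).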
227
+ def cosine_annealing(step, total_steps, start_lr, end_lr):
228
+ return end_lr + 0.5 * (start_lr - end_lr) * (
229
+ 1 + np.cos(np.pi * step / total_steps)
230
+ )
231
+
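+ # Anisotropic total-variation loss: mean absolute difference between
+ # vertically and horizontally adjacent texels, which discourages speckle
+ # noise in texels observed by few views.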
232
+ def tv_loss(texture):
233
+ return torch.nn.functional.l1_loss(
234
+ texture[:, :-1, :, :], texture[:, 1:, :, :]
235
+ ) + torch.nn.functional.l1_loss(
236
+ texture[:, :, :-1, :], texture[:, :, 1:, :]
237
+ )
238
+
239
+ with tqdm(total=total_steps, desc="Texture baking") as pbar:
240
+ for step in range(total_steps):
241
+ optimizer.zero_grad()
242
+ selected = np.random.randint(0, len(w2cs))
243
+ uv, uv_dr, observation, mask = (
244
+ _uv[selected],
245
+ _uv_dr[selected],
246
+ observations[selected],
247
+ masks[selected],
248
+ )
249
+ render = dr.texture(texture, uv, uv_dr)[0]
250
+ loss = torch.nn.functional.l1_loss(
251
+ render[mask], observation[mask]
252
+ )
253
+ if lambda_tv > 0:
254
+ loss += lambda_tv * tv_loss(texture)
255
+ loss.backward()
256
+ optimizer.step()
257
+
258
+ optimizer.param_groups[0]["lr"] = cosine_anealing(
259
+ step, total_steps, 1e-2, 1e-5
260
+ )
261
+ pbar.set_postfix({"loss": loss.item()})
262
+ pbar.update()
263
+ texture = np.clip(
264
+ texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
265
+ ).astype(np.uint8)
266
+ mask = 1 - utils3d.torch.rasterize_triangle_faces(
267
+ rastctx,
268
+ (self.uvs * 2 - 1)[None],
269
+ self.faces,
270
+ texture_size,
271
+ texture_size,
272
+ )["mask"][0].detach().cpu().numpy().astype(np.uint8)
273
+ texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
274
+
275
+ return texture
276
+
277
+ def bake_texture(
278
+ self,
279
+ images: List[np.array],
280
+ texture_size: int = 1024,
281
+ mode: Literal["fast", "opt"] = "opt",
282
+ lambda_tv: float = 1e-2,
283
+ opt_step: int = 2000,
284
+ ):
285
+ masks = [np.any(img > 0, axis=-1) for img in images]
286
+ masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
287
+ images = [
288
+ torch.tensor(obs / 255.0).float().to(self.device) for obs in images
289
+ ]
290
+
291
+ if mode == "fast":
292
+ return self._bake_fast(
293
+ images, self.w2cs, self.projections, texture_size, masks
294
+ )
295
+ elif mode == "opt":
296
+ return self._bake_opt(
297
+ images,
298
+ self.w2cs,
299
+ self.projections,
300
+ texture_size,
301
+ lambda_tv,
302
+ masks,
303
+ opt_step,
304
+ )
305
+ else:
306
+ raise ValueError(f"Unknown mode: {mode}")
307
+
308
+
309
+ def kaolin_to_opencv_view(raw_matrix):
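+ # Reorder the rotation columns of the Kaolin world-to-camera matrices to
+ # match the OpenCV-style camera axes expected downstream; the translation
+ # component is left unchanged.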
310
+ R_orig = raw_matrix[:, :3, :3]
311
+ t_orig = raw_matrix[:, :3, 3]
312
+
313
+ R_target = torch.zeros_like(R_orig)
314
+ R_target[:, :, 0] = R_orig[:, :, 2]
315
+ R_target[:, :, 1] = R_orig[:, :, 0]
316
+ R_target[:, :, 2] = R_orig[:, :, 1]
317
+
318
+ t_target = t_orig
319
+
320
+ target_matrix = (
321
+ torch.eye(4, device=raw_matrix.device)
322
+ .unsqueeze(0)
323
+ .repeat(raw_matrix.size(0), 1, 1)
324
+ )
325
+ target_matrix[:, :3, :3] = R_target
326
+ target_matrix[:, :3, 3] = t_target
327
+
328
+ return target_matrix
329
+
330
+
331
+ def parse_args():
332
+ parser = argparse.ArgumentParser(description="Render settings")
333
+
334
+ parser.add_argument(
335
+ "--mesh_path",
336
+ type=str,
337
+ nargs="+",
338
+ required=True,
339
+ help="Paths to the mesh files for rendering.",
340
+ )
341
+ parser.add_argument(
342
+ "--color_path",
343
+ type=str,
344
+ nargs="+",
345
+ required=True,
346
+ help="Paths to the mesh files for rendering.",
347
+ )
348
+ parser.add_argument(
349
+ "--output_root",
350
+ type=str,
351
+ default="./outputs",
352
+ help="Root directory for output",
353
+ )
354
+ parser.add_argument(
355
+ "--uuid",
356
+ type=str,
357
+ nargs="+",
358
+ default=None,
359
+ help="uuid for rendering saving.",
360
+ )
361
+ parser.add_argument(
362
+ "--num_images", type=int, default=6, help="Number of images to render."
363
+ )
364
+ parser.add_argument(
365
+ "--elevation",
366
+ type=float,
367
+ nargs="+",
368
+ default=[20.0, -10.0],
369
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
370
+ )
371
+ parser.add_argument(
372
+ "--distance",
373
+ type=float,
374
+ default=5,
375
+ help="Camera distance (default: 5)",
376
+ )
377
+ parser.add_argument(
378
+ "--resolution_hw",
379
+ type=int,
380
+ nargs=2,
381
+ default=(512, 512),
382
+ help="Resolution of the output images (default: (512, 512))",
383
+ )
384
+ parser.add_argument(
385
+ "--fov",
386
+ type=float,
387
+ default=30,
388
+ help="Field of view in degrees (default: 30)",
389
+ )
390
+ parser.add_argument(
391
+ "--device",
392
+ type=str,
393
+ choices=["cpu", "cuda"],
394
+ default="cuda",
395
+ help="Device to run on (default: `cuda`)",
396
+ )
397
+ parser.add_argument(
398
+ "--texture_size",
399
+ type=int,
400
+ default=1024,
401
+ help="Texture size for texture baking (default: 1024)",
402
+ )
403
+ parser.add_argument(
404
+ "--baker_mode",
405
+ type=str,
406
+ default="opt",
407
+ help="Texture baking mode, `fast` or `opt` (default: opt)",
408
+ )
409
+ parser.add_argument(
410
+ "--opt_step",
411
+ type=int,
412
+ default=2500,
413
+ help="Optimization steps for texture baking (default: 2500)",
414
+ )
415
+ parser.add_argument(
416
+ "--mesh_sipmlify_ratio",
417
+ type=float,
418
+ default=0.9,
419
+ help="Mesh simplification ratio (default: 0.9)",
420
+ )
421
+ parser.add_argument(
422
+ "--no_coor_trans",
423
+ action="store_true",
424
+ help="Do not transform the asset coordinate system.",
425
+ )
426
+ parser.add_argument(
427
+ "--delight", action="store_true", help="Use delighting model."
428
+ )
429
+ parser.add_argument(
430
+ "--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
431
+ )
432
+
433
+ args = parser.parse_args()
434
+
435
+ if args.uuid is None:
436
+ args.uuid = []
437
+ for path in args.mesh_path:
438
+ uuid = os.path.basename(path).split(".")[0]
439
+ args.uuid.append(uuid)
440
+
441
+ return args
442
+
443
+
444
+ def entrypoint() -> None:
445
+ args = parse_args()
446
+ camera_params = CameraSetting(
447
+ num_images=args.num_images,
448
+ elevation=args.elevation,
449
+ distance=args.distance,
450
+ resolution_hw=args.resolution_hw,
451
+ fov=math.radians(args.fov),
452
+ device=args.device,
453
+ )
454
+
455
+ for mesh_path, uuid, img_path in zip(
456
+ args.mesh_path, args.uuid, args.color_path
457
+ ):
458
+ mesh = trimesh.load(mesh_path)
459
+ if isinstance(mesh, trimesh.Scene):
460
+ mesh = mesh.dump(concatenate=True)
461
+ vertices, scale, center = normalize_vertices_array(mesh.vertices)
462
+
463
+ if not args.no_coor_trans:
464
+ x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
465
+ z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
466
+ vertices = vertices @ x_rot
467
+ vertices = vertices @ z_rot
468
+
469
+ faces = mesh.faces.astype(np.int32)
470
+ vertices = vertices.astype(np.float32)
471
+
472
+ if not args.skip_fix_mesh:
473
+ mesh_fixer = MeshFixer(vertices, faces, args.device)
474
+ vertices, faces = mesh_fixer(
475
+ filter_ratio=args.mesh_sipmlify_ratio,
476
+ max_hole_size=0.04,
477
+ resolution=1024,
478
+ num_views=1000,
479
+ norm_mesh_ratio=0.5,
480
+ )
481
+
482
+ vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
483
+ texture_backer = TextureBaker(
484
+ vertices,
485
+ faces,
486
+ uvs,
487
+ camera_params,
488
+ )
489
+ images = get_images_from_grid(
490
+ img_path, img_size=camera_params.resolution_hw[0]
491
+ )
492
+ if args.delight:
493
+ delight_model = DelightingModel()
494
+ images = [delight_model(img) for img in images]
495
+
496
+ images = [np.array(img) for img in images]
497
+ texture = texture_backer.bake_texture(
498
+ images=[img[..., :3] for img in images],
499
+ texture_size=args.texture_size,
500
+ mode=args.baker_mode,
501
+ opt_step=args.opt_step,
502
+ )
503
+ texture = post_process_texture(texture)
504
+
505
+ if not args.no_coor_trans:
506
+ vertices = vertices @ np.linalg.inv(z_rot)
507
+ vertices = vertices @ np.linalg.inv(x_rot)
508
+ vertices = vertices / scale
509
+ vertices = vertices + center
510
+
511
+ output_path = os.path.join(args.output_root, f"{uuid}.obj")
512
+ mesh = save_mesh_with_mtl(vertices, faces, uvs, texture, output_path)
513
+
514
+ return
515
+
516
+
517
+ if __name__ == "__main__":
518
+ entrypoint()
embodied_gen/data/backproject_v2.py ADDED
@@ -0,0 +1,721 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import logging
20
+ import math
21
+ import os
22
+
23
+ import cv2
24
+ import numpy as np
25
+ import nvdiffrast.torch as dr
26
+ import spaces
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import trimesh
30
+ import xatlas
31
+ from PIL import Image
32
+ from embodied_gen.data.mesh_operator import MeshFixer
33
+ from embodied_gen.data.utils import (
34
+ CameraSetting,
35
+ DiffrastRender,
36
+ as_list,
37
+ get_images_from_grid,
38
+ init_kal_camera,
39
+ normalize_vertices_array,
40
+ post_process_texture,
41
+ save_mesh_with_mtl,
42
+ )
43
+ from embodied_gen.models.delight_model import DelightingModel
44
+ from embodied_gen.models.sr_model import ImageRealESRGAN
45
+ from embodied_gen.utils.process_media import vcat_pil_images
46
+
47
+ logging.basicConfig(
48
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
49
+ )
50
+ logger = logging.getLogger(__name__)
51
+
52
+
53
+ __all__ = [
54
+ "TextureBacker",
55
+ ]
56
+
57
+
58
+ def _transform_vertices(
59
+ mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
60
+ ) -> torch.Tensor:
61
+ """Transform 3D vertices using a projection matrix."""
62
+ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
63
+ if pos.size(-1) == 3:
64
+ pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)
65
+
66
+ result = pos @ t_mtx.T
67
+
68
+ return result if keepdim else result.unsqueeze(0)
69
+
70
+
71
+ def _bilinear_interpolation_scattering(
72
+ image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
73
+ ) -> torch.Tensor:
74
+ """Bilinear interpolation scattering for grid-based value accumulation."""
75
+ device = values.device
76
+ dtype = values.dtype
77
+ C = values.shape[-1]
78
+
79
+ indices = coords * torch.tensor(
80
+ [image_h - 1, image_w - 1], dtype=dtype, device=device
81
+ )
82
+ i, j = indices.unbind(-1)
83
+
84
+ i0, j0 = (
85
+ indices.floor()
86
+ .long()
87
+ .clamp(0, image_h - 2)
88
+ .clamp(0, image_w - 2)
89
+ .unbind(-1)
90
+ )
91
+ i1, j1 = i0 + 1, j0 + 1
92
+
93
+ w_i = i - i0.float()
94
+ w_j = j - j0.float()
95
+ weights = torch.stack(
96
+ [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
97
+ dim=1,
98
+ )
99
+
100
+ indices_comb = torch.stack(
101
+ [
102
+ torch.stack([i0, j0], dim=1),
103
+ torch.stack([i0, j1], dim=1),
104
+ torch.stack([i1, j0], dim=1),
105
+ torch.stack([i1, j1], dim=1),
106
+ ],
107
+ dim=1,
108
+ )
109
+
110
+ grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
111
+ cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)
112
+
113
+ for k in range(4):
114
+ idx = indices_comb[:, k]
115
+ w = weights[:, k].unsqueeze(-1)
116
+
117
+ stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
118
+ flat_idx = (idx * stride).sum(-1)
119
+
120
+ grid.view(-1, C).scatter_add_(
121
+ 0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
122
+ )
123
+ cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)
124
+
125
+ mask = cnt.squeeze(-1) > 0
126
+ grid[mask] = grid[mask] / cnt[mask].repeat(1, C)
127
+
128
+ return grid
129
+
130
+
131
+ def _texture_inpaint_smooth(
132
+ texture: np.ndarray,
133
+ mask: np.ndarray,
134
+ vertices: np.ndarray,
135
+ faces: np.ndarray,
136
+ uv_map: np.ndarray,
137
+ ) -> tuple[np.ndarray, np.ndarray]:
138
+ """Perform texture inpainting using vertex-based color propagation."""
139
+ image_h, image_w, C = texture.shape
140
+ N = vertices.shape[0]
141
+
142
+ # Initialize vertex data structures
143
+ vtx_mask = np.zeros(N, dtype=np.float32)
144
+ vtx_colors = np.zeros((N, C), dtype=np.float32)
145
+ unprocessed = []
146
+ adjacency = [[] for _ in range(N)]
147
+
148
+ # Build adjacency graph and initial color assignment
149
+ for face_idx in range(faces.shape[0]):
150
+ for k in range(3):
151
+ uv_idx_k = faces[face_idx, k]
152
+ v_idx = faces[face_idx, k]
153
+
154
+ # Convert UV to pixel coordinates with boundary clamping
155
+ u = np.clip(
156
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
157
+ )
158
+ v = np.clip(
159
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
160
+ 0,
161
+ image_h - 1,
162
+ )
163
+
164
+ if mask[v, u]:
165
+ vtx_mask[v_idx] = 1.0
166
+ vtx_colors[v_idx] = texture[v, u]
167
+ elif v_idx not in unprocessed:
168
+ unprocessed.append(v_idx)
169
+
170
+ # Build undirected adjacency graph
171
+ neighbor = faces[face_idx, (k + 1) % 3]
172
+ if neighbor not in adjacency[v_idx]:
173
+ adjacency[v_idx].append(neighbor)
174
+ if v_idx not in adjacency[neighbor]:
175
+ adjacency[neighbor].append(v_idx)
176
+
177
+ # Color propagation with dynamic stopping
178
+ remaining_iters, prev_count = 2, 0
179
+ while remaining_iters > 0:
180
+ current_unprocessed = []
181
+
182
+ for v_idx in unprocessed:
183
+ valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
184
+ if not valid_neighbors:
185
+ current_unprocessed.append(v_idx)
186
+ continue
187
+
188
+ # Calculate inverse square distance weights
189
+ neighbors_pos = vertices[valid_neighbors]
190
+ dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
191
+ weights = 1 / np.maximum(dist_sq, 1e-8)
192
+
193
+ vtx_colors[v_idx] = np.average(
194
+ vtx_colors[valid_neighbors], weights=weights, axis=0
195
+ )
196
+ vtx_mask[v_idx] = 1.0
197
+
198
+ # Update iteration control
199
+ if len(current_unprocessed) == prev_count:
200
+ remaining_iters -= 1
201
+ else:
202
+ remaining_iters = min(remaining_iters + 1, 2)
203
+ prev_count = len(current_unprocessed)
204
+ unprocessed = current_unprocessed
205
+
206
+ # Generate output texture
207
+ inpainted_texture, updated_mask = texture.copy(), mask.copy()
208
+ for face_idx in range(faces.shape[0]):
209
+ for k in range(3):
210
+ v_idx = faces[face_idx, k]
211
+ if not vtx_mask[v_idx]:
212
+ continue
213
+
214
+ # UV coordinate conversion
215
+ uv_idx_k = faces[face_idx, k]
216
+ u = np.clip(
217
+ int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
218
+ )
219
+ v = np.clip(
220
+ int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
221
+ 0,
222
+ image_h - 1,
223
+ )
224
+
225
+ inpainted_texture[v, u] = vtx_colors[v_idx]
226
+ updated_mask[v, u] = 255
227
+
228
+ return inpainted_texture, updated_mask
229
+
230
+
231
+ class TextureBacker:
232
+ """Texture baking pipeline for multi-view projection and fusion.
233
+
234
+ This class performs UV-based texture generation for a 3D mesh using
235
+ multi-view color images, depth, and normal information. The pipeline
236
+ includes mesh normalization and UV unwrapping, visibility-aware
237
+ back-projection, confidence-weighted texture fusion, and inpainting
238
+ of missing texture regions.
239
+
240
+ Args:
241
+ camera_params (CameraSetting): Camera intrinsics and extrinsics used
242
+ for rendering each view.
243
+ view_weights (list[float]): A list of weights for each view, used
244
+ to blend confidence maps during texture fusion.
245
+ render_wh (tuple[int, int], optional): Resolution (width, height) for
246
+ intermediate rendering passes. Defaults to (2048, 2048).
247
+ texture_wh (tuple[int, int], optional): Output texture resolution
248
+ (width, height). Defaults to (2048, 2048).
249
+ bake_angle_thresh (int, optional): Maximum angle (in degrees) between
250
+ view direction and surface normal for projection to be considered valid.
251
+ Defaults to 75.
252
+ mask_thresh (float, optional): Threshold applied to visibility masks
253
+ during rendering. Defaults to 0.5.
254
+ smooth_texture (bool, optional): If True, apply post-processing (e.g.,
255
+ blurring) to the final texture. Defaults to True.
256
+ inpaint_smooth (bool, optional): If True, run vertex-based inpainting to smooth missing regions before the final inpaint. Defaults to False.
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ camera_params: CameraSetting,
262
+ view_weights: list[float],
263
+ render_wh: tuple[int, int] = (2048, 2048),
264
+ texture_wh: tuple[int, int] = (2048, 2048),
265
+ bake_angle_thresh: int = 75,
266
+ mask_thresh: float = 0.5,
267
+ smooth_texture: bool = True,
268
+ inpaint_smooth: bool = False,
269
+ ) -> None:
270
+ self.camera_params = camera_params
271
+ self.renderer = None
272
+ self.view_weights = view_weights
273
+ self.device = camera_params.device
274
+ self.render_wh = render_wh
275
+ self.texture_wh = texture_wh
276
+ self.mask_thresh = mask_thresh
277
+ self.smooth_texture = smooth_texture
278
+ self.inpaint_smooth = inpaint_smooth
279
+
280
+ self.bake_angle_thresh = bake_angle_thresh
281
+ self.bake_unreliable_kernel_size = int(
282
+ (2 / 512) * max(self.render_wh[0], self.render_wh[1])
283
+ )
284
+
285
+ def _lazy_init_render(self, camera_params, mask_thresh):
286
+ if self.renderer is None:
287
+ camera = init_kal_camera(camera_params)
288
+ mv = camera.view_matrix() # (n 4 4) world2cam
289
+ p = camera.intrinsics.projection_matrix()
290
+ # NOTE: negate P[1, 1] because the y axis is flipped in the `nvdiffrast` output. # noqa
291
+ p[:, 1, 1] = -p[:, 1, 1]
292
+ self.renderer = DiffrastRender(
293
+ p_matrix=p,
294
+ mv_matrix=mv,
295
+ resolution_hw=camera_params.resolution_hw,
296
+ context=dr.RasterizeCudaContext(),
297
+ mask_thresh=mask_thresh,
298
+ grad_db=False,
299
+ device=self.device,
300
+ antialias_mask=True,
301
+ )
302
+
303
+ def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
304
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
305
+ self.scale, self.center = scale, center
306
+
307
+ vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
308
+ uvs[:, 1] = 1 - uvs[:, 1]
309
+ mesh.vertices = mesh.vertices[vmapping]
310
+ mesh.faces = indices
311
+ mesh.visual.uv = uvs
312
+
313
+ return mesh
314
+
315
+ def get_mesh_np_attrs(
316
+ self,
317
+ mesh: trimesh.Trimesh,
318
+ scale: float = None,
319
+ center: np.ndarray = None,
320
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
321
+ vertices = mesh.vertices.copy()
322
+ faces = mesh.faces.copy()
323
+ uv_map = mesh.visual.uv.copy()
324
+ uv_map[:, 1] = 1.0 - uv_map[:, 1]
325
+
326
+ if scale is not None:
327
+ vertices = vertices / scale
328
+ if center is not None:
329
+ vertices = vertices + center
330
+
331
+ return vertices, faces, uv_map
332
+
333
+ def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
334
+ depth_image_np = depth_image.cpu().numpy()
335
+ depth_image_np = (depth_image_np * 255).astype(np.uint8)
336
+ depth_edges = cv2.Canny(depth_image_np, 30, 80)
337
+ sketch_image = (
338
+ torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
339
+ )
340
+ sketch_image = sketch_image.unsqueeze(-1)
341
+
342
+ return sketch_image
343
+
344
+ def compute_enhanced_viewnormal(
345
+ self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
346
+ ) -> torch.Tensor:
347
+ rast, _ = self.renderer.compute_dr_raster(vertices, faces)
348
+ rendered_view_normals = []
349
+ for idx in range(len(mv_mtx)):
350
+ pos_cam = _transform_vertices(mv_mtx[idx], vertices, keepdim=True)
351
+ pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
352
+ v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
353
+ face_norm = F.normalize(
354
+ torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
355
+ )
356
+ vertex_norm = (
357
+ torch.from_numpy(
358
+ trimesh.geometry.mean_vertex_normals(
359
+ len(pos_cam), faces.cpu(), face_norm.cpu()
360
+ )
361
+ )
362
+ .to(vertices.device)
363
+ .contiguous()
364
+ )
365
+ im_base_normals, _ = dr.interpolate(
366
+ vertex_norm[None, ...].float(),
367
+ rast[idx : idx + 1],
368
+ faces.to(torch.int32),
369
+ )
370
+ rendered_view_normals.append(im_base_normals)
371
+
372
+ rendered_view_normals = torch.cat(rendered_view_normals, dim=0)
373
+
374
+ return rendered_view_normals
375
+
376
+ def back_project(
377
+ self, image, vis_mask, depth, normal, uv
378
+ ) -> tuple[torch.Tensor, torch.Tensor]:
379
+ image = np.array(image)
380
+ image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
381
+ if image.ndim == 2:
382
+ image = image.unsqueeze(-1)
383
+ image = image / 255
384
+
385
+ depth_inv = (1.0 - depth) * vis_mask
386
+ sketch_image = self._render_depth_edges(depth_inv)
387
+
388
+ cos = F.cosine_similarity(
389
+ torch.tensor([[0, 0, 1]], device=self.device),
390
+ normal.view(-1, 3),
391
+ ).view_as(normal[..., :1])
392
+ cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0
393
+
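+ # Erode the visibility mask and dilate depth-edge pixels with the same
+ # kernel so texels near silhouettes and depth discontinuities are skipped.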
394
+ k = self.bake_unreliable_kernel_size * 2 + 1
395
+ kernel = torch.ones((1, 1, k, k), device=self.device)
396
+
397
+ vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
398
+ vis_mask = F.conv2d(
399
+ 1.0 - vis_mask,
400
+ kernel,
401
+ padding=k // 2,
402
+ )
403
+ vis_mask = 1.0 - (vis_mask > 0).float()
404
+ vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)
405
+
406
+ sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
407
+ sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
408
+ sketch_image = (sketch_image > 0).float()
409
+ sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
410
+ vis_mask = vis_mask * (sketch_image < 0.5)
411
+
412
+ cos[vis_mask == 0] = 0
413
+ valid_pixels = (vis_mask != 0).view(-1)
414
+
415
+ return (
416
+ self._scatter_texture(uv, image, valid_pixels),
417
+ self._scatter_texture(uv, cos, valid_pixels),
418
+ )
419
+
420
+ def _scatter_texture(self, uv, data, mask):
421
+ def __filter_data(data, mask):
422
+ return data.view(-1, data.shape[-1])[mask]
423
+
424
+ return _bilinear_interpolation_scattering(
425
+ self.texture_wh[1],
426
+ self.texture_wh[0],
427
+ __filter_data(uv, mask)[..., [1, 0]],
428
+ __filter_data(data, mask),
429
+ )
430
+
431
+ @torch.no_grad()
432
+ def fast_bake_texture(
433
+ self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
434
+ ) -> tuple[torch.Tensor, torch.Tensor]:
435
+ channel = textures[0].shape[-1]
436
+ texture_merge = torch.zeros(list(self.texture_wh) + [channel]).to(
437
+ self.device
438
+ )
439
+ trust_map_merge = torch.zeros(list(self.texture_wh) + [1]).to(self.device)
440
+ for texture, cos_map in zip(textures, confidence_maps):
441
+ view_sum = (cos_map > 0).sum()
442
+ painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
443
+ if painted_sum / view_sum > 0.99:
444
+ continue
445
+ texture_merge += texture * cos_map
446
+ trust_map_merge += cos_map
447
+ texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)
448
+
449
+ return texture_merge, trust_map_merge > 1e-8
450
+
451
+ def uv_inpaint(
452
+ self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
453
+ ) -> np.ndarray:
454
+ if self.inpaint_smooth:
455
+ vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)
456
+ texture, mask = _texture_inpaint_smooth(
457
+ texture, mask, vertices, faces, uv_map
458
+ )
459
+
460
+ texture = texture.clip(0, 1)
461
+ texture = cv2.inpaint(
462
+ (texture * 255).astype(np.uint8),
463
+ 255 - mask,
464
+ 3,
465
+ cv2.INPAINT_NS,
466
+ )
467
+
468
+ return texture
469
+
470
+ @spaces.GPU
471
+ def compute_texture(
472
+ self,
473
+ colors: list[Image.Image],
474
+ mesh: trimesh.Trimesh,
475
+ ) -> trimesh.Trimesh:
476
+ self._lazy_init_render(self.camera_params, self.mask_thresh)
477
+
478
+ vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
479
+ faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
480
+ uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()
481
+
482
+ rendered_depth, masks = self.renderer.render_depth(vertices, faces)
483
+ norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
484
+ render_uvs, _ = self.renderer.render_uv(vertices, faces, uv_map)
485
+ view_normals = self.compute_enhanced_viewnormal(
486
+ self.renderer.mv_mtx, vertices, faces
487
+ )
488
+
489
+ textures, weighted_cos_maps = [], []
490
+ for color, mask, dep, normal, uv, weight in zip(
491
+ colors,
492
+ masks,
493
+ norm_deps,
494
+ view_normals,
495
+ render_uvs,
496
+ self.view_weights,
497
+ ):
498
+ texture, cos_map = self.back_project(color, mask, dep, normal, uv)
499
+ textures.append(texture)
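+ # Raise the view-normal cosine to the 4th power so oblique views contribute
+ # much less than frontal views, then apply the per-view weight.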
500
+ weighted_cos_maps.append(weight * (cos_map**4))
501
+
502
+ texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)
503
+
504
+ texture_np = texture.cpu().numpy()
505
+ mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)
506
+
507
+ return texture_np, mask_np
508
+
509
+ def __call__(
510
+ self,
511
+ colors: list[Image.Image],
512
+ mesh: trimesh.Trimesh,
513
+ output_path: str,
514
+ ) -> trimesh.Trimesh:
515
+ """Runs the texture baking and exports the textured mesh.
516
+
517
+ Args:
518
+ colors (list[Image.Image]): List of input view images.
519
+ mesh (trimesh.Trimesh): Input mesh to be textured.
520
+ output_path (str): Path to save the output textured mesh (.obj or .glb).
521
+
522
+ Returns:
523
+ trimesh.Trimesh: The textured mesh with UV and texture image.
524
+ """
525
+ mesh = self.load_mesh(mesh)
526
+ texture_np, mask_np = self.compute_texture(colors, mesh)
527
+
528
+ texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
529
+ if self.smooth_texture:
530
+ texture_np = post_process_texture(texture_np)
531
+
532
+ vertices, faces, uv_map = self.get_mesh_np_attrs(
533
+ mesh, self.scale, self.center
534
+ )
535
+ textured_mesh = save_mesh_with_mtl(
536
+ vertices, faces, uv_map, texture_np, output_path
537
+ )
538
+
539
+ return textured_mesh
540
+
541
+
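+ # Minimal usage sketch (names and values are illustrative; see `entrypoint`):
+ #
+ #     camera = CameraSetting(
+ #         num_images=6, elevation=[20.0, -10.0], distance=5,
+ #         resolution_hw=(2048, 2048), fov=math.radians(30), device="cuda",
+ #     )
+ #     backer = TextureBacker(camera, view_weights=[1, 0.1, 0.02, 0.1, 1, 0.02])
+ #     textured_mesh = backer(multiview_images, mesh, "textured_mesh.obj")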
542
+ def parse_args():
543
+ parser = argparse.ArgumentParser(description="Backproject texture")
544
+ parser.add_argument(
545
+ "--color_path",
546
+ nargs="+",
547
+ type=str,
548
+ help="Multiview color image in grid file paths",
549
+ )
550
+ parser.add_argument(
551
+ "--mesh_path",
552
+ type=str,
553
+ help="Mesh path, .obj, .glb or .ply",
554
+ )
555
+ parser.add_argument(
556
+ "--output_path",
557
+ type=str,
558
+ help="Output mesh path with suffix",
559
+ )
560
+ parser.add_argument(
561
+ "--num_images", type=int, default=6, help="Number of images to render."
562
+ )
563
+ parser.add_argument(
564
+ "--elevation",
565
+ nargs="+",
566
+ type=float,
567
+ default=[20.0, -10.0],
568
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
569
+ )
570
+ parser.add_argument(
571
+ "--distance",
572
+ type=float,
573
+ default=5,
574
+ help="Camera distance (default: 5)",
575
+ )
576
+ parser.add_argument(
577
+ "--resolution_hw",
578
+ type=int,
579
+ nargs=2,
580
+ default=(2048, 2048),
581
+ help="Resolution of the output images (default: (2048, 2048))",
582
+ )
583
+ parser.add_argument(
584
+ "--fov",
585
+ type=float,
586
+ default=30,
587
+ help="Field of view in degrees (default: 30)",
588
+ )
589
+ parser.add_argument(
590
+ "--device",
591
+ type=str,
592
+ choices=["cpu", "cuda"],
593
+ default="cuda",
594
+ help="Device to run on (default: `cuda`)",
595
+ )
596
+ parser.add_argument(
597
+ "--skip_fix_mesh", action="store_true", help="Fix mesh geometry."
598
+ )
599
+ parser.add_argument(
600
+ "--texture_wh",
601
+ nargs=2,
602
+ type=int,
603
+ default=[2048, 2048],
604
+ help="Texture resolution width and height",
605
+ )
606
+ parser.add_argument(
607
+ "--mesh_sipmlify_ratio",
608
+ type=float,
609
+ default=0.9,
610
+ help="Mesh simplification ratio (default: 0.9)",
611
+ )
612
+ parser.add_argument(
613
+ "--delight", action="store_true", help="Use delighting model."
614
+ )
615
+ parser.add_argument(
616
+ "--no_smooth_texture",
617
+ action="store_true",
618
+ help="Do not smooth the texture.",
619
+ )
620
+ parser.add_argument(
621
+ "--save_glb_path", type=str, default=None, help="Save glb path."
622
+ )
623
+ parser.add_argument(
624
+ "--no_save_delight_img",
625
+ action="store_true",
626
+ help="Disable saving delight image",
627
+ )
628
+ parser.add_argument("--n_max_faces", type=int, default=30000)
629
+ args, unknown = parser.parse_known_args()
630
+
631
+ return args
632
+
633
+
634
+ def entrypoint(
635
+ delight_model: DelightingModel = None,
636
+ imagesr_model: ImageRealESRGAN = None,
637
+ **kwargs,
638
+ ) -> trimesh.Trimesh:
639
+ args = parse_args()
640
+ for k, v in kwargs.items():
641
+ if hasattr(args, k) and v is not None:
642
+ setattr(args, k, v)
643
+
644
+ # Setup camera parameters.
645
+ camera_params = CameraSetting(
646
+ num_images=args.num_images,
647
+ elevation=args.elevation,
648
+ distance=args.distance,
649
+ resolution_hw=args.resolution_hw,
650
+ fov=math.radians(args.fov),
651
+ device=args.device,
652
+ )
653
+
654
+ args.color_path = as_list(args.color_path)
655
+ if args.delight and delight_model is None:
656
+ delight_model = DelightingModel()
657
+
658
+ color_grid = [Image.open(color_path) for color_path in args.color_path]
659
+ color_grid = vcat_pil_images(color_grid, image_mode="RGBA")
660
+ if args.delight:
661
+ color_grid = delight_model(color_grid)
662
+ if not args.no_save_delight_img:
663
+ save_dir = os.path.dirname(args.output_path)
664
+ os.makedirs(save_dir, exist_ok=True)
665
+ color_grid.save(f"{save_dir}/color_delight.png")
666
+
667
+ multiviews = get_images_from_grid(color_grid, img_size=512)
668
+ view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]
669
+ view_weights += [0.01] * (len(multiviews) - len(view_weights))
670
+
671
+ # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
672
+ if imagesr_model is None:
673
+ imagesr_model = ImageRealESRGAN(outscale=4)
674
+ multiviews = [imagesr_model(img) for img in multiviews]
675
+ multiviews = [img.convert("RGB") for img in multiviews]
676
+ mesh = trimesh.load(args.mesh_path)
677
+ if isinstance(mesh, trimesh.Scene):
678
+ mesh = mesh.dump(concatenate=True)
679
+
680
+ if not args.skip_fix_mesh:
681
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
682
+ mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
683
+ mesh.vertices, mesh.faces = mesh_fixer(
684
+ filter_ratio=args.mesh_sipmlify_ratio,
685
+ max_hole_size=0.04,
686
+ resolution=1024,
687
+ num_views=1000,
688
+ norm_mesh_ratio=0.5,
689
+ )
690
+ if len(mesh.faces) > args.n_max_faces:
691
+ mesh.vertices, mesh.faces = mesh_fixer(
692
+ filter_ratio=0.8,
693
+ max_hole_size=0.04,
694
+ resolution=1024,
695
+ num_views=1000,
696
+ norm_mesh_ratio=0.5,
697
+ )
698
+ # Restore scale.
699
+ mesh.vertices = mesh.vertices / scale
700
+ mesh.vertices = mesh.vertices + center
701
+
702
+ # Baking texture to mesh.
703
+ texture_backer = TextureBacker(
704
+ camera_params=camera_params,
705
+ view_weights=view_weights,
706
+ render_wh=args.resolution_hw,
707
+ texture_wh=args.texture_wh,
708
+ smooth_texture=not args.no_smooth_texture,
709
+ )
710
+
711
+ textured_mesh = texture_backer(multiviews, mesh, args.output_path)
712
+
713
+ if args.save_glb_path is not None:
714
+ os.makedirs(os.path.dirname(args.save_glb_path), exist_ok=True)
715
+ textured_mesh.export(args.save_glb_path)
716
+
717
+ return textured_mesh
718
+
719
+
720
+ if __name__ == "__main__":
721
+ entrypoint()
embodied_gen/data/convex_decomposer.py ADDED
@@ -0,0 +1,190 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import logging
18
+ import multiprocessing as mp
19
+ import os
20
+
21
+ import coacd
22
+ import numpy as np
23
+ import trimesh
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ __all__ = [
28
+ "decompose_convex_coacd",
29
+ "decompose_convex_mesh",
30
+ "decompose_convex_mp",
31
+ ]
32
+
33
+
34
+ def decompose_convex_coacd(
35
+ filename: str,
36
+ outfile: str,
37
+ params: dict,
38
+ verbose: bool = False,
39
+ auto_scale: bool = True,
40
+ scale_factor: float = 1.0,
41
+ ) -> None:
42
+ coacd.set_log_level("info" if verbose else "warn")
43
+
44
+ mesh = trimesh.load(filename, force="mesh")
45
+ mesh = coacd.Mesh(mesh.vertices, mesh.faces)
46
+
47
+ result = coacd.run_coacd(mesh, **params)
48
+
49
+ meshes = []
50
+ for v, f in result:
51
+ meshes.append(trimesh.Trimesh(v, f))
52
+
53
+ # Compute collision_scale because convex decomposition usually makes the mesh larger.
54
+ if auto_scale:
55
+ all_mesh = sum([trimesh.Trimesh(*m) for m in result])
56
+ convex_mesh_shape = np.ptp(all_mesh.vertices, axis=0)
57
+ visual_mesh_shape = np.ptp(mesh.vertices, axis=0)
58
+ scale_factor *= visual_mesh_shape / convex_mesh_shape
59
+
60
+ combined = trimesh.Scene()
61
+ for mesh_part in meshes:
62
+ mesh_part.vertices *= scale_factor
63
+ combined.add_geometry(mesh_part)
64
+
65
+ combined.export(outfile)
66
+
67
+
68
+ def decompose_convex_mesh(
69
+ filename: str,
70
+ outfile: str,
71
+ threshold: float = 0.05,
72
+ max_convex_hull: int = -1,
73
+ preprocess_mode: str = "auto",
74
+ preprocess_resolution: int = 30,
75
+ resolution: int = 2000,
76
+ mcts_nodes: int = 20,
77
+ mcts_iterations: int = 150,
78
+ mcts_max_depth: int = 3,
79
+ pca: bool = False,
80
+ merge: bool = True,
81
+ seed: int = 0,
82
+ auto_scale: bool = True,
83
+ scale_factor: float = 1.005,
84
+ verbose: bool = False,
85
+ ) -> str:
86
+ """Decompose a mesh into convex parts using the CoACD algorithm."""
87
+ coacd.set_log_level("info" if verbose else "warn")
88
+
89
+ if os.path.exists(outfile):
90
+ logger.warning(f"Output file {outfile} already exists, removing it.")
91
+ os.remove(outfile)
92
+
93
+ params = dict(
94
+ threshold=threshold,
95
+ max_convex_hull=max_convex_hull,
96
+ preprocess_mode=preprocess_mode,
97
+ preprocess_resolution=preprocess_resolution,
98
+ resolution=resolution,
99
+ mcts_nodes=mcts_nodes,
100
+ mcts_iterations=mcts_iterations,
101
+ mcts_max_depth=mcts_max_depth,
102
+ pca=pca,
103
+ merge=merge,
104
+ seed=seed,
105
+ )
106
+
107
+ try:
108
+ decompose_convex_coacd(
109
+ filename, outfile, params, verbose, auto_scale, scale_factor
110
+ )
111
+ if os.path.exists(outfile):
112
+ return outfile
113
+ except Exception as e:
114
+ if verbose:
115
+ print(f"Decompose convex first attempt failed: {e}.")
116
+
117
+ if preprocess_mode != "on":
118
+ try:
119
+ params["preprocess_mode"] = "on"
120
+ decompose_convex_coacd(
121
+ filename, outfile, params, verbose, auto_scale, scale_factor
122
+ )
123
+ if os.path.exists(outfile):
124
+ return outfile
125
+ except Exception as e:
126
+ if verbose:
127
+ print(
128
+ f"Decompose convex second attempt with preprocess_mode='on' failed: {e}"
129
+ )
130
+
131
+ raise RuntimeError(f"Convex decomposition failed on {filename}")
132
+
133
+
134
+ def decompose_convex_mp(
135
+ filename: str,
136
+ outfile: str,
137
+ threshold: float = 0.05,
138
+ max_convex_hull: int = -1,
139
+ preprocess_mode: str = "auto",
140
+ preprocess_resolution: int = 30,
141
+ resolution: int = 2000,
142
+ mcts_nodes: int = 20,
143
+ mcts_iterations: int = 150,
144
+ mcts_max_depth: int = 3,
145
+ pca: bool = False,
146
+ merge: bool = True,
147
+ seed: int = 0,
148
+ verbose: bool = False,
149
+ auto_scale: bool = True,
150
+ ) -> str:
151
+ """Decompose a mesh into convex parts using the CoACD algorithm in a separate process.
152
+
153
+ See https://simulately.wiki/docs/toolkits/ConvexDecomp for details.
154
+ """
155
+ params = dict(
156
+ threshold=threshold,
157
+ max_convex_hull=max_convex_hull,
158
+ preprocess_mode=preprocess_mode,
159
+ preprocess_resolution=preprocess_resolution,
160
+ resolution=resolution,
161
+ mcts_nodes=mcts_nodes,
162
+ mcts_iterations=mcts_iterations,
163
+ mcts_max_depth=mcts_max_depth,
164
+ pca=pca,
165
+ merge=merge,
166
+ seed=seed,
167
+ )
168
+
169
+ ctx = mp.get_context("spawn")
170
+ p = ctx.Process(
171
+ target=decompose_convex_coacd,
172
+ args=(filename, outfile, params, verbose, auto_scale),
173
+ )
174
+ p.start()
175
+ p.join()
176
+ if p.exitcode == 0 and os.path.exists(outfile):
177
+ return outfile
178
+
179
+ if preprocess_mode != "on":
180
+ params["preprocess_mode"] = "on"
181
+ p = ctx.Process(
182
+ target=decompose_convex_coacd,
183
+ args=(filename, outfile, params, verbose, auto_scale),
184
+ )
185
+ p.start()
186
+ p.join()
187
+ if p.exitcode == 0 and os.path.exists(outfile):
188
+ return outfile
189
+
190
+ raise RuntimeError(f"Convex decomposition failed on {filename}")
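For reference, a minimal usage sketch of the decomposition helpers above. The mesh paths are placeholders; `decompose_convex_mesh` retries with `preprocess_mode="on"` on failure, while `decompose_convex_mp` runs the same CoACD call in a spawned subprocess to isolate occasional native crashes.

from embodied_gen.data.convex_decomposer import (
    decompose_convex_mesh,
    decompose_convex_mp,
)

# In-process decomposition of a visual mesh into a convex collision mesh.
decompose_convex_mesh("assets/visual.obj", "assets/collision.obj", threshold=0.05)

# Same decomposition, but executed in a separate "spawn" process.
decompose_convex_mp("assets/visual.obj", "assets/collision_mp.obj", threshold=0.05)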
embodied_gen/data/datasets.py ADDED
@@ -0,0 +1,320 @@
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import json
19
+ import logging
20
+ import os
21
+ import random
22
+ from typing import Any, Callable, Dict, List, Literal, Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from PIL import Image
28
+ from torch import nn
29
+ from torch.utils.data import Dataset
30
+ from torchvision import transforms
31
+
32
+ logging.basicConfig(
33
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
34
+ )
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ __all__ = [
39
+ "Asset3dGenDataset",
40
+ "PanoGSplatDataset",
41
+ ]
42
+
43
+
44
+ class Asset3dGenDataset(Dataset):
45
+ def __init__(
46
+ self,
47
+ index_file: str,
48
+ target_hw: Tuple[int, int],
49
+ transform: Callable = None,
50
+ control_transform: Callable = None,
51
+ max_train_samples: int = None,
52
+ sub_idxs: List[List[int]] = None,
53
+ seed: int = 79,
54
+ ) -> None:
55
+ if not os.path.exists(index_file):
56
+ raise FileNotFoundError(f"{index_file} index_file not found.")
57
+
58
+ self.index_file = index_file
59
+ self.target_hw = target_hw
60
+ self.transform = transform
61
+ self.control_transform = control_transform
62
+ self.max_train_samples = max_train_samples
63
+ self.meta_info = self.prepare_data_index(index_file)
64
+ self.data_list = sorted(self.meta_info.keys())
65
+ self.sub_idxs = sub_idxs # sub_idxs [[0,1,2], [3,4,5], [...], ...]
66
+ self.image_num = 6 # hardcode temp.
67
+ random.seed(seed)
68
+ logger.info(f"Trainset: {len(self)} asset3d instances.")
69
+
70
+ def __len__(self) -> int:
71
+ return len(self.meta_info)
72
+
73
+ def prepare_data_index(self, index_file: str) -> Dict[str, Any]:
74
+ with open(index_file, "r") as fin:
75
+ meta_info = json.load(fin)
76
+
77
+ meta_info_filtered = dict()
78
+ for idx, uid in enumerate(meta_info):
79
+ if "status" not in meta_info[uid]:
80
+ continue
81
+ if meta_info[uid]["status"] != "success":
82
+ continue
83
+ if self.max_train_samples and idx >= self.max_train_samples:
84
+ break
85
+
86
+ meta_info_filtered[uid] = meta_info[uid]
87
+
88
+ logger.info(
89
+ f"Load {len(meta_info)} assets, keep {len(meta_info_filtered)} valids." # noqa
90
+ )
91
+
92
+ return meta_info_filtered
93
+
94
+ def fetch_sample_images(
95
+ self,
96
+ uid: str,
97
+ attrs: List[str],
98
+ sub_index: int = None,
99
+ transform: Callable = None,
100
+ ) -> torch.Tensor:
101
+ sample = self.meta_info[uid]
102
+ images = []
103
+ for attr in attrs:
104
+ item = sample[attr]
105
+ if sub_index is not None:
106
+ item = item[sub_index]
107
+ mode = "L" if attr == "image_mask" else "RGB"
108
+ image = Image.open(item).convert(mode)
109
+ if transform is not None:
110
+ image = transform(image)
111
+ if len(image.shape) == 2:
112
+ image = image[..., None]
113
+ images.append(image)
114
+
115
+ images = torch.cat(images, dim=0)
116
+
117
+ return images
118
+
119
+ def fetch_sample_grid_images(
120
+ self,
121
+ uid: str,
122
+ attrs: List[str],
123
+ sub_idxs: List[List[int]],
124
+ transform: Callable = None,
125
+ ) -> torch.Tensor:
126
+ assert transform is not None
127
+
128
+ grid_image = []
129
+ for row_idxs in sub_idxs:
130
+ row_image = []
131
+ for row_idx in row_idxs:
132
+ image = self.fetch_sample_images(
133
+ uid, attrs, row_idx, transform
134
+ )
135
+ row_image.append(image)
136
+ row_image = torch.cat(row_image, dim=2) # (c h w)
137
+ grid_image.append(row_image)
138
+
139
+ grid_image = torch.cat(grid_image, dim=1)
140
+
141
+ return grid_image
142
+
143
+ def compute_text_embeddings(
144
+ self, embed_path: str, original_size: Tuple[int, int]
145
+ ) -> Dict[str, nn.Module]:
146
+ data_dict = torch.load(embed_path)
147
+ prompt_embeds = data_dict["prompt_embeds"][0]
148
+ add_text_embeds = data_dict["pooled_prompt_embeds"][0]
149
+
150
+ # Need changed if random crop, set as crop_top_left [y1, x1], center crop as [0, 0]. # noqa
151
+ crops_coords_top_left = (0, 0)
152
+ add_time_ids = list(
153
+ original_size + crops_coords_top_left + self.target_hw
154
+ )
155
+ add_time_ids = torch.tensor([add_time_ids])
156
+ # add_time_ids = add_time_ids.repeat((len(add_text_embeds), 1))
157
+
158
+ unet_added_cond_kwargs = {
159
+ "text_embeds": add_text_embeds,
160
+ "time_ids": add_time_ids,
161
+ }
162
+
163
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
164
+
165
+ def visualize_item(
166
+ self,
167
+ control: torch.Tensor,
168
+ color: torch.Tensor,
169
+ save_dir: str = None,
170
+ ) -> List[Image.Image]:
171
+ to_pil = transforms.ToPILImage()
172
+
173
+ color = (color + 1) / 2
174
+ color_pil = to_pil(color)
175
+ normal_pil = to_pil(control[0:3])
176
+ position_pil = to_pil(control[3:6])
177
+ mask_pil = to_pil(control[6:])
178
+
179
+ if save_dir is not None:
180
+ os.makedirs(save_dir, exist_ok=True)
181
+ color_pil.save(f"{save_dir}/rgb.jpg")
182
+ normal_pil.save(f"{save_dir}/normal.jpg")
183
+ position_pil.save(f"{save_dir}/position.jpg")
184
+ mask_pil.save(f"{save_dir}/mask.jpg")
185
+ logger.info(f"Visualization in {save_dir}")
186
+
187
+ return normal_pil, position_pil, mask_pil, color_pil
188
+
189
+ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
190
+ uid = self.data_list[index]
191
+
192
+ sub_idxs = self.sub_idxs
193
+ if sub_idxs is None:
194
+ sub_idxs = [[random.randint(0, self.image_num - 1)]]
195
+
196
+ input_image = self.fetch_sample_grid_images(
197
+ uid,
198
+ attrs=["image_view_normal", "image_position", "image_mask"],
199
+ sub_idxs=sub_idxs,
200
+ transform=self.control_transform,
201
+ )
202
+ assert input_image.shape[1:] == self.target_hw
203
+
204
+ output_image = self.fetch_sample_grid_images(
205
+ uid,
206
+ attrs=["image_color"],
207
+ sub_idxs=sub_idxs,
208
+ transform=self.transform,
209
+ )
210
+
211
+ sample = self.meta_info[uid]
212
+ text_feats = self.compute_text_embeddings(
213
+ sample["text_feat"], tuple(sample["image_hw"])
214
+ )
215
+
216
+ data = dict(
217
+ pixel_values=output_image,
218
+ conditioning_pixel_values=input_image,
219
+ prompt_embeds=text_feats["prompt_embeds"],
220
+ text_embeds=text_feats["text_embeds"],
221
+ time_ids=text_feats["time_ids"],
222
+ )
223
+
224
+ return data
225
+
226
+
227
+ class PanoGSplatDataset(Dataset):
228
+ """A PyTorch Dataset for loading panorama-based 3D Gaussian Splatting data.
229
+
230
+ This dataset is designed to be compatible with train and eval pipelines
231
+ that use COLMAP-style camera conventions.
232
+
233
+ Args:
234
+ data_dir (str): Root directory where the dataset file is located.
235
+ split (str): Dataset split to use, either "train" or "eval".
236
+ data_name (str, optional): Name of the dataset file (default: "gs_data.pt").
237
+ max_sample_num (int, optional): Maximum number of samples to load. If None,
238
+ all available samples in the split will be used.
239
+ """
240
+
241
+ def __init__(
242
+ self,
243
+ data_dir: str,
244
+ split: str = Literal["train", "eval"],
245
+ data_name: str = "gs_data.pt",
246
+ max_sample_num: int = None,
247
+ ) -> None:
248
+ self.data_path = os.path.join(data_dir, data_name)
249
+ self.split = split
250
+ self.max_sample_num = max_sample_num
251
+ if not os.path.exists(self.data_path):
252
+ raise FileNotFoundError(
253
+ f"Dataset file {self.data_path} not found. Please provide the correct path."
254
+ )
255
+ self.data = torch.load(self.data_path, weights_only=False)
256
+ self.frames = self.data[split]
257
+ if max_sample_num is not None:
258
+ self.frames = self.frames[:max_sample_num]
259
+ self.points = self.data.get("points", None)
260
+ self.points_rgb = self.data.get("points_rgb", None)
261
+
262
+ def __len__(self) -> int:
263
+ return len(self.frames)
264
+
265
+ def cvt_blender_to_colmap_coord(self, c2w: np.ndarray) -> np.ndarray:
266
+ # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
267
+ tranformed_c2w = np.copy(c2w)
268
+ tranformed_c2w[:3, 1:3] *= -1
269
+
270
+ return tranformed_c2w
271
+
272
+ def __getitem__(self, index: int) -> dict[str, any]:
273
+ data = self.frames[index]
274
+ c2w = self.cvt_blender_to_colmap_coord(data["camtoworld"])
275
+ item = dict(
276
+ camtoworld=c2w,
277
+ K=data["K"],
278
+ image_h=data["image_h"],
279
+ image_w=data["image_w"],
280
+ )
281
+ if "image" in data:
282
+ item["image"] = data["image"]
283
+ if "image_id" in data:
284
+ item["image_id"] = data["image_id"]
285
+
286
+ return item
287
+
288
+
289
+ if __name__ == "__main__":
290
+ index_file = "datasets/objaverse/v1.0/statistics_1.0_gobjaverse_filter/view6s_v4/meta_ac2e0ddea8909db26d102c8465b5bcb2.json" # noqa
291
+ target_hw = (512, 512)
292
+ transform_list = [
293
+ transforms.Resize(
294
+ target_hw, interpolation=transforms.InterpolationMode.BILINEAR
295
+ ),
296
+ transforms.CenterCrop(target_hw),
297
+ transforms.ToTensor(),
298
+ transforms.Normalize([0.5], [0.5]),
299
+ ]
300
+ image_transform = transforms.Compose(transform_list)
301
+ control_transform = transforms.Compose(transform_list[:-1])
302
+
303
+ sub_idxs = [[0, 1, 2], [3, 4, 5]] # None
304
+ if sub_idxs is not None:
305
+ target_hw = (
306
+ target_hw[0] * len(sub_idxs),
307
+ target_hw[1] * len(sub_idxs[0]),
308
+ )
309
+
310
+ dataset = Asset3dGenDataset(
311
+ index_file,
312
+ target_hw,
313
+ image_transform,
314
+ control_transform,
315
+ sub_idxs=sub_idxs,
316
+ )
317
+ data = dataset[0]
318
+ dataset.visualize_item(
319
+ data["conditioning_pixel_values"], data["pixel_values"], save_dir="./"
320
+ )
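Below is a minimal sketch of consuming `PanoGSplatDataset` with a standard PyTorch DataLoader; the `outputs/scene` directory is a placeholder and is expected to contain the `gs_data.pt` file described above.

from torch.utils.data import DataLoader

from embodied_gen.data.datasets import PanoGSplatDataset

dataset = PanoGSplatDataset("outputs/scene", split="train")
loader = DataLoader(dataset, batch_size=1, shuffle=True)
for batch in loader:
    # Each sample provides a COLMAP-convention camera-to-world pose and intrinsics.
    print(batch["camtoworld"].shape, batch["K"].shape)
    break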
embodied_gen/data/differentiable_render.py ADDED
@@ -0,0 +1,589 @@
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import json
20
+ import logging
21
+ import math
22
+ import os
23
+ from collections import defaultdict
24
+ from typing import List, Union
25
+
26
+ import cv2
27
+ import imageio
28
+ import numpy as np
29
+ import nvdiffrast.torch as dr
30
+ import PIL.Image as Image
31
+ import torch
32
+ from tqdm import tqdm
33
+ from embodied_gen.data.utils import (
34
+ CameraSetting,
35
+ DiffrastRender,
36
+ as_list,
37
+ calc_vertex_normals,
38
+ import_kaolin_mesh,
39
+ init_kal_camera,
40
+ normalize_vertices_array,
41
+ render_pbr,
42
+ save_images,
43
+ )
44
+ from embodied_gen.utils.enum import RenderItems
45
+
46
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
47
+ os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
48
+ "~/.cache/torch_extensions"
49
+ )
50
+ logging.basicConfig(
51
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
52
+ )
53
+ logger = logging.getLogger(__name__)
54
+
55
+
56
+ __all__ = [
57
+ "ImageRender",
58
+ "create_mp4_from_images",
59
+ "create_gif_from_images",
60
+ ]
61
+
62
+
63
+ def create_mp4_from_images(
64
+ images: list[np.ndarray],
65
+ output_path: str,
66
+ fps: int = 10,
67
+ prompt: str = None,
68
+ ):
69
+ font = cv2.FONT_HERSHEY_SIMPLEX
70
+ font_scale = 0.5
71
+ font_thickness = 1
72
+ color = (255, 255, 255)
73
+ position = (20, 25)
74
+
75
+ with imageio.get_writer(output_path, fps=fps) as writer:
76
+ for image in images:
77
+ image = image.clip(min=0, max=1)
78
+ image = (255.0 * image).astype(np.uint8)
79
+ image = image[..., :3]
80
+ if prompt is not None:
81
+ cv2.putText(
82
+ image,
83
+ prompt,
84
+ position,
85
+ font,
86
+ font_scale,
87
+ color,
88
+ font_thickness,
89
+ )
90
+
91
+ writer.append_data(image)
92
+
93
+ logger.info(f"MP4 video saved to {output_path}")
94
+
95
+
96
+ def create_gif_from_images(
97
+ images: list[np.ndarray], output_path: str, fps: int = 10
98
+ ) -> None:
99
+ pil_images = []
100
+ for image in images:
101
+ image = image.clip(min=0, max=1)
102
+ image = (255.0 * image).astype(np.uint8)
103
+ image = Image.fromarray(image, mode="RGBA")
104
+ pil_images.append(image.convert("RGB"))
105
+
106
+ duration = 1000 // fps
107
+ pil_images[0].save(
108
+ output_path,
109
+ save_all=True,
110
+ append_images=pil_images[1:],
111
+ duration=duration,
112
+ loop=0,
113
+ )
114
+
115
+ logger.info(f"GIF saved to {output_path}")
116
+
117
+
118
+ class ImageRender(object):
119
+ """A differentiable mesh renderer supporting multi-view rendering.
120
+
121
+ This class wraps a differentiable rasterization using `nvdiffrast` to
122
+ render mesh geometry to various maps (normal, depth, alpha, albedo, etc.).
123
+
124
+ Args:
125
+ render_items (list[RenderItems]): A list of rendering targets to
126
+ generate (e.g., IMAGE, DEPTH, NORMAL, etc.).
127
+ camera_params (CameraSetting): The camera parameters for rendering,
128
+ including intrinsic and extrinsic matrices.
129
+ recompute_vtx_normal (bool, optional): If True, recomputes
130
+ vertex normals from the mesh geometry. Defaults to True.
131
+ with_mtl (bool, optional): Whether to load `.mtl` material files
132
+ for meshes. Defaults to False.
133
+ gen_color_gif (bool, optional): Generate a GIF of rendered
134
+ color images. Defaults to False.
135
+ gen_color_mp4 (bool, optional): Generate an MP4 video of rendered
136
+ color images. Defaults to False.
137
+ gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of
138
+ view-space normals. Defaults to False.
139
+ gen_glonormal_mp4 (bool, optional): Generate an MP4 video of
140
+ global-space normals. Defaults to False.
141
+ no_index_file (bool, optional): If True, skip saving the `index.json`
142
+ summary file. Defaults to False.
143
+ light_factor (float, optional): A scalar multiplier for
144
+ PBR light intensity. Defaults to 1.0.
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ render_items: list[RenderItems],
150
+ camera_params: CameraSetting,
151
+ recompute_vtx_normal: bool = True,
152
+ with_mtl: bool = False,
153
+ gen_color_gif: bool = False,
154
+ gen_color_mp4: bool = False,
155
+ gen_viewnormal_mp4: bool = False,
156
+ gen_glonormal_mp4: bool = False,
157
+ no_index_file: bool = False,
158
+ light_factor: float = 1.0,
159
+ ) -> None:
160
+ camera = init_kal_camera(camera_params)
161
+ self.camera = camera
162
+
163
+ # Setup MVP matrix and renderer.
164
+ mv = camera.view_matrix() # (n 4 4) world2cam
165
+ p = camera.intrinsics.projection_matrix()
166
+ # NOTE: add a negative sign at P[0, 2] as the y axis is flipped in `nvdiffrast` output. # noqa
167
+ p[:, 1, 1] = -p[:, 1, 1]
168
+ # mvp = torch.bmm(p, mv) # camera.view_projection_matrix()
169
+ self.mv = mv
170
+ self.p = p
171
+
172
+ renderer = DiffrastRender(
173
+ p_matrix=p,
174
+ mv_matrix=mv,
175
+ resolution_hw=camera_params.resolution_hw,
176
+ context=dr.RasterizeCudaContext(),
177
+ mask_thresh=0.5,
178
+ grad_db=False,
179
+ device=camera_params.device,
180
+ antialias_mask=True,
181
+ )
182
+ self.renderer = renderer
183
+ self.recompute_vtx_normal = recompute_vtx_normal
184
+ self.render_items = render_items
185
+ self.device = camera_params.device
186
+ self.with_mtl = with_mtl
187
+ self.gen_color_gif = gen_color_gif
188
+ self.gen_color_mp4 = gen_color_mp4
189
+ self.gen_viewnormal_mp4 = gen_viewnormal_mp4
190
+ self.gen_glonormal_mp4 = gen_glonormal_mp4
191
+ self.light_factor = light_factor
192
+ self.no_index_file = no_index_file
193
+
194
+ def render_mesh(
195
+ self,
196
+ mesh_path: Union[str, List[str]],
197
+ output_root: str,
198
+ uuid: Union[str, List[str]] = None,
199
+ prompts: List[str] = None,
200
+ ) -> None:
201
+ mesh_path = as_list(mesh_path)
202
+ if uuid is None:
203
+ uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
204
+ uuid = as_list(uuid)
205
+ assert len(mesh_path) == len(uuid)
206
+ os.makedirs(output_root, exist_ok=True)
207
+
208
+ meta_info = dict()
209
+ for idx, (path, uid) in tqdm(
210
+ enumerate(zip(mesh_path, uuid)), total=len(mesh_path)
211
+ ):
212
+ output_dir = os.path.join(output_root, uid)
213
+ os.makedirs(output_dir, exist_ok=True)
214
+ prompt = prompts[idx] if prompts else None
215
+ data_dict = self(path, output_dir, prompt)
216
+ meta_info[uid] = data_dict
217
+
218
+ if self.no_index_file:
219
+ return
220
+
221
+ index_file = os.path.join(output_root, "index.json")
222
+ with open(index_file, "w") as fout:
223
+ json.dump(meta_info, fout)
224
+
225
+ logger.info(f"Rendering meta info logged in {index_file}")
226
+
227
+ def __call__(
228
+ self, mesh_path: str, output_dir: str, prompt: str = None
229
+ ) -> dict[str, str]:
230
+ """Render a single mesh and return paths to the rendered outputs.
231
+
232
+ Processes the input mesh, renders multiple modalities (e.g., normals,
233
+ depth, albedo), and optionally saves video or image sequences.
234
+
235
+ Args:
236
+ mesh_path (str): Path to the mesh file (.obj/.glb).
237
+ output_dir (str): Directory to save rendered outputs.
238
+ prompt (str, optional): Optional caption prompt for MP4 metadata.
239
+
240
+ Returns:
241
+ dict[str, str]: A mapping render types to the saved image paths.
242
+ """
243
+ try:
244
+ mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
245
+ except Exception as e:
246
+ logger.error(f"[ERROR MESH LOAD]: {e}, skip {mesh_path}")
247
+ return
248
+
249
+ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
250
+ if self.recompute_vtx_normal:
251
+ mesh.vertex_normals = calc_vertex_normals(
252
+ mesh.vertices, mesh.faces
253
+ )
254
+
255
+ mesh = mesh.to(self.device)
256
+ vertices, faces, vertex_normals = (
257
+ mesh.vertices,
258
+ mesh.faces,
259
+ mesh.vertex_normals,
260
+ )
261
+
262
+ # Perform rendering.
263
+ data_dict = defaultdict(list)
264
+ if RenderItems.ALPHA.value in self.render_items:
265
+ masks, _ = self.renderer.render_rast_alpha(vertices, faces)
266
+ render_paths = save_images(
267
+ masks, f"{output_dir}/{RenderItems.ALPHA}"
268
+ )
269
+ data_dict[RenderItems.ALPHA.value] = render_paths
270
+
271
+ if RenderItems.GLOBAL_NORMAL.value in self.render_items:
272
+ rendered_normals, masks = self.renderer.render_global_normal(
273
+ vertices, faces, vertex_normals
274
+ )
275
+ if self.gen_glonormal_mp4:
276
+ if isinstance(rendered_normals, torch.Tensor):
277
+ rendered_normals = rendered_normals.detach().cpu().numpy()
278
+ create_mp4_from_images(
279
+ rendered_normals,
280
+ output_path=f"{output_dir}/normal.mp4",
281
+ fps=15,
282
+ prompt=prompt,
283
+ )
284
+ else:
285
+ render_paths = save_images(
286
+ rendered_normals,
287
+ f"{output_dir}/{RenderItems.GLOBAL_NORMAL}",
288
+ cvt_color=cv2.COLOR_BGR2RGB,
289
+ )
290
+ data_dict[RenderItems.GLOBAL_NORMAL.value] = render_paths
291
+
292
+ if RenderItems.VIEW_NORMAL.value in self.render_items:
293
+ assert (
294
+ RenderItems.GLOBAL_NORMAL in self.render_items
295
+ ), f"Must render global normal firstly, got render_items: {self.render_items}." # noqa
296
+ rendered_view_normals = self.renderer.transform_normal(
297
+ rendered_normals, self.mv, masks, to_view=True
298
+ )
299
+
300
+ if self.gen_viewnormal_mp4:
301
+ create_mp4_from_images(
302
+ rendered_view_normals,
303
+ output_path=f"{output_dir}/view_normal.mp4",
304
+ fps=15,
305
+ prompt=prompt,
306
+ )
307
+ else:
308
+ render_paths = save_images(
309
+ rendered_view_normals,
310
+ f"{output_dir}/{RenderItems.VIEW_NORMAL}",
311
+ cvt_color=cv2.COLOR_BGR2RGB,
312
+ )
313
+ data_dict[RenderItems.VIEW_NORMAL.value] = render_paths
314
+
315
+ if RenderItems.POSITION_MAP.value in self.render_items:
316
+ rendered_position, masks = self.renderer.render_position(
317
+ vertices, faces
318
+ )
319
+ norm_position = self.renderer.normalize_map_by_mask(
320
+ rendered_position, masks
321
+ )
322
+ render_paths = save_images(
323
+ norm_position,
324
+ f"{output_dir}/{RenderItems.POSITION_MAP}",
325
+ cvt_color=cv2.COLOR_BGR2RGB,
326
+ )
327
+ data_dict[RenderItems.POSITION_MAP.value] = render_paths
328
+
329
+ if RenderItems.DEPTH.value in self.render_items:
330
+ rendered_depth, masks = self.renderer.render_depth(vertices, faces)
331
+ norm_depth = self.renderer.normalize_map_by_mask(
332
+ rendered_depth, masks
333
+ )
334
+ render_paths = save_images(
335
+ norm_depth,
336
+ f"{output_dir}/{RenderItems.DEPTH}",
337
+ )
338
+ data_dict[RenderItems.DEPTH.value] = render_paths
339
+
340
+ render_paths = save_images(
341
+ rendered_depth,
342
+ f"{output_dir}/{RenderItems.DEPTH}_exr",
343
+ to_uint8=False,
344
+ format=".exr",
345
+ )
346
+ data_dict[f"{RenderItems.DEPTH.value}_exr"] = render_paths
347
+
348
+ if RenderItems.IMAGE.value in self.render_items:
349
+ images = []
350
+ albedos = []
351
+ diffuses = []
352
+ masks, _ = self.renderer.render_rast_alpha(vertices, faces)
353
+ try:
354
+ for idx, cam in enumerate(self.camera):
355
+ image, albedo, diffuse, _ = render_pbr(
356
+ mesh, cam, light_factor=self.light_factor
357
+ )
358
+ image = torch.cat([image[0], masks[idx]], axis=-1)
359
+ images.append(image.detach().cpu().numpy())
360
+
361
+ if RenderItems.ALBEDO.value in self.render_items:
362
+ albedo = torch.cat([albedo[0], masks[idx]], axis=-1)
363
+ albedos.append(albedo.detach().cpu().numpy())
364
+
365
+ if RenderItems.DIFFUSE.value in self.render_items:
366
+ diffuse = torch.cat([diffuse[0], masks[idx]], axis=-1)
367
+ diffuses.append(diffuse.detach().cpu().numpy())
368
+
369
+ except Exception as e:
370
+ logger.error(f"[ERROR pbr render]: {e}, skip {mesh_path}")
371
+ return
372
+
373
+ if self.gen_color_gif:
374
+ create_gif_from_images(
375
+ images,
376
+ output_path=f"{output_dir}/color.gif",
377
+ fps=15,
378
+ )
379
+
380
+ if self.gen_color_mp4:
381
+ create_mp4_from_images(
382
+ images,
383
+ output_path=f"{output_dir}/color.mp4",
384
+ fps=15,
385
+ prompt=prompt,
386
+ )
387
+
388
+ if self.gen_color_mp4 or self.gen_color_gif:
389
+ return data_dict
390
+
391
+ render_paths = save_images(
392
+ images,
393
+ f"{output_dir}/{RenderItems.IMAGE}",
394
+ cvt_color=cv2.COLOR_BGRA2RGBA,
395
+ )
396
+ data_dict[RenderItems.IMAGE.value] = render_paths
397
+
398
+ render_paths = save_images(
399
+ albedos,
400
+ f"{output_dir}/{RenderItems.ALBEDO}",
401
+ cvt_color=cv2.COLOR_BGRA2RGBA,
402
+ )
403
+ data_dict[RenderItems.ALBEDO.value] = render_paths
404
+
405
+ render_paths = save_images(
406
+ diffuses,
407
+ f"{output_dir}/{RenderItems.DIFFUSE}",
408
+ cvt_color=cv2.COLOR_BGRA2RGBA,
409
+ )
410
+ data_dict[RenderItems.DIFFUSE.value] = render_paths
411
+
412
+ data_dict["status"] = "success"
413
+
414
+ logger.info(f"Finish rendering in {output_dir}")
415
+
416
+ return data_dict
417
+
418
+
419
+ def parse_args():
420
+ parser = argparse.ArgumentParser(description="Render settings")
421
+
422
+ parser.add_argument(
423
+ "--mesh_path",
424
+ type=str,
425
+ nargs="+",
426
+ help="Paths to the mesh files for rendering.",
427
+ )
428
+ parser.add_argument(
429
+ "--output_root",
430
+ type=str,
431
+ help="Root directory for output",
432
+ )
433
+ parser.add_argument(
434
+ "--uuid",
435
+ type=str,
436
+ nargs="+",
437
+ default=None,
438
+ help="uuid for rendering saving.",
439
+ )
440
+ parser.add_argument(
441
+ "--num_images", type=int, default=6, help="Number of images to render."
442
+ )
443
+ parser.add_argument(
444
+ "--elevation",
445
+ type=float,
446
+ nargs="+",
447
+ default=[20.0, -10.0],
448
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
449
+ )
450
+ parser.add_argument(
451
+ "--distance",
452
+ type=float,
453
+ default=5,
454
+ help="Camera distance (default: 5)",
455
+ )
456
+ parser.add_argument(
457
+ "--resolution_hw",
458
+ type=int,
459
+ nargs=2,
460
+ default=(512, 512),
461
+ help="Resolution of the output images (default: (512, 512))",
462
+ )
463
+ parser.add_argument(
464
+ "--fov",
465
+ type=float,
466
+ default=30,
467
+ help="Field of view in degrees (default: 30)",
468
+ )
469
+ parser.add_argument(
470
+ "--pbr_light_factor",
471
+ type=float,
472
+ default=1.0,
473
+ help="Light factor for mesh PBR rendering (default: 1.)",
474
+ )
475
+ parser.add_argument(
476
+ "--with_mtl",
477
+ action="store_true",
478
+ help="Whether to render with mesh material.",
479
+ )
480
+ parser.add_argument(
481
+ "--gen_color_gif",
482
+ action="store_true",
483
+ help="Whether to generate color .gif rendering file.",
484
+ )
485
+ parser.add_argument(
486
+ "--no_index_file",
487
+ action="store_true",
488
+ help="Whether skip the index file saving.",
489
+ )
490
+ parser.add_argument(
491
+ "--gen_color_mp4",
492
+ action="store_true",
493
+ help="Whether to generate color .mp4 rendering file.",
494
+ )
495
+ parser.add_argument(
496
+ "--gen_viewnormal_mp4",
497
+ action="store_true",
498
+ help="Whether to generate view normal .mp4 rendering file.",
499
+ )
500
+ parser.add_argument(
501
+ "--gen_glonormal_mp4",
502
+ action="store_true",
503
+ help="Whether to generate global normal .mp4 rendering file.",
504
+ )
505
+ parser.add_argument(
506
+ "--video_prompts",
507
+ type=str,
508
+ nargs="+",
509
+ default=None,
510
+ help="Text prompts for the rendering.",
511
+ )
512
+
513
+ args, unknown = parser.parse_known_args()
514
+
515
+ if args.uuid is None and args.mesh_path is not None:
516
+ args.uuid = []
517
+ for path in args.mesh_path:
518
+ uuid = os.path.basename(path).split(".")[0]
519
+ args.uuid.append(uuid)
520
+
521
+ return args
522
+
523
+
524
+ def entrypoint(**kwargs) -> None:
525
+ args = parse_args()
526
+ for k, v in kwargs.items():
527
+ if hasattr(args, k) and v is not None:
528
+ setattr(args, k, v)
529
+
530
+ camera_settings = CameraSetting(
531
+ num_images=args.num_images,
532
+ elevation=args.elevation,
533
+ distance=args.distance,
534
+ resolution_hw=args.resolution_hw,
535
+ fov=math.radians(args.fov),
536
+ device="cuda",
537
+ )
538
+
539
+ render_items = [
540
+ RenderItems.ALPHA.value,
541
+ RenderItems.GLOBAL_NORMAL.value,
542
+ RenderItems.VIEW_NORMAL.value,
543
+ RenderItems.POSITION_MAP.value,
544
+ RenderItems.IMAGE.value,
545
+ RenderItems.DEPTH.value,
546
+ # RenderItems.ALBEDO.value,
547
+ # RenderItems.DIFFUSE.value,
548
+ ]
549
+
550
+ gen_video = (
551
+ args.gen_color_gif
552
+ or args.gen_color_mp4
553
+ or args.gen_viewnormal_mp4
554
+ or args.gen_glonormal_mp4
555
+ )
556
+ if gen_video:
557
+ render_items = []
558
+ if args.gen_color_gif or args.gen_color_mp4:
559
+ render_items.append(RenderItems.IMAGE.value)
560
+ if args.gen_glonormal_mp4:
561
+ render_items.append(RenderItems.GLOBAL_NORMAL.value)
562
+ if args.gen_viewnormal_mp4:
563
+ render_items.append(RenderItems.VIEW_NORMAL.value)
564
+ if RenderItems.GLOBAL_NORMAL.value not in render_items:
565
+ render_items.append(RenderItems.GLOBAL_NORMAL.value)
566
+
567
+ image_render = ImageRender(
568
+ render_items=render_items,
569
+ camera_params=camera_settings,
570
+ with_mtl=args.with_mtl,
571
+ gen_color_gif=args.gen_color_gif,
572
+ gen_color_mp4=args.gen_color_mp4,
573
+ gen_viewnormal_mp4=args.gen_viewnormal_mp4,
574
+ gen_glonormal_mp4=args.gen_glonormal_mp4,
575
+ light_factor=args.pbr_light_factor,
576
+ no_index_file=gen_video or args.no_index_file,
577
+ )
578
+ image_render.render_mesh(
579
+ mesh_path=args.mesh_path,
580
+ output_root=args.output_root,
581
+ uuid=args.uuid,
582
+ prompts=args.video_prompts,
583
+ )
584
+
585
+ return
586
+
587
+
588
+ if __name__ == "__main__":
589
+ entrypoint()
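As a usage sketch, `entrypoint` can also be driven from Python because it forwards keyword overrides onto the parsed CLI arguments; the mesh path and output directory below are placeholders.

from embodied_gen.data.differentiable_render import entrypoint

# Render 6 views of a mesh and export a color turntable video (color.mp4).
entrypoint(
    mesh_path=["assets/example.glb"],
    output_root="outputs/renders",
    num_images=6,
    gen_color_mp4=True,
)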
embodied_gen/data/mesh_operator.py ADDED
@@ -0,0 +1,461 @@
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import logging
19
+ import multiprocessing as mp
20
+ import os
21
+ from typing import Tuple, Union
22
+
23
+ import coacd
24
+ import igraph
25
+ import numpy as np
26
+ import pyvista as pv
27
+ import spaces
28
+ import torch
29
+ import trimesh
30
+ import utils3d
31
+ from pymeshfix import _meshfix
32
+ from tqdm import tqdm
33
+
34
+ logging.basicConfig(
35
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
36
+ )
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ __all__ = [
41
+ "MeshFixer",
42
+ ]
43
+
44
+
45
+ def _radical_inverse(base, n):
46
+ val = 0
47
+ inv_base = 1.0 / base
48
+ inv_base_n = inv_base
49
+ while n > 0:
50
+ digit = n % base
51
+ val += digit * inv_base_n
52
+ n //= base
53
+ inv_base_n *= inv_base
54
+ return val
55
+
56
+
57
+ def _halton_sequence(dim, n):
58
+ PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
59
+ return [_radical_inverse(PRIMES[dim], n) for dim in range(dim)]
60
+
61
+
62
+ def _hammersley_sequence(dim, n, num_samples):
63
+ return [n / num_samples] + _halton_sequence(dim - 1, n)
64
+
65
+
66
+ def _sphere_hammersley_seq(n, num_samples, offset=(0, 0), remap=False):
67
+ """Generate a point on a unit sphere using the Hammersley sequence.
68
+
69
+ Args:
70
+ n (int): The index of the sample.
71
+ num_samples (int): The total number of samples.
72
+ offset (tuple, optional): Offset for the u and v coordinates.
73
+ remap (bool, optional): Whether to remap the u coordinate.
74
+
75
+ Returns:
76
+ list: A list containing the spherical coordinates [phi, theta].
77
+ """
78
+ u, v = _hammersley_sequence(2, n, num_samples)
79
+ u += offset[0] / num_samples
80
+ v += offset[1]
81
+
82
+ if remap:
83
+ u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
84
+
85
+ theta = np.arccos(1 - 2 * u) - np.pi / 2
86
+ phi = v * 2 * np.pi
87
+ return [phi, theta]
88
+
89
+
90
+ class MeshFixer(object):
91
+ """MeshFixer simplifies and repairs 3D triangle meshes via decimation and visibility-based hole filling.
92
+
93
+ Attributes:
94
+ vertices (torch.Tensor): A tensor of shape (V, 3) representing vertex positions.
95
+ faces (torch.Tensor): A tensor of shape (F, 3) representing face indices.
96
+ device (str): Device to run computations on, typically "cuda" or "cpu".
97
+
98
+ Main logic reference: https://github.com/microsoft/TRELLIS/blob/main/trellis/utils/postprocessing_utils.py#L22
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ vertices: Union[torch.Tensor, np.ndarray],
104
+ faces: Union[torch.Tensor, np.ndarray],
105
+ device: str = "cuda",
106
+ ) -> None:
107
+ self.device = device
108
+ if isinstance(vertices, np.ndarray):
109
+ vertices = torch.tensor(vertices)
110
+ self.vertices = vertices
111
+
112
+ if isinstance(faces, np.ndarray):
113
+ faces = torch.tensor(faces)
114
+ self.faces = faces
115
+
116
+ @staticmethod
117
+ def log_mesh_changes(method):
118
+ def wrapper(self, *args, **kwargs):
119
+ logger.info(
120
+ f"Before {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
121
+ )
122
+ result = method(self, *args, **kwargs)
123
+ logger.info(
124
+ f"After {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
125
+ )
126
+ return result
127
+
128
+ return wrapper
129
+
130
+ @log_mesh_changes
131
+ def fill_holes(
132
+ self,
133
+ max_hole_size: float,
134
+ max_hole_nbe: int,
135
+ resolution: int,
136
+ num_views: int,
137
+ norm_mesh_ratio: float = 1.0,
138
+ ) -> None:
139
+ self.vertices = self.vertices * norm_mesh_ratio
140
+ vertices, self.faces = self._fill_holes(
141
+ self.vertices,
142
+ self.faces,
143
+ max_hole_size,
144
+ max_hole_nbe,
145
+ resolution,
146
+ num_views,
147
+ )
148
+ self.vertices = vertices / norm_mesh_ratio
149
+
150
+ @staticmethod
151
+ @torch.no_grad()
152
+ def _fill_holes(
153
+ vertices: torch.Tensor,
154
+ faces: torch.Tensor,
155
+ max_hole_size: float,
156
+ max_hole_nbe: int,
157
+ resolution: int,
158
+ num_views: int,
159
+ ) -> Union[torch.Tensor, torch.Tensor]:
160
+ yaws, pitchs = [], []
161
+ for i in range(num_views):
162
+ y, p = _sphere_hammersley_seq(i, num_views)
163
+ yaws.append(y)
164
+ pitchs.append(p)
165
+
166
+ yaws, pitchs = torch.tensor(yaws).to(vertices), torch.tensor(
167
+ pitchs
168
+ ).to(vertices)
169
+ radius, fov = 2.0, torch.deg2rad(torch.tensor(40)).to(vertices)
170
+ projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
171
+
172
+ views = []
173
+ for yaw, pitch in zip(yaws, pitchs):
174
+ orig = (
175
+ torch.tensor(
176
+ [
177
+ torch.sin(yaw) * torch.cos(pitch),
178
+ torch.cos(yaw) * torch.cos(pitch),
179
+ torch.sin(pitch),
180
+ ]
181
+ ).to(vertices)
182
+ * radius
183
+ )
184
+ view = utils3d.torch.view_look_at(
185
+ orig,
186
+ torch.tensor([0, 0, 0]).to(vertices),
187
+ torch.tensor([0, 0, 1]).to(vertices),
188
+ )
189
+ views.append(view)
190
+ views = torch.stack(views, dim=0)
191
+
192
+ # Rasterize the mesh
193
+ visibility = torch.zeros(
194
+ faces.shape[0], dtype=torch.int32, device=faces.device
195
+ )
196
+ rastctx = utils3d.torch.RastContext(backend="cuda")
197
+
198
+ for i in tqdm(
199
+ range(views.shape[0]), total=views.shape[0], desc="Rasterizing"
200
+ ):
201
+ view = views[i]
202
+ buffers = utils3d.torch.rasterize_triangle_faces(
203
+ rastctx,
204
+ vertices[None],
205
+ faces,
206
+ resolution,
207
+ resolution,
208
+ view=view,
209
+ projection=projection,
210
+ )
211
+ face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
212
+ face_id = torch.unique(face_id).long()
213
+ visibility[face_id] += 1
214
+
215
+ # Normalize visibility by the number of views
216
+ visibility = visibility.float() / num_views
217
+
218
+ # Mincut: Identify outer and inner faces
219
+ edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
220
+ boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
221
+ connected_components = utils3d.torch.compute_connected_components(
222
+ faces, edges, face2edge
223
+ )
224
+
225
+ outer_face_indices = torch.zeros(
226
+ faces.shape[0], dtype=torch.bool, device=faces.device
227
+ )
228
+ for i in range(len(connected_components)):
229
+ outer_face_indices[connected_components[i]] = visibility[
230
+ connected_components[i]
231
+ ] > min(
232
+ max(
233
+ visibility[connected_components[i]].quantile(0.75).item(),
234
+ 0.25,
235
+ ),
236
+ 0.5,
237
+ )
238
+
239
+ outer_face_indices = outer_face_indices.nonzero().reshape(-1)
240
+ inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
241
+
242
+ if inner_face_indices.shape[0] == 0:
243
+ return vertices, faces
244
+
245
+ # Construct dual graph (faces as nodes, edges as edges)
246
+ dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(
247
+ face2edge
248
+ )
249
+ dual_edge2edge = edges[dual_edge2edge]
250
+ dual_edges_weights = torch.norm(
251
+ vertices[dual_edge2edge[:, 0]] - vertices[dual_edge2edge[:, 1]],
252
+ dim=1,
253
+ )
254
+
255
+ # Mincut: Construct main graph and solve the mincut problem
256
+ g = igraph.Graph()
257
+ g.add_vertices(faces.shape[0])
258
+ g.add_edges(dual_edges.cpu().numpy())
259
+ g.es["weight"] = dual_edges_weights.cpu().numpy()
260
+
261
+ g.add_vertex("s") # source
262
+ g.add_vertex("t") # target
263
+
264
+ g.add_edges(
265
+ [(f, "s") for f in inner_face_indices],
266
+ attributes={
267
+ "weight": torch.ones(
268
+ inner_face_indices.shape[0], dtype=torch.float32
269
+ )
270
+ .cpu()
271
+ .numpy()
272
+ },
273
+ )
274
+ g.add_edges(
275
+ [(f, "t") for f in outer_face_indices],
276
+ attributes={
277
+ "weight": torch.ones(
278
+ outer_face_indices.shape[0], dtype=torch.float32
279
+ )
280
+ .cpu()
281
+ .numpy()
282
+ },
283
+ )
284
+
285
+ cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
286
+ remove_face_indices = torch.tensor(
287
+ [v for v in cut.partition[0] if v < faces.shape[0]],
288
+ dtype=torch.long,
289
+ device=faces.device,
290
+ )
291
+
292
+ # Check if the cut is valid with each connected component
293
+ to_remove_cc = utils3d.torch.compute_connected_components(
294
+ faces[remove_face_indices]
295
+ )
296
+ valid_remove_cc = []
297
+ cutting_edges = []
298
+ for cc in to_remove_cc:
299
+ # Check visibility median for connected component
300
+ visibility_median = visibility[remove_face_indices[cc]].median()
301
+ if visibility_median > 0.25:
302
+ continue
303
+
304
+ # Check if the cutting loop is small enough
305
+ cc_edge_indices, cc_edges_degree = torch.unique(
306
+ face2edge[remove_face_indices[cc]], return_counts=True
307
+ )
308
+ cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
309
+ cc_new_boundary_edge_indices = cc_boundary_edge_indices[
310
+ ~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
311
+ ]
312
+ if len(cc_new_boundary_edge_indices) > 0:
313
+ cc_new_boundary_edge_cc = (
314
+ utils3d.torch.compute_edge_connected_components(
315
+ edges[cc_new_boundary_edge_indices]
316
+ )
317
+ )
318
+ cc_new_boundary_edges_cc_center = [
319
+ vertices[edges[cc_new_boundary_edge_indices[edge_cc]]]
320
+ .mean(dim=1)
321
+ .mean(dim=0)
322
+ for edge_cc in cc_new_boundary_edge_cc
323
+ ]
324
+ cc_new_boundary_edges_cc_area = []
325
+ for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
326
+ _e1 = (
327
+ vertices[
328
+ edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]
329
+ ]
330
+ - cc_new_boundary_edges_cc_center[i]
331
+ )
332
+ _e2 = (
333
+ vertices[
334
+ edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]
335
+ ]
336
+ - cc_new_boundary_edges_cc_center[i]
337
+ )
338
+ cc_new_boundary_edges_cc_area.append(
339
+ torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum()
340
+ * 0.5
341
+ )
342
+ cutting_edges.append(cc_new_boundary_edge_indices)
343
+ if any(
344
+ [
345
+ _l > max_hole_size
346
+ for _l in cc_new_boundary_edges_cc_area
347
+ ]
348
+ ):
349
+ continue
350
+
351
+ valid_remove_cc.append(cc)
352
+
353
+ if len(valid_remove_cc) > 0:
354
+ remove_face_indices = remove_face_indices[
355
+ torch.cat(valid_remove_cc)
356
+ ]
357
+ mask = torch.ones(
358
+ faces.shape[0], dtype=torch.bool, device=faces.device
359
+ )
360
+ mask[remove_face_indices] = 0
361
+ faces = faces[mask]
362
+ faces, vertices = utils3d.torch.remove_unreferenced_vertices(
363
+ faces, vertices
364
+ )
365
+
366
+ tqdm.write(f"Removed {(~mask).sum()} faces by mincut")
367
+ else:
368
+ tqdm.write(f"Removed 0 faces by mincut")
369
+
370
+ # Fill small boundaries (holes)
371
+ mesh = _meshfix.PyTMesh()
372
+ mesh.load_array(vertices.cpu().numpy(), faces.cpu().numpy())
373
+ mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
374
+
375
+ _vertices, _faces = mesh.return_arrays()
376
+ vertices = torch.tensor(_vertices).to(vertices)
377
+ faces = torch.tensor(_faces).to(faces)
378
+
379
+ return vertices, faces
380
+
381
+ @property
382
+ def vertices_np(self) -> np.ndarray:
383
+ return self.vertices.cpu().numpy()
384
+
385
+ @property
386
+ def faces_np(self) -> np.ndarray:
387
+ return self.faces.cpu().numpy()
388
+
389
+ @log_mesh_changes
390
+ def simplify(self, ratio: float) -> None:
391
+ """Simplify the mesh using quadric edge collapse decimation.
392
+
393
+ Args:
394
+ ratio (float): Ratio of faces to filter out.
395
+ """
396
+ if ratio <= 0 or ratio >= 1:
397
+ raise ValueError("Simplify ratio must be between 0 and 1.")
398
+
399
+ # Convert to PyVista format for simplification
400
+ mesh = pv.PolyData(
401
+ self.vertices_np,
402
+ np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
403
+ )
404
+ mesh.clean(inplace=True)
405
+ mesh.clear_data()
406
+ mesh = mesh.triangulate()
407
+ mesh = mesh.decimate(ratio, progress_bar=True)
408
+
409
+ # Update vertices and faces
410
+ self.vertices = torch.tensor(
411
+ mesh.points, device=self.device, dtype=torch.float32
412
+ )
413
+ self.faces = torch.tensor(
414
+ mesh.faces.reshape(-1, 4)[:, 1:],
415
+ device=self.device,
416
+ dtype=torch.int32,
417
+ )
418
+
419
+ @spaces.GPU
420
+ def __call__(
421
+ self,
422
+ filter_ratio: float,
423
+ max_hole_size: float,
424
+ resolution: int,
425
+ num_views: int,
426
+ norm_mesh_ratio: float = 1.0,
427
+ ) -> Tuple[np.ndarray, np.ndarray]:
428
+ """Post-process the mesh by simplifying and filling holes.
429
+
430
+ This method performs a two-step process:
431
+ 1. Simplifies mesh by reducing faces using quadric edge decimation.
432
+ 2. Fills holes by removing invisible faces, repairing small boundaries.
433
+
434
+ Args:
435
+ filter_ratio (float): Ratio of faces to simplify out.
436
+ Must be in the range (0, 1).
437
+ max_hole_size (float): Maximum area of a hole to fill. Connected
438
+ components of holes larger than this size will not be repaired.
439
+ resolution (int): Resolution of the rasterization buffer.
440
+ num_views (int): Number of viewpoints to sample for rasterization.
441
+ norm_mesh_ratio (float, optional): A scaling factor applied to the
442
+ vertices of the mesh during processing.
443
+
444
+ Returns:
445
+ Tuple[np.ndarray, np.ndarray]:
446
+ - vertices: Simplified and repaired vertex array of (V, 3).
447
+ - faces: Simplified and repaired face array of (F, 3).
448
+ """
449
+ self.vertices = self.vertices.to(self.device)
450
+ self.faces = self.faces.to(self.device)
451
+
452
+ self.simplify(ratio=filter_ratio)
453
+ self.fill_holes(
454
+ max_hole_size=max_hole_size,
455
+ max_hole_nbe=int(250 * np.sqrt(1 - filter_ratio)),
456
+ resolution=resolution,
457
+ num_views=num_views,
458
+ norm_mesh_ratio=norm_mesh_ratio,
459
+ )
460
+
461
+ return self.vertices_np, self.faces_np
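A brief usage sketch of `MeshFixer` on a mesh loaded with trimesh; the input path is a placeholder, and the parameter values mirror the fallback settings used earlier in this commit.

import trimesh

from embodied_gen.data.mesh_operator import MeshFixer

mesh = trimesh.load("assets/raw_mesh.obj", force="mesh")
mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, device="cuda")
mesh.vertices, mesh.faces = mesh_fixer(
    filter_ratio=0.8,     # decimate ~80% of the faces
    max_hole_size=0.04,   # only fill holes up to this area
    resolution=1024,
    num_views=1000,
    norm_mesh_ratio=0.5,
)
mesh.export("assets/fixed_mesh.obj")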
embodied_gen/data/utils.py ADDED
@@ -0,0 +1,1039 @@
 
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import math
19
+ import os
20
+ import random
21
+ import zipfile
22
+ from shutil import rmtree
23
+ from typing import List, Tuple, Union
24
+
25
+ import cv2
26
+ import kaolin as kal
27
+ import numpy as np
28
+ import nvdiffrast.torch as dr
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from PIL import Image, ImageEnhance
32
+
33
+ try:
34
+ from kolors.models.modeling_chatglm import ChatGLMModel
35
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
36
+ except ImportError:
37
+ ChatGLMTokenizer = None
38
+ ChatGLMModel = None
39
+ import logging
40
+ from dataclasses import dataclass, field
41
+
42
+ import trimesh
43
+ from kaolin.render.camera import Camera
44
+ from torch import nn
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ __all__ = [
50
+ "DiffrastRender",
51
+ "save_images",
52
+ "render_pbr",
53
+ "prelabel_text_feature",
54
+ "calc_vertex_normals",
55
+ "normalize_vertices_array",
56
+ "load_mesh_to_unit_cube",
57
+ "as_list",
58
+ "CameraSetting",
59
+ "import_kaolin_mesh",
60
+ "save_mesh_with_mtl",
61
+ "get_images_from_grid",
62
+ "post_process_texture",
63
+ "quat_mult",
64
+ "quat_to_rotmat",
65
+ "gamma_shs",
66
+ "resize_pil",
67
+ "trellis_preprocess",
68
+ "delete_dir",
69
+ ]
70
+
71
+
72
+ class DiffrastRender(object):
73
+ """A class to handle differentiable rendering using nvdiffrast.
74
+
75
+ This class provides methods to render position, depth, and normal maps
76
+ with optional anti-aliasing and gradient disabling for rasterization.
77
+
78
+ Attributes:
79
+ p_mtx (torch.Tensor): Projection matrix.
80
+ mv_mtx (torch.Tensor): Model-view matrix.
81
+ mvp_mtx (torch.Tensor): Model-view-projection matrix, calculated as
82
+ p_mtx @ mv_mtx if not provided.
83
+ resolution_hw (Tuple[int, int]): Height and width of the rendering resolution. # noqa
84
+ _ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): Rasterization context. # noqa
85
+ mask_thresh (float): Threshold for mask creation.
86
+ grad_db (bool): Whether to disable gradients during rasterization.
87
+ antialias_mask (bool): Whether to apply anti-aliasing to the mask.
88
+ device (str): Device used for rendering ('cuda' or 'cpu').
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ p_matrix: torch.Tensor,
94
+ mv_matrix: torch.Tensor,
95
+ resolution_hw: Tuple[int, int],
96
+ context: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None,
97
+ mvp_matrix: torch.Tensor = None,
98
+ mask_thresh: float = 0.5,
99
+ grad_db: bool = False,
100
+ antialias_mask: bool = True,
101
+ align_coordinate: bool = True,
102
+ device: str = "cuda",
103
+ ) -> None:
104
+ self.p_mtx = p_matrix
105
+ self.mv_mtx = mv_matrix
106
+ if mvp_matrix is None:
107
+ self.mvp_mtx = torch.bmm(p_matrix, mv_matrix)
108
+
109
+ self.resolution_hw = resolution_hw
110
+ if context is None:
111
+ context = dr.RasterizeCudaContext(device=device)
112
+ self._ctx = context
113
+ self.mask_thresh = mask_thresh
114
+ self.grad_db = grad_db
115
+ self.antialias_mask = antialias_mask
116
+ self.align_coordinate = align_coordinate
117
+ self.device = device
118
+
119
+ def compute_dr_raster(
120
+ self,
121
+ vertices: torch.Tensor,
122
+ faces: torch.Tensor,
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ vertices_clip = self.transform_vertices(vertices, matrix=self.mvp_mtx)
125
+ rast, _ = dr.rasterize(
126
+ self._ctx,
127
+ vertices_clip,
128
+ faces.int(),
129
+ resolution=self.resolution_hw,
130
+ grad_db=self.grad_db,
131
+ )
132
+
133
+ return rast, vertices_clip
134
+
135
+ def transform_vertices(
136
+ self,
137
+ vertices: torch.Tensor,
138
+ matrix: torch.Tensor,
139
+ ) -> torch.Tensor:
140
+ verts_ones = torch.ones(
141
+ (len(vertices), 1), device=vertices.device, dtype=vertices.dtype
142
+ )
143
+ verts_homo = torch.cat([vertices, verts_ones], dim=-1)
144
+ trans_vertices = torch.matmul(verts_homo, matrix.permute(0, 2, 1))
145
+
146
+ return trans_vertices
147
+
148
+ def normalize_map_by_mask_separately(
149
+ self, map: torch.Tensor, mask: torch.Tensor
150
+ ) -> torch.Tensor:
151
+ # Normalize each map separately by mask, normalized map in [0, 1].
152
+ normalized_maps = []
153
+ for map_item, mask_item in zip(map, mask):
154
+ normalized_map = self.normalize_map_by_mask(map_item, mask_item)
155
+ normalized_maps.append(normalized_map)
156
+
157
+ normalized_maps = torch.stack(normalized_maps, dim=0)
158
+
159
+ return normalized_maps
160
+
161
+ @staticmethod
162
+ def normalize_map_by_mask(
163
+ map: torch.Tensor, mask: torch.Tensor
164
+ ) -> torch.Tensor:
165
+ # Normalize all maps in total by mask, normalized map in [0, 1].
166
+ foreground = (mask == 1).squeeze(dim=-1)
167
+ foreground_elements = map[foreground]
168
+ if len(foreground_elements) == 0:
169
+ return map
170
+
171
+ min_val, _ = foreground_elements.min(dim=0)
172
+ max_val, _ = foreground_elements.max(dim=0)
173
+ val_range = (max_val - min_val).clip(min=1e-6)
174
+
175
+ normalized_map = (map - min_val) / val_range
176
+ normalized_map = torch.lerp(
177
+ torch.zeros_like(normalized_map), normalized_map, mask
178
+ )
179
+ normalized_map[normalized_map < 0] = 0
180
+
181
+ return normalized_map
182
+
183
+ def _compute_mask(
184
+ self,
185
+ rast: torch.Tensor,
186
+ vertices_clip: torch.Tensor,
187
+ faces: torch.Tensor,
188
+ ) -> torch.Tensor:
189
+ mask = (rast[..., 3:] > 0).float()
190
+ mask = mask.clip(min=0, max=1)
191
+
192
+ if self.antialias_mask is True:
193
+ mask = dr.antialias(mask, rast, vertices_clip, faces)
194
+ else:
195
+ foreground = mask > self.mask_thresh
196
+ mask[foreground] = 1
197
+ mask[~foreground] = 0
198
+
199
+ return mask
200
+
201
+ def render_rast_alpha(
202
+ self,
203
+ vertices: torch.Tensor,
204
+ faces: torch.Tensor,
205
+ ):
206
+ faces = faces.to(torch.int32)
207
+ rast, vertices_clip = self.compute_dr_raster(vertices, faces)
208
+ mask = self._compute_mask(rast, vertices_clip, faces)
209
+
210
+ return mask, rast
211
+
212
+ def render_position(
213
+ self,
214
+ vertices: torch.Tensor,
215
+ faces: torch.Tensor,
216
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
217
+ # Vertices are in the model coordinate system; the map stores real position coordinates.
218
+ faces = faces.to(torch.int32)
219
+ mask, rast = self.render_rast_alpha(vertices, faces)
220
+
221
+ vertices_model = vertices[None, ...].contiguous().float()
222
+ position_map, _ = dr.interpolate(vertices_model, rast, faces)
223
+ # Align with blender.
224
+ if self.align_coordinate:
225
+ position_map = position_map[..., [0, 2, 1]]
226
+ position_map[..., 1] = -position_map[..., 1]
227
+
228
+ position_map = torch.lerp(
229
+ torch.zeros_like(position_map), position_map, mask
230
+ )
231
+
232
+ return position_map, mask
233
+
234
+ def render_uv(
235
+ self,
236
+ vertices: torch.Tensor,
237
+ faces: torch.Tensor,
238
+ vtx_uv: torch.Tensor,
239
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
240
+ faces = faces.to(torch.int32)
241
+ mask, rast = self.render_rast_alpha(vertices, faces)
242
+ uv_map, _ = dr.interpolate(vtx_uv, rast, faces)
243
+ uv_map = torch.lerp(torch.zeros_like(uv_map), uv_map, mask)
244
+
245
+ return uv_map, mask
246
+
247
+ def render_depth(
248
+ self,
249
+ vertices: torch.Tensor,
250
+ faces: torch.Tensor,
251
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ # Vertices are in the model coordinate system; the map stores metric camera-space depth.
253
+ faces = faces.to(torch.int32)
254
+ mask, rast = self.render_rast_alpha(vertices, faces)
255
+
256
+ vertices_camera = self.transform_vertices(vertices, matrix=self.mv_mtx)
257
+ vertices_camera = vertices_camera[..., 2:3].contiguous().float()
258
+ depth_map, _ = dr.interpolate(vertices_camera, rast, faces)
259
+ # Flip the sign so camera-space depth becomes positive.
260
+ if self.align_coordinate:
261
+ depth_map = -depth_map
262
+ depth_map = torch.lerp(torch.zeros_like(depth_map), depth_map, mask)
263
+
264
+ return depth_map, mask
265
+
266
+ def render_global_normal(
267
+ self,
268
+ vertices: torch.Tensor,
269
+ faces: torch.Tensor,
270
+ vertice_normals: torch.Tensor,
271
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
272
+ # NOTE: vertice_normals in [-1, 1], return normal in [0, 1].
273
+ # vertices / vertice_normals in model coordinate system.
274
+ faces = faces.to(torch.int32)
275
+ mask, rast = self.render_rast_alpha(vertices, faces)
276
+ im_base_normals, _ = dr.interpolate(
277
+ vertice_normals[None, ...].float(), rast, faces
278
+ )
279
+
280
+ if im_base_normals is not None:
281
+ faces = faces.to(torch.int64)
282
+ vertices_cam = self.transform_vertices(
283
+ vertices, matrix=self.mv_mtx
284
+ )
285
+ face_vertices_ndc = kal.ops.mesh.index_vertices_by_faces(
286
+ vertices_cam[..., :3], faces
287
+ )
288
+ face_normal_sign = kal.ops.mesh.face_normals(face_vertices_ndc)[
289
+ ..., 2
290
+ ]
291
+ for idx in range(len(im_base_normals)):
292
+ face_idx = (rast[idx, ..., -1].long() - 1).contiguous()
293
+ im_normal_sign = torch.sign(face_normal_sign[idx, face_idx])
294
+ im_normal_sign[face_idx == -1] = 0
295
+ im_base_normals[idx] *= im_normal_sign.unsqueeze(-1)
296
+
297
+ normal = (im_base_normals + 1) / 2
298
+ normal = normal.clip(min=0, max=1)
299
+ normal = torch.lerp(torch.zeros_like(normal), normal, mask)
300
+
301
+ return normal, mask
302
+
303
+ def transform_normal(
304
+ self,
305
+ normals: torch.Tensor,
306
+ trans_matrix: torch.Tensor,
307
+ masks: torch.Tensor,
308
+ to_view: bool,
309
+ ) -> torch.Tensor:
310
+ # NOTE: input normals in [0, 1], output normals in [0, 1].
311
+ normals = normals.clone()
312
+ assert len(normals) == len(trans_matrix)
313
+
314
+ if not to_view:
315
+ # Flip the sign on the x-axis to match inv bae system for global transformation. # noqa
316
+ normals[..., 0] = 1 - normals[..., 0]
317
+
318
+ normals = 2 * normals - 1
319
+ b, h, w, c = normals.shape
320
+
321
+ transformed_normals = []
322
+ for normal, matrix in zip(normals, trans_matrix):
323
+ # Transform normals using the transformation matrix (4x4).
324
+ reshaped_normals = normal.view(-1, c) # (h w 3) -> (hw 3)
325
+ padded_vectors = torch.nn.functional.pad(
326
+ reshaped_normals, pad=(0, 1), mode="constant", value=0.0
327
+ )
328
+ transformed_normal = torch.matmul(
329
+ padded_vectors, matrix.transpose(0, 1)
330
+ )[..., :3]
331
+
332
+ # Normalize and clip the normals to [0, 1] range.
333
+ transformed_normal = F.normalize(transformed_normal, p=2, dim=-1)
334
+ transformed_normal = (transformed_normal + 1) / 2
335
+
336
+ if to_view:
337
+ # Flip the sign on the x-axis to match bae system for view transformation. # noqa
338
+ transformed_normal[..., 0] = 1 - transformed_normal[..., 0]
339
+
340
+ transformed_normals.append(transformed_normal.view(h, w, c))
341
+
342
+ transformed_normals = torch.stack(transformed_normals, dim=0)
343
+
344
+ if masks is not None:
345
+ transformed_normals = torch.lerp(
346
+ torch.zeros_like(transformed_normals),
347
+ transformed_normals,
348
+ masks,
349
+ )
350
+
351
+ return transformed_normals
352
+
353
+
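# Minimal usage sketch for DiffrastRender (all values below are illustrative
# placeholders): batched 4x4 projection / model-view matrices of shape
# [n_views, 4, 4] plus float vertices [V, 3] and integer faces [F, 3] on CUDA.
import torch

n_views = 4
p_mtx = torch.eye(4, device="cuda").repeat(n_views, 1, 1)
mv_mtx = torch.eye(4, device="cuda").repeat(n_views, 1, 1)
renderer = DiffrastRender(
    p_matrix=p_mtx,
    mv_matrix=mv_mtx,
    resolution_hw=(512, 512),
)
vertices = torch.rand(100, 3, device="cuda")
faces = torch.randint(0, 100, (64, 3), device="cuda")
depth_map, mask = renderer.render_depth(vertices, faces)  # [n_views, 512, 512, 1]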
354
+ def _az_el_to_points(
355
+ azimuths: np.ndarray, elevations: np.ndarray
356
+ ) -> np.ndarray:
357
+ x = np.cos(azimuths) * np.cos(elevations)
358
+ y = np.sin(azimuths) * np.cos(elevations)
359
+ z = np.sin(elevations)
360
+
361
+ return np.stack([x, y, z], axis=-1)
362
+
363
+
364
+ def _compute_az_el_by_views(
365
+ num_view: int, el: float
366
+ ) -> Tuple[np.ndarray, np.ndarray]:
367
+ azimuths = np.arange(num_view) / num_view * np.pi * 2
368
+ elevations = np.deg2rad(np.array([el] * num_view))
369
+
370
+ return azimuths, elevations
371
+
372
+
373
+ def _compute_cam_pts_by_az_el(
374
+ azs: np.ndarray,
375
+ els: np.ndarray,
376
+ distance: float,
377
+ extra_pts: np.ndarray = None,
378
+ ) -> np.ndarray:
379
+ distances = np.array([distance for _ in range(len(azs))])
380
+ cam_pts = _az_el_to_points(azs, els) * distances[:, None]
381
+
382
+ if extra_pts is not None:
383
+ cam_pts = np.concatenate([cam_pts, extra_pts], axis=0)
384
+
385
+ # Align coordinate system.
386
+ cam_pts = cam_pts[:, [0, 2, 1]] # xyz -> xzy
387
+ cam_pts[..., 2] = -cam_pts[..., 2]
388
+
389
+ return cam_pts
390
+
391
+
392
+ def compute_cam_pts_by_views(
393
+ num_view: int, el: float, distance: float, extra_pts: np.ndarray = None
394
+ ) -> np.ndarray:
395
+ """Computes object-center camera points for a given number of views.
396
+
397
+ Args:
398
+ num_view (int): The number of views (camera positions) to compute.
399
+ el (float): The elevation angle in degrees.
400
+ distance (float): The distance from the origin to the camera.
401
+ extra_pts (np.ndarray): Extra camera point positions.
402
+
403
+ Returns:
404
+ np.ndarray: An array containing the camera points for each view, with shape `(num_view, 3)`. # noqa
405
+ """
406
+ azimuths, elevations = _compute_az_el_by_views(num_view, el)
407
+ cam_pts = _compute_cam_pts_by_az_el(
408
+ azimuths, elevations, distance, extra_pts
409
+ )
410
+
411
+ return cam_pts
412
+
413
+
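# Minimal sketch (illustrative values): eight camera positions on a ring at
# 20 degrees elevation, 2.5 units from the origin.
cam_pts = compute_cam_pts_by_views(num_view=8, el=20.0, distance=2.5)
print(cam_pts.shape)  # (8, 3), a numpy array in the flipped x/z/-y convention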
414
+ def save_images(
415
+ images: Union[list[np.ndarray], list[torch.Tensor]],
416
+ output_dir: str,
417
+ cvt_color: str = None,
418
+ format: str = ".png",
419
+ to_uint8: bool = True,
420
+ verbose: bool = False,
421
+ ) -> List[str]:
422
+ # NOTE: images in [0, 1]
423
+ os.makedirs(output_dir, exist_ok=True)
424
+ save_paths = []
425
+ for idx, image in enumerate(images):
426
+ if isinstance(image, torch.Tensor):
427
+ image = image.detach().cpu().numpy()
428
+ if to_uint8:
429
+ image = image.clip(min=0, max=1)
430
+ image = (255.0 * image).astype(np.uint8)
431
+ if cvt_color is not None:
432
+ image = cv2.cvtColor(image, cvt_color)
433
+ save_path = os.path.join(output_dir, f"{idx:04d}{format}")
434
+ save_paths.append(save_path)
435
+
436
+ cv2.imwrite(save_path, image)
437
+
438
+ if verbose:
439
+ logger.info(f"Images saved in {output_dir}")
440
+
441
+ return save_paths
442
+
443
+
444
+ def _current_lighting(
445
+ azimuths: List[float],
446
+ elevations: List[float],
447
+ light_factor: float = 1.0,
448
+ device: str = "cuda",
449
+ ):
450
+ # azimuths, elevations in degrees.
451
+ directions = []
452
+ for az, el in zip(azimuths, elevations):
453
+ az, el = math.radians(az), math.radians(el)
454
+ direction = kal.render.lighting.sg_direction_from_azimuth_elevation(
455
+ az, el
456
+ )
457
+ directions.append(direction)
458
+ directions = torch.cat(directions, dim=0)
459
+
460
+ amplitude = torch.ones_like(directions) * light_factor
461
+ light_condition = kal.render.lighting.SgLightingParameters(
462
+ amplitude=amplitude,
463
+ direction=directions,
464
+ sharpness=3,
465
+ ).to(device)
466
+
467
+ # light_condition = kal.render.lighting.SgLightingParameters.from_sun(
468
+ # directions, strength=1, angle=90, color=None
469
+ # ).to(device)
470
+
471
+ return light_condition
472
+
473
+
474
+ def render_pbr(
475
+ mesh,
476
+ camera,
477
+ device="cuda",
478
+ cxt=None,
479
+ custom_materials=None,
480
+ light_factor=1.0,
481
+ ):
482
+ if cxt is None:
483
+ cxt = dr.RasterizeCudaContext()
484
+
485
+ light_condition = _current_lighting(
486
+ azimuths=[0, 90, 180, 270],
487
+ elevations=[90, 60, 30, 20],
488
+ light_factor=light_factor,
489
+ device=device,
490
+ )
491
+ render_res = kal.render.easy_render.render_mesh(
492
+ camera,
493
+ mesh,
494
+ lighting=light_condition,
495
+ nvdiffrast_context=cxt,
496
+ custom_materials=custom_materials,
497
+ )
498
+
499
+ image = render_res[kal.render.easy_render.RenderPass.render]
500
+ image = image.clip(0, 1)
501
+
502
+ albedo = render_res[kal.render.easy_render.RenderPass.albedo]
503
+ albedo = albedo.clip(0, 1)
504
+
505
+ diffuse = render_res[kal.render.easy_render.RenderPass.diffuse]
506
+ diffuse = diffuse.clip(0, 1)
507
+
508
+ normal = render_res[kal.render.easy_render.RenderPass.normals]
509
+ normal = normal.clip(-1, 1)
510
+
511
+ return image, albedo, diffuse, normal
512
+
513
+
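# Minimal sketch for render_pbr (the .glb path is a placeholder): build cameras
# from a CameraSetting (defined later in this module), import a mesh with
# kaolin, and render the PBR passes; the mesh tensors are assumed to already
# live on the CUDA device used by the rasterizer.
camera = init_kal_camera(
    CameraSetting(
        num_images=4,
        elevation=[20.0],
        distance=2.5,
        resolution_hw=(512, 512),
        fov=math.radians(50),
    )
)
mesh = import_kaolin_mesh("assets/example.glb", with_mtl=False)
image, albedo, diffuse, normal = render_pbr(mesh, camera, light_factor=1.0)
# Each pass is a batched [n_views, H, W, C] tensor; `normal` is in [-1, 1].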
514
+ def _move_to_target_device(data, device: str):
515
+ if isinstance(data, dict):
516
+ for key, value in data.items():
517
+ data[key] = _move_to_target_device(value, device)
518
+ elif isinstance(data, torch.Tensor):
519
+ return data.to(device)
520
+
521
+ return data
522
+
523
+
524
+ def _encode_prompt(
525
+ prompt_batch,
526
+ text_encoders,
527
+ tokenizers,
528
+ proportion_empty_prompts=0,
529
+ is_train=True,
530
+ ):
531
+ prompt_embeds_list = []
532
+
533
+ captions = []
534
+ for caption in prompt_batch:
535
+ if random.random() < proportion_empty_prompts:
536
+ captions.append("")
537
+ elif isinstance(caption, str):
538
+ captions.append(caption)
539
+ elif isinstance(caption, (list, np.ndarray)):
540
+ captions.append(random.choice(caption) if is_train else caption[0])
541
+
542
+ with torch.no_grad():
543
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
544
+ text_inputs = tokenizer(
545
+ captions,
546
+ padding="max_length",
547
+ max_length=256,
548
+ truncation=True,
549
+ return_tensors="pt",
550
+ ).to(text_encoder.device)
551
+
552
+ output = text_encoder(
553
+ input_ids=text_inputs.input_ids,
554
+ attention_mask=text_inputs.attention_mask,
555
+ position_ids=text_inputs.position_ids,
556
+ output_hidden_states=True,
557
+ )
558
+
559
+ # Use the penultimate hidden states as prompt embeddings and the last hidden state's final token as the pooled output.
560
+ prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
561
+ pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()
562
+ bs_embed, seq_len, _ = prompt_embeds.shape
563
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
564
+ prompt_embeds_list.append(prompt_embeds)
565
+
566
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
567
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
568
+
569
+ return prompt_embeds, pooled_prompt_embeds
570
+
571
+
572
+ def load_llm_models(pretrained_model_name_or_path: str, device: str):
573
+ tokenizer = ChatGLMTokenizer.from_pretrained(
574
+ pretrained_model_name_or_path,
575
+ subfolder="text_encoder",
576
+ )
577
+ text_encoder = ChatGLMModel.from_pretrained(
578
+ pretrained_model_name_or_path,
579
+ subfolder="text_encoder",
580
+ ).to(device)
581
+
582
+ text_encoders = [
583
+ text_encoder,
584
+ ]
585
+ tokenizers = [
586
+ tokenizer,
587
+ ]
588
+
589
+ logger.info(f"Load model from {pretrained_model_name_or_path} done.")
590
+
591
+ return tokenizers, text_encoders
592
+
593
+
594
+ def prelabel_text_feature(
595
+ prompt_batch: List[str],
596
+ output_dir: str,
597
+ tokenizers: nn.Module,
598
+ text_encoders: nn.Module,
599
+ ) -> List[str]:
600
+ os.makedirs(output_dir, exist_ok=True)
601
+
602
+ # prompt_batch ["text..."]
603
+ prompt_embeds, pooled_prompt_embeds = _encode_prompt(
604
+ prompt_batch, text_encoders, tokenizers
605
+ )
606
+
607
+ prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu")
608
+ pooled_prompt_embeds = _move_to_target_device(
609
+ pooled_prompt_embeds, device="cpu"
610
+ )
611
+
612
+ data_dict = dict(
613
+ prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds
614
+ )
615
+
616
+ save_path = os.path.join(output_dir, "text_feat.pth")
617
+ torch.save(data_dict, save_path)
618
+
619
+ return save_path
620
+
621
+
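# Minimal sketch: cache ChatGLM text features for a caption batch. The
# checkpoint id and output directory are placeholders, not fixed by this repo.
tokenizers, text_encoders = load_llm_models(
    "Kwai-Kolors/Kolors-diffusers", device="cuda"
)
feat_path = prelabel_text_feature(
    ["a red ceramic mug on a wooden table"],
    output_dir="outputs/text_feat",
    tokenizers=tokenizers,
    text_encoders=text_encoders,
)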
622
+ def _calc_face_normals(
623
+ vertices: torch.Tensor, # V,3 first vertex may be unreferenced
624
+ faces: torch.Tensor, # F,3 long, first face may be all zero
625
+ normalize: bool = False,
626
+ ) -> torch.Tensor: # F,3
627
+ full_vertices = vertices[faces] # F,C=3,3
628
+ v0, v1, v2 = full_vertices.unbind(dim=1) # F,3
629
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=1) # F,3
630
+ if normalize:
631
+ face_normals = F.normalize(
632
+ face_normals, eps=1e-6, dim=1
633
+ ) # TODO inplace?
634
+ return face_normals # F,3
635
+
636
+
637
+ def calc_vertex_normals(
638
+ vertices: torch.Tensor, # V,3 first vertex may be unreferenced
639
+ faces: torch.Tensor, # F,3 long, first face may be all zero
640
+ face_normals: torch.Tensor = None, # F,3, not normalized
641
+ ) -> torch.Tensor: # V,3
642
+ _F = faces.shape[0]
643
+
644
+ if face_normals is None:
645
+ face_normals = _calc_face_normals(vertices, faces)
646
+
647
+ vertex_normals = torch.zeros(
648
+ (vertices.shape[0], 3, 3), dtype=vertices.dtype, device=vertices.device
649
+ ) # V,C=3,3
650
+ vertex_normals.scatter_add_(
651
+ dim=0,
652
+ index=faces[:, :, None].expand(_F, 3, 3),
653
+ src=face_normals[:, None, :].expand(_F, 3, 3),
654
+ )
655
+ vertex_normals = vertex_normals.sum(dim=1) # V,3
656
+ return F.normalize(vertex_normals, eps=1e-6, dim=1)
657
+
658
+
659
+ def normalize_vertices_array(
660
+ vertices: Union[torch.Tensor, np.ndarray],
661
+ mesh_scale: float = 1.0,
662
+ exec_norm: bool = True,
663
+ ):
664
+ if isinstance(vertices, torch.Tensor):
665
+ bbmin, bbmax = vertices.min(0)[0], vertices.max(0)[0]
666
+ else:
667
+ bbmin, bbmax = vertices.min(0), vertices.max(0) # (3,)
668
+ center = (bbmin + bbmax) * 0.5
669
+ bbsize = bbmax - bbmin
670
+ scale = 2 * mesh_scale / bbsize.max()
671
+ if exec_norm:
672
+ vertices = (vertices - center) * scale
673
+
674
+ return vertices, scale, center
675
+
676
+
677
+ def load_mesh_to_unit_cube(
678
+ mesh_file: str,
679
+ mesh_scale: float = 1.0,
680
+ ) -> tuple[trimesh.Trimesh, float, list[float]]:
681
+ if not os.path.exists(mesh_file):
682
+ raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.")
683
+
684
+ mesh = trimesh.load(mesh_file)
685
+ if isinstance(mesh, trimesh.Scene):
686
+ mesh = trimesh.util.concatenate(mesh.dump())
687
+
688
+ vertices, scale, center = normalize_vertices_array(
689
+ mesh.vertices, mesh_scale
690
+ )
691
+ mesh.vertices = vertices
692
+
693
+ return mesh, scale, center
694
+
695
+
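# Minimal sketch (placeholder path): normalize a mesh into a unit cube and keep
# `scale` / `center` so the normalization can be undone later.
mesh, scale, center = load_mesh_to_unit_cube("assets/example.obj", mesh_scale=1.0)
restored_vertices = mesh.vertices / scale + center  # back to the original frame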
696
+ def as_list(obj):
697
+ if isinstance(obj, (list, tuple)):
698
+ return obj
699
+ elif isinstance(obj, set):
700
+ return list(obj)
701
+ elif obj is None:
702
+ return obj
703
+ else:
704
+ return [obj]
705
+
706
+
707
+ @dataclass
708
+ class CameraSetting:
709
+ """Camera settings for images rendering."""
710
+
711
+ num_images: int
712
+ elevation: list[float]
713
+ distance: float
714
+ resolution_hw: tuple[int, int]
715
+ fov: float
716
+ at: tuple[float, float, float] = field(
717
+ default_factory=lambda: (0.0, 0.0, 0.0)
718
+ )
719
+ up: tuple[float, float, float] = field(
720
+ default_factory=lambda: (0.0, 1.0, 0.0)
721
+ )
722
+ device: str = "cuda"
723
+ near: float = 1e-2
724
+ far: float = 1e2
725
+
726
+ def __post_init__(
727
+ self,
728
+ ):
729
+ h = self.resolution_hw[0]
730
+ f = (h / 2) / math.tan(self.fov / 2)
731
+ cx = self.resolution_hw[1] / 2
732
+ cy = self.resolution_hw[0] / 2
733
+ Ks = [
734
+ [f, 0, cx],
735
+ [0, f, cy],
736
+ [0, 0, 1],
737
+ ]
738
+
739
+ self.Ks = Ks
740
+
741
+
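# Minimal sketch: __post_init__ builds pinhole intrinsics with
# f = (H / 2) / tan(fov / 2) and the principal point at the image center,
# so `fov` is expected in radians (values below are illustrative).
camera_params = CameraSetting(
    num_images=6,
    elevation=[20.0, -10.0],
    distance=2.5,
    resolution_hw=(512, 512),
    fov=math.radians(50),
)
print(camera_params.Ks)  # [[~549, 0, 256], [0, ~549, 256], [0, 0, 1]]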
742
+ def _compute_az_el_by_camera_params(
743
+ camera_params: CameraSetting, flip_az: bool = False
744
+ ):
745
+ num_view = camera_params.num_images // len(camera_params.elevation)
746
+ view_interval = 2 * np.pi / num_view / 2
747
+ if num_view == 1:
748
+ view_interval = np.pi / 2
749
+ azimuths = []
750
+ elevations = []
751
+ for idx, el in enumerate(camera_params.elevation):
752
+ azs = np.arange(num_view) / num_view * np.pi * 2 + idx * view_interval
753
+ if flip_az:
754
+ azs *= -1
755
+ els = np.deg2rad(np.array([el] * num_view))
756
+ azimuths.append(azs)
757
+ elevations.append(els)
758
+
759
+ azimuths = np.concatenate(azimuths, axis=0)
760
+ elevations = np.concatenate(elevations, axis=0)
761
+
762
+ return azimuths, elevations
763
+
764
+
765
+ def init_kal_camera(
766
+ camera_params: CameraSetting,
767
+ flip_az: bool = False,
768
+ ) -> Camera:
769
+ azimuths, elevations = _compute_az_el_by_camera_params(
770
+ camera_params, flip_az
771
+ )
772
+ cam_pts = _compute_cam_pts_by_az_el(
773
+ azimuths, elevations, camera_params.distance
774
+ )
775
+
776
+ up = torch.cat(
777
+ [
778
+ torch.tensor(camera_params.up).repeat(camera_params.num_images, 1),
779
+ ],
780
+ dim=0,
781
+ )
782
+
783
+ camera = Camera.from_args(
784
+ eye=torch.tensor(cam_pts),
785
+ at=torch.tensor(camera_params.at),
786
+ up=up,
787
+ fov=camera_params.fov,
788
+ height=camera_params.resolution_hw[0],
789
+ width=camera_params.resolution_hw[1],
790
+ near=camera_params.near,
791
+ far=camera_params.far,
792
+ device=camera_params.device,
793
+ )
794
+
795
+ return camera
796
+
797
+
798
+ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False):
799
+ if mesh_path.endswith(".glb"):
800
+ mesh = kal.io.gltf.import_mesh(mesh_path)
801
+ elif mesh_path.endswith(".obj"):
802
+ with_material = True if with_mtl else False
803
+ mesh = kal.io.obj.import_mesh(mesh_path, with_materials=with_material)
804
+ if with_mtl and mesh.materials and len(mesh.materials) > 0:
805
+ material = kal.render.materials.PBRMaterial()
806
+ assert (
807
+ "map_Kd" in mesh.materials[0]
808
+ ), "'map_Kd' not found in materials."
809
+ material.diffuse_texture = mesh.materials[0]["map_Kd"] / 255.0
810
+ mesh.materials = [material]
811
+ elif mesh_path.endswith(".ply"):
812
+ mesh = trimesh.load(mesh_path)
813
+ mesh_path = mesh_path.replace(".ply", ".obj")
814
+ mesh.export(mesh_path)
815
+ mesh = kal.io.obj.import_mesh(mesh_path)
816
+ elif mesh_path.endswith(".off"):
817
+ mesh = kal.io.off.import_mesh(mesh_path)
818
+ else:
819
+ raise RuntimeError(
820
+ f"{mesh_path} mesh type not supported, "
821
+ "supported mesh type `.glb`, `.obj`, `.ply`, `.off`."
822
+ )
823
+
824
+ return mesh
825
+
826
+
827
+ def save_mesh_with_mtl(
828
+ vertices: np.ndarray,
829
+ faces: np.ndarray,
830
+ uvs: np.ndarray,
831
+ texture: Union[Image.Image, np.ndarray],
832
+ output_path: str,
833
+ material_base=(250, 250, 250, 255),
834
+ ) -> trimesh.Trimesh:
835
+ if isinstance(texture, np.ndarray):
836
+ texture = Image.fromarray(texture)
837
+
838
+ mesh = trimesh.Trimesh(
839
+ vertices,
840
+ faces,
841
+ visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture),
842
+ )
843
+ mesh.visual.material = trimesh.visual.material.SimpleMaterial(
844
+ image=texture,
845
+ diffuse=material_base,
846
+ ambient=material_base,
847
+ specular=material_base,
848
+ )
849
+
850
+ dir_name = os.path.dirname(output_path)
851
+ os.makedirs(dir_name, exist_ok=True)
852
+
853
+ _ = mesh.export(output_path)
854
+ # texture.save(os.path.join(dir_name, f"{file_name}_texture.png"))
855
+
856
+ logger.info(f"Saved mesh with texture to {output_path}")
857
+
858
+ return mesh
859
+
860
+
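# Minimal sketch (placeholder arrays and path): write a textured mesh with a
# simple material from raw geometry, UVs, and a texture image.
import numpy as np

vertices = np.random.rand(100, 3)
faces = np.random.randint(0, 100, (64, 3))
uvs = np.random.rand(100, 2)
texture = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
saved_mesh = save_mesh_with_mtl(
    vertices, faces, uvs, texture, "outputs/mesh/sample.obj"
)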
861
+ def get_images_from_grid(
862
+ image: Union[str, Image.Image], img_size: int
863
+ ) -> list[Image.Image]:
864
+ if isinstance(image, str):
865
+ image = Image.open(image)
866
+
867
+ view_images = np.array(image)
868
+ height, width, _ = view_images.shape
869
+ rows = height // img_size
870
+ cols = width // img_size
871
+ blocks = []
872
+ for i in range(rows):
873
+ for j in range(cols):
874
+ block = view_images[
875
+ i * img_size : (i + 1) * img_size,
876
+ j * img_size : (j + 1) * img_size,
877
+ :,
878
+ ]
879
+ blocks.append(Image.fromarray(block))
880
+
881
+ return blocks
882
+
883
+
884
+ def enhance_image(
885
+ image: Image.Image,
886
+ contrast_factor: float = 1.3,
887
+ color_factor: float = 1.2,
888
+ brightness_factor: float = 0.95,
889
+ ) -> Image.Image:
890
+ enhancer_contrast = ImageEnhance.Contrast(image)
891
+ img_contrasted = enhancer_contrast.enhance(contrast_factor)
892
+
893
+ enhancer_color = ImageEnhance.Color(img_contrasted)
894
+ img_colored = enhancer_color.enhance(color_factor)
895
+
896
+ enhancer_brightness = ImageEnhance.Brightness(img_colored)
897
+ enhanced_image = enhancer_brightness.enhance(brightness_factor)
898
+
899
+ return enhanced_image
900
+
901
+
902
+ def post_process_texture(texture: np.ndarray, iter: int = 1) -> np.ndarray:
903
+ for _ in range(iter):
904
+ texture = cv2.fastNlMeansDenoisingColored(texture, None, 2, 2, 7, 15)
905
+ texture = cv2.bilateralFilter(
906
+ texture, d=5, sigmaColor=20, sigmaSpace=20
907
+ )
908
+
909
+ texture = enhance_image(
910
+ image=Image.fromarray(texture),
911
+ contrast_factor=1.3,
912
+ color_factor=1.2,
913
+ brightness_factor=0.95,
914
+ )
915
+
916
+ return np.array(texture)
917
+
918
+
919
+ def quat_mult(q1, q2):
920
+ # NOTE:
921
+ # Q1 is the quaternion that rotates the vector from the original position to the final position # noqa
922
+ # Q2 is the quaternion that is being rotated
923
+ w1, x1, y1, z1 = q1.T
924
+ w2, x2, y2, z2 = q2.T
925
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
926
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
927
+ y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
928
+ z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
929
+ return torch.stack([w, x, y, z]).T
930
+
931
+
932
+ def quat_to_rotmat(quats: torch.Tensor, mode="wxyz") -> torch.Tensor:
933
+ """Convert quaternion to rotation matrix."""
934
+ quats = F.normalize(quats, p=2, dim=-1)
935
+
936
+ if mode == "xyzw":
937
+ x, y, z, w = torch.unbind(quats, dim=-1)
938
+ elif mode == "wxyz":
939
+ w, x, y, z = torch.unbind(quats, dim=-1)
940
+ else:
941
+ raise ValueError(f"Invalid mode: {mode}.")
942
+
943
+ R = torch.stack(
944
+ [
945
+ 1 - 2 * (y**2 + z**2),
946
+ 2 * (x * y - w * z),
947
+ 2 * (x * z + w * y),
948
+ 2 * (x * y + w * z),
949
+ 1 - 2 * (x**2 + z**2),
950
+ 2 * (y * z - w * x),
951
+ 2 * (x * z - w * y),
952
+ 2 * (y * z + w * x),
953
+ 1 - 2 * (x**2 + y**2),
954
+ ],
955
+ dim=-1,
956
+ )
957
+
958
+ return R.reshape(quats.shape[:-1] + (3, 3))
959
+
960
+
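# Minimal sketch: compose two wxyz quaternions and convert the result into a
# rotation matrix (a 90-degree yaw composed with identity here).
import torch

q_identity = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
q_yaw90 = torch.tensor([[0.7071, 0.0, 0.0, 0.7071]])  # 90 deg about z, wxyz
q = quat_mult(q_yaw90, q_identity)
R = quat_to_rotmat(q, mode="wxyz")  # shape (1, 3, 3)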
961
+ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
962
+ C0 = 0.28209479177387814 # Constant for normalization in spherical harmonics # noqa
963
+ # Clip to the range [0.0, 1.0], apply gamma correction, and then un-clip back # noqa
964
+ new_shs = torch.clip(shs * C0 + 0.5, 0.0, 1.0)
965
+ new_shs = (torch.pow(new_shs, gamma) - 0.5) / C0
966
+ return new_shs
967
+
968
+
969
+ def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image:
970
+ largest_side = max(image.size)
971
+ scale = min(1, max_size / largest_side)
972
+ if scale < 1:
973
+ new_size = (int(image.width * scale), int(image.height * scale))
974
+ image = image.resize(new_size, Image.Resampling.LANCZOS)
975
+
976
+ return image
977
+
978
+
979
+ def trellis_preprocess(image: Image.Image) -> Image.Image:
980
+ """Process the input image as trellis done."""
981
+ image_np = np.array(image)
982
+ alpha = image_np[:, :, 3]
983
+ bbox = np.argwhere(alpha > 0.8 * 255)
984
+ bbox = (
985
+ np.min(bbox[:, 1]),
986
+ np.min(bbox[:, 0]),
987
+ np.max(bbox[:, 1]),
988
+ np.max(bbox[:, 0]),
989
+ )
990
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
991
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
992
+ size = int(size * 1.2)
993
+ bbox = (
994
+ center[0] - size // 2,
995
+ center[1] - size // 2,
996
+ center[0] + size // 2,
997
+ center[1] + size // 2,
998
+ )
999
+ image = image.crop(bbox)
1000
+ image = image.resize((518, 518), Image.Resampling.LANCZOS)
1001
+ image = np.array(image).astype(np.float32) / 255
1002
+ image = image[:, :, :3] * image[:, :, 3:4]
1003
+ image = Image.fromarray((image * 255).astype(np.uint8))
1004
+
1005
+ return image
1006
+
1007
+
1008
+ def zip_files(input_paths: list[str], output_zip: str) -> str:
1009
+ with zipfile.ZipFile(output_zip, "w", zipfile.ZIP_DEFLATED) as zipf:
1010
+ for input_path in input_paths:
1011
+ if not os.path.exists(input_path):
1012
+ raise FileNotFoundError(f"File not found: {input_path}")
1013
+
1014
+ if os.path.isdir(input_path):
1015
+ for root, _, files in os.walk(input_path):
1016
+ for file in files:
1017
+ file_path = os.path.join(root, file)
1018
+ arcname = os.path.relpath(
1019
+ file_path, start=os.path.commonpath(input_paths)
1020
+ )
1021
+ zipf.write(file_path, arcname=arcname)
1022
+ else:
1023
+ arcname = os.path.relpath(
1024
+ input_path, start=os.path.commonpath(input_paths)
1025
+ )
1026
+ zipf.write(input_path, arcname=arcname)
1027
+
1028
+ return output_zip
1029
+
1030
+
1031
+ def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None:
1032
+ for item in os.listdir(folder_path):
1033
+ if keep_subs is not None and item in keep_subs:
1034
+ continue
1035
+ item_path = os.path.join(folder_path, item)
1036
+ if os.path.isdir(item_path):
1037
+ rmtree(item_path)
1038
+ else:
1039
+ os.remove(item_path)
embodied_gen/envs/pick_embodiedgen.py ADDED
@@ -0,0 +1,420 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import json
18
+ import os
19
+
20
+ import numpy as np
21
+ import sapien
22
+ import torch
23
+ import torchvision.transforms as transforms
24
+ from mani_skill.envs.sapien_env import BaseEnv
25
+ from mani_skill.sensors.camera import CameraConfig
26
+ from mani_skill.utils import sapien_utils
27
+ from mani_skill.utils.building import actors
28
+ from mani_skill.utils.building.ground import build_ground
29
+ from mani_skill.utils.registration import register_env
30
+ from mani_skill.utils.structs.actor import Actor
31
+ from mani_skill.utils.structs.pose import Pose
32
+ from mani_skill.utils.structs.types import (
33
+ GPUMemoryConfig,
34
+ SceneConfig,
35
+ SimConfig,
36
+ )
37
+ from mani_skill.utils.visualization.misc import tile_images
38
+ from tqdm import tqdm
39
+ from embodied_gen.models.gs_model import GaussianOperator
40
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
41
+ from embodied_gen.utils.geometry import bfs_placement, quaternion_multiply
42
+ from embodied_gen.utils.log import logger
43
+ from embodied_gen.utils.process_media import alpha_blend_rgba
44
+ from embodied_gen.utils.simulation import (
45
+ SIM_COORD_ALIGN,
46
+ load_assets_from_layout_file,
47
+ )
48
+
49
+ __all__ = ["PickEmbodiedGen"]
50
+
51
+
52
+ @register_env("PickEmbodiedGen-v1", max_episode_steps=100)
53
+ class PickEmbodiedGen(BaseEnv):
54
+ SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"]
55
+ goal_thresh = 0.0
56
+
57
+ def __init__(
58
+ self,
59
+ *args,
60
+ robot_uids: str | list[str] = "panda",
61
+ robot_init_qpos_noise: float = 0.02,
62
+ num_envs: int = 1,
63
+ reconfiguration_freq: int = None,
64
+ **kwargs,
65
+ ):
66
+ self.robot_init_qpos_noise = robot_init_qpos_noise
67
+ if reconfiguration_freq is None:
68
+ if num_envs == 1:
69
+ reconfiguration_freq = 1
70
+ else:
71
+ reconfiguration_freq = 0
72
+
73
+ # Init params from kwargs.
74
+ layout_file = kwargs.pop("layout_file", None)
75
+ replace_objs = kwargs.pop("replace_objs", True)
76
+ self.enable_grasp = kwargs.pop("enable_grasp", False)
77
+ self.init_3dgs_quat = kwargs.pop(
78
+ "init_3dgs_quat", [0.7071, 0, 0, 0.7071]
79
+ )
80
+ # Add small offset in z-axis to avoid collision.
81
+ self.objs_z_offset = kwargs.pop("objs_z_offset", 0.002)
82
+ self.robot_z_offset = kwargs.pop("robot_z_offset", 0.002)
83
+ self.camera_cfg = kwargs.pop("camera_cfg", None)
84
+ if self.camera_cfg is None:
85
+ self.camera_cfg = dict(
86
+ camera_eye=[0.9, 0.0, 1.1],
87
+ camera_target_pt=[0.0, 0.0, 0.9],
88
+ image_hw=[256, 256],
89
+ fovy_deg=75,
90
+ )
91
+
92
+ self.layouts = self.init_env_layouts(
93
+ layout_file, num_envs, replace_objs
94
+ )
95
+ self.robot_pose = self.compute_robot_init_pose(
96
+ self.layouts, num_envs, self.robot_z_offset
97
+ )
98
+ self.env_actors = dict()
99
+ self.image_transform = transforms.PILToTensor()
100
+
101
+ super().__init__(
102
+ *args,
103
+ robot_uids=robot_uids,
104
+ reconfiguration_freq=reconfiguration_freq,
105
+ num_envs=num_envs,
106
+ **kwargs,
107
+ )
108
+
109
+ self.bg_images = dict()
110
+ if self.render_mode == "hybrid":
111
+ self.bg_images = self.render_gs3d_images(
112
+ self.layouts, num_envs, self.init_3dgs_quat
113
+ )
114
+
115
+ @staticmethod
116
+ def init_env_layouts(
117
+ layout_file: str, num_envs: int, replace_objs: bool
118
+ ) -> list[LayoutInfo]:
119
+ layouts = []
120
+ for env_idx in range(num_envs):
121
+ if replace_objs and env_idx > 0:
122
+ layout_info = bfs_placement(layout_file)
123
+ else:
124
+ layout_info = json.load(open(layout_file, "r"))
125
+ layout_info = LayoutInfo.from_dict(layout_info)
126
+
127
+ layout_path = layout_file.replace(".json", f"_env{env_idx}.json")
128
+ with open(layout_path, "w") as f:
129
+ json.dump(layout_info.to_dict(), f, indent=4)
130
+
131
+ layouts.append(layout_path)
132
+
133
+ return layouts
134
+
135
+ @staticmethod
136
+ def compute_robot_init_pose(
137
+ layouts: list[str], num_envs: int, z_offset: float = 0.0
138
+ ) -> list[list[float]]:
139
+ robot_pose = []
140
+ for env_idx in range(num_envs):
141
+ layout = json.load(open(layouts[env_idx], "r"))
142
+ layout = LayoutInfo.from_dict(layout)
143
+ robot_node = layout.relation[Scene3DItemEnum.ROBOT.value]
144
+ x, y, z, qx, qy, qz, qw = layout.position[robot_node]
145
+ robot_pose.append([x, y, z + z_offset, qw, qx, qy, qz])
146
+
147
+ return robot_pose
148
+
149
+ @property
150
+ def _default_sim_config(self):
151
+ return SimConfig(
152
+ scene_config=SceneConfig(
153
+ solver_position_iterations=30,
154
+ # contact_offset=0.04,
155
+ # rest_offset=0.001,
156
+ ),
157
+ # sim_freq=200,
158
+ control_freq=50,
159
+ gpu_memory_config=GPUMemoryConfig(
160
+ max_rigid_contact_count=2**20, max_rigid_patch_count=2**19
161
+ ),
162
+ )
163
+
164
+ @property
165
+ def _default_sensor_configs(self):
166
+ pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
167
+
168
+ return [
169
+ CameraConfig("base_camera", pose, 128, 128, np.pi / 2, 0.01, 100)
170
+ ]
171
+
172
+ @property
173
+ def _default_human_render_camera_configs(self):
174
+ pose = sapien_utils.look_at(
175
+ eye=self.camera_cfg["camera_eye"],
176
+ target=self.camera_cfg["camera_target_pt"],
177
+ )
178
+
179
+ return CameraConfig(
180
+ "render_camera",
181
+ pose,
182
+ self.camera_cfg["image_hw"][1],
183
+ self.camera_cfg["image_hw"][0],
184
+ np.deg2rad(self.camera_cfg["fovy_deg"]),
185
+ 0.01,
186
+ 100,
187
+ )
188
+
189
+ def _load_agent(self, options: dict):
190
+ self.ground = build_ground(self.scene)
191
+ super()._load_agent(options, sapien.Pose(p=[-10, 0, 10]))
192
+
193
+ def _load_scene(self, options: dict):
194
+ all_objects = []
195
+ logger.info(f"Loading EmbodiedGen assets...")
196
+ for env_idx in range(self.num_envs):
197
+ env_actors = load_assets_from_layout_file(
198
+ self.scene,
199
+ self.layouts[env_idx],
200
+ z_offset=self.objs_z_offset,
201
+ env_idx=env_idx,
202
+ )
203
+ self.env_actors[f"env{env_idx}"] = env_actors
204
+ all_objects.extend(env_actors.values())
205
+
206
+ self.obj = all_objects[-1]
207
+ for obj in all_objects:
208
+ self.remove_from_state_dict_registry(obj)
209
+
210
+ self.all_objects = Actor.merge(all_objects, name="all_objects")
211
+ self.add_to_state_dict_registry(self.all_objects)
212
+
213
+ self.goal_site = actors.build_sphere(
214
+ self.scene,
215
+ radius=self.goal_thresh,
216
+ color=[0, 1, 0, 0],
217
+ name="goal_site",
218
+ body_type="kinematic",
219
+ add_collision=False,
220
+ initial_pose=sapien.Pose(),
221
+ )
222
+ self._hidden_objects.append(self.goal_site)
223
+
224
+ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
225
+ with torch.device(self.device):
226
+ b = len(env_idx)
227
+ goal_xyz = torch.zeros((b, 3))
228
+ goal_xyz[:, :2] = torch.rand((b, 2)) * 0.2 - 0.1
229
+ self.goal_site.set_pose(Pose.create_from_pq(goal_xyz))
230
+
231
+ qpos = np.array(
232
+ [
233
+ 0.0,
234
+ np.pi / 8,
235
+ 0,
236
+ -np.pi * 3 / 8,
237
+ 0,
238
+ np.pi * 3 / 4,
239
+ np.pi / 4,
240
+ 0.04,
241
+ 0.04,
242
+ ]
243
+ )
244
+ qpos = (
245
+ np.random.normal(
246
+ 0, self.robot_init_qpos_noise, (self.num_envs, len(qpos))
247
+ )
248
+ + qpos
249
+ )
250
+ qpos[:, -2:] = 0.04
251
+ self.agent.robot.set_root_pose(np.array(self.robot_pose))
252
+ self.agent.reset(qpos)
253
+ self.agent.init_qpos = qpos
254
+ self.agent.controller.controllers["gripper"].reset()
255
+
256
+ def render_gs3d_images(
257
+ self, layouts: list[str], num_envs: int, init_quat: list[float]
258
+ ) -> dict[str, np.ndarray]:
259
+ sim_coord_align = (
260
+ torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device)
261
+ )
262
+ cameras = self.scene.sensors.copy()
263
+ cameras.update(self.scene.human_render_cameras)
264
+
265
+ # Preload the background Gaussian Splatting model.
266
+ asset_root = os.path.dirname(layouts[0])
267
+ layout = LayoutInfo.from_dict(json.load(open(layouts[0], "r")))
268
+ bg_node = layout.relation[Scene3DItemEnum.BACKGROUND.value]
269
+ gs_path = os.path.join(
270
+ asset_root, layout.assets[bg_node], "gs_model.ply"
271
+ )
272
+ raw_gs: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
273
+ bg_images = dict()
274
+ for env_idx in tqdm(range(num_envs), desc="Pre-rendering Background"):
275
+ layout = json.load(open(layouts[env_idx], "r"))
276
+ layout = LayoutInfo.from_dict(layout)
277
+ x, y, z, qx, qy, qz, qw = layout.position[bg_node]
278
+ qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], init_quat)
279
+ init_pose = torch.tensor([x, y, z, qx, qy, qz, qw])
280
+ gs_model = raw_gs.get_gaussians(instance_pose=init_pose)
281
+ for key in cameras:
282
+ camera = cameras[key]
283
+ Ks = camera.camera.get_intrinsic_matrix() # (n_env, 3, 3)
284
+ c2w = camera.camera.get_model_matrix() # (n_env, 4, 4)
285
+ result = gs_model.render(
286
+ c2w[env_idx] @ sim_coord_align,
287
+ Ks[env_idx],
288
+ image_width=camera.config.width,
289
+ image_height=camera.config.height,
290
+ )
291
+ bg_images[f"{key}-env{env_idx}"] = result.rgb[..., ::-1]
292
+
293
+ return bg_images
294
+
295
+ def render(self):
296
+ if self.render_mode is None:
297
+ raise RuntimeError("render_mode is not set.")
298
+ if self.render_mode == "human":
299
+ return self.render_human()
300
+ elif self.render_mode == "rgb_array":
301
+ res = self.render_rgb_array()
302
+ return res
303
+ elif self.render_mode == "sensors":
304
+ res = self.render_sensors()
305
+ return res
306
+ elif self.render_mode == "all":
307
+ return self.render_all()
308
+ elif self.render_mode == "hybrid":
309
+ return self.hybrid_render()
310
+ else:
311
+ raise NotImplementedError(
312
+ f"Unsupported render mode {self.render_mode}."
313
+ )
314
+
315
+ def render_rgb_array(
316
+ self, camera_name: str = None, return_alpha: bool = False
317
+ ):
318
+ for obj in self._hidden_objects:
319
+ obj.show_visual()
320
+ self.scene.update_render(
321
+ update_sensors=False, update_human_render_cameras=True
322
+ )
323
+ images = []
324
+ render_images = self.scene.get_human_render_camera_images(
325
+ camera_name, return_alpha
326
+ )
327
+ for image in render_images.values():
328
+ images.append(image)
329
+ if len(images) == 0:
330
+ return None
331
+ if len(images) == 1:
332
+ return images[0]
333
+ for obj in self._hidden_objects:
334
+ obj.hide_visual()
335
+ return tile_images(images)
336
+
337
+ def render_sensors(self):
338
+ images = []
339
+ sensor_images = self.get_sensor_images()
340
+ for image in sensor_images.values():
341
+ for img in image.values():
342
+ images.append(img)
343
+ return tile_images(images)
344
+
345
+ def hybrid_render(self):
346
+ fg_images = self.render_rgb_array(
347
+ return_alpha=True
348
+ ) # (n_env, h, w, 3)
349
+ images = []
350
+ for key in self.bg_images:
351
+ if "render_camera" not in key:
352
+ continue
353
+ env_idx = int(key.split("-env")[-1])
354
+ rgba = alpha_blend_rgba(
355
+ fg_images[env_idx].cpu().numpy(), self.bg_images[key]
356
+ )
357
+ images.append(self.image_transform(rgba))
358
+
359
+ images = torch.stack(images, dim=0)
360
+ images = images.permute(0, 2, 3, 1)
361
+
362
+ return images[..., :3]
363
+
364
+ def evaluate(self):
365
+ obj_to_goal_pos = (
366
+ self.obj.pose.p
367
+ ) # self.goal_site.pose.p - self.obj.pose.p
368
+ is_obj_placed = (
369
+ torch.linalg.norm(obj_to_goal_pos, axis=1) <= self.goal_thresh
370
+ )
371
+ is_grasped = self.agent.is_grasping(self.obj)
372
+ is_robot_static = self.agent.is_static(0.2)
373
+
374
+ return dict(
375
+ is_grasped=is_grasped,
376
+ obj_to_goal_pos=obj_to_goal_pos,
377
+ is_obj_placed=is_obj_placed,
378
+ is_robot_static=is_robot_static,
379
+ is_grasping=self.agent.is_grasping(self.obj),
380
+ success=torch.logical_and(is_obj_placed, is_robot_static),
381
+ )
382
+
383
+ def _get_obs_extra(self, info: dict):
384
+
385
+ return dict()
386
+
387
+ def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict):
388
+ tcp_to_obj_dist = torch.linalg.norm(
389
+ self.obj.pose.p - self.agent.tcp.pose.p, axis=1
390
+ )
391
+ reaching_reward = 1 - torch.tanh(5 * tcp_to_obj_dist)
392
+ reward = reaching_reward
393
+
394
+ is_grasped = info["is_grasped"]
395
+ reward += is_grasped
396
+
397
+ # obj_to_goal_dist = torch.linalg.norm(
398
+ # self.goal_site.pose.p - self.obj.pose.p, axis=1
399
+ # )
400
+ obj_to_goal_dist = torch.linalg.norm(
401
+ self.obj.pose.p - self.obj.pose.p, axis=1
402
+ )
403
+ place_reward = 1 - torch.tanh(5 * obj_to_goal_dist)
404
+ reward += place_reward * is_grasped
405
+
406
+ reward += info["is_obj_placed"] * is_grasped
407
+
408
+ static_reward = 1 - torch.tanh(
409
+ 5
410
+ * torch.linalg.norm(self.agent.robot.get_qvel()[..., :-2], axis=1)
411
+ )
412
+ reward += static_reward * info["is_obj_placed"] * is_grasped
413
+
414
+ reward[info["success"]] = 6
415
+ return reward
416
+
417
+ def compute_normalized_dense_reward(
418
+ self, obs: any, action: torch.Tensor, info: dict
419
+ ):
420
+ return self.compute_dense_reward(obs=obs, action=action, info=info) / 6
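# Minimal sketch (paths and kwargs are illustrative): the registered id can be
# created through gymnasium once a layout json from the layout pipeline exists.
import gymnasium as gym

env = gym.make(
    "PickEmbodiedGen-v1",
    num_envs=2,
    robot_uids="panda",
    render_mode="hybrid",
    layout_file="outputs/layouts/layout.json",
)
obs, info = env.reset(seed=0)
frame = env.render()  # 3DGS background blended with the simulated foreground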
embodied_gen/models/delight_model.py ADDED
@@ -0,0 +1,202 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import os
19
+ from typing import Union
20
+
21
+ import cv2
22
+ import numpy as np
23
+ import spaces
24
+ import torch
25
+ from diffusers import (
26
+ EulerAncestralDiscreteScheduler,
27
+ StableDiffusionInstructPix2PixPipeline,
28
+ )
29
+ from huggingface_hub import snapshot_download
30
+ from PIL import Image
31
+ from embodied_gen.models.segment_model import RembgRemover
32
+ from embodied_gen.utils.log import logger
33
+
34
+ __all__ = [
35
+ "DelightingModel",
36
+ ]
37
+
38
+
39
+ class DelightingModel(object):
40
+ """A model to remove the lighting in image space.
41
+
42
+ This model is encapsulated based on the Hunyuan3D-Delight model
43
+ from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa
44
+
45
+ Attributes:
46
+ image_guide_scale (float): Weight of image guidance in diffusion process.
47
+ text_guide_scale (float): Weight of text (prompt) guidance in diffusion process.
48
+ num_infer_step (int): Number of inference steps for diffusion model.
49
+ mask_erosion_size (int): Size of erosion kernel for alpha mask cleanup.
50
+ device (str): Device used for inference, e.g., 'cuda' or 'cpu'.
51
+ seed (int): Random seed for diffusion model reproducibility.
52
+ model_path (str): Filesystem path to pretrained model weights.
53
+ pipeline: Lazy-loaded diffusion pipeline instance.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ model_path: str = None,
59
+ num_infer_step: int = 50,
60
+ mask_erosion_size: int = 3,
61
+ image_guide_scale: float = 1.5,
62
+ text_guide_scale: float = 1.0,
63
+ device: str = "cuda",
64
+ seed: int = 0,
65
+ ) -> None:
66
+ self.image_guide_scale = image_guide_scale
67
+ self.text_guide_scale = text_guide_scale
68
+ self.num_infer_step = num_infer_step
69
+ self.mask_erosion_size = mask_erosion_size
70
+ self.kernel = np.ones(
71
+ (self.mask_erosion_size, self.mask_erosion_size), np.uint8
72
+ )
73
+ self.seed = seed
74
+ self.device = device
75
+ self.pipeline = None # lazy load model adapt to @spaces.GPU
76
+
77
+ if model_path is None:
78
+ suffix = "hunyuan3d-delight-v2-0"
79
+ model_path = snapshot_download(
80
+ repo_id="tencent/Hunyuan3D-2", allow_patterns=f"{suffix}/*"
81
+ )
82
+ model_path = os.path.join(model_path, suffix)
83
+
84
+ self.model_path = model_path
85
+
86
+ def _lazy_init_pipeline(self):
87
+ if self.pipeline is None:
88
+ logger.info("Loading Delighting Model...")
89
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
90
+ self.model_path,
91
+ torch_dtype=torch.float16,
92
+ safety_checker=None,
93
+ )
94
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
95
+ pipeline.scheduler.config
96
+ )
97
+ pipeline.set_progress_bar_config(disable=True)
98
+
99
+ pipeline.to(self.device, torch.float16)
100
+ self.pipeline = pipeline
101
+
102
+ def recenter_image(
103
+ self, image: Image.Image, border_ratio: float = 0.2
104
+ ) -> Image.Image:
105
+ if image.mode == "RGB":
106
+ return image
107
+ elif image.mode == "L":
108
+ image = image.convert("RGB")
109
+ return image
110
+
111
+ alpha_channel = np.array(image)[:, :, 3]
112
+ non_zero_indices = np.argwhere(alpha_channel > 0)
113
+ if non_zero_indices.size == 0:
114
+ raise ValueError("Image is fully transparent")
115
+
116
+ min_row, min_col = non_zero_indices.min(axis=0)
117
+ max_row, max_col = non_zero_indices.max(axis=0)
118
+
119
+ cropped_image = image.crop(
120
+ (min_col, min_row, max_col + 1, max_row + 1)
121
+ )
122
+
123
+ width, height = cropped_image.size
124
+ border_width = int(width * border_ratio)
125
+ border_height = int(height * border_ratio)
126
+
127
+ new_width = width + 2 * border_width
128
+ new_height = height + 2 * border_height
129
+
130
+ square_size = max(new_width, new_height)
131
+
132
+ new_image = Image.new(
133
+ "RGBA", (square_size, square_size), (255, 255, 255, 0)
134
+ )
135
+
136
+ paste_x = (square_size - new_width) // 2 + border_width
137
+ paste_y = (square_size - new_height) // 2 + border_height
138
+
139
+ new_image.paste(cropped_image, (paste_x, paste_y))
140
+
141
+ return new_image
142
+
143
+ @spaces.GPU
144
+ @torch.no_grad()
145
+ def __call__(
146
+ self,
147
+ image: Union[str, np.ndarray, Image.Image],
148
+ preprocess: bool = False,
149
+ target_wh: tuple[int, int] = None,
150
+ ) -> Image.Image:
151
+ self._lazy_init_pipeline()
152
+
153
+ if isinstance(image, str):
154
+ image = Image.open(image)
155
+ elif isinstance(image, np.ndarray):
156
+ image = Image.fromarray(image)
157
+
158
+ if preprocess:
159
+ bg_remover = RembgRemover()
160
+ image = bg_remover(image)
161
+ image = self.recenter_image(image)
162
+
163
+ if target_wh is not None:
164
+ image = image.resize(target_wh)
165
+ else:
166
+ target_wh = image.size
167
+
168
+ image_array = np.array(image)
169
+ assert image_array.shape[-1] == 4, "Image must have alpha channel"
170
+
171
+ raw_alpha_channel = image_array[:, :, 3]
172
+ alpha_channel = cv2.erode(raw_alpha_channel, self.kernel, iterations=1)
173
+ image_array[alpha_channel == 0, :3] = 255 # must be white background
174
+ image_array[:, :, 3] = alpha_channel
175
+
176
+ image = self.pipeline(
177
+ prompt="",
178
+ image=Image.fromarray(image_array).convert("RGB"),
179
+ generator=torch.manual_seed(self.seed),
180
+ num_inference_steps=self.num_infer_step,
181
+ image_guidance_scale=self.image_guide_scale,
182
+ guidance_scale=self.text_guide_scale,
183
+ ).images[0]
184
+
185
+ alpha_channel = Image.fromarray(alpha_channel)
186
+ rgba_image = image.convert("RGBA").resize(target_wh)
187
+ rgba_image.putalpha(alpha_channel)
188
+
189
+ return rgba_image
190
+
191
+
192
+ if __name__ == "__main__":
193
+ delighting_model = DelightingModel()
194
+ image_path = "apps/assets/example_image/sample_12.jpg"
195
+ image = delighting_model(
196
+ image_path, preprocess=True, target_wh=(512, 512)
197
+ ) # noqa
198
+ image.save("delight.png")
199
+
200
+ # image_path = "embodied_gen/scripts/test_robot.png"
201
+ # image = delighting_model(image_path)
202
+ # image.save("delighting_image_a2.png")
embodied_gen/models/gs_model.py ADDED
@@ -0,0 +1,511 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import logging
19
+ import os
20
+ import struct
21
+ from dataclasses import dataclass
22
+ from typing import Optional
23
+
24
+ import cv2
25
+ import numpy as np
26
+ import torch
27
+ from gsplat.cuda._wrapper import spherical_harmonics
28
+ from gsplat.rendering import rasterization
29
+ from plyfile import PlyData
30
+ from scipy.spatial.transform import Rotation
31
+ from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat
32
+
33
+ logging.basicConfig(level=logging.INFO)
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ __all__ = [
38
+ "RenderResult",
39
+ "GaussianOperator",
40
+ ]
41
+
42
+ SH_C0 = 0.2820947917738781
43
+
44
+
45
+ @dataclass
46
+ class RenderResult:
47
+ rgb: np.ndarray
48
+ depth: np.ndarray
49
+ opacity: np.ndarray
50
+ mask_threshold: float = 10
51
+ mask: Optional[np.ndarray] = None
52
+ rgba: Optional[np.ndarray] = None
53
+
54
+ def __post_init__(self):
55
+ if isinstance(self.rgb, torch.Tensor):
56
+ rgb = (self.rgb * 255).to(torch.uint8)
57
+ self.rgb = rgb.cpu().numpy()[..., ::-1]
58
+ if isinstance(self.depth, torch.Tensor):
59
+ self.depth = self.depth.cpu().numpy()
60
+ if isinstance(self.opacity, torch.Tensor):
61
+ opacity = (self.opacity * 255).to(torch.uint8)
62
+ self.opacity = opacity.cpu().numpy()
63
+ mask = np.where(self.opacity > self.mask_threshold, 255, 0)
64
+ self.mask = mask.astype(np.uint8)
65
+ self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
66
+
67
+
68
+ @dataclass
69
+ class GaussianBase:
70
+ _opacities: torch.Tensor
71
+ _means: torch.Tensor
72
+ _scales: torch.Tensor
73
+ _quats: torch.Tensor
74
+ _rgbs: Optional[torch.Tensor] = None
75
+ _features_dc: Optional[torch.Tensor] = None
76
+ _features_rest: Optional[torch.Tensor] = None
77
+ sh_degree: Optional[int] = 0
78
+ device: str = "cuda"
79
+
80
+ def __post_init__(self):
81
+ self.active_sh_degree: int = self.sh_degree
82
+ self.to(self.device)
83
+
84
+ def to(self, device: str) -> None:
85
+ for k, v in self.__dict__.items():
86
+ if not isinstance(v, torch.Tensor):
87
+ continue
88
+ self.__dict__[k] = v.to(device)
89
+
90
+ def get_numpy_data(self):
91
+ data = {}
92
+ for k, v in self.__dict__.items():
93
+ if not isinstance(v, torch.Tensor):
94
+ continue
95
+ data[k] = v.detach().cpu().numpy()
96
+
97
+ return data
98
+
99
+ def quat_norm(self, x: torch.Tensor) -> torch.Tensor:
100
+ return x / x.norm(dim=-1, keepdim=True)
101
+
102
+ @classmethod
103
+ def load_from_ply(
104
+ cls,
105
+ path: str,
106
+ gamma: float = 1.0,
107
+ device: str = "cuda",
108
+ ) -> "GaussianBase":
109
+ plydata = PlyData.read(path)
110
+ xyz = torch.stack(
111
+ (
112
+ torch.tensor(plydata.elements[0]["x"], dtype=torch.float32),
113
+ torch.tensor(plydata.elements[0]["y"], dtype=torch.float32),
114
+ torch.tensor(plydata.elements[0]["z"], dtype=torch.float32),
115
+ ),
116
+ dim=1,
117
+ )
118
+
119
+ opacities = torch.tensor(
120
+ plydata.elements[0]["opacity"], dtype=torch.float32
121
+ ).unsqueeze(-1)
122
+ features_dc = torch.zeros((xyz.shape[0], 3), dtype=torch.float32)
123
+ features_dc[:, 0] = torch.tensor(
124
+ plydata.elements[0]["f_dc_0"], dtype=torch.float32
125
+ )
126
+ features_dc[:, 1] = torch.tensor(
127
+ plydata.elements[0]["f_dc_1"], dtype=torch.float32
128
+ )
129
+ features_dc[:, 2] = torch.tensor(
130
+ plydata.elements[0]["f_dc_2"], dtype=torch.float32
131
+ )
132
+
133
+ scale_names = [
134
+ p.name
135
+ for p in plydata.elements[0].properties
136
+ if p.name.startswith("scale_")
137
+ ]
138
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
139
+ scales = torch.zeros(
140
+ (xyz.shape[0], len(scale_names)), dtype=torch.float32
141
+ )
142
+ for idx, attr_name in enumerate(scale_names):
143
+ scales[:, idx] = torch.tensor(
144
+ plydata.elements[0][attr_name], dtype=torch.float32
145
+ )
146
+
147
+ rot_names = [
148
+ p.name
149
+ for p in plydata.elements[0].properties
150
+ if p.name.startswith("rot_")
151
+ ]
152
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
153
+ rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32)
154
+ for idx, attr_name in enumerate(rot_names):
155
+ rots[:, idx] = torch.tensor(
156
+ plydata.elements[0][attr_name], dtype=torch.float32
157
+ )
158
+
159
+ rots = rots / torch.norm(rots, dim=-1, keepdim=True)
160
+
161
+ # extra features
162
+ extra_f_names = [
163
+ p.name
164
+ for p in plydata.elements[0].properties
165
+ if p.name.startswith("f_rest_")
166
+ ]
167
+ extra_f_names = sorted(
168
+ extra_f_names, key=lambda x: int(x.split("_")[-1])
169
+ )
170
+
171
+ max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1)
172
+ if max_sh_degree != 0:
173
+ features_extra = torch.zeros(
174
+ (xyz.shape[0], len(extra_f_names)), dtype=torch.float32
175
+ )
176
+ for idx, attr_name in enumerate(extra_f_names):
177
+ features_extra[:, idx] = torch.tensor(
178
+ plydata.elements[0][attr_name], dtype=torch.float32
179
+ )
180
+
181
+ features_extra = features_extra.view(
182
+ (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)
183
+ )
184
+ features_extra = features_extra.permute(0, 2, 1)
185
+
186
+ if abs(gamma - 1.0) > 1e-3:
187
+ features_dc = gamma_shs(features_dc, gamma)
188
+ features_extra[..., :] = 0.0
189
+ opacities *= 0.8
190
+
191
+ shs = torch.cat(
192
+ [
193
+ features_dc.reshape(-1, 3),
194
+ features_extra.reshape(len(features_dc), -1),
195
+ ],
196
+ dim=-1,
197
+ )
198
+ else:
199
+ # sh_dim is 0, only dc features
200
+ shs = features_dc
201
+ features_extra = None
202
+
203
+ return cls(
204
+ sh_degree=max_sh_degree,
205
+ _means=xyz,
206
+ _opacities=opacities,
207
+ _rgbs=shs,
208
+ _scales=scales,
209
+ _quats=rots,
210
+ _features_dc=features_dc,
211
+ _features_rest=features_extra,
212
+ device=device,
213
+ )
214
+
215
+ def save_to_ply(self, path: str, enable_mask: bool = False) -> None:
216
+ os.makedirs(os.path.dirname(path), exist_ok=True)
217
+ numpy_data = self.get_numpy_data()
218
+ means = numpy_data["_means"]
219
+ scales = numpy_data["_scales"]
220
+ quats = numpy_data["_quats"]
221
+ opacities = numpy_data["_opacities"]
222
+ sh0 = numpy_data["_features_dc"]
223
+ shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0)))
224
+ shN = shN.reshape(means.shape[0], -1)
225
+
226
+ # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays # noqa
227
+ if enable_mask:
228
+ invalid_mask = (
229
+ np.isnan(means).any(axis=1)
230
+ | np.isinf(means).any(axis=1)
231
+ | np.isnan(scales).any(axis=1)
232
+ | np.isinf(scales).any(axis=1)
233
+ | np.isnan(quats).any(axis=1)
234
+ | np.isinf(quats).any(axis=1)
235
+ | np.isnan(opacities).any(axis=1)
236
+ | np.isinf(opacities).any(axis=1)
237
+ | np.isnan(sh0).any(axis=1)
238
+ | np.isinf(sh0).any(axis=1)
239
+ | np.isnan(shN).any(axis=1)
240
+ | np.isinf(shN).any(axis=1)
241
+ )
242
+
243
+ # Filter out rows with NaNs or Infs from all data arrays
244
+ means = means[~invalid_mask]
245
+ scales = scales[~invalid_mask]
246
+ quats = quats[~invalid_mask]
247
+ opacities = opacities[~invalid_mask]
248
+ sh0 = sh0[~invalid_mask]
249
+ shN = shN[~invalid_mask]
250
+
251
+ num_points = means.shape[0]
252
+ with open(path, "wb") as f:
253
+ # Write PLY header
254
+ f.write(b"ply\n")
255
+ f.write(b"format binary_little_endian 1.0\n")
256
+ f.write(f"element vertex {num_points}\n".encode())
257
+ f.write(b"property float x\n")
258
+ f.write(b"property float y\n")
259
+ f.write(b"property float z\n")
260
+
261
+ for i, data in enumerate([sh0, shN]):
262
+ prefix = "f_dc" if i == 0 else "f_rest"
263
+ for j in range(data.shape[1]):
264
+ f.write(f"property float {prefix}_{j}\n".encode())
265
+
266
+ f.write(b"property float opacity\n")
267
+
268
+ for i in range(scales.shape[1]):
269
+ f.write(f"property float scale_{i}\n".encode())
270
+ for i in range(quats.shape[1]):
271
+ f.write(f"property float rot_{i}\n".encode())
272
+
273
+ f.write(b"end_header\n")
274
+
275
+ # Write vertex data
276
+ for i in range(num_points):
277
+ f.write(struct.pack("<fff", *means[i])) # x, y, z
278
+
279
+ for data in [sh0, shN]:
280
+ for j in range(data.shape[1]):
281
+ f.write(struct.pack("<f", data[i, j]))
282
+
283
+ f.write(struct.pack("<f", opacities[i].item())) # opacity
284
+
285
+ for data in [scales, quats]:
286
+ for j in range(data.shape[1]):
287
+ f.write(struct.pack("<f", data[i, j]))
288
+
289
+ return
290
+
291
+
292
+ @dataclass
293
+ class GaussianOperator(GaussianBase):
294
+ """Gaussian Splatting operator.
295
+
296
+ Supports transformation, scaling, color computation, and
297
+ rasterization-based rendering.
298
+
299
+ Inherits:
300
+ GaussianBase: Base class with Gaussian params (means, scales, etc.)
301
+
302
+ Functionality includes:
303
+ - Applying instance poses to transform Gaussian means and quaternions.
304
+ - Scaling Gaussians to a real-world size.
305
+ - Computing colors using spherical harmonics.
306
+ - Rendering images via differentiable rasterization.
307
+ - Exporting transformed and rescaled models to .ply format.
308
+ """
309
+
310
+ def _compute_transform(
311
+ self,
312
+ means: torch.Tensor,
313
+ quats: torch.Tensor,
314
+ instance_pose: torch.Tensor,
315
+ ):
316
+ """Compute the transform of the GS models.
317
+
318
+ Args:
319
+ means: tensor of gs means.
320
+ quats: tensor of gs quaternions.
321
+ instance_pose: instances poses in [x y z qx qy qz qw] format.
322
+
323
+ """
324
+ # (x y z qx qy qz qw) -> (x y z qw qx qy qz)
325
+ instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]]
326
+ cur_instances_quats = self.quat_norm(instance_pose[3:])
327
+ rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz")
328
+
329
+ # update the means
330
+ num_gs = means.shape[0]
331
+ trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0)
332
+ quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0)
333
+ rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0) # (num_gs, 3, 3)
334
+
335
+ # update the means
336
+ cur_means = (
337
+ torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1)
338
+ + trans_per_pts
339
+ )
340
+
341
+ # update the quats
342
+ _quats = self.quat_norm(quats)
343
+ cur_quats = quat_mult(quat_per_pts, _quats)
344
+
345
+ return cur_means, cur_quats
346
+
347
+ def get_gaussians(
348
+ self,
349
+ c2w: torch.Tensor = None,
350
+ instance_pose: torch.Tensor = None,
351
+ apply_activate: bool = False,
352
+ ) -> "GaussianBase":
353
+ """Get Gaussian data under the given instance_pose."""
354
+ if c2w is None:
355
+ c2w = torch.eye(4).to(self.device)
356
+
357
+ if instance_pose is not None:
358
+ # compute the transformed gs means and quats
359
+ world_means, world_quats = self._compute_transform(
360
+ self._means, self._quats, instance_pose.float().to(self.device)
361
+ )
362
+ else:
363
+ world_means, world_quats = self._means, self._quats
364
+
365
+ # get colors of gaussians
366
+ if self._features_rest is not None:
367
+ colors = torch.cat(
368
+ (self._features_dc[:, None, :], self._features_rest), dim=1
369
+ )
370
+ else:
371
+ colors = self._features_dc[:, None, :]
372
+
373
+ if self.sh_degree > 0:
374
+ viewdirs = world_means.detach() - c2w[..., :3, 3] # (N, 3)
375
+ viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
376
+ rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
377
+ rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
378
+ else:
379
+ rgbs = torch.sigmoid(colors[:, 0, :])
380
+
381
+ gs_dict = dict(
382
+ _means=world_means,
383
+ _opacities=(
384
+ torch.sigmoid(self._opacities)
385
+ if apply_activate
386
+ else self._opacities
387
+ ),
388
+ _rgbs=rgbs,
389
+ _scales=(
390
+ torch.exp(self._scales) if apply_activate else self._scales
391
+ ),
392
+ _quats=self.quat_norm(world_quats),
393
+ _features_dc=self._features_dc,
394
+ _features_rest=self._features_rest,
395
+ sh_degree=self.sh_degree,
396
+ device=self.device,
397
+ )
398
+
399
+ return GaussianOperator(**gs_dict)
400
+
401
+ def rescale(self, scale: float):
402
+ if scale != 1.0:
403
+ self._means *= scale
404
+ self._scales += torch.log(self._scales.new_tensor(scale))
405
+
406
+ def set_scale_by_height(self, real_height: float) -> None:
407
+ def _ptp(tensor, dim):
408
+ val = tensor.max(dim=dim).values - tensor.min(dim=dim).values
409
+ return val.tolist()
410
+
411
+ xyz_scale = max(_ptp(self._means, dim=0))
412
+ self.rescale(1 / (xyz_scale + 1e-6)) # Normalize to [-0.5, 0.5]
413
+ raw_height = _ptp(self._means, dim=0)[1]
414
+ scale = real_height / raw_height
415
+
416
+ self.rescale(scale)
417
+
418
+ return
419
+
420
+ @staticmethod
421
+ def resave_ply(
422
+ in_ply: str,
423
+ out_ply: str,
424
+ real_height: float = None,
425
+ instance_pose: np.ndarray = None,
426
+ device: str = "cuda",
427
+ ) -> None:
428
+ gs_model = GaussianOperator.load_from_ply(in_ply, device=device)
429
+
430
+ if instance_pose is not None:
431
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
432
+
433
+ if real_height is not None:
434
+ gs_model.set_scale_by_height(real_height)
435
+
436
+ gs_model.save_to_ply(out_ply)
437
+
438
+ return
439
+
440
+ @staticmethod
441
+ def trans_to_quatpose(
442
+ rot_matrix: list[list[float]],
443
+ trans_matrix: list[float] = [0, 0, 0],
444
+ ) -> torch.Tensor:
445
+ if isinstance(rot_matrix, list):
446
+ rot_matrix = np.array(rot_matrix)
447
+
448
+ rot = Rotation.from_matrix(rot_matrix)
449
+ qx, qy, qz, qw = rot.as_quat()
450
+ instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw])
451
+
452
+ return instance_pose
453
+
454
+ def render(
455
+ self,
456
+ c2w: torch.Tensor,
457
+ Ks: torch.Tensor,
458
+ image_width: int,
459
+ image_height: int,
460
+ ) -> RenderResult:
461
+ gs = self.get_gaussians(c2w, apply_activate=True)
462
+ renders, alphas, _ = rasterization(
463
+ means=gs._means,
464
+ quats=gs._quats,
465
+ scales=gs._scales,
466
+ opacities=gs._opacities.squeeze(),
467
+ colors=gs._rgbs,
468
+ viewmats=torch.linalg.inv(c2w)[None, ...],
469
+ Ks=Ks[None, ...],
470
+ width=image_width,
471
+ height=image_height,
472
+ packed=False,
473
+ absgrad=True,
474
+ sparse_grad=False,
475
+ # rasterize_mode="classic",
476
+ rasterize_mode="antialiased",
477
+ **{
478
+ "near_plane": 0.01,
479
+ "far_plane": 1000000000,
480
+ "radius_clip": 0.0,
481
+ "render_mode": "RGB+ED",
482
+ },
483
+ )
484
+ renders = renders[0]
485
+ alphas = alphas[0].squeeze(-1)
486
+
487
+ assert renders.shape[-1] == 4, "Must render rgb, depth and alpha"
488
+ rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1)
489
+
490
+ return RenderResult(
491
+ torch.clamp(rendered_rgb, min=0, max=1),
492
+ rendered_depth,
493
+ alphas[..., None],
494
+ )
495
+
496
+
497
+ if __name__ == "__main__":
498
+ input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply"
499
+ output_gs = "./gs_model.ply"
500
+ gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs)
501
+
502
+ # Rotate 180° around the x-axis
503
+ R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
504
+ instance_pose = gs_model.trans_to_quatpose(R_x)
505
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
506
+
507
+ gs_model.rescale(2)
508
+
509
+ gs_model.set_scale_by_height(1.3)
510
+
511
+ gs_model.save_to_ply(output_gs)
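Below is a minimal, hedged sketch of the `render` path of the `GaussianOperator` defined above, which the `__main__` block does not exercise. It assumes this file lands as `embodied_gen/models/gs_model.py` per the commit file list; the PLY path, camera intrinsics, and pose are illustrative placeholders, not values from this commit.

```python
import torch

from embodied_gen.models.gs_model import GaussianOperator  # assumed module path

# Load a splat and render it from a simple, assumed camera (all values illustrative).
gs = GaussianOperator.load_from_ply("gs_model.ply", device="cuda")

H, W = 512, 512
# Assumed pinhole intrinsics (fx, fy, cx, cy); replace with real calibration values.
Ks = torch.tensor(
    [[500.0, 0.0, W / 2], [0.0, 500.0, H / 2], [0.0, 0.0, 1.0]], device="cuda"
)
# Camera-to-world pose: identity rotation, camera placed 2 m along +z (assumption).
c2w = torch.eye(4, device="cuda")
c2w[2, 3] = 2.0

result = gs.render(c2w, Ks, image_width=W, image_height=H)  # RGB, depth, alpha bundle
```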
embodied_gen/models/image_comm_model.py ADDED
@@ -0,0 +1,236 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+ # Text-to-Image generation models from Hugging Face community.
17
+
18
+ import os
19
+ from abc import ABC, abstractmethod
20
+
21
+ import torch
22
+ from diffusers import (
23
+ ChromaPipeline,
24
+ Cosmos2TextToImagePipeline,
25
+ DPMSolverMultistepScheduler,
26
+ FluxPipeline,
27
+ KolorsPipeline,
28
+ StableDiffusion3Pipeline,
29
+ )
30
+ from diffusers.quantizers import PipelineQuantizationConfig
31
+ from huggingface_hub import snapshot_download
32
+ from PIL import Image
33
+ from transformers import AutoModelForCausalLM, SiglipProcessor
34
+
35
+ __all__ = [
36
+ "build_hf_image_pipeline",
37
+ ]
38
+
39
+
40
+ class BasePipelineLoader(ABC):
41
+ def __init__(self, device="cuda"):
42
+ self.device = device
43
+
44
+ @abstractmethod
45
+ def load(self):
46
+ pass
47
+
48
+
49
+ class BasePipelineRunner(ABC):
50
+ def __init__(self, pipe):
51
+ self.pipe = pipe
52
+
53
+ @abstractmethod
54
+ def run(self, prompt: str, **kwargs) -> list[Image.Image]:
55
+ pass
56
+
57
+
58
+ # ===== SD3.5-medium =====
59
+ class SD35Loader(BasePipelineLoader):
60
+ def load(self):
61
+ pipe = StableDiffusion3Pipeline.from_pretrained(
62
+ "stabilityai/stable-diffusion-3.5-medium",
63
+ torch_dtype=torch.float16,
64
+ )
65
+ pipe = pipe.to(self.device)
66
+ pipe.enable_model_cpu_offload()
67
+ pipe.enable_xformers_memory_efficient_attention()
68
+ pipe.enable_attention_slicing()
69
+ return pipe
70
+
71
+
72
+ class SD35Runner(BasePipelineRunner):
73
+ def run(self, prompt: str, **kwargs) -> list[Image.Image]:
74
+ return self.pipe(prompt=prompt, **kwargs).images
75
+
76
+
77
+ # ===== Cosmos2 =====
78
+ class CosmosLoader(BasePipelineLoader):
79
+ def __init__(
80
+ self,
81
+ model_id="nvidia/Cosmos-Predict2-2B-Text2Image",
82
+ local_dir="weights/cosmos2",
83
+ device="cuda",
84
+ ):
85
+ super().__init__(device)
86
+ self.model_id = model_id
87
+ self.local_dir = local_dir
88
+
89
+ def _patch(self):
90
+ def patch_model(cls):
91
+ orig = cls.from_pretrained
92
+
93
+ def new(*args, **kwargs):
94
+ kwargs.setdefault("attn_implementation", "flash_attention_2")
95
+ kwargs.setdefault("torch_dtype", torch.bfloat16)
96
+ return orig(*args, **kwargs)
97
+
98
+ cls.from_pretrained = new
99
+
100
+ def patch_processor(cls):
101
+ orig = cls.from_pretrained
102
+
103
+ def new(*args, **kwargs):
104
+ kwargs.setdefault("use_fast", True)
105
+ return orig(*args, **kwargs)
106
+
107
+ cls.from_pretrained = new
108
+
109
+ patch_model(AutoModelForCausalLM)
110
+ patch_processor(SiglipProcessor)
111
+
112
+ def load(self):
113
+ self._patch()
114
+ snapshot_download(
115
+ repo_id=self.model_id,
116
+ local_dir=self.local_dir,
117
+ local_dir_use_symlinks=False,
118
+ resume_download=True,
119
+ )
120
+
121
+ config = PipelineQuantizationConfig(
122
+ quant_backend="bitsandbytes_4bit",
123
+ quant_kwargs={
124
+ "load_in_4bit": True,
125
+ "bnb_4bit_quant_type": "nf4",
126
+ "bnb_4bit_compute_dtype": torch.bfloat16,
127
+ "bnb_4bit_use_double_quant": True,
128
+ },
129
+ components_to_quantize=["text_encoder", "transformer", "unet"],
130
+ )
131
+
132
+ pipe = Cosmos2TextToImagePipeline.from_pretrained(
133
+ self.model_id,
134
+ torch_dtype=torch.bfloat16,
135
+ quantization_config=config,
136
+ use_safetensors=True,
137
+ safety_checker=None,
138
+ requires_safety_checker=False,
139
+ ).to(self.device)
140
+ return pipe
141
+
142
+
143
+ class CosmosRunner(BasePipelineRunner):
144
+ def run(self, prompt: str, negative_prompt=None, **kwargs) -> list[Image.Image]:
145
+ return self.pipe(
146
+ prompt=prompt, negative_prompt=negative_prompt, **kwargs
147
+ ).images
148
+
149
+
150
+ # ===== Kolors =====
151
+ class KolorsLoader(BasePipelineLoader):
152
+ def load(self):
153
+ pipe = KolorsPipeline.from_pretrained(
154
+ "Kwai-Kolors/Kolors-diffusers",
155
+ torch_dtype=torch.float16,
156
+ variant="fp16",
157
+ ).to(self.device)
158
+ pipe.enable_model_cpu_offload()
159
+ pipe.enable_xformers_memory_efficient_attention()
160
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
161
+ pipe.scheduler.config, use_karras_sigmas=True
162
+ )
163
+ return pipe
164
+
165
+
166
+ class KolorsRunner(BasePipelineRunner):
167
+ def run(self, prompt: str, **kwargs) -> list[Image.Image]:
168
+ return self.pipe(prompt=prompt, **kwargs).images
169
+
170
+
171
+ # ===== Flux =====
172
+ class FluxLoader(BasePipelineLoader):
173
+ def load(self):
174
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
175
+ pipe = FluxPipeline.from_pretrained(
176
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
177
+ )
178
+ pipe.enable_model_cpu_offload()
179
+ pipe.enable_xformers_memory_efficient_attention()
180
+ pipe.enable_attention_slicing()
181
+ return pipe.to(self.device)
182
+
183
+
184
+ class FluxRunner(BasePipelineRunner):
185
+ def run(self, prompt: str, **kwargs) -> list[Image.Image]:
186
+ return self.pipe(prompt=prompt, **kwargs).images
187
+
188
+
189
+ # ===== Chroma =====
190
+ class ChromaLoader(BasePipelineLoader):
191
+ def load(self):
192
+ return ChromaPipeline.from_pretrained(
193
+ "lodestones/Chroma", torch_dtype=torch.bfloat16
194
+ ).to(self.device)
195
+
196
+
197
+ class ChromaRunner(BasePipelineRunner):
198
+ def run(self, prompt: str, negative_prompt=None, **kwargs) -> list[Image.Image]:
199
+ return self.pipe(
200
+ prompt=prompt, negative_prompt=negative_prompt, **kwargs
201
+ ).images
202
+
203
+
204
+ PIPELINE_REGISTRY = {
205
+ "sd35": (SD35Loader, SD35Runner),
206
+ "cosmos": (CosmosLoader, CosmosRunner),
207
+ "kolors": (KolorsLoader, KolorsRunner),
208
+ "flux": (FluxLoader, FluxRunner),
209
+ "chroma": (ChromaLoader, ChromaRunner),
210
+ }
211
+
212
+
213
+ def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner:
214
+ if name not in PIPELINE_REGISTRY:
215
+ raise ValueError(f"Unsupported model: {name}")
216
+ loader_cls, runner_cls = PIPELINE_REGISTRY[name]
217
+ pipe = loader_cls(device=device).load()
218
+
219
+ return runner_cls(pipe)
220
+
221
+
222
+ if __name__ == "__main__":
223
+ model_name = "sd35"
224
+ runner = build_hf_image_pipeline(model_name)
225
+ # NOTE: Just for pipeline testing, generation quality at low resolution is poor.
226
+ images = runner.run(
227
+ prompt="A robot holding a sign that says 'Hello'",
228
+ height=512,
229
+ width=512,
230
+ num_inference_steps=10,
231
+ guidance_scale=6,
232
+ num_images_per_prompt=1,
233
+ )
234
+
235
+ for i, img in enumerate(images):
236
+ img.save(f"image_{model_name}_{i}.jpg")
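The loader/runner registry above keeps weight loading separate from inference. As a rough sketch of how an extra backend could be plugged in, the snippet below registers a hypothetical `"sdxl"` entry; the key, checkpoint id, and settings are assumptions, not part of this commit.

```python
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image

# Assumed module path for the classes and registry defined above.
from embodied_gen.models.image_comm_model import (
    PIPELINE_REGISTRY,
    BasePipelineLoader,
    BasePipelineRunner,
    build_hf_image_pipeline,
)


class SDXLLoader(BasePipelineLoader):
    def load(self):
        # Assumed checkpoint; any diffusers text-to-image pipeline fits this interface.
        return StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to(self.device)


class SDXLRunner(BasePipelineRunner):
    def run(self, prompt: str, **kwargs) -> list[Image.Image]:
        return self.pipe(prompt=prompt, **kwargs).images


PIPELINE_REGISTRY["sdxl"] = (SDXLLoader, SDXLRunner)
runner = build_hf_image_pipeline("sdxl")
images = runner.run(prompt="A wooden desk in an empty office", num_inference_steps=20)
```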
embodied_gen/models/layout.py ADDED
@@ -0,0 +1,510 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+
24
+ import json_repair
25
+ from embodied_gen.utils.enum import (
26
+ LayoutInfo,
27
+ RobotItemEnum,
28
+ Scene3DItemEnum,
29
+ SpatialRelationEnum,
30
+ )
31
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
32
+ from embodied_gen.utils.process_media import SceneTreeVisualizer
33
+
34
+ logging.basicConfig(level=logging.INFO)
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ __all__ = [
39
+ "LayoutDesigner",
40
+ "LAYOUT_DISASSEMBLER",
41
+ "LAYOUT_GRAPHER",
42
+ "LAYOUT_DESCRIBER",
43
+ ]
44
+
45
+
46
+ DISTRACTOR_NUM = 2 # Maximum number of distractor objects allowed
47
+ LAYOUT_DISASSEMBLE_PROMPT = f"""
48
+ You are an intelligent 3D scene planner. Given a natural language
49
+ description of a robotic task, output a structured description of
50
+ an interactive 3D scene.
51
+
52
+ The output must include the following fields:
53
+ - task: A high-level task type (e.g., "single-arm pick",
54
+ "dual-arm grasping", "pick and place", "object sorting").
55
+ - {Scene3DItemEnum.ROBOT}: The name or type of robot involved. If not mentioned,
56
+ use {RobotItemEnum.FRANKA} as default.
57
+ - {Scene3DItemEnum.BACKGROUND}: The room or indoor environment where the task happens
58
+ (e.g., Kitchen, Bedroom, Living Room, Workshop, Office).
59
+ - {Scene3DItemEnum.CONTEXT}: An indoor object involved in the manipulation
60
+ (e.g., Table, Shelf, Desk, Bed, Cabinet).
61
+ - {Scene3DItemEnum.MANIPULATED_OBJS}: The main object(s) that the robot directly interacts with.
62
+ - {Scene3DItemEnum.DISTRACTOR_OBJS}: Other objects that naturally belong to the scene but are not part of the main task.
63
+
64
+ Constraints:
65
+ - The {Scene3DItemEnum.BACKGROUND} must logically match the described task.
66
+ - The {Scene3DItemEnum.CONTEXT} must fit within the {Scene3DItemEnum.BACKGROUND}. (e.g., a bedroom may include a table or bed, but not a workbench.)
67
+ - The {Scene3DItemEnum.CONTEXT} must be a concrete indoor object, such as a "table",
68
+ "shelf", "desk", or "bed". It must not be an abstract concept (e.g., "area", "space", "zone")
69
+ or structural surface (e.g., "floor", "ground"). If the input describes an interaction near
70
+ the floor or vague space, you must infer a plausible object like a "table", "cabinet", or "storage box" instead.
71
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} objects must be plausible,
72
+ and semantically compatible with the {Scene3DItemEnum.CONTEXT} and {Scene3DItemEnum.BACKGROUND}.
73
+ - {Scene3DItemEnum.DISTRACTOR_OBJS} must not confuse or overlap with the manipulated objects.
74
+ - {Scene3DItemEnum.DISTRACTOR_OBJS} number limit: {DISTRACTOR_NUM} distractors maximum.
75
+ - All {Scene3DItemEnum.BACKGROUND} are limited to indoor environments.
76
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} are rigid bodies and do not include flexible objects.
77
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} must be common
78
+ household or office items or furniture, not abstract concepts, not too small like needle.
79
+ - If the input includes a plural or grouped object (e.g., "pens", "bottles", "plates", "fruit"),
80
+ you must decompose it into multiple individual instances (e.g., ["pen1", "pen2"], ["apple", "pear"]).
81
+ - Containers that hold objects (e.g., "bowl of apples", "box of tools") must
82
+ be separated into individual items (e.g., ["bowl", "apple1", "apple2"]).
83
+ - Do not include transparent objects such as "glass", "plastic", etc.
84
+ - All {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} must be child nodes of {Scene3DItemEnum.CONTEXT}.
85
+ - The output must be in compact JSON format and use Markdown syntax, just like the output in the example below.
86
+
87
+ Examples:
88
+
89
+ Input:
90
+ "Pick up the marker from the table and put it in the bowl robot {RobotItemEnum.UR5}."
91
+ Output:
92
+ ```json
93
+ {{
94
+ "task_desc": "Pick up the marker from the table and put it in the bowl.",
95
+ "task": "pick and place",
96
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
97
+ "{Scene3DItemEnum.BACKGROUND}": "kitchen",
98
+ "{Scene3DItemEnum.CONTEXT}": "table",
99
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker"],
100
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["mug", "notebook", "bowl"]
101
+ }}
102
+ ```
103
+
104
+ Input:
105
+ "Put the rubik's cube on the top of the shelf."
106
+ Output:
107
+ ```json
108
+ {{
109
+ "task_desc": "Put the rubik's cube on the top of the shelf.",
110
+ "task": "pick and place",
111
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
112
+ "{Scene3DItemEnum.BACKGROUND}": "bedroom",
113
+ "{Scene3DItemEnum.CONTEXT}": "shelf",
114
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rubik's cube"],
115
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["pen", "cup", "toy car"]
116
+ }}
117
+ ```
118
+
119
+ Input:
120
+ "Remove all the objects from the white basket and put them on the table."
121
+ Output:
122
+ ```json
123
+ {{
124
+ "task_desc": "Remove all the objects from the white basket and put them on the table, robot {RobotItemEnum.PIPER}.",
125
+ "task": "pick and place",
126
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.PIPER}",
127
+ "{Scene3DItemEnum.BACKGROUND}": "office",
128
+ "{Scene3DItemEnum.CONTEXT}": "table",
129
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["banana", "mobile phone"],
130
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["plate", "white basket"]
131
+ }}
132
+ ```
133
+
134
+ Input:
135
+ "Pick up the rope on the chair and put it in the box."
136
+ Output:
137
+ ```json
138
+ {{
139
+ "task_desc": "Pick up the rope on the chair and put it in the box, robot {RobotItemEnum.FRANKA}.",
140
+ "task": "pick and place",
141
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
142
+ "{Scene3DItemEnum.BACKGROUND}": "living room",
143
+ "{Scene3DItemEnum.CONTEXT}": "chair",
144
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rope", "box"],
145
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["magazine"]
146
+ }}
147
+ ```
148
+
149
+ Input:
150
+ "Pick up the seal tape and plastic from the counter and put them in the open drawer and close it."
151
+ Output:
152
+ ```json
153
+ {{
154
+ "task_desc": "Pick up the seal tape and plastic from the counter and put them in the open drawer and close it.",
155
+ "task": "pick and place",
156
+ "robot": "franka",
157
+ "background": "kitchen",
158
+ "context": "counter",
159
+ "manipulated_objs": ["seal tape", "plastic", "opened drawer"],
160
+ "distractor_objs": ["scissors"]
161
+ }}
162
+ ```
163
+
164
+ Input:
165
+ "Put the pens in the grey bowl."
166
+ Output:
167
+ ```json
168
+ {{
169
+ "task_desc": "Put the pens in the grey bowl.",
170
+ "task": "pick and place",
171
+ "robot": "franka",
172
+ "background": "office",
173
+ "context": "table",
174
+ "manipulated_objs": ["pen1", "pen2", "grey bowl"],
175
+ "distractor_objs": ["notepad", "cup"]
176
+ }}
177
+ ```
178
+
179
+ """
180
+
181
+
182
+ LAYOUT_HIERARCHY_PROMPT = f"""
183
+ You are a 3D scene layout reasoning expert.
184
+ Your task is to generate a spatial relationship dictionary as a multiway tree
185
+ that describes how objects are arranged in a 3D environment
186
+ based on a given task description and object list.
187
+
188
+ Input in JSON format containing the task description, task type,
189
+ {Scene3DItemEnum.ROBOT}, {Scene3DItemEnum.BACKGROUND}, {Scene3DItemEnum.CONTEXT},
190
+ and a list of objects, including {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS}.
191
+
192
+ ### Supported Spatial Relations:
193
+ - "{SpatialRelationEnum.ON}": The child object bottom is directly on top of the parent object top.
194
+ - "{SpatialRelationEnum.INSIDE}": The child object is inside the context object.
195
+ - "{SpatialRelationEnum.IN}": The {Scene3DItemEnum.ROBOT} in the {Scene3DItemEnum.BACKGROUND}.
196
+ - "{SpatialRelationEnum.FLOOR}": The child object bottom is on the floor of the {Scene3DItemEnum.BACKGROUND}.
197
+
198
+ ### Rules:
199
+ - The {Scene3DItemEnum.CONTEXT} object must be "{SpatialRelationEnum.FLOOR}" the {Scene3DItemEnum.BACKGROUND}.
200
+ - {Scene3DItemEnum.MANIPULATED_OBJS} and {Scene3DItemEnum.DISTRACTOR_OBJS} must be either
201
+ "{SpatialRelationEnum.ON}" or "{SpatialRelationEnum.INSIDE}" the {Scene3DItemEnum.CONTEXT}
202
+ - Or "{SpatialRelationEnum.FLOOR}" {Scene3DItemEnum.BACKGROUND}.
203
+ - Use "{SpatialRelationEnum.INSIDE}" only if the parent is a container-like object (e.g., shelf, rack, cabinet).
204
+ - Do not define relationship edges between objects, only for the child and parent nodes.
205
+ - {Scene3DItemEnum.ROBOT} must "{SpatialRelationEnum.IN}" the {Scene3DItemEnum.BACKGROUND}.
206
+ - Ensure that each object appears only once in the layout tree, and its spatial relationship is defined with only one parent.
207
+ - Ensure a valid multiway tree structure with a maximum depth of 2 levels suitable for a 3D scene layout representation.
208
+ - Only output the final output in JSON format, using Markdown syntax as in examples.
209
+
210
+ ### Example
211
+ Input:
212
+ {{
213
+ "task_desc": "Pick up the marker from the table and put it in the bowl.",
214
+ "task": "pick and place",
215
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.FRANKA}",
216
+ "{Scene3DItemEnum.BACKGROUND}": "kitchen",
217
+ "{Scene3DItemEnum.CONTEXT}": "table",
218
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker", "bowl"],
219
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["mug", "chair"]
220
+ }}
221
+ Intermediate Think:
222
+ table {SpatialRelationEnum.FLOOR} kitchen
223
+ chair {SpatialRelationEnum.FLOOR} kitchen
224
+ {RobotItemEnum.FRANKA} {SpatialRelationEnum.IN} kitchen
225
+ marker {SpatialRelationEnum.ON} table
226
+ bowl {SpatialRelationEnum.ON} table
227
+ mug {SpatialRelationEnum.ON} table
228
+ Final Output:
229
+ ```json
230
+ {{
231
+ "kitchen": [
232
+ ["table", "{SpatialRelationEnum.FLOOR}"],
233
+ ["chair", "{SpatialRelationEnum.FLOOR}"],
234
+ ["{RobotItemEnum.FRANKA}", "{SpatialRelationEnum.IN}"]
235
+ ],
236
+ "table": [
237
+ ["marker", "{SpatialRelationEnum.ON}"],
238
+ ["bowl", "{SpatialRelationEnum.ON}"],
239
+ ["mug", "{SpatialRelationEnum.ON}"]
240
+ ]
241
+ }}
242
+ ```
243
+
244
+ Input:
245
+ {{
246
+ "task_desc": "Put the marker on top of the book.",
247
+ "task": "pick and place",
248
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
249
+ "{Scene3DItemEnum.BACKGROUND}": "office",
250
+ "{Scene3DItemEnum.CONTEXT}": "desk",
251
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["marker", "book"],
252
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["pen holder", "notepad"]
253
+ }}
254
+ Intermediate Think:
255
+ desk {SpatialRelationEnum.FLOOR} office
256
+ {RobotItemEnum.UR5} {SpatialRelationEnum.IN} office
257
+ marker {SpatialRelationEnum.ON} desk
258
+ book {SpatialRelationEnum.ON} desk
259
+ pen holder {SpatialRelationEnum.ON} desk
260
+ notepad {SpatialRelationEnum.ON} desk
261
+ Final Output:
262
+ ```json
263
+ {{
264
+ "office": [
265
+ ["desk", "{SpatialRelationEnum.FLOOR}"],
266
+ ["{RobotItemEnum.UR5}", "{SpatialRelationEnum.IN}"]
267
+ ],
268
+ "desk": [
269
+ ["marker", "{SpatialRelationEnum.ON}"],
270
+ ["book", "{SpatialRelationEnum.ON}"],
271
+ ["pen holder", "{SpatialRelationEnum.ON}"],
272
+ ["notepad", "{SpatialRelationEnum.ON}"]
273
+ ]
274
+ }}
275
+ ```
276
+
277
+ Input:
278
+ {{
279
+ "task_desc": "Put the rubik's cube on the top of the shelf.",
280
+ "task": "pick and place",
281
+ "{Scene3DItemEnum.ROBOT}": "{RobotItemEnum.UR5}",
282
+ "{Scene3DItemEnum.BACKGROUND}": "bedroom",
283
+ "{Scene3DItemEnum.CONTEXT}": "shelf",
284
+ "{Scene3DItemEnum.MANIPULATED_OBJS}": ["rubik's cube"],
285
+ "{Scene3DItemEnum.DISTRACTOR_OBJS}": ["toy car", "pen"]
286
+ }}
287
+ Intermediate Think:
288
+ shelf {SpatialRelationEnum.FLOOR} bedroom
289
+ {RobotItemEnum.UR5} {SpatialRelationEnum.IN} bedroom
290
+ rubik's cube {SpatialRelationEnum.INSIDE} shelf
291
+ toy car {SpatialRelationEnum.INSIDE} shelf
292
+ pen {SpatialRelationEnum.INSIDE} shelf
293
+ Final Output:
294
+ ```json
295
+ {{
296
+ "bedroom": [
297
+ ["shelf", "{SpatialRelationEnum.FLOOR}"],
298
+ ["{RobotItemEnum.UR5}", "{SpatialRelationEnum.IN}"]
299
+ ],
300
+ "shelf": [
301
+ ["rubik's cube", "{SpatialRelationEnum.INSIDE}"],
302
+ ["toy car", "{SpatialRelationEnum.INSIDE}"],
303
+ ["pen", "{SpatialRelationEnum.INSIDE}"]
304
+ ]
305
+ }}
306
+ ```
307
+
308
+ Input:
309
+ {{
310
+ "task_desc": "Put the marker in the cup on the counter.",
311
+ "task": "pick and place",
312
+ "robot": "franka",
313
+ "background": "kitchen",
314
+ "context": "counter",
315
+ "manipulated_objs": ["marker", "cup"],
316
+ "distractor_objs": ["plate", "spoon"]
317
+ }}
318
+ Intermediate Think:
319
+ counter {SpatialRelationEnum.FLOOR} kitchen
320
+ {RobotItemEnum.FRANKA} {SpatialRelationEnum.IN} kitchen
321
+ marker {SpatialRelationEnum.ON} counter
322
+ cup {SpatialRelationEnum.ON} counter
323
+ plate {SpatialRelationEnum.ON} counter
324
+ spoon {SpatialRelationEnum.ON} counter
325
+ Final Output:
326
+ ```json
327
+ {{
328
+ "kitchen": [
329
+ ["counter", "{SpatialRelationEnum.FLOOR}"],
330
+ ["{RobotItemEnum.FRANKA}", "{SpatialRelationEnum.IN}"]
331
+ ],
332
+ "counter": [
333
+ ["marker", "{SpatialRelationEnum.ON}"],
334
+ ["cup", "{SpatialRelationEnum.ON}"],
335
+ ["plate", "{SpatialRelationEnum.ON}"],
336
+ ["spoon", "{SpatialRelationEnum.ON}"]
337
+ ]
338
+ }}
339
+ ```
340
+ """
341
+
342
+
343
+ LAYOUT_DESCRIBER_PROMPT = """
344
+ You are a 3D asset style descriptor.
345
+
346
+ Given a task description and a dictionary where the key is the object content and
347
+ the value is the object type, output a JSON dictionary with each object paired
348
+ with a concise, styled visual description suitable for 3D asset generation.
349
+
350
+ Generation Guidelines:
351
+ - For each object, brainstorm multiple style candidates before selecting the final
352
+ description. Vary phrasing, material, texture, color, and spatial details.
353
+ - Each description must be a maximum of 15 words, including color, style, materials.
354
+ - Descriptions should be visually grounded, specific, and reflect surface texture and structure.
355
+ - For objects marked as "context", explicitly mention the object is standalone, has an empty top.
356
+ - Use rich style descriptors: e.g., "scratched brown wooden desk" etc.
357
+ - Ensure all object styles align with the task's overall context and environment.
358
+
359
+ Format your output in JSON like the example below.
360
+
361
+ Example Input:
362
+ "Pick up the rope on the chair and put it in the box. {'living room': 'background', 'chair': 'context',
363
+ 'rope': 'manipulated_objs', 'box': 'manipulated_objs', 'magazine': 'distractor_objs'}"
364
+
365
+ Example Output:
366
+ ```json
367
+ {
368
+ "living room": "modern cozy living room with soft sunlight and light grey carpet",
369
+ "chair": "standalone dark oak chair with no surroundings and clean empty seat",
370
+ "rope": "twisted hemp rope with rough fibers and dusty beige texture",
371
+ "box": "slightly crumpled cardboard box with open flaps and brown textured surface",
372
+ "magazine": "celebrity magazine with glossy red cover and large bold title"
373
+ }
374
+ ```
375
+ """
376
+
377
+
378
+ class LayoutDesigner(object):
379
+ def __init__(
380
+ self,
381
+ gpt_client: GPTclient,
382
+ system_prompt: str,
383
+ verbose: bool = False,
384
+ ) -> None:
385
+ self.prompt = system_prompt.strip()
386
+ self.verbose = verbose
387
+ self.gpt_client = gpt_client
388
+
389
+ def query(self, prompt: str, params: dict = None) -> str:
390
+ full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\""
391
+
392
+ response = self.gpt_client.query(
393
+ text_prompt=full_prompt,
394
+ params=params,
395
+ )
396
+
397
+ if self.verbose:
398
+ logger.info(f"Response: {response}")
399
+
400
+ return response
401
+
402
+ def format_response(self, response: str) -> dict:
403
+ cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip())
404
+ try:
405
+ output = json.loads(cleaned)
406
+ except json.JSONDecodeError as e:
407
+ raise json.JSONDecodeError(
408
+ f"Error: {e}, failed to parse JSON response: {response}", e.doc, e.pos
409
+ )
410
+
411
+ return output
412
+
413
+ def format_response_repair(self, response: str) -> dict:
414
+ return json_repair.loads(response)
415
+
416
+ def save_output(self, output: dict, save_path: str) -> None:
417
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
418
+ with open(save_path, 'w') as f:
419
+ json.dump(output, f, indent=4)
420
+
421
+ def __call__(
422
+ self, prompt: str, save_path: str = None, params: dict = None
423
+ ) -> dict | str:
424
+ response = self.query(prompt, params=params)
425
+ output = self.format_response_repair(response)
426
+ self.save_output(output, save_path) if save_path else None
427
+
428
+ return output
429
+
430
+
431
+ LAYOUT_DISASSEMBLER = LayoutDesigner(
432
+ gpt_client=GPT_CLIENT, system_prompt=LAYOUT_DISASSEMBLE_PROMPT
433
+ )
434
+ LAYOUT_GRAPHER = LayoutDesigner(
435
+ gpt_client=GPT_CLIENT, system_prompt=LAYOUT_HIERARCHY_PROMPT
436
+ )
437
+ LAYOUT_DESCRIBER = LayoutDesigner(
438
+ gpt_client=GPT_CLIENT, system_prompt=LAYOUT_DESCRIBER_PROMPT
439
+ )
440
+
441
+
442
+ def build_scene_layout(
443
+ task_desc: str, output_path: str = None, gpt_params: dict = None
444
+ ) -> LayoutInfo:
445
+ layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params)
446
+ layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params)
447
+ object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
448
+ obj_prompt = f'{layout_relation["task_desc"]} {object_mapping}'
449
+ objs_desc = LAYOUT_DESCRIBER(obj_prompt, params=gpt_params)
450
+ layout_info = LayoutInfo(
451
+ layout_tree, layout_relation, objs_desc, object_mapping
452
+ )
453
+
454
+ if output_path is not None:
455
+ visualizer = SceneTreeVisualizer(layout_info)
456
+ visualizer.render(save_path=output_path)
457
+ logger.info(f"Scene hierarchy tree saved to {output_path}")
458
+
459
+ return layout_info
460
+
461
+
462
+ def parse_args():
463
+ parser = argparse.ArgumentParser(description="3D Scene Layout Designer")
464
+ parser.add_argument(
465
+ "--task_desc",
466
+ type=str,
467
+ default="Put the apples on the table on the plate",
468
+ help="Natural language description of the robotic task",
469
+ )
470
+ parser.add_argument(
471
+ "--save_root",
472
+ type=str,
473
+ default="outputs/layout_tree",
474
+ help="Path to save the layout output",
475
+ )
476
+ return parser.parse_args()
477
+
478
+
479
+ if __name__ == "__main__":
480
+ from embodied_gen.utils.enum import LayoutInfo
481
+ from embodied_gen.utils.process_media import SceneTreeVisualizer
482
+
483
+ args = parse_args()
484
+ params = {
485
+ "temperature": 1.0,
486
+ "top_p": 0.95,
487
+ "frequency_penalty": 0.3,
488
+ "presence_penalty": 0.5,
489
+ }
490
+ layout_relation = LAYOUT_DISASSEMBLER(args.task_desc, params=params)
491
+ layout_tree = LAYOUT_GRAPHER(layout_relation, params=params)
492
+
493
+ object_mapping = Scene3DItemEnum.object_mapping(layout_relation)
494
+ obj_prompt = f'{layout_relation["task_desc"]} {object_mapping}'
495
+
496
+ objs_desc = LAYOUT_DESCRIBER(obj_prompt, params=params)
497
+
498
+ layout_info = LayoutInfo(layout_tree, layout_relation, objs_desc, object_mapping)
499
+
500
+ visualizer = SceneTreeVisualizer(layout_info)
501
+ os.makedirs(args.save_root, exist_ok=True)
502
+ scene_graph_path = f"{args.save_root}/scene_tree.jpg"
503
+ visualizer.render(save_path=scene_graph_path)
504
+ with open(f"{args.save_root}/layout.json", "w") as f:
505
+ json.dump(layout_info.to_dict(), f, indent=4)
506
+
507
+ print(f"Scene hierarchy tree saved to {scene_graph_path}")
508
+ print(f"Disassembled Layout: {layout_relation}")
509
+ print(f"Layout Graph: {layout_tree}")
510
+ print(f"Layout Descriptions: {objs_desc}")
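A minimal sketch of calling the end-to-end helper `build_scene_layout` defined above. It assumes a working GPT client configuration (see `embodied_gen/utils/gpt_config.yaml`); the task string, output path, and sampling parameters below are illustrative.

```python
from embodied_gen.models.layout import build_scene_layout  # assumed module path

layout_info = build_scene_layout(
    task_desc="Pick up the mug from the desk and place it on the shelf",
    output_path="outputs/layout_tree/scene_tree.jpg",  # optional scene-tree rendering
    gpt_params={"temperature": 1.0, "top_p": 0.95},
)
print(layout_info.to_dict())
```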
embodied_gen/models/segment_model.py ADDED
@@ -0,0 +1,379 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import logging
19
+ import os
20
+ from typing import Literal, Union
21
+
22
+ import cv2
23
+ import numpy as np
24
+ import rembg
25
+ import torch
26
+ from huggingface_hub import snapshot_download
27
+ from PIL import Image
28
+ from segment_anything import (
29
+ SamAutomaticMaskGenerator,
30
+ SamPredictor,
31
+ sam_model_registry,
32
+ )
33
+ from transformers import pipeline
34
+ from embodied_gen.data.utils import resize_pil, trellis_preprocess
35
+ from embodied_gen.utils.process_media import filter_small_connected_components
36
+ from embodied_gen.validators.quality_checkers import ImageSegChecker
37
+
38
+ logging.basicConfig(level=logging.INFO)
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ __all__ = [
43
+ "SAMRemover",
44
+ "SAMPredictor",
45
+ "RembgRemover",
46
+ "get_segmented_image_by_agent",
47
+ ]
48
+
49
+
50
+ class SAMRemover(object):
51
+ """Loading SAM models and performing background removal on images.
52
+
53
+ Attributes:
54
+ checkpoint (str): Path to the model checkpoint.
55
+ model_type (str): Type of the SAM model to load (default: "vit_h").
56
+ area_ratio (float): Area ratio filtering small connected components.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ checkpoint: str = None,
62
+ model_type: str = "vit_h",
63
+ area_ratio: float = 15,
64
+ ):
65
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
66
+ self.model_type = model_type
67
+ self.area_ratio = area_ratio
68
+
69
+ if checkpoint is None:
70
+ suffix = "sam"
71
+ model_path = snapshot_download(
72
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
73
+ )
74
+ checkpoint = os.path.join(
75
+ model_path, suffix, "sam_vit_h_4b8939.pth"
76
+ )
77
+
78
+ self.mask_generator = self._load_sam_model(checkpoint)
79
+
80
+ def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator:
81
+ sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
82
+ sam.to(device=self.device)
83
+
84
+ return SamAutomaticMaskGenerator(sam)
85
+
86
+ def __call__(
87
+ self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
88
+ ) -> Image.Image:
89
+ """Removes the background from an image using the SAM model.
90
+
91
+ Args:
92
+ image (Union[str, Image.Image, np.ndarray]): Input image,
93
+ can be a file path, PIL Image, or numpy array.
94
+ save_path (str): Path to save the output image (default: None).
95
+
96
+ Returns:
97
+ Image.Image: The image with background removed,
98
+ including an alpha channel.
99
+ """
100
+ # Convert input to numpy array
101
+ if isinstance(image, str):
102
+ image = Image.open(image)
103
+ elif isinstance(image, np.ndarray):
104
+ image = Image.fromarray(image).convert("RGB")
105
+ image = resize_pil(image)
106
+ image = np.array(image.convert("RGB"))
107
+
108
+ # Generate masks
109
+ masks = self.mask_generator.generate(image)
110
+ masks = sorted(masks, key=lambda x: x["area"], reverse=True)
111
+
112
+ if not masks:
113
+ logger.warning(
114
+ "Segmentation failed: No mask generated, return raw image."
115
+ )
116
+ output_image = Image.fromarray(image, mode="RGB")
117
+ else:
118
+ # Use the largest mask
119
+ best_mask = masks[0]["segmentation"]
120
+ mask = (best_mask * 255).astype(np.uint8)
121
+ mask = filter_small_connected_components(
122
+ mask, area_ratio=self.area_ratio
123
+ )
124
+ # Apply the mask to remove the background
125
+ background_removed = cv2.bitwise_and(image, image, mask=mask)
126
+ output_image = np.dstack((background_removed, mask))
127
+ output_image = Image.fromarray(output_image, mode="RGBA")
128
+
129
+ if save_path is not None:
130
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
131
+ output_image.save(save_path)
132
+
133
+ return output_image
134
+
135
+
136
+ class SAMPredictor(object):
137
+ def __init__(
138
+ self,
139
+ checkpoint: str = None,
140
+ model_type: str = "vit_h",
141
+ binary_thresh: float = 0.1,
142
+ device: str = "cuda",
143
+ ):
144
+ self.device = device
145
+ self.model_type = model_type
146
+
147
+ if checkpoint is None:
148
+ suffix = "sam"
149
+ model_path = snapshot_download(
150
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
151
+ )
152
+ checkpoint = os.path.join(
153
+ model_path, suffix, "sam_vit_h_4b8939.pth"
154
+ )
155
+
156
+ self.predictor = self._load_sam_model(checkpoint)
157
+ self.binary_thresh = binary_thresh
158
+
159
+ def _load_sam_model(self, checkpoint: str) -> SamPredictor:
160
+ sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
161
+ sam.to(device=self.device)
162
+
163
+ return SamPredictor(sam)
164
+
165
+ def preprocess_image(self, image: Image.Image) -> np.ndarray:
166
+ if isinstance(image, str):
167
+ image = Image.open(image)
168
+ elif isinstance(image, np.ndarray):
169
+ image = Image.fromarray(image).convert("RGB")
170
+
171
+ image = resize_pil(image)
172
+ image = np.array(image.convert("RGB"))
173
+
174
+ return image
175
+
176
+ def generate_masks(
177
+ self,
178
+ image: np.ndarray,
179
+ selected_points: list[list[int]],
180
+ ) -> np.ndarray:
181
+ if len(selected_points) == 0:
182
+ return []
183
+
184
+ points = (
185
+ torch.Tensor([p for p, _ in selected_points])
186
+ .to(self.predictor.device)
187
+ .unsqueeze(1)
188
+ )
189
+
190
+ labels = (
191
+ torch.Tensor([int(l) for _, l in selected_points])
192
+ .to(self.predictor.device)
193
+ .unsqueeze(1)
194
+ )
195
+
196
+ transformed_points = self.predictor.transform.apply_coords_torch(
197
+ points, image.shape[:2]
198
+ )
199
+
200
+ masks, scores, _ = self.predictor.predict_torch(
201
+ point_coords=transformed_points,
202
+ point_labels=labels,
203
+ multimask_output=True,
204
+ )
205
+ valid_mask = masks[:, torch.argmax(scores, dim=1)]
206
+ masks_pos = valid_mask[labels[:, 0] == 1, 0].cpu().detach().numpy()
207
+ masks_neg = valid_mask[labels[:, 0] == 0, 0].cpu().detach().numpy()
208
+ if len(masks_neg) == 0:
209
+ masks_neg = np.zeros_like(masks_pos)
210
+ if len(masks_pos) == 0:
211
+ masks_pos = np.zeros_like(masks_neg)
212
+ masks_neg = masks_neg.max(axis=0, keepdims=True)
213
+ masks_pos = masks_pos.max(axis=0, keepdims=True)
214
+ valid_mask = (masks_pos.astype(int) - masks_neg.astype(int)).clip(0, 1)
215
+
216
+ binary_mask = (valid_mask > self.binary_thresh).astype(np.int32)
217
+
218
+ return [(mask, f"mask_{i}") for i, mask in enumerate(binary_mask)]
219
+
220
+ def get_segmented_image(
221
+ self, image: np.ndarray, masks: list[tuple[np.ndarray, str]]
222
+ ) -> Image.Image:
223
+ seg_image = Image.fromarray(image, mode="RGB")
224
+ alpha_channel = np.zeros(
225
+ (seg_image.height, seg_image.width), dtype=np.uint8
226
+ )
227
+ for mask, _ in masks:
228
+ # Use the maximum to combine multiple masks
229
+ alpha_channel = np.maximum(alpha_channel, mask)
230
+
231
+ alpha_channel = np.clip(alpha_channel, 0, 1)
232
+ alpha_channel = (alpha_channel * 255).astype(np.uint8)
233
+ alpha_image = Image.fromarray(alpha_channel, mode="L")
234
+ r, g, b = seg_image.split()
235
+ seg_image = Image.merge("RGBA", (r, g, b, alpha_image))
236
+
237
+ return seg_image
238
+
239
+ def __call__(
240
+ self,
241
+ image: Union[str, Image.Image, np.ndarray],
242
+ selected_points: list[list[int]],
243
+ ) -> Image.Image:
244
+ image = self.preprocess_image(image)
245
+ self.predictor.set_image(image)
246
+ masks = self.generate_masks(image, selected_points)
247
+
248
+ return self.get_segmented_image(image, masks)
249
+
250
+
251
+ class RembgRemover(object):
252
+ def __init__(self):
253
+ self.rembg_session = rembg.new_session("u2net")
254
+
255
+ def __call__(
256
+ self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
257
+ ) -> Image.Image:
258
+ if isinstance(image, str):
259
+ image = Image.open(image)
260
+ elif isinstance(image, np.ndarray):
261
+ image = Image.fromarray(image)
262
+
263
+ image = resize_pil(image)
264
+ output_image = rembg.remove(image, session=self.rembg_session)
265
+
266
+ if save_path is not None:
267
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
268
+ output_image.save(save_path)
269
+
270
+ return output_image
271
+
272
+
273
+ class BMGG14Remover(object):
274
+ def __init__(self) -> None:
275
+ self.model = pipeline(
276
+ "image-segmentation",
277
+ model="briaai/RMBG-1.4",
278
+ trust_remote_code=True,
279
+ )
280
+
281
+ def __call__(
282
+ self, image: Union[str, Image.Image, np.ndarray], save_path: str = None
283
+ ):
284
+ if isinstance(image, str):
285
+ image = Image.open(image)
286
+ elif isinstance(image, np.ndarray):
287
+ image = Image.fromarray(image)
288
+
289
+ image = resize_pil(image)
290
+ output_image = self.model(image)
291
+
292
+ if save_path is not None:
293
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
294
+ output_image.save(save_path)
295
+
296
+ return output_image
297
+
298
+
299
+ def invert_rgba_pil(
300
+ image: Image.Image, mask: Image.Image, save_path: str = None
301
+ ) -> Image.Image:
302
+ mask = (255 - np.array(mask))[..., None]
303
+ image_array = np.concatenate([np.array(image), mask], axis=-1)
304
+ inverted_image = Image.fromarray(image_array, "RGBA")
305
+
306
+ if save_path is not None:
307
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
308
+ inverted_image.save(save_path)
309
+
310
+ return inverted_image
311
+
312
+
313
+ def get_segmented_image_by_agent(
314
+ image: Image.Image,
315
+ sam_remover: SAMRemover,
316
+ rbg_remover: RembgRemover,
317
+ seg_checker: ImageSegChecker = None,
318
+ save_path: str = None,
319
+ mode: Literal["loose", "strict"] = "loose",
320
+ ) -> Image.Image:
321
+ def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool:
322
+ if seg_checker is None:
323
+ return True
324
+ return raw_img.mode == "RGBA" and seg_checker([raw_img, seg_img])[0]
325
+
326
+ out_sam = f"{save_path}_sam.png" if save_path else None
327
+ out_sam_inv = f"{save_path}_sam_inv.png" if save_path else None
328
+ out_rbg = f"{save_path}_rbg.png" if save_path else None
329
+
330
+ seg_image = sam_remover(image, out_sam)
331
+ seg_image = seg_image.convert("RGBA")
332
+ _, _, _, alpha = seg_image.split()
333
+ seg_image_inv = invert_rgba_pil(image.convert("RGB"), alpha, out_sam_inv)
334
+ seg_image_rbg = rbg_remover(image, out_rbg)
335
+
336
+ final_image = None
337
+ if _is_valid_seg(image, seg_image):
338
+ final_image = seg_image
339
+ elif _is_valid_seg(image, seg_image_inv):
340
+ final_image = seg_image_inv
341
+ elif _is_valid_seg(image, seg_image_rbg):
342
+ logger.warning("Failed to segment by `SAM`, retry with `rembg`.")
343
+ final_image = seg_image_rbg
344
+ else:
345
+ if mode == "strict":
346
+ raise RuntimeError(
347
+ "Failed to segment by `SAM` or `rembg`, abort."
348
+ )
349
+ logger.warning("Failed to segment by SAM or rembg, use raw image.")
350
+ final_image = image.convert("RGBA")
351
+
352
+ if save_path:
353
+ final_image.save(save_path)
354
+
355
+ final_image = trellis_preprocess(final_image)
356
+
357
+ return final_image
358
+
359
+
360
+ if __name__ == "__main__":
361
+ input_image = "outputs/text2image/demo_objects/electrical/sample_0.jpg"
362
+ output_image = "sample_0_seg2.png"
363
+
364
+ # input_image = "outputs/text2image/tmp/coffee_machine.jpeg"
365
+ # output_image = "outputs/text2image/tmp/coffee_machine_seg.png"
366
+
367
+ # input_image = "outputs/text2image/tmp/bucket.jpeg"
368
+ # output_image = "outputs/text2image/tmp/bucket_seg.png"
369
+
370
+ remover = SAMRemover(model_type="vit_h")
371
+ remover = RembgRemover()
372
+ clean_image = remover(input_image)
373
+ clean_image.save(output_image)
374
+ get_segmented_image_by_agent(
375
+ Image.open(input_image), remover, remover, None, "./test_seg.png"
376
+ )
377
+
378
+ remover = BMGG14Remover()
379
+ remover("embodied_gen/models/test_seg.jpg", "./seg.png")
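For reference, a hedged sketch of the fallback flow exposed by `get_segmented_image_by_agent`: SAM first, then the inverted SAM mask, then rembg, before falling back to the raw image. The input/output paths are placeholders, and `seg_checker=None` disables the GPT-based validity check.

```python
from PIL import Image

from embodied_gen.models.segment_model import (  # assumed module path
    RembgRemover,
    SAMRemover,
    get_segmented_image_by_agent,
)

image = Image.open("path/to/input.jpg")  # placeholder path
sam_remover = SAMRemover(model_type="vit_h")  # SAM weights auto-download on first use
rbg_remover = RembgRemover()

segmented = get_segmented_image_by_agent(
    image, sam_remover, rbg_remover, seg_checker=None, save_path="path/to/output.png"
)
```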
embodied_gen/models/sr_model.py ADDED
@@ -0,0 +1,174 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import logging
19
+ import os
20
+ from typing import Union
21
+
22
+ import numpy as np
23
+ import spaces
24
+ import torch
25
+ from huggingface_hub import snapshot_download
26
+ from PIL import Image
27
+ from embodied_gen.data.utils import get_images_from_grid
28
+
29
+ logging.basicConfig(
30
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
31
+ )
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ __all__ = [
36
+ "ImageStableSR",
37
+ "ImageRealESRGAN",
38
+ ]
39
+
40
+
41
+ class ImageStableSR:
42
+ """Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI."""
43
+
44
+ def __init__(
45
+ self,
46
+ model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
47
+ device="cuda",
48
+ ) -> None:
49
+ from diffusers import StableDiffusionUpscalePipeline
50
+
51
+ self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
52
+ model_path,
53
+ torch_dtype=torch.float16,
54
+ ).to(device)
55
+ self.up_pipeline_x4.set_progress_bar_config(disable=True)
56
+ # self.up_pipeline_x4.enable_model_cpu_offload()
57
+
58
+ @spaces.GPU
59
+ def __call__(
60
+ self,
61
+ image: Union[Image.Image, np.ndarray],
62
+ prompt: str = "",
63
+ infer_step: int = 20,
64
+ ) -> Image.Image:
65
+ if isinstance(image, np.ndarray):
66
+ image = Image.fromarray(image)
67
+
68
+ image = image.convert("RGB")
69
+
70
+ with torch.no_grad():
71
+ upscaled_image = self.up_pipeline_x4(
72
+ image=image,
73
+ prompt=[prompt],
74
+ num_inference_steps=infer_step,
75
+ ).images[0]
76
+
77
+ return upscaled_image
78
+
79
+
80
+ class ImageRealESRGAN:
81
+ """A wrapper for Real-ESRGAN-based image super-resolution.
82
+
83
+ This class uses the RealESRGAN model to perform image upscaling,
84
+ typically by a factor of 4.
85
+
86
+ Attributes:
87
+ outscale (int): The output image scale factor (e.g., 2, 4).
88
+ model_path (str): Path to the pre-trained model weights.
89
+ """
90
+
91
+ def __init__(self, outscale: int, model_path: str = None) -> None:
92
+ # monkey patch to support torchvision>=0.16
93
+ import torchvision
94
+ from packaging import version
95
+
96
+ if version.parse(torchvision.__version__) > version.parse("0.16"):
97
+ import sys
98
+ import types
99
+
100
+ import torchvision.transforms.functional as TF
101
+
102
+ functional_tensor = types.ModuleType(
103
+ "torchvision.transforms.functional_tensor"
104
+ )
105
+ functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
106
+ sys.modules["torchvision.transforms.functional_tensor"] = (
107
+ functional_tensor
108
+ )
109
+
110
+ self.outscale = outscale
111
+ self.upsampler = None
112
+
113
+ if model_path is None:
114
+ suffix = "super_resolution"
115
+ model_path = snapshot_download(
116
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
117
+ )
118
+ model_path = os.path.join(
119
+ model_path, suffix, "RealESRGAN_x4plus.pth"
120
+ )
121
+
122
+ self.model_path = model_path
123
+
124
+ def _lazy_init(self):
125
+ if self.upsampler is None:
126
+ from basicsr.archs.rrdbnet_arch import RRDBNet
127
+ from realesrgan import RealESRGANer
128
+
129
+ model = RRDBNet(
130
+ num_in_ch=3,
131
+ num_out_ch=3,
132
+ num_feat=64,
133
+ num_block=23,
134
+ num_grow_ch=32,
135
+ scale=4,
136
+ )
137
+
138
+ self.upsampler = RealESRGANer(
139
+ scale=4,
140
+ model_path=self.model_path,
141
+ model=model,
142
+ pre_pad=0,
143
+ half=True,
144
+ )
145
+
146
+ @spaces.GPU
147
+ def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
148
+ self._lazy_init()
149
+
150
+ if isinstance(image, Image.Image):
151
+ image = np.array(image)
152
+
153
+ with torch.no_grad():
154
+ output, _ = self.upsampler.enhance(image, outscale=self.outscale)
155
+
156
+ return Image.fromarray(output)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"
161
+
162
+ # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
163
+ super_model = ImageRealESRGAN(outscale=4)
164
+ multiviews = get_images_from_grid(color_path, img_size=512)
165
+ multiviews = [super_model(img.convert("RGB")) for img in multiviews]
166
+ for idx, img in enumerate(multiviews):
167
+ img.save(f"sr{idx}.png")
168
+
169
+ # # Use stable diffusion for x4 (512->2048) image super resolution.
170
+ # super_model = ImageStableSR()
171
+ # multiviews = get_images_from_grid(color_path, img_size=512)
172
+ # multiviews = [super_model(img) for img in multiviews]
173
+ # for idx, img in enumerate(multiviews):
174
+ # img.save(f"sr_stable{idx}.png")
embodied_gen/models/text_model.py ADDED
@@ -0,0 +1,213 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import logging
19
+ import os
20
+ import random
21
+ import subprocess
22
+
23
+ import numpy as np
24
+ import torch
25
+ from diffusers import (
26
+ AutoencoderKL,
27
+ EulerDiscreteScheduler,
28
+ UNet2DConditionModel,
29
+ )
30
+ from kolors.models.modeling_chatglm import ChatGLMModel
31
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
32
+ from kolors.models.unet_2d_condition import (
33
+ UNet2DConditionModel as UNet2DConditionModelIP,
34
+ )
35
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
36
+ StableDiffusionXLPipeline,
37
+ )
38
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
39
+ StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
40
+ )
41
+ from PIL import Image
42
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
43
+
44
+ logging.basicConfig(level=logging.INFO)
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+ __all__ = [
49
+ "build_text2img_ip_pipeline",
50
+ "build_text2img_pipeline",
51
+ "text2img_gen",
52
+ "download_kolors_weights",
53
+ ]
54
+
55
+ PROMPT_APPEND = (
56
+ "Angled 3D view of one {object}, centered, no cropping, no occlusion, isolated product photo, "
57
+ "no surroundings, high-quality appearance, vivid colors, on a plain clean surface, 3D style revealing multiple surfaces"
58
+ )
59
+ PROMPT_KAPPEND = "Single {object}, in the center of the image, white background, 3D style, best quality"
60
+
61
+
62
+ def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
63
+ logger.info(f"Download kolors weights from huggingface...")
64
+ os.makedirs(local_dir, exist_ok=True)
65
+ subprocess.run(
66
+ [
67
+ "huggingface-cli",
68
+ "download",
69
+ "--resume-download",
70
+ "Kwai-Kolors/Kolors",
71
+ "--local-dir",
72
+ local_dir,
73
+ ],
74
+ check=True,
75
+ )
76
+
77
+ ip_adapter_path = f"{local_dir}/../Kolors-IP-Adapter-Plus"
78
+ subprocess.run(
79
+ [
80
+ "huggingface-cli",
81
+ "download",
82
+ "--resume-download",
83
+ "Kwai-Kolors/Kolors-IP-Adapter-Plus",
84
+ "--local-dir",
85
+ ip_adapter_path,
86
+ ],
87
+ check=True,
88
+ )
89
+
90
+
91
+ def build_text2img_ip_pipeline(
92
+ ckpt_dir: str,
93
+ ref_scale: float,
94
+ device: str = "cuda",
95
+ ) -> StableDiffusionXLPipelineIP:
96
+ download_kolors_weights(ckpt_dir)
97
+
98
+ text_encoder = ChatGLMModel.from_pretrained(
99
+ f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
100
+ ).half()
101
+ tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
102
+ vae = AutoencoderKL.from_pretrained(
103
+ f"{ckpt_dir}/vae", revision=None
104
+ ).half()
105
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
106
+ unet = UNet2DConditionModelIP.from_pretrained(
107
+ f"{ckpt_dir}/unet", revision=None
108
+ ).half()
109
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
110
+ f"{ckpt_dir}/../Kolors-IP-Adapter-Plus/image_encoder",
111
+ ignore_mismatched_sizes=True,
112
+ ).to(dtype=torch.float16)
113
+ clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)
114
+
115
+ pipe = StableDiffusionXLPipelineIP(
116
+ vae=vae,
117
+ text_encoder=text_encoder,
118
+ tokenizer=tokenizer,
119
+ unet=unet,
120
+ scheduler=scheduler,
121
+ image_encoder=image_encoder,
122
+ feature_extractor=clip_image_processor,
123
+ force_zeros_for_empty_prompt=False,
124
+ )
125
+
126
+ if hasattr(pipe.unet, "encoder_hid_proj"):
127
+ pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
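+ # (Presumably keeps the text projection reachable after load_ip_adapter below swaps encoder_hid_proj for the image-projection layers.)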
128
+
129
+ pipe.load_ip_adapter(
130
+ f"{ckpt_dir}/../Kolors-IP-Adapter-Plus",
131
+ subfolder="",
132
+ weight_name=["ip_adapter_plus_general.bin"],
133
+ )
134
+ pipe.set_ip_adapter_scale([ref_scale])
135
+
136
+ pipe = pipe.to(device)
137
+ pipe.image_encoder = pipe.image_encoder.to(device)
138
+ # pipe.enable_model_cpu_offload()
139
+ # # pipe.enable_xformers_memory_efficient_attention()
140
+ # pipe.enable_vae_slicing()
141
+
142
+ return pipe
143
+
144
+
145
+ def build_text2img_pipeline(
146
+ ckpt_dir: str,
147
+ device: str = "cuda",
148
+ ) -> StableDiffusionXLPipeline:
149
+ download_kolors_weights(ckpt_dir)
150
+
151
+ text_encoder = ChatGLMModel.from_pretrained(
152
+ f"{ckpt_dir}/text_encoder", torch_dtype=torch.float16
153
+ ).half()
154
+ tokenizer = ChatGLMTokenizer.from_pretrained(f"{ckpt_dir}/text_encoder")
155
+ vae = AutoencoderKL.from_pretrained(
156
+ f"{ckpt_dir}/vae", revision=None
157
+ ).half()
158
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
159
+ unet = UNet2DConditionModel.from_pretrained(
160
+ f"{ckpt_dir}/unet", revision=None
161
+ ).half()
162
+ pipe = StableDiffusionXLPipeline(
163
+ vae=vae,
164
+ text_encoder=text_encoder,
165
+ tokenizer=tokenizer,
166
+ unet=unet,
167
+ scheduler=scheduler,
168
+ force_zeros_for_empty_prompt=False,
169
+ )
170
+ pipe = pipe.to(device)
171
+ # pipe.enable_model_cpu_offload()
172
+ # pipe.enable_xformers_memory_efficient_attention()
173
+
174
+ return pipe
175
+
176
+
177
+ def text2img_gen(
178
+ prompt: str,
179
+ n_sample: int,
180
+ guidance_scale: float,
181
+ pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
182
+ ip_image: Image.Image | str = None,
183
+ image_wh: tuple[int, int] = [1024, 1024],
184
+ infer_step: int = 50,
185
+ ip_image_size: int = 512,
186
+ seed: int = None,
187
+ ) -> list[Image.Image]:
188
+ prompt = PROMPT_KAPPEND.format(object=prompt.strip())
189
+ logger.info(f"Processing prompt: {prompt}")
190
+
191
+ generator = None
192
+ if seed is not None:
193
+ generator = torch.Generator(pipeline.device).manual_seed(seed)
194
+ torch.manual_seed(seed)
195
+ np.random.seed(seed)
196
+ random.seed(seed)
197
+
198
+ kwargs = dict(
199
+ prompt=prompt,
200
+ height=image_wh[1],
201
+ width=image_wh[0],
202
+ num_inference_steps=infer_step,
203
+ guidance_scale=guidance_scale,
204
+ num_images_per_prompt=n_sample,
205
+ generator=generator,
206
+ )
207
+ if ip_image is not None:
208
+ if isinstance(ip_image, str):
209
+ ip_image = Image.open(ip_image)
210
+ ip_image = ip_image.resize((ip_image_size, ip_image_size))
211
+ kwargs.update(ip_adapter_image=[ip_image])
212
+
213
+ return pipeline(**kwargs).images
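A minimal usage sketch for the helpers above (not part of the commit; the reference image path, prompt, and ref_scale value are illustrative assumptions, and loading both pipelines at once needs substantial GPU memory):

    from embodied_gen.models.text_model import (
        build_text2img_ip_pipeline,
        build_text2img_pipeline,
        text2img_gen,
    )

    # Plain text-to-image generation with Kolors.
    pipe = build_text2img_pipeline("weights/Kolors", device="cuda")
    images = text2img_gen(
        prompt="red ceramic mug",
        n_sample=3,
        guidance_scale=7.0,
        pipeline=pipe,
        seed=0,
    )

    # IP-Adapter variant conditioned on a reference image (placeholder path).
    ip_pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
    ip_images = text2img_gen(
        prompt="red ceramic mug",
        n_sample=3,
        guidance_scale=7.0,
        pipeline=ip_pipe,
        ip_image="reference.png",
        seed=0,
    )

    for idx, img in enumerate(images + ip_images):
        img.save(f"text2img_{idx}.png")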
embodied_gen/models/texture_model.py ADDED
@@ -0,0 +1,112 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import os
19
+
20
+ import torch
21
+ from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler
22
+ from huggingface_hub import snapshot_download
23
+ from kolors.models.controlnet import ControlNetModel
24
+ from kolors.models.modeling_chatglm import ChatGLMModel
25
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
26
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
27
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
28
+ StableDiffusionXLControlNetImg2ImgPipeline,
29
+ )
30
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
31
+ from embodied_gen.models.text_model import download_kolors_weights
32
+ from embodied_gen.utils.log import logger
33
+
34
+ __all__ = [
35
+ "build_texture_gen_pipe",
36
+ ]
37
+
38
+
39
+ def build_texture_gen_pipe(
40
+ base_ckpt_dir: str,
41
+ controlnet_ckpt: str = None,
42
+ ip_adapt_scale: float = 0,
43
+ device: str = "cuda",
44
+ ) -> DiffusionPipeline:
45
+ download_kolors_weights(f"{base_ckpt_dir}/Kolors")
46
+ logger.info(f"Load Kolors weights...")
47
+ tokenizer = ChatGLMTokenizer.from_pretrained(
48
+ f"{base_ckpt_dir}/Kolors/text_encoder"
49
+ )
50
+ text_encoder = ChatGLMModel.from_pretrained(
51
+ f"{base_ckpt_dir}/Kolors/text_encoder", torch_dtype=torch.float16
52
+ ).half()
53
+ vae = AutoencoderKL.from_pretrained(
54
+ f"{base_ckpt_dir}/Kolors/vae", revision=None
55
+ ).half()
56
+ unet = UNet2DConditionModel.from_pretrained(
57
+ f"{base_ckpt_dir}/Kolors/unet", revision=None
58
+ ).half()
59
+ scheduler = EulerDiscreteScheduler.from_pretrained(
60
+ f"{base_ckpt_dir}/Kolors/scheduler"
61
+ )
62
+
63
+ if controlnet_ckpt is None:
64
+ suffix = "texture_gen_mv_v1" # "geo_cond_mv"
65
+ model_path = snapshot_download(
66
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
67
+ )
68
+ controlnet_ckpt = os.path.join(model_path, suffix)
69
+
70
+ controlnet = ControlNetModel.from_pretrained(
71
+ controlnet_ckpt, use_safetensors=True
72
+ ).half()
73
+
74
+ # IP-Adapter model
75
+ image_encoder = None
76
+ clip_image_processor = None
77
+ if ip_adapt_scale > 0:
78
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
79
+ f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus/image_encoder",
80
+ # ignore_mismatched_sizes=True,
81
+ ).to(dtype=torch.float16)
82
+ ip_img_size = 336
83
+ clip_image_processor = CLIPImageProcessor(
84
+ size=ip_img_size, crop_size=ip_img_size
85
+ )
86
+
87
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline(
88
+ vae=vae,
89
+ controlnet=controlnet,
90
+ text_encoder=text_encoder,
91
+ tokenizer=tokenizer,
92
+ unet=unet,
93
+ scheduler=scheduler,
94
+ image_encoder=image_encoder,
95
+ feature_extractor=clip_image_processor,
96
+ force_zeros_for_empty_prompt=False,
97
+ )
98
+
99
+ if ip_adapt_scale > 0:
100
+ if hasattr(pipe.unet, "encoder_hid_proj"):
101
+ pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
102
+ pipe.load_ip_adapter(
103
+ f"{base_ckpt_dir}/Kolors-IP-Adapter-Plus",
104
+ subfolder="",
105
+ weight_name=["ip_adapter_plus_general.bin"],
106
+ )
107
+ pipe.set_ip_adapter_scale([ip_adapt_scale])
108
+
109
+ pipe = pipe.to(device)
110
+ # pipe.enable_model_cpu_offload()
111
+
112
+ return pipe
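A minimal usage sketch for build_texture_gen_pipe (not part of the commit; the checkpoint root mirrors the "./weights" default used by the scripts in this repo):

    from embodied_gen.models.texture_model import build_texture_gen_pipe

    # ControlNet weights are fetched from the hub when controlnet_ckpt is None;
    # ip_adapt_scale=0 skips loading the IP-Adapter branch.
    pipe = build_texture_gen_pipe(
        base_ckpt_dir="./weights",
        controlnet_ckpt=None,
        ip_adapt_scale=0.0,
        device="cuda",
    )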
embodied_gen/scripts/compose_layout.py ADDED
@@ -0,0 +1,79 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ import shutil
20
+ from dataclasses import dataclass
21
+
22
+ import tyro
23
+ from embodied_gen.scripts.simulate_sapien import entrypoint as sim_cli
24
+ from embodied_gen.utils.enum import LayoutInfo
25
+ from embodied_gen.utils.geometry import bfs_placement, compose_mesh_scene
26
+ from embodied_gen.utils.log import logger
27
+
28
+
29
+ @dataclass
30
+ class LayoutPlacementConfig:
31
+ layout_path: str
32
+ output_dir: str | None = None
33
+ seed: int | None = None
34
+ max_attempts: int = 1000
35
+ output_iscene: bool = False
36
+ insert_robot: bool = False
37
+
38
+
39
+ def entrypoint(**kwargs):
40
+ if kwargs is None or len(kwargs) == 0:
41
+ args = tyro.cli(LayoutPlacementConfig)
42
+ else:
43
+ args = LayoutPlacementConfig(**kwargs)
44
+
45
+ output_dir = (
46
+ args.output_dir
47
+ if args.output_dir is not None
48
+ else os.path.dirname(args.layout_path)
49
+ )
50
+ os.makedirs(output_dir, exist_ok=True)
51
+ out_scene_path = f"{output_dir}/Iscene.glb"
52
+ out_layout_path = f"{output_dir}/layout.json"
53
+
54
+ layout_info = bfs_placement(args.layout_path, seed=args.seed)
55
+ origin_dir = os.path.dirname(args.layout_path)
56
+ for key in layout_info.assets:
57
+ src = f"{origin_dir}/{layout_info.assets[key]}"
58
+ dst = f"{output_dir}/{layout_info.assets[key]}"
59
+ if src == dst:
60
+ continue
61
+ shutil.copytree(src, dst, dirs_exist_ok=True)
62
+
63
+ with open(out_layout_path, "w") as f:
64
+ json.dump(layout_info.to_dict(), f, indent=4)
65
+
66
+ if args.output_iscene:
67
+ compose_mesh_scene(layout_info, out_scene_path)
68
+
69
+ sim_cli(
70
+ layout_path=out_layout_path,
71
+ output_dir=output_dir,
72
+ insert_robot=args.insert_robot,
73
+ )
74
+
75
+ logger.info(f"Layout placement completed in {output_dir}")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ entrypoint()
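A minimal programmatic usage sketch (not part of the commit; the layout path is a placeholder):

    from embodied_gen.scripts.compose_layout import entrypoint

    # Re-run BFS placement on an existing layout.json, export the composed
    # mesh scene, and replay it in the SAPIEN simulator.
    entrypoint(
        layout_path="outputs/task_0000/layout.json",
        output_iscene=True,
    )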
embodied_gen/scripts/gen_layout.py ADDED
@@ -0,0 +1,171 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import gc
18
+ import json
19
+ import os
20
+ from dataclasses import dataclass, field
21
+ from shutil import copytree
22
+ from time import time
23
+ from typing import Optional
24
+
25
+ import torch
26
+ import tyro
27
+ from embodied_gen.models.layout import build_scene_layout
28
+ from embodied_gen.scripts.simulate_sapien import entrypoint as sim_cli
29
+ from embodied_gen.scripts.textto3d import text_to_3d
30
+ from embodied_gen.utils.config import GptParamsConfig
31
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
32
+ from embodied_gen.utils.geometry import bfs_placement, compose_mesh_scene
33
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
34
+ from embodied_gen.utils.log import logger
35
+ from embodied_gen.utils.process_media import (
36
+ load_scene_dict,
37
+ parse_text_prompts,
38
+ )
39
+ from embodied_gen.validators.quality_checkers import SemanticMatcher
40
+
41
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
42
+
43
+
44
+ @dataclass
45
+ class LayoutGenConfig:
46
+ task_descs: list[str]
47
+ output_root: str
48
+ bg_list: str = "outputs/bg_scenes/scene_list.txt"
49
+ n_img_sample: int = 3
50
+ text_guidance_scale: float = 7.0
51
+ img_denoise_step: int = 25
52
+ n_image_retry: int = 4
53
+ n_asset_retry: int = 3
54
+ n_pipe_retry: int = 2
55
+ seed_img: Optional[int] = None
56
+ seed_3d: Optional[int] = None
57
+ seed_layout: Optional[int] = None
58
+ keep_intermediate: bool = False
59
+ output_iscene: bool = False
60
+ insert_robot: bool = False
61
+ gpt_params: GptParamsConfig = field(
62
+ default_factory=lambda: GptParamsConfig(
63
+ temperature=1.0,
64
+ top_p=0.95,
65
+ frequency_penalty=0.3,
66
+ presence_penalty=0.5,
67
+ )
68
+ )
69
+
70
+
71
+ def entrypoint() -> None:
72
+ args = tyro.cli(LayoutGenConfig)
73
+ SCENE_MATCHER = SemanticMatcher(GPT_CLIENT)
74
+ task_descs = parse_text_prompts(args.task_descs)
75
+ scene_dict = load_scene_dict(args.bg_list)
76
+ gpt_params = args.gpt_params.to_dict()
77
+ for idx, task_desc in enumerate(task_descs):
78
+ logger.info(f"Generate Layout and 3D scene for task: {task_desc}")
79
+ output_root = f"{args.output_root}/task_{idx:04d}"
80
+ scene_graph_path = f"{output_root}/scene_tree.jpg"
81
+ start_time = time()
82
+ layout_info: LayoutInfo = build_scene_layout(
83
+ task_desc, scene_graph_path, gpt_params
84
+ )
85
+ prompts_mapping = {v: k for k, v in layout_info.objs_desc.items()}
86
+ prompts = [
87
+ v
88
+ for k, v in layout_info.objs_desc.items()
89
+ if layout_info.objs_mapping[k] != Scene3DItemEnum.BACKGROUND.value
90
+ ]
91
+
92
+ for prompt in prompts:
93
+ node = prompts_mapping[prompt]
94
+ generation_log = text_to_3d(
95
+ prompts=[
96
+ prompt,
97
+ ],
98
+ output_root=output_root,
99
+ asset_names=[
100
+ node,
101
+ ],
102
+ n_img_sample=args.n_img_sample,
103
+ text_guidance_scale=args.text_guidance_scale,
104
+ img_denoise_step=args.img_denoise_step,
105
+ n_image_retry=args.n_image_retry,
106
+ n_asset_retry=args.n_asset_retry,
107
+ n_pipe_retry=args.n_pipe_retry,
108
+ seed_img=args.seed_img,
109
+ seed_3d=args.seed_3d,
110
+ keep_intermediate=args.keep_intermediate,
111
+ )
112
+ layout_info.assets.update(generation_log["assets"])
113
+ layout_info.quality.update(generation_log["quality"])
114
+
115
+ # Background GEN (for efficiency, temp use retrieval instead)
116
+ bg_node = layout_info.relation[Scene3DItemEnum.BACKGROUND.value]
117
+ text = layout_info.objs_desc[bg_node]
118
+ match_key = SCENE_MATCHER.query(
119
+ text, str(scene_dict), params=gpt_params
120
+ )
121
+ n_max_attempt = 10
122
+ while match_key not in scene_dict and n_max_attempt > 0:
123
+ logger.error(
124
+ f"Cannot find matched scene {match_key}, retrying left {n_max_attempt}..."
125
+ )
126
+ match_key = SCENE_MATCHER.query(
127
+ text, str(scene_dict), params=gpt_params
128
+ )
129
+ n_max_attempt -= 1
130
+
131
+ match_scene_path = f"{os.path.dirname(args.bg_list)}/{match_key}"
132
+ bg_save_dir = os.path.join(output_root, "background")
133
+ copytree(match_scene_path, bg_save_dir, dirs_exist_ok=True)
134
+ layout_info.assets[bg_node] = "background"
135
+
136
+ # BFS layout placement.
137
+ layout_path = f"{output_root}/layout.json"
138
+ with open(layout_path, "w") as f:
139
+ json.dump(layout_info.to_dict(), f, indent=4)
140
+
141
+ layout_info = bfs_placement(
142
+ layout_path,
143
+ seed=args.seed_layout,
144
+ )
145
+ layout_path = f"{output_root}/layout.json"
146
+ with open(layout_path, "w") as f:
147
+ json.dump(layout_info.to_dict(), f, indent=4)
148
+
149
+ if args.output_iscene:
150
+ compose_mesh_scene(layout_info, f"{output_root}/Iscene.glb")
151
+
152
+ sim_cli(
153
+ layout_path=layout_path,
154
+ output_dir=output_root,
155
+ insert_robot=args.insert_robot,
156
+ )
157
+
158
+ torch.cuda.empty_cache()
159
+ gc.collect()
160
+
161
+ elapsed_time = (time() - start_time) / 60
162
+ logger.info(
163
+ f"Layout generation done for {scene_graph_path}, layout result "
164
+ f"in {layout_path}, finished in {elapsed_time:.2f} mins."
165
+ )
166
+
167
+ logger.info(f"All tasks completed in {args.output_root}")
168
+
169
+
170
+ if __name__ == "__main__":
171
+ entrypoint()
embodied_gen/scripts/gen_scene3d.py ADDED
@@ -0,0 +1,191 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import time
5
+ import warnings
6
+ from dataclasses import dataclass, field
7
+ from shutil import copy, rmtree
8
+
9
+ import torch
10
+ import tyro
11
+ from huggingface_hub import snapshot_download
12
+ from packaging import version
13
+
14
+ # Suppress warnings
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+ logging.getLogger("transformers").setLevel(logging.ERROR)
17
+ logging.getLogger("diffusers").setLevel(logging.ERROR)
18
+
19
+ # TorchVision monkey patch for >0.16
20
+ if version.parse(torch.__version__) >= version.parse("0.16"):
21
+ import sys
22
+ import types
23
+
24
+ import torchvision.transforms.functional as TF
25
+
26
+ functional_tensor = types.ModuleType(
27
+ "torchvision.transforms.functional_tensor"
28
+ )
29
+ functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
30
+ sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
31
+
32
+ from gsplat.distributed import cli
33
+ from txt2panoimg import Text2360PanoramaImagePipeline
34
+ from embodied_gen.trainer.gsplat_trainer import (
35
+ DefaultStrategy,
36
+ GsplatTrainConfig,
37
+ )
38
+ from embodied_gen.trainer.gsplat_trainer import entrypoint as gsplat_entrypoint
39
+ from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline
40
+ from embodied_gen.utils.config import Pano2MeshSRConfig
41
+ from embodied_gen.utils.gaussian import restore_scene_scale_and_position
42
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
43
+ from embodied_gen.utils.log import logger
44
+ from embodied_gen.utils.process_media import is_image_file, parse_text_prompts
45
+ from embodied_gen.validators.quality_checkers import (
46
+ PanoHeightEstimator,
47
+ PanoImageOccChecker,
48
+ )
49
+
50
+ __all__ = [
51
+ "generate_pano_image",
52
+ "entrypoint",
53
+ ]
54
+
55
+
56
+ @dataclass
57
+ class Scene3DGenConfig:
58
+ prompts: list[str]  # Text description of an indoor room, or a style reference image.
59
+ output_dir: str
60
+ seed: int | None = None
61
+ real_height: float | None = None # The real height of the room in meters.
62
+ pano_image_only: bool = False
63
+ disable_pano_check: bool = False
64
+ keep_middle_result: bool = False
65
+ n_retry: int = 7
66
+ gs3d: GsplatTrainConfig = field(
67
+ default_factory=lambda: GsplatTrainConfig(
68
+ strategy=DefaultStrategy(verbose=True),
69
+ max_steps=4000,
70
+ init_opa=0.9,
71
+ opacity_reg=2e-3,
72
+ sh_degree=0,
73
+ means_lr=1e-4,
74
+ scales_lr=1e-3,
75
+ )
76
+ )
77
+
78
+
79
+ def generate_pano_image(
80
+ prompt: str,
81
+ output_path: str,
82
+ pipeline,
83
+ seed: int,
84
+ n_retry: int,
85
+ checker=None,
86
+ num_inference_steps: int = 40,
87
+ ) -> None:
88
+ for i in range(n_retry):
89
+ logger.info(
90
+ f"GEN Panorama: Retry {i+1}/{n_retry} for prompt: {prompt}, seed: {seed}"
91
+ )
92
+ if is_image_file(prompt):
93
+ raise NotImplementedError("Image mode not implemented yet.")
94
+ else:
95
+ txt_prompt = f"{prompt}, spacious, empty, wide open, open floor, minimal furniture"
96
+ inputs = {
97
+ "prompt": txt_prompt,
98
+ "num_inference_steps": num_inference_steps,
99
+ "upscale": False,
100
+ "seed": seed,
101
+ }
102
+ pano_image = pipeline(inputs)
103
+
104
+ pano_image.save(output_path)
105
+ if checker is None:
106
+ break
107
+
108
+ flag, response = checker(pano_image)
109
+ logger.warning(f"{response}, image saved in {output_path}")
110
+ if flag is True or flag is None:
111
+ break
112
+
113
+ seed = random.randint(0, 100000)
114
+
115
+ return
116
+
117
+
118
+ def entrypoint(*args, **kwargs):
119
+ cfg = tyro.cli(Scene3DGenConfig)
120
+
121
+ # Init global models.
122
+ model_path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
123
+ IMG2PANO_PIPE = Text2360PanoramaImagePipeline(
124
+ model_path, torch_dtype=torch.float16, device="cuda"
125
+ )
126
+ PANOMESH_CFG = Pano2MeshSRConfig()
127
+ PANO2MESH_PIPE = Pano2MeshSRPipeline(PANOMESH_CFG)
128
+ PANO_CHECKER = PanoImageOccChecker(GPT_CLIENT, box_hw=[95, 1000])
129
+ PANOHEIGHT_ESTOR = PanoHeightEstimator(GPT_CLIENT)
130
+
131
+ prompts = parse_text_prompts(cfg.prompts)
132
+ for idx, prompt in enumerate(prompts):
133
+ start_time = time.time()
134
+ output_dir = os.path.join(cfg.output_dir, f"scene_{idx:04d}")
135
+ os.makedirs(output_dir, exist_ok=True)
136
+ pano_path = os.path.join(output_dir, "pano_image.png")
137
+ with open(f"{output_dir}/prompt.txt", "w") as f:
138
+ f.write(prompt)
139
+
140
+ generate_pano_image(
141
+ prompt,
142
+ pano_path,
143
+ IMG2PANO_PIPE,
144
+ cfg.seed if cfg.seed is not None else random.randint(0, 100000),
145
+ cfg.n_retry,
146
+ checker=None if cfg.disable_pano_check else PANO_CHECKER,
147
+ )
148
+
149
+ if cfg.pano_image_only:
150
+ continue
151
+
152
+ logger.info("GEN and REPAIR Mesh from Panorama...")
153
+ PANO2MESH_PIPE(pano_path, output_dir)
154
+
155
+ logger.info("TRAIN 3DGS from Mesh Init and Cube Image...")
156
+ cfg.gs3d.data_dir = output_dir
157
+ cfg.gs3d.result_dir = f"{output_dir}/gaussian"
158
+ cfg.gs3d.adjust_steps(cfg.gs3d.steps_scaler)
159
+ torch.set_default_device("cpu") # recover default setting.
160
+ cli(gsplat_entrypoint, cfg.gs3d, verbose=True)
161
+
162
+ # Clean up the middle results.
163
+ gs_path = (
164
+ f"{cfg.gs3d.result_dir}/ply/point_cloud_{cfg.gs3d.max_steps-1}.ply"
165
+ )
166
+ copy(gs_path, f"{output_dir}/gs_model.ply")
167
+ video_path = f"{cfg.gs3d.result_dir}/renders/video_step{cfg.gs3d.max_steps-1}.mp4"
168
+ copy(video_path, f"{output_dir}/video.mp4")
169
+ gs_cfg_path = f"{cfg.gs3d.result_dir}/cfg.yml"
170
+ copy(gs_cfg_path, f"{output_dir}/gsplat_cfg.yml")
171
+ if not cfg.keep_middle_result:
172
+ rmtree(cfg.gs3d.result_dir, ignore_errors=True)
173
+ os.remove(f"{output_dir}/{PANOMESH_CFG.gs_data_file}")
174
+
175
+ real_height = (
176
+ PANOHEIGHT_ESTOR(pano_path)
177
+ if cfg.real_height is None
178
+ else cfg.real_height
179
+ )
180
+ gs_path = os.path.join(output_dir, "gs_model.ply")
181
+ mesh_path = os.path.join(output_dir, "mesh_model.ply")
182
+ restore_scene_scale_and_position(real_height, mesh_path, gs_path)
183
+
184
+ elapsed_time = (time.time() - start_time) / 60
185
+ logger.info(
186
+ f"FINISHED 3D scene generation in {output_dir} in {elapsed_time:.2f} mins."
187
+ )
188
+
189
+
190
+ if __name__ == "__main__":
191
+ entrypoint()
embodied_gen/scripts/gen_texture.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+ import shutil
3
+ from dataclasses import dataclass
4
+
5
+ import tyro
6
+ from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
7
+ from embodied_gen.data.differentiable_render import entrypoint as drender_api
8
+ from embodied_gen.data.utils import as_list
9
+ from embodied_gen.models.delight_model import DelightingModel
10
+ from embodied_gen.models.sr_model import ImageRealESRGAN
11
+ from embodied_gen.scripts.render_mv import (
12
+ build_texture_gen_pipe,
13
+ )
14
+ from embodied_gen.scripts.render_mv import infer_pipe as render_mv_api
15
+ from embodied_gen.utils.log import logger
16
+
17
+
18
+ @dataclass
19
+ class TextureGenConfig:
20
+ mesh_path: str | list[str]
21
+ prompt: str | list[str]
22
+ output_root: str
23
+ controlnet_cond_scale: float = 0.7
24
+ guidance_scale: float = 9
25
+ strength: float = 0.9
26
+ num_inference_steps: int = 40
27
+ delight: bool = True
28
+ seed: int = 0
29
+ base_ckpt_dir: str = "./weights"
30
+ texture_size: int = 2048
31
+ ip_adapt_scale: float = 0.0
32
+ ip_img_path: str | list[str] | None = None
33
+
34
+
35
+ def entrypoint() -> None:
36
+ cfg = tyro.cli(TextureGenConfig)
37
+ cfg.mesh_path = as_list(cfg.mesh_path)
38
+ cfg.prompt = as_list(cfg.prompt)
39
+ cfg.ip_img_path = as_list(cfg.ip_img_path)
40
+ assert len(cfg.mesh_path) == len(cfg.prompt)
41
+
42
+ # Pre-load models.
43
+ if cfg.ip_adapt_scale > 0:
44
+ PIPELINE = build_texture_gen_pipe(
45
+ base_ckpt_dir="./weights",
46
+ ip_adapt_scale=cfg.ip_adapt_scale,
47
+ device="cuda",
48
+ )
49
+ else:
50
+ PIPELINE = build_texture_gen_pipe(
51
+ base_ckpt_dir="./weights",
52
+ ip_adapt_scale=0,
53
+ device="cuda",
54
+ )
55
+ DELIGHT = None
56
+ if cfg.delight:
57
+ DELIGHT = DelightingModel()
58
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
59
+
60
+ for idx in range(len(cfg.mesh_path)):
61
+ mesh_path = cfg.mesh_path[idx]
62
+ prompt = cfg.prompt[idx]
63
+ uuid = os.path.splitext(os.path.basename(mesh_path))[0]
64
+ output_root = os.path.join(cfg.output_root, uuid)
65
+ drender_api(
66
+ mesh_path=mesh_path,
67
+ output_root=f"{output_root}/condition",
68
+ uuid=uuid,
69
+ )
70
+ render_mv_api(
71
+ index_file=f"{output_root}/condition/index.json",
72
+ controlnet_cond_scale=cfg.controlnet_cond_scale,
73
+ guidance_scale=cfg.guidance_scale,
74
+ strength=cfg.strength,
75
+ num_inference_steps=cfg.num_inference_steps,
76
+ ip_adapt_scale=cfg.ip_adapt_scale,
77
+ ip_img_path=(
78
+ None if cfg.ip_img_path is None else cfg.ip_img_path[idx]
79
+ ),
80
+ prompt=prompt,
81
+ save_dir=f"{output_root}/multi_view",
82
+ sub_idxs=[[0, 1, 2], [3, 4, 5]],
83
+ pipeline=PIPELINE,
84
+ seed=cfg.seed,
85
+ )
86
+ textured_mesh = backproject_api(
87
+ delight_model=DELIGHT,
88
+ imagesr_model=IMAGESR_MODEL,
89
+ mesh_path=mesh_path,
90
+ color_path=f"{output_root}/multi_view/color_sample0.png",
91
+ output_path=f"{output_root}/texture_mesh/{uuid}.obj",
92
+ save_glb_path=f"{output_root}/texture_mesh/{uuid}.glb",
93
+ skip_fix_mesh=True,
94
+ delight=cfg.delight,
95
+ no_save_delight_img=True,
96
+ texture_wh=[cfg.texture_size, cfg.texture_size],
97
+ )
98
+ drender_api(
99
+ mesh_path=f"{output_root}/texture_mesh/{uuid}.obj",
100
+ output_root=f"{output_root}/texture_mesh",
101
+ uuid=uuid,
102
+ num_images=90,
103
+ elevation=[20],
104
+ with_mtl=True,
105
+ gen_color_mp4=True,
106
+ pbr_light_factor=1.2,
107
+ )
108
+
109
+ # Re-organize folders
110
+ shutil.rmtree(f"{output_root}/condition")
111
+ shutil.copy(
112
+ f"{output_root}/texture_mesh/{uuid}/color.mp4",
113
+ f"{output_root}/color.mp4",
114
+ )
115
+ shutil.rmtree(f"{output_root}/texture_mesh/{uuid}")
116
+
117
+ logger.info(
118
+ f"Successfully generate textured mesh in {output_root}/texture_mesh"
119
+ )
120
+
121
+
122
+ if __name__ == "__main__":
123
+ entrypoint()
embodied_gen/scripts/imageto3d.py ADDED
@@ -0,0 +1,350 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import os
20
+ import random
21
+ import sys
22
+ from glob import glob
23
+ from shutil import copy, copytree, rmtree
24
+
25
+ import numpy as np
26
+ import torch
27
+ import trimesh
28
+ from PIL import Image
29
+ from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
30
+ from embodied_gen.data.utils import delete_dir, trellis_preprocess
31
+ from embodied_gen.models.delight_model import DelightingModel
32
+ from embodied_gen.models.gs_model import GaussianOperator
33
+ from embodied_gen.models.segment_model import RembgRemover
34
+ from embodied_gen.models.sr_model import ImageRealESRGAN
35
+ from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
36
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
37
+ from embodied_gen.utils.log import logger
38
+ from embodied_gen.utils.process_media import merge_images_video
39
+ from embodied_gen.utils.tags import VERSION
40
+ from embodied_gen.utils.trender import render_video
41
+ from embodied_gen.validators.quality_checkers import (
42
+ BaseChecker,
43
+ ImageAestheticChecker,
44
+ ImageSegChecker,
45
+ MeshGeoChecker,
46
+ )
47
+ from embodied_gen.validators.urdf_convertor import URDFGenerator
48
+
49
+ current_file_path = os.path.abspath(__file__)
50
+ current_dir = os.path.dirname(current_file_path)
51
+ sys.path.append(os.path.join(current_dir, "../.."))
52
+ from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
53
+
54
+ os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
55
+ "~/.cache/torch_extensions"
56
+ )
57
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
58
+ os.environ["SPCONV_ALGO"] = "native"
59
+ random.seed(0)
60
+
61
+ logger.info("Loading Image3D Models...")
62
+ DELIGHT = DelightingModel()
63
+ IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
64
+ RBG_REMOVER = RembgRemover()
65
+ PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
66
+ "microsoft/TRELLIS-image-large"
67
+ )
68
+ # PIPELINE.cuda()
69
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
70
+ GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
71
+ AESTHETIC_CHECKER = ImageAestheticChecker()
72
+ CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
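+ # Quality checkers applied to rendered views (and the raw/segmented image pair); their verdicts are written into the URDF quality tag.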
73
+
74
+
75
+ def parse_args():
76
+ parser = argparse.ArgumentParser(description="Image to 3D pipeline args.")
77
+ parser.add_argument(
78
+ "--image_path", type=str, nargs="+", help="Path to the input images."
79
+ )
80
+ parser.add_argument(
81
+ "--image_root", type=str, help="Path to the input images folder."
82
+ )
83
+ parser.add_argument(
84
+ "--output_root",
85
+ type=str,
86
+ help="Root directory for saving outputs.",
87
+ )
88
+ parser.add_argument(
89
+ "--height_range",
90
+ type=str,
91
+ default=None,
92
+ help="The hight in meter to restore the mesh real size.",
93
+ )
94
+ parser.add_argument(
95
+ "--mass_range",
96
+ type=str,
97
+ default=None,
98
+ help="The mass in kg to restore the mesh real weight.",
99
+ )
100
+ parser.add_argument("--asset_type", type=str, nargs="+", default=None)
101
+ parser.add_argument("--skip_exists", action="store_true")
102
+ parser.add_argument("--version", type=str, default=VERSION)
103
+ parser.add_argument("--keep_intermediate", action="store_true")
104
+ parser.add_argument("--seed", type=int, default=0)
105
+ parser.add_argument(
106
+ "--n_retry",
107
+ type=int,
108
+ default=2,
109
+ )
110
+ parser.add_argument("--disable_decompose_convex", action="store_true")
111
+ parser.add_argument(
112
+ "--texture_wh", type=int, nargs=2, default=[2048, 2048]
113
+ )
114
+ args, unknown = parser.parse_known_args()
115
+
116
+ return args
117
+
118
+
119
+ def entrypoint(**kwargs):
120
+ args = parse_args()
121
+ for k, v in kwargs.items():
122
+ if hasattr(args, k) and v is not None:
123
+ setattr(args, k, v)
124
+
125
+ assert (
126
+ args.image_path or args.image_root
127
+ ), "Please provide either --image_path or --image_root."
128
+ if not args.image_path:
129
+ args.image_path = glob(os.path.join(args.image_root, "*.png"))
130
+ args.image_path += glob(os.path.join(args.image_root, "*.jpg"))
131
+ args.image_path += glob(os.path.join(args.image_root, "*.jpeg"))
132
+
133
+ for idx, image_path in enumerate(args.image_path):
134
+ try:
135
+ filename = os.path.basename(image_path).split(".")[0]
136
+ output_root = args.output_root
137
+ if args.image_root is not None or len(args.image_path) > 1:
138
+ output_root = os.path.join(output_root, filename)
139
+ os.makedirs(output_root, exist_ok=True)
140
+
141
+ mesh_out = f"{output_root}/{filename}.obj"
142
+ if args.skip_exists and os.path.exists(mesh_out):
143
+ logger.warning(
144
+ f"Skip {image_path}, already processed in {mesh_out}"
145
+ )
146
+ continue
147
+
148
+ image = Image.open(image_path)
149
+ image.save(f"{output_root}/{filename}_raw.png")
150
+
151
+ # Segmentation: Get segmented image using Rembg.
152
+ seg_path = f"{output_root}/{filename}_cond.png"
153
+ seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
154
+ seg_image = trellis_preprocess(seg_image)
155
+ seg_image.save(seg_path)
156
+
157
+ seed = args.seed
158
+ asset_node = "unknown"
159
+ if isinstance(args.asset_type, list) and args.asset_type[idx]:
160
+ asset_node = args.asset_type[idx]
161
+ for try_idx in range(args.n_retry):
162
+ logger.info(
163
+ f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
164
+ )
165
+ # Run the pipeline
166
+ try:
167
+ PIPELINE.cuda()
168
+ outputs = PIPELINE.run(
169
+ seg_image,
170
+ preprocess_image=False,
171
+ seed=(
172
+ random.randint(0, 100000) if seed is None else seed
173
+ ),
174
+ # Optional parameters
175
+ # sparse_structure_sampler_params={
176
+ # "steps": 12,
177
+ # "cfg_strength": 7.5,
178
+ # },
179
+ # slat_sampler_params={
180
+ # "steps": 12,
181
+ # "cfg_strength": 3,
182
+ # },
183
+ )
184
+ PIPELINE.cpu()
185
+ torch.cuda.empty_cache()
186
+ except Exception as e:
187
+ logger.error(
188
+ f"[Pipeline Failed] process {image_path}: {e}, skip."
189
+ )
190
+ continue
191
+
192
+ gs_model = outputs["gaussian"][0]
193
+ mesh_model = outputs["mesh"][0]
194
+
195
+ # Save the raw Gaussian model
196
+ gs_path = mesh_out.replace(".obj", "_gs.ply")
197
+ gs_model.save_ply(gs_path)
198
+
199
+ # Rotate mesh and GS by 90 degrees around Z-axis.
200
+ rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
201
+ gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
202
+ mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
203
+
204
+ # Additional rotation for GS to align with the mesh.
205
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix)
206
+ pose = GaussianOperator.trans_to_quatpose(gs_rot)
207
+ aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
208
+ GaussianOperator.resave_ply(
209
+ in_ply=gs_path,
210
+ out_ply=aligned_gs_path,
211
+ instance_pose=pose,
212
+ device="cpu",
213
+ )
214
+ color_path = os.path.join(output_root, "color.png")
215
+ render_gs_api(
216
+ input_gs=aligned_gs_path,
217
+ output_path=color_path,
218
+ elevation=[20, -10, 60, -50],
219
+ num_images=12,
220
+ )
221
+
222
+ color_img = Image.open(color_path)
223
+ keep_height = int(color_img.height * 2 / 3)
224
+ crop_img = color_img.crop((0, 0, color_img.width, keep_height))
225
+ geo_flag, geo_result = GEO_CHECKER([crop_img], text=asset_node)
226
+ logger.warning(
227
+ f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
228
+ )
229
+ if geo_flag is True or geo_flag is None:
230
+ break
231
+
232
+ seed = random.randint(0, 100000) if seed is not None else None
233
+
234
+ # Render the video for generated 3D asset.
235
+ color_images = render_video(gs_model)["color"]
236
+ normal_images = render_video(mesh_model)["normal"]
237
+ video_path = os.path.join(output_root, "gs_mesh.mp4")
238
+ merge_images_video(color_images, normal_images, video_path)
239
+
240
+ mesh = trimesh.Trimesh(
241
+ vertices=mesh_model.vertices.cpu().numpy(),
242
+ faces=mesh_model.faces.cpu().numpy(),
243
+ )
244
+ mesh.vertices = mesh.vertices @ np.array(mesh_add_rot)
245
+ mesh.vertices = mesh.vertices @ np.array(rot_matrix)
246
+
247
+ mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
248
+ mesh.export(mesh_obj_path)
249
+
250
+ mesh = backproject_api(
251
+ delight_model=DELIGHT,
252
+ imagesr_model=IMAGESR_MODEL,
253
+ color_path=color_path,
254
+ mesh_path=mesh_obj_path,
255
+ output_path=mesh_obj_path,
256
+ skip_fix_mesh=False,
257
+ delight=True,
258
+ texture_wh=args.texture_wh,
259
+ elevation=[20, -10, 60, -50],
260
+ num_images=12,
261
+ )
262
+
263
+ mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
264
+ mesh.export(mesh_glb_path)
265
+
266
+ urdf_convertor = URDFGenerator(
267
+ GPT_CLIENT,
268
+ render_view_num=4,
269
+ decompose_convex=not args.disable_decompose_convex,
270
+ )
271
+ asset_attrs = {
272
+ "version": VERSION,
273
+ "gs_model": f"{urdf_convertor.output_mesh_dir}/{filename}_gs.ply",
274
+ }
275
+ if args.height_range:
276
+ min_height, max_height = map(
277
+ float, args.height_range.split("-")
278
+ )
279
+ asset_attrs["min_height"] = min_height
280
+ asset_attrs["max_height"] = max_height
281
+ if args.mass_range:
282
+ min_mass, max_mass = map(float, args.mass_range.split("-"))
283
+ asset_attrs["min_mass"] = min_mass
284
+ asset_attrs["max_mass"] = max_mass
285
+ if isinstance(args.asset_type, list) and args.asset_type[idx]:
286
+ asset_attrs["category"] = args.asset_type[idx]
287
+ if args.version:
288
+ asset_attrs["version"] = args.version
289
+
290
+ urdf_root = f"{output_root}/URDF_{filename}"
291
+ urdf_path = urdf_convertor(
292
+ mesh_path=mesh_obj_path,
293
+ output_root=urdf_root,
294
+ **asset_attrs,
295
+ )
296
+
297
+ # Rescale GS and save to URDF/mesh folder.
298
+ real_height = urdf_convertor.get_attr_from_urdf(
299
+ urdf_path, attr_name="real_height"
300
+ )
301
+ out_gs = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}_gs.ply" # noqa
302
+ GaussianOperator.resave_ply(
303
+ in_ply=aligned_gs_path,
304
+ out_ply=out_gs,
305
+ real_height=real_height,
306
+ device="cpu",
307
+ )
308
+
309
+ # Quality check and update .urdf file.
310
+ mesh_out = f"{urdf_root}/{urdf_convertor.output_mesh_dir}/{filename}.obj" # noqa
311
+ trimesh.load(mesh_out).export(mesh_out.replace(".obj", ".glb"))
312
+
313
+ image_dir = f"{urdf_root}/{urdf_convertor.output_render_dir}/image_color" # noqa
314
+ image_paths = glob(f"{image_dir}/*.png")
315
+ images_list = []
316
+ for checker in CHECKERS:
317
+ images = image_paths
318
+ if isinstance(checker, ImageSegChecker):
319
+ images = [
320
+ f"{output_root}/{filename}_raw.png",
321
+ f"{output_root}/{filename}_cond.png",
322
+ ]
323
+ images_list.append(images)
324
+
325
+ qa_results = BaseChecker.validate(CHECKERS, images_list)
326
+ urdf_convertor.add_quality_tag(urdf_path, qa_results)
327
+
328
+ # Organize the final result files
329
+ result_dir = f"{output_root}/result"
330
+ if os.path.exists(result_dir):
331
+ rmtree(result_dir, ignore_errors=True)
332
+ os.makedirs(result_dir, exist_ok=True)
333
+ copy(urdf_path, f"{result_dir}/{os.path.basename(urdf_path)}")
334
+ copytree(
335
+ f"{urdf_root}/{urdf_convertor.output_mesh_dir}",
336
+ f"{result_dir}/{urdf_convertor.output_mesh_dir}",
337
+ )
338
+ copy(video_path, f"{result_dir}/video.mp4")
339
+ if not args.keep_intermediate:
340
+ delete_dir(output_root, keep_subs=["result"])
341
+
342
+ except Exception as e:
343
+ logger.error(f"Failed to process {image_path}: {e}, skip.")
344
+ continue
345
+
346
+ logger.info(f"Processing complete. Outputs saved to {args.output_root}")
347
+
348
+
349
+ if __name__ == "__main__":
350
+ entrypoint()
embodied_gen/scripts/parallel_sim.py ADDED
@@ -0,0 +1,166 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ from embodied_gen.utils.monkey_patches import monkey_patch_maniskill
19
+
20
+ monkey_patch_maniskill()
21
+ import json
22
+ from collections import defaultdict
23
+ from dataclasses import dataclass, field
24
+ from typing import Literal
25
+
26
+ import gymnasium as gym
27
+ import numpy as np
28
+ import torch
29
+ import tyro
30
+ from mani_skill.utils.wrappers import RecordEpisode
31
+ from tqdm import tqdm
32
+ import embodied_gen.envs.pick_embodiedgen
33
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
34
+ from embodied_gen.utils.log import logger
35
+ from embodied_gen.utils.simulation import FrankaPandaGrasper
36
+
37
+
38
+ @dataclass
39
+ class ParallelSimConfig:
40
+ """CLI parameters for Parallel Sapien simulation."""
41
+
42
+ # Environment configuration
43
+ layout_file: str
44
+ """Path to the layout JSON file"""
45
+ output_dir: str
46
+ """Directory to save recorded videos"""
47
+ gym_env_name: str = "PickEmbodiedGen-v1"
48
+ """Name of the Gym environment to use"""
49
+ num_envs: int = 4
50
+ """Number of parallel environments"""
51
+ render_mode: Literal["rgb_array", "hybrid"] = "hybrid"
52
+ """Rendering mode: rgb_array or hybrid"""
53
+ enable_shadow: bool = True
54
+ """Whether to enable shadows in rendering"""
55
+ control_mode: str = "pd_joint_pos"
56
+ """Control mode for the agent"""
57
+
58
+ # Recording configuration
59
+ max_steps_per_video: int = 1000
60
+ """Maximum steps to record per video"""
61
+ save_trajectory: bool = False
62
+ """Whether to save trajectory data"""
63
+
64
+ # Simulation parameters
65
+ seed: int = 0
66
+ """Random seed for environment reset"""
67
+ warmup_steps: int = 50
68
+ """Number of warmup steps before action computation"""
69
+ reach_target_only: bool = True
70
+ """Whether to only reach target without full action"""
71
+
72
+ # Camera settings
73
+ camera_eye: list[float] = field(default_factory=lambda: [0.9, 0.0, 1.1])
74
+ """Camera eye position [x, y, z] in global coordiante system"""
75
+ camera_target_pt: list[float] = field(
76
+ default_factory=lambda: [0.0, 0.0, 0.9]
77
+ )
78
+ """Camera target(look-at) point [x, y, z] in global coordiante system"""
79
+ image_hw: list[int] = field(default_factory=lambda: [256, 256])
80
+ """Rendered image height and width [height, width]"""
81
+ fovy_deg: float = 75
82
+ """Camera vertical field of view in degrees"""
83
+
84
+
85
+ def entrypoint(**kwargs):
86
+ if kwargs is None or len(kwargs) == 0:
87
+ cfg = tyro.cli(ParallelSimConfig)
88
+ else:
89
+ cfg = ParallelSimConfig(**kwargs)
90
+
91
+ env = gym.make(
92
+ cfg.gym_env_name,
93
+ num_envs=cfg.num_envs,
94
+ render_mode=cfg.render_mode,
95
+ enable_shadow=cfg.enable_shadow,
96
+ layout_file=cfg.layout_file,
97
+ control_mode=cfg.control_mode,
98
+ camera_cfg=dict(
99
+ camera_eye=cfg.camera_eye,
100
+ camera_target_pt=cfg.camera_target_pt,
101
+ image_hw=cfg.image_hw,
102
+ fovy_deg=cfg.fovy_deg,
103
+ ),
104
+ )
105
+ env = RecordEpisode(
106
+ env,
107
+ cfg.output_dir,
108
+ max_steps_per_video=cfg.max_steps_per_video,
109
+ save_trajectory=cfg.save_trajectory,
110
+ )
111
+ env.reset(seed=cfg.seed)
112
+
113
+ default_action = env.unwrapped.agent.init_qpos[:, :8]
114
+ for _ in tqdm(range(cfg.warmup_steps), desc="SIM Warmup"):
115
+ # action = env.action_space.sample() # Random action
116
+ obs, reward, terminated, truncated, info = env.step(default_action)
117
+
118
+ grasper = FrankaPandaGrasper(
119
+ env.unwrapped.agent,
120
+ env.unwrapped.sim_config.control_freq,
121
+ )
122
+
123
+ layout_data = LayoutInfo.from_dict(json.load(open(cfg.layout_file, "r")))
124
+ actions = defaultdict(list)
125
+ # Plan Grasp reach pose for each manipulated object in each env.
126
+ for env_idx in range(env.num_envs):
127
+ actors = env.unwrapped.env_actors[f"env{env_idx}"]
128
+ for node in layout_data.relation[
129
+ Scene3DItemEnum.MANIPULATED_OBJS.value
130
+ ]:
131
+ action = grasper.compute_grasp_action(
132
+ actor=actors[node]._objs[0],
133
+ reach_target_only=True,
134
+ env_idx=env_idx,
135
+ )
136
+ actions[node].append(action)
137
+
138
+ # Excute the planned actions for each manipulated object in each env.
139
+ for node in actions:
140
+ max_env_steps = 0
141
+ for env_idx in range(env.num_envs):
142
+ if actions[node][env_idx] is None:
143
+ continue
144
+ max_env_steps = max(max_env_steps, len(actions[node][env_idx]))
145
+
146
+ action_tensor = np.ones(
147
+ (max_env_steps, env.num_envs, env.action_space.shape[-1])
148
+ )
149
+ action_tensor *= default_action[None, ...]
150
+ for env_idx in range(env.num_envs):
151
+ action = actions[node][env_idx]
152
+ if action is None:
153
+ continue
154
+ action_tensor[: len(action), env_idx, :] = action
155
+
156
+ for step in tqdm(range(max_env_steps), desc=f"Grasping: {node}"):
157
+ action = torch.Tensor(action_tensor[step]).to(env.unwrapped.device)
158
+ env.unwrapped.agent.set_action(action)
159
+ obs, reward, terminated, truncated, info = env.step(action)
160
+
161
+ env.close()
162
+ logger.info(f"Results saved in {cfg.output_dir}")
163
+
164
+
165
+ if __name__ == "__main__":
166
+ entrypoint()
embodied_gen/scripts/render_gs.py ADDED
@@ -0,0 +1,164 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import logging
20
+ import math
21
+
22
+ import cv2
23
+ import spaces
24
+ import torch
25
+ from PIL import Image
26
+ from tqdm import tqdm
27
+ from embodied_gen.data.utils import (
28
+ CameraSetting,
29
+ init_kal_camera,
30
+ normalize_vertices_array,
31
+ )
32
+ from embodied_gen.models.gs_model import GaussianOperator
33
+ from embodied_gen.utils.process_media import combine_images_to_grid
34
+
35
+ logging.basicConfig(
36
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
37
+ )
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
+ def parse_args():
42
+ parser = argparse.ArgumentParser(description="Render GS color images")
43
+
44
+ parser.add_argument(
45
+ "--input_gs", type=str, help="Input render GS.ply path."
46
+ )
47
+ parser.add_argument(
48
+ "--output_path",
49
+ type=str,
50
+ help="Output grid image path for rendered GS color images.",
51
+ )
52
+ parser.add_argument(
53
+ "--num_images", type=int, default=6, help="Number of images to render."
54
+ )
55
+ parser.add_argument(
56
+ "--elevation",
57
+ type=float,
58
+ nargs="+",
59
+ default=[20.0, -10.0],
60
+ help="Elevation angles for the camera (default: [20.0, -10.0])",
61
+ )
62
+ parser.add_argument(
63
+ "--distance",
64
+ type=float,
65
+ default=5,
66
+ help="Camera distance (default: 5)",
67
+ )
68
+ parser.add_argument(
69
+ "--resolution_hw",
70
+ type=int,
71
+ nargs=2,
72
+ default=(512, 512),
73
+ help="Resolution of the output images (default: (512, 512))",
74
+ )
75
+ parser.add_argument(
76
+ "--fov",
77
+ type=float,
78
+ default=30,
79
+ help="Field of view in degrees (default: 30)",
80
+ )
81
+ parser.add_argument(
82
+ "--device",
83
+ type=str,
84
+ choices=["cpu", "cuda"],
85
+ default="cuda",
86
+ help="Device to run on (default: `cuda`)",
87
+ )
88
+ parser.add_argument(
89
+ "--image_size",
90
+ type=int,
91
+ default=512,
92
+ help="Output image size for single view in color grid (default: 512)",
93
+ )
94
+
95
+ args, unknown = parser.parse_known_args()
96
+
97
+ return args
98
+
99
+
100
+ def load_gs_model(
101
+ input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071]
102
+ ) -> GaussianOperator:
103
+ gs_model = GaussianOperator.load_from_ply(input_gs)
104
+ # Normalize vertices to [-1, 1], center to (0, 0, 0).
105
+ _, scale, center = normalize_vertices_array(gs_model._means)
106
+ scale, center = float(scale), center.tolist()
107
+ transpose = [*[v for v in center], *pre_quat]
108
+ instance_pose = torch.tensor(transpose).to(gs_model.device)
109
+ gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
110
+ gs_model.rescale(scale)
111
+
112
+ return gs_model
113
+
114
+
115
+ @spaces.GPU
116
+ def entrypoint(**kwargs) -> None:
117
+ args = parse_args()
118
+ for k, v in kwargs.items():
119
+ if hasattr(args, k) and v is not None:
120
+ setattr(args, k, v)
121
+
122
+ # Setup camera parameters
123
+ camera_params = CameraSetting(
124
+ num_images=args.num_images,
125
+ elevation=args.elevation,
126
+ distance=args.distance,
127
+ resolution_hw=args.resolution_hw,
128
+ fov=math.radians(args.fov),
129
+ device=args.device,
130
+ )
131
+ camera = init_kal_camera(camera_params, flip_az=True)
132
+ matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam
133
+ matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3]
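+ # Sign-convention adjustment on the translation so the inverted matrices below give the expected camera-to-world poses.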
134
+ w2cs = matrix_mv.to(camera_params.device)
135
+ c2ws = [torch.linalg.inv(matrix) for matrix in w2cs]
136
+ Ks = torch.tensor(camera_params.Ks).to(camera_params.device)
137
+
138
+ # Load GS model and normalize.
139
+ gs_model = load_gs_model(args.input_gs, pre_quat=[0.0, 0.0, 1.0, 0.0])
140
+
141
+ # Render GS color images.
142
+ images = []
143
+ for idx in tqdm(range(len(c2ws)), desc="Rendering GS"):
144
+ result = gs_model.render(
145
+ c2ws[idx],
146
+ Ks=Ks,
147
+ image_width=camera_params.resolution_hw[1],
148
+ image_height=camera_params.resolution_hw[0],
149
+ )
150
+ color = cv2.resize(
151
+ result.rgba,
152
+ (args.image_size, args.image_size),
153
+ interpolation=cv2.INTER_AREA,
154
+ )
155
+ color = cv2.cvtColor(color, cv2.COLOR_BGRA2RGBA)
156
+ images.append(Image.fromarray(color))
157
+
158
+ combine_images_to_grid(images, image_mode="RGBA")[0].save(args.output_path)
159
+
160
+ logger.info(f"Saved grid image to {args.output_path}")
161
+
162
+
163
+ if __name__ == "__main__":
164
+ entrypoint()
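A minimal example of invoking the render_gs.py CLI defined above; the flags correspond to its argparse options, and the input/output paths are placeholders.

# Placeholder paths; flags match the parser in render_gs.py.
python3 embodied_gen/scripts/render_gs.py \
    --input_gs outputs/sample/gs_model.ply \
    --output_path outputs/sample/gs_color_grid.png \
    --num_images 6 \
    --distance 5 \
    --image_size 512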
embodied_gen/scripts/render_mv.py ADDED
@@ -0,0 +1,198 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import logging
19
+ import os
20
+ import random
21
+ from typing import List, Tuple
22
+
23
+ import fire
24
+ import numpy as np
25
+ import torch
26
+ from diffusers.utils import make_image_grid
27
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import (
28
+ StableDiffusionXLControlNetImg2ImgPipeline,
29
+ )
30
+ from PIL import Image, ImageEnhance, ImageFilter
31
+ from torchvision import transforms
32
+ from embodied_gen.data.datasets import Asset3dGenDataset
33
+ from embodied_gen.models.texture_model import build_texture_gen_pipe
34
+
35
+ logging.basicConfig(level=logging.INFO)
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ def get_init_noise_image(image: Image.Image) -> Image.Image:
40
+ blurred_image = image.convert("L").filter(
41
+ ImageFilter.GaussianBlur(radius=3)
42
+ )
43
+
44
+ enhancer = ImageEnhance.Contrast(blurred_image)
45
+ image_decreased_contrast = enhancer.enhance(factor=0.5)
46
+
47
+ return image_decreased_contrast
48
+
49
+
50
+ def infer_pipe(
51
+ index_file: str,
52
+ controlnet_ckpt: str = None,
53
+ uid: str = None,
54
+ prompt: str = None,
55
+ controlnet_cond_scale: float = 0.4,
56
+ control_guidance_end: float = 0.9,
57
+ strength: float = 1.0,
58
+ num_inference_steps: int = 50,
59
+ guidance_scale: float = 10,
60
+ ip_adapt_scale: float = 0,
61
+ ip_img_path: str = None,
62
+ sub_idxs: List[List[int]] = None,
63
+ num_images_per_prompt: int = 3,  # Increase if you want more similar images.
64
+ device: str = "cuda",
65
+ save_dir: str = "infer_vis",
66
+ seed: int = None,
67
+ target_hw: tuple[int, int] = (512, 512),
68
+ pipeline: StableDiffusionXLControlNetImg2ImgPipeline = None,
69
+ ) -> str:
70
+ # sub_idxs = [[0, 1, 2], [3, 4, 5]] # None for single image.
71
+ if sub_idxs is None:
72
+ sub_idxs = [[random.randint(0, 5)]] # 6 views.
73
+ target_hw = [2 * size for size in target_hw]
74
+
75
+ transform_list = [
76
+ transforms.Resize(
77
+ target_hw, interpolation=transforms.InterpolationMode.BILINEAR
78
+ ),
79
+ transforms.CenterCrop(target_hw),
80
+ transforms.ToTensor(),
81
+ transforms.Normalize([0.5], [0.5]),
82
+ ]
83
+ image_transform = transforms.Compose(transform_list)
84
+ control_transform = transforms.Compose(transform_list[:-1])
85
+
86
+ grid_hw = (target_hw[0] * len(sub_idxs), target_hw[1] * len(sub_idxs[0]))
87
+ dataset = Asset3dGenDataset(
88
+ index_file, target_hw=grid_hw, sub_idxs=sub_idxs
89
+ )
90
+
91
+ if uid is None:
92
+ uid = random.choice(list(dataset.meta_info.keys()))
93
+ if prompt is None:
94
+ prompt = dataset.meta_info[uid]["capture"]
95
+ if isinstance(prompt, List) or isinstance(prompt, Tuple):
96
+ prompt = ", ".join(map(str, prompt))
97
+ # prompt += "high quality, ultra-clear, high resolution, best quality, 4k"
98
+ # prompt += "高品质,清晰,细节"
99
+ prompt += ", high quality, high resolution, best quality"
100
+ # prompt += ", with diffuse lighting, showing no reflections."
101
+ logger.info(f"Inference with prompt: {prompt}")
102
+
103
+ negative_prompt = "nsfw,阴影,低分辨率,伪影、模糊,霓虹灯,高光,镜面反射"
104
+
105
+ control_image = dataset.fetch_sample_grid_images(
106
+ uid,
107
+ attrs=["image_view_normal", "image_position", "image_mask"],
108
+ sub_idxs=sub_idxs,
109
+ transform=control_transform,
110
+ )
111
+
112
+ color_image = dataset.fetch_sample_grid_images(
113
+ uid,
114
+ attrs=["image_color"],
115
+ sub_idxs=sub_idxs,
116
+ transform=image_transform,
117
+ )
118
+
119
+ normal_pil, position_pil, mask_pil, color_pil = dataset.visualize_item(
120
+ control_image,
121
+ color_image,
122
+ save_dir=save_dir,
123
+ )
124
+
125
+ if pipeline is None:
126
+ pipeline = build_texture_gen_pipe(
127
+ base_ckpt_dir="./weights",
128
+ controlnet_ckpt=controlnet_ckpt,
129
+ ip_adapt_scale=ip_adapt_scale,
130
+ device=device,
131
+ )
132
+
133
+ if ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0:
134
+ ip_image = Image.open(ip_img_path).convert("RGB")
135
+ ip_image = ip_image.resize(target_hw[::-1])
136
+ ip_image = [ip_image]
137
+ pipeline.set_ip_adapter_scale([ip_adapt_scale])
138
+ else:
139
+ ip_image = None
140
+
141
+ generator = None
142
+ if seed is not None:
143
+ generator = torch.Generator(device).manual_seed(seed)
144
+ torch.manual_seed(seed)
145
+ np.random.seed(seed)
146
+ random.seed(seed)
147
+
148
+ init_image = get_init_noise_image(normal_pil)
149
+ # init_image = get_init_noise_image(color_pil)
150
+
151
+ images = []
152
+ row_num, col_num = 2, 3
153
+ img_save_paths = []
154
+ while len(images) < col_num:
155
+ image = pipeline(
156
+ prompt=prompt,
157
+ image=init_image,
158
+ controlnet_conditioning_scale=controlnet_cond_scale,
159
+ control_guidance_end=control_guidance_end,
160
+ strength=strength,
161
+ control_image=control_image[None, ...],
162
+ negative_prompt=negative_prompt,
163
+ num_inference_steps=num_inference_steps,
164
+ guidance_scale=guidance_scale,
165
+ num_images_per_prompt=num_images_per_prompt,
166
+ ip_adapter_image=ip_image,
167
+ generator=generator,
168
+ ).images
169
+ images.extend(image)
170
+
171
+ grid_image = [normal_pil, position_pil, color_pil] + images[:col_num]
172
+ # save_dir = os.path.join(save_dir, uid)
173
+ os.makedirs(save_dir, exist_ok=True)
174
+
175
+ for idx in range(col_num):
176
+ rgba_image = Image.merge("RGBA", (*images[idx].split(), mask_pil))
177
+ img_save_path = os.path.join(save_dir, f"color_sample{idx}.png")
178
+ rgba_image.save(img_save_path)
179
+ img_save_paths.append(img_save_path)
180
+
181
+ sub_idxs = "_".join(
182
+ [str(item) for sublist in sub_idxs for item in sublist]
183
+ )
184
+ save_path = os.path.join(
185
+ save_dir, f"sample_idx{str(sub_idxs)}_ip{ip_adapt_scale}.jpg"
186
+ )
187
+ make_image_grid(grid_image, row_num, col_num).save(save_path)
188
+ logger.info(f"Visualize in {save_path}")
189
+
190
+ return img_save_paths
191
+
192
+
193
+ def entrypoint() -> None:
194
+ fire.Fire(infer_pipe)
195
+
196
+
197
+ if __name__ == "__main__":
198
+ entrypoint()
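Because entrypoint wraps infer_pipe with fire.Fire, each keyword argument of infer_pipe is exposed as a CLI flag; a minimal sketch with placeholder paths, assuming an index.json produced by the condition-rendering step (the full invocation used in practice appears in texture_gen.sh further below).

# Placeholder paths; flag names mirror the parameters of infer_pipe.
python3 embodied_gen/scripts/render_mv.py \
    --index_file outputs/condition/index.json \
    --prompt "a wooden chair with red cushions" \
    --controlnet_cond_scale 0.7 \
    --save_dir outputs/multi_view \
    --seed 0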
embodied_gen/scripts/simulate_sapien.py ADDED
@@ -0,0 +1,196 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import json
19
+ import os
20
+ from collections import defaultdict
21
+ from dataclasses import dataclass, field
22
+ from typing import Literal
23
+
24
+ import imageio
25
+ import numpy as np
26
+ import torch
27
+ import tyro
28
+ from tqdm import tqdm
29
+ from embodied_gen.models.gs_model import GaussianOperator
30
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
31
+ from embodied_gen.utils.geometry import quaternion_multiply
32
+ from embodied_gen.utils.log import logger
33
+ from embodied_gen.utils.process_media import alpha_blend_rgba
34
+ from embodied_gen.utils.simulation import (
35
+ SIM_COORD_ALIGN,
36
+ FrankaPandaGrasper,
37
+ SapienSceneManager,
38
+ load_assets_from_layout_file,
39
+ load_mani_skill_robot,
40
+ render_images,
41
+ )
42
+
43
+
44
+ @dataclass
45
+ class SapienSimConfig:
46
+ # Simulation settings.
47
+ layout_path: str
48
+ output_dir: str
49
+ sim_freq: int = 200
50
+ sim_step: int = 400
51
+ z_offset: float = 0.004
52
+ init_3dgs_quat: list[float] = field(
53
+ default_factory=lambda: [0.7071, 0, 0, 0.7071]
54
+ ) # xyzw
55
+ device: str = "cuda"
56
+ control_freq: int = 50
57
+ insert_robot: bool = False
58
+ # Camera settings.
59
+ render_interval: int = 10
60
+ num_cameras: int = 3
61
+ camera_radius: float = 0.9
62
+ camera_height: float = 1.1
63
+ image_hw: tuple[int, int] = (512, 512)
64
+ ray_tracing: bool = True
65
+ fovy_deg: float = 75.0
66
+ camera_target_pt: list[float] = field(
67
+ default_factory=lambda: [0.0, 0.0, 0.9]
68
+ )
69
+ render_keys: list[
70
+ Literal[
71
+ "Color", "Foreground", "Segmentation", "Normal", "Mask", "Depth"
72
+ ]
73
+ ] = field(default_factory=lambda: ["Foreground"])
74
+
75
+
76
+ def entrypoint(**kwargs):
77
+ if kwargs is None or len(kwargs) == 0:
78
+ cfg = tyro.cli(SapienSimConfig)
79
+ else:
80
+ cfg = SapienSimConfig(**kwargs)
81
+
82
+ scene_manager = SapienSceneManager(
83
+ cfg.sim_freq, ray_tracing=cfg.ray_tracing
84
+ )
85
+ _ = scene_manager.initialize_circular_cameras(
86
+ num_cameras=cfg.num_cameras,
87
+ radius=cfg.camera_radius,
88
+ height=cfg.camera_height,
89
+ target_pt=cfg.camera_target_pt,
90
+ image_hw=cfg.image_hw,
91
+ fovy_deg=cfg.fovy_deg,
92
+ )
93
+ with open(cfg.layout_path, "r") as f:
94
+ layout_data: LayoutInfo = LayoutInfo.from_dict(json.load(f))
95
+
96
+ actors = load_assets_from_layout_file(
97
+ scene_manager.scene,
98
+ cfg.layout_path,
99
+ cfg.z_offset,
100
+ )
101
+ agent = load_mani_skill_robot(
102
+ scene_manager.scene, cfg.layout_path, cfg.control_freq
103
+ )
104
+
105
+ frames = defaultdict(list)
106
+ image_cnt = 0
107
+ for step in tqdm(range(cfg.sim_step), desc="Simulation"):
108
+ scene_manager.scene.step()
109
+ agent.reset(agent.init_qpos)
110
+ if step % cfg.render_interval != 0:
111
+ continue
112
+ scene_manager.scene.update_render()
113
+ image_cnt += 1
114
+ for camera in scene_manager.cameras:
115
+ camera.take_picture()
116
+ images = render_images(camera, cfg.render_keys)
117
+ frames[camera.name].append(images)
118
+
119
+ actions = dict()
120
+ if cfg.insert_robot:
121
+ grasper = FrankaPandaGrasper(
122
+ agent,
123
+ cfg.control_freq,
124
+ )
125
+ for node in layout_data.relation[
126
+ Scene3DItemEnum.MANIPULATED_OBJS.value
127
+ ]:
128
+ actions[node] = grasper.compute_grasp_action(
129
+ actor=actors[node], reach_target_only=True
130
+ )
131
+
132
+ if "Foreground" not in cfg.render_keys:
133
+ return
134
+
135
+ asset_root = os.path.dirname(cfg.layout_path)
136
+ bg_node = layout_data.relation[Scene3DItemEnum.BACKGROUND.value]
137
+ gs_path = f"{asset_root}/{layout_data.assets[bg_node]}/gs_model.ply"
138
+ gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
139
+ x, y, z, qx, qy, qz, qw = layout_data.position[bg_node]
140
+ qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], cfg.init_3dgs_quat)
141
+ init_pose = torch.tensor([x, y, z, qx, qy, qz, qw])
142
+ gs_model = gs_model.get_gaussians(instance_pose=init_pose)
143
+
144
+ bg_images = dict()
145
+ for camera in scene_manager.cameras:
146
+ Ks = camera.get_intrinsic_matrix()
147
+ c2w = camera.get_model_matrix()
148
+ c2w = c2w @ SIM_COORD_ALIGN
149
+ result = gs_model.render(
150
+ torch.tensor(c2w, dtype=torch.float32).to(cfg.device),
151
+ torch.tensor(Ks, dtype=torch.float32).to(cfg.device),
152
+ image_width=cfg.image_hw[1],
153
+ image_height=cfg.image_hw[0],
154
+ )
155
+ bg_images[camera.name] = result.rgb[..., ::-1]
156
+
157
+ video_frames = []
158
+ for idx, camera in enumerate(scene_manager.cameras):
159
+ # Scene rendering
160
+ if idx == 0:
161
+ for step in range(image_cnt):
162
+ rgba = alpha_blend_rgba(
163
+ frames[camera.name][step]["Foreground"],
164
+ bg_images[camera.name],
165
+ )
166
+ video_frames.append(np.array(rgba))
167
+
168
+ # Grasp rendering
169
+ for node in actions:
170
+ if actions[node] is None:
171
+ continue
172
+ logger.info(f"Render SIM grasping in camera {idx} for {node}...")
173
+ for action in actions[node]:
174
+ grasp_frames = scene_manager.step_action(
175
+ agent,
176
+ torch.Tensor(action[None, ...]),
177
+ scene_manager.cameras,
178
+ cfg.render_keys,
179
+ sim_steps_per_control=cfg.sim_freq // cfg.control_freq,
180
+ )
181
+ rgba = alpha_blend_rgba(
182
+ grasp_frames[camera.name][0]["Foreground"],
183
+ bg_images[camera.name],
184
+ )
185
+ video_frames.append(np.array(rgba))
186
+
187
+ agent.reset(agent.init_qpos)
188
+
189
+ os.makedirs(cfg.output_dir, exist_ok=True)
190
+ video_path = f"{cfg.output_dir}/Iscene.mp4"
191
+ imageio.mimsave(video_path, video_frames, fps=30)
192
+ logger.info(f"Interative 3D Scene Visualization saved in {video_path}")
193
+
194
+
195
+ if __name__ == "__main__":
196
+ entrypoint()
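A sketch of launching the SAPIEN simulation entrypoint above; SapienSimConfig is parsed by tyro, so the exact flag spelling (underscores vs. hyphens) depends on the tyro version, and the layout path below is a placeholder produced by the layout-generation scripts.

# Placeholder layout file; field names come from SapienSimConfig.
# Depending on the tyro version, flags may be kebab-case (e.g. --layout-path).
python3 embodied_gen/scripts/simulate_sapien.py \
    --layout_path outputs/layouts/layout.json \
    --output_dir outputs/sim_vis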
embodied_gen/scripts/text2image.py ADDED
@@ -0,0 +1,162 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import logging
20
+ import os
21
+
22
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import (
23
+ StableDiffusionXLPipeline,
24
+ )
25
+ from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( # noqa
26
+ StableDiffusionXLPipeline as StableDiffusionXLPipelineIP,
27
+ )
28
+ from tqdm import tqdm
29
+ from embodied_gen.models.text_model import (
30
+ build_text2img_ip_pipeline,
31
+ build_text2img_pipeline,
32
+ text2img_gen,
33
+ )
34
+ from embodied_gen.utils.process_media import parse_text_prompts
35
+
36
+ logging.basicConfig(level=logging.INFO)
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ def parse_args():
41
+ parser = argparse.ArgumentParser(description="Text to Image.")
42
+ parser.add_argument(
43
+ "--prompts",
44
+ type=str,
45
+ nargs="+",
46
+ help="List of prompts (space-separated).",
47
+ )
48
+ parser.add_argument(
49
+ "--ref_image",
50
+ type=str,
51
+ nargs="+",
52
+ help="List of ref_image paths (space-separated).",
53
+ )
54
+ parser.add_argument(
55
+ "--output_root",
56
+ type=str,
57
+ help="Root directory for saving outputs.",
58
+ )
59
+ parser.add_argument(
60
+ "--guidance_scale",
61
+ type=float,
62
+ default=12.0,
63
+ help="Guidance scale for the diffusion model.",
64
+ )
65
+ parser.add_argument(
66
+ "--ref_scale",
67
+ type=float,
68
+ default=0.3,
69
+ help="Reference image scale for the IP adapter.",
70
+ )
71
+ parser.add_argument(
72
+ "--n_sample",
73
+ type=int,
74
+ default=1,
75
+ )
76
+ parser.add_argument(
77
+ "--resolution",
78
+ type=int,
79
+ default=1024,
80
+ )
81
+ parser.add_argument(
82
+ "--infer_step",
83
+ type=int,
84
+ default=50,
85
+ )
86
+ parser.add_argument(
87
+ "--seed",
88
+ type=int,
89
+ default=None,
90
+ )
91
+ args = parser.parse_args()
92
+
93
+ return args
94
+
95
+
96
+ def entrypoint(
97
+ pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP = None,
98
+ **kwargs,
99
+ ) -> list[str]:
100
+ args = parse_args()
101
+ for k, v in kwargs.items():
102
+ if hasattr(args, k) and v is not None:
103
+ setattr(args, k, v)
104
+
105
+ prompts = parse_text_prompts(args.prompts)
106
+ os.makedirs(args.output_root, exist_ok=True)
107
+
108
+ ip_img_paths = args.ref_image
109
+ if ip_img_paths is None or len(ip_img_paths) == 0:
110
+ args.ref_scale = 0
111
+ ip_img_paths = [None] * len(prompts)
112
+ elif isinstance(ip_img_paths, str):
113
+ ip_img_paths = [ip_img_paths] * len(prompts)
114
+ elif isinstance(ip_img_paths, list):
115
+ if len(ip_img_paths) == 1:
116
+ ip_img_paths = ip_img_paths * len(prompts)
117
+ else:
118
+ raise ValueError("Invalid ref_image paths.")
119
+ assert len(ip_img_paths) == len(
120
+ prompts
121
+ ), f"Number of ref images does not match prompts, {len(ip_img_paths)} != {len(prompts)}" # noqa
122
+
123
+ if pipeline is None:
124
+ if args.ref_scale > 0:
125
+ pipeline = build_text2img_ip_pipeline(
126
+ "weights/Kolors",
127
+ ref_scale=args.ref_scale,
128
+ )
129
+ else:
130
+ pipeline = build_text2img_pipeline("weights/Kolors")
131
+
132
+ for idx, (prompt, ip_img_path) in tqdm(
133
+ enumerate(zip(prompts, ip_img_paths)),
134
+ desc="Generating images",
135
+ total=len(prompts),
136
+ ):
137
+ images = text2img_gen(
138
+ prompt=prompt,
139
+ n_sample=args.n_sample,
140
+ guidance_scale=args.guidance_scale,
141
+ pipeline=pipeline,
142
+ ip_image=ip_img_path,
143
+ image_wh=[args.resolution, args.resolution],
144
+ infer_step=args.infer_step,
145
+ seed=args.seed,
146
+ )
147
+
148
+ save_paths = []
149
+ for sub_idx, image in enumerate(images):
150
+ save_path = (
151
+ f"{args.output_root}/sample_{idx*args.n_sample+sub_idx}.png"
152
+ )
153
+ image.save(save_path)
154
+ save_paths.append(save_path)
155
+
156
+ logger.info(f"Images saved to {args.output_root}")
157
+
158
+ return save_paths
159
+
160
+
161
+ if __name__ == "__main__":
162
+ entrypoint()
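A minimal example call of the text2image.py CLI above (textto3d.sh later in this commit drives it the same way); the prompts and output directory are placeholders, and the Kolors weights are assumed to be available under weights/Kolors as referenced in the code.

# Placeholder prompts/paths; flags match the argparse options in text2image.py.
python3 embodied_gen/scripts/text2image.py \
    --prompts "a red ceramic mug" "a wooden toy car" \
    --output_root outputs/text2image \
    --n_sample 1 \
    --seed 0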
embodied_gen/scripts/textto3d.py ADDED
@@ -0,0 +1,282 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import argparse
18
+ import os
19
+ import random
20
+ from collections import defaultdict
21
+
22
+ import numpy as np
23
+ import torch
24
+ from PIL import Image
25
+ from embodied_gen.models.image_comm_model import build_hf_image_pipeline
26
+ from embodied_gen.models.segment_model import RembgRemover
27
+ from embodied_gen.models.text_model import PROMPT_APPEND
28
+ from embodied_gen.scripts.imageto3d import entrypoint as imageto3d_api
29
+ from embodied_gen.utils.gpt_clients import GPT_CLIENT
30
+ from embodied_gen.utils.log import logger
31
+ from embodied_gen.utils.process_media import (
32
+ check_object_edge_truncated,
33
+ render_asset3d,
34
+ )
35
+ from embodied_gen.validators.quality_checkers import (
36
+ ImageSegChecker,
37
+ SemanticConsistChecker,
38
+ TextGenAlignChecker,
39
+ )
40
+
41
+ # Avoid the "huggingface/tokenizers: The current process just got forked" warning.
42
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
43
+ random.seed(0)
44
+
45
+ logger.info("Loading TEXT2IMG_MODEL...")
46
+ SEMANTIC_CHECKER = SemanticConsistChecker(GPT_CLIENT)
47
+ SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
48
+ TXTGEN_CHECKER = TextGenAlignChecker(GPT_CLIENT)
49
+ PIPE_IMG = build_hf_image_pipeline(os.environ.get("TEXT_MODEL", "sd35"))
50
+ BG_REMOVER = RembgRemover()
51
+
52
+
53
+ __all__ = [
54
+ "text_to_image",
55
+ "text_to_3d",
56
+ ]
57
+
58
+
59
+ def text_to_image(
60
+ prompt: str,
61
+ save_path: str,
62
+ n_retry: int,
63
+ img_denoise_step: int,
64
+ text_guidance_scale: float,
65
+ n_img_sample: int,
66
+ image_hw: tuple[int, int] = (1024, 1024),
67
+ seed: int = None,
68
+ ) -> bool:
69
+ select_image = None
70
+ success_flag = False
71
+ assert save_path.endswith(".png"), "Image save path must end with `.png`."
72
+ for try_idx in range(n_retry):
73
+ if select_image is not None:
74
+ select_image[0].save(save_path.replace(".png", "_raw.png"))
75
+ select_image[1].save(save_path)
76
+ break
77
+
78
+ f_prompt = PROMPT_APPEND.format(object=prompt)
79
+ logger.info(
80
+ f"Image GEN for {os.path.basename(save_path)}\n"
81
+ f"Try: {try_idx + 1}/{n_retry}, Seed: {seed}, Prompt: {f_prompt}"
82
+ )
83
+ torch.cuda.empty_cache()
84
+ images = PIPE_IMG.run(
85
+ f_prompt,
86
+ num_inference_steps=img_denoise_step,
87
+ guidance_scale=text_guidance_scale,
88
+ num_images_per_prompt=n_img_sample,
89
+ height=image_hw[0],
90
+ width=image_hw[1],
91
+ generator=(
92
+ torch.Generator().manual_seed(seed)
93
+ if seed is not None
94
+ else None
95
+ ),
96
+ )
97
+
98
+ for idx in range(len(images)):
99
+ raw_image: Image.Image = images[idx]
100
+ image = BG_REMOVER(raw_image)
101
+ image.save(save_path)
102
+ semantic_flag, semantic_result = SEMANTIC_CHECKER(
103
+ prompt, [image.convert("RGB")]
104
+ )
105
+ seg_flag, seg_result = SEG_CHECKER(
106
+ [raw_image, image.convert("RGB")]
107
+ )
108
+ image_mask = np.array(image)[..., -1]
109
+ edge_flag = check_object_edge_truncated(image_mask)
110
+ logger.warning(
111
+ f"SEMANTIC: {semantic_result}. SEG: {seg_result}. EDGE: {edge_flag}"
112
+ )
113
+ if (
114
+ (edge_flag and semantic_flag and seg_flag)
115
+ or (edge_flag and semantic_flag is None)
116
+ or (edge_flag and seg_flag is None)
117
+ ):
118
+ select_image = [raw_image, image]
119
+ success_flag = True
120
+ break
121
+
122
+ seed = random.randint(0, 100000) if seed is not None else None
123
+
124
+ return success_flag
125
+
126
+
127
+ def text_to_3d(**kwargs) -> dict:
128
+ args = parse_args()
129
+ for k, v in kwargs.items():
130
+ if hasattr(args, k) and v is not None:
131
+ setattr(args, k, v)
132
+
133
+ if args.asset_names is None or len(args.asset_names) == 0:
134
+ args.asset_names = [f"sample3d_{i}" for i in range(len(args.prompts))]
135
+ img_save_dir = os.path.join(args.output_root, "images")
136
+ asset_save_dir = os.path.join(args.output_root, "asset3d")
137
+ os.makedirs(img_save_dir, exist_ok=True)
138
+ os.makedirs(asset_save_dir, exist_ok=True)
139
+ results = defaultdict(dict)
140
+ for prompt, node in zip(args.prompts, args.asset_names):
141
+ success_flag = False
142
+ n_pipe_retry = args.n_pipe_retry
143
+ seed_img = args.seed_img
144
+ seed_3d = args.seed_3d
145
+ while success_flag is False and n_pipe_retry > 0:
146
+ logger.info(
147
+ f"GEN pipeline for node {node}\n"
148
+ f"Try round: {args.n_pipe_retry-n_pipe_retry+1}/{args.n_pipe_retry}, Prompt: {prompt}"
149
+ )
150
+ # Text-to-image GEN
151
+ save_node = node.replace(" ", "_")
152
+ gen_image_path = f"{img_save_dir}/{save_node}.png"
153
+ textgen_flag = text_to_image(
154
+ prompt,
155
+ gen_image_path,
156
+ args.n_image_retry,
157
+ args.img_denoise_step,
158
+ args.text_guidance_scale,
159
+ args.n_img_sample,
160
+ seed=seed_img,
161
+ )
162
+
163
+ # Asset 3D GEN
164
+ node_save_dir = f"{asset_save_dir}/{save_node}"
165
+ asset_type = node if "sample3d_" not in node else None
166
+ imageto3d_api(
167
+ image_path=[gen_image_path],
168
+ output_root=node_save_dir,
169
+ asset_type=[asset_type],
170
+ seed=random.randint(0, 100000) if seed_3d is None else seed_3d,
171
+ n_retry=args.n_asset_retry,
172
+ keep_intermediate=args.keep_intermediate,
173
+ disable_decompose_convex=args.disable_decompose_convex,
174
+ )
175
+ mesh_path = f"{node_save_dir}/result/mesh/{save_node}.obj"
176
+ image_path = render_asset3d(
177
+ mesh_path,
178
+ output_root=f"{node_save_dir}/result",
179
+ num_images=6,
180
+ elevation=(30, -30),
181
+ output_subdir="renders",
182
+ no_index_file=True,
183
+ )
184
+
185
+ check_text = asset_type if asset_type is not None else prompt
186
+ qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
187
+ logger.warning(
188
+ f"Node {node}, {TXTGEN_CHECKER.__class__.__name__}: {qa_result}"
189
+ )
190
+ results["assets"][node] = f"asset3d/{save_node}/result"
191
+ results["quality"][node] = qa_result
192
+
193
+ if qa_flag is None or qa_flag is True:
194
+ success_flag = True
195
+ break
196
+
197
+ n_pipe_retry -= 1
198
+ seed_img = (
199
+ random.randint(0, 100000) if seed_img is not None else None
200
+ )
201
+ seed_3d = (
202
+ random.randint(0, 100000) if seed_3d is not None else None
203
+ )
204
+
205
+ torch.cuda.empty_cache()
206
+
207
+ return results
208
+
209
+
210
+ def parse_args():
211
+ parser = argparse.ArgumentParser(description="Text-to-3D Asset Generation Config")
212
+ parser.add_argument("--prompts", nargs="+", help="text descriptions")
213
+ parser.add_argument(
214
+ "--output_root",
215
+ type=str,
216
+ help="Directory to save outputs",
217
+ )
218
+ parser.add_argument(
219
+ "--asset_names",
220
+ type=str,
221
+ nargs="+",
222
+ default=None,
223
+ help="Asset names to generate",
224
+ )
225
+ parser.add_argument(
226
+ "--n_img_sample",
227
+ type=int,
228
+ default=3,
229
+ help="Number of image samples to generate",
230
+ )
231
+ parser.add_argument(
232
+ "--text_guidance_scale",
233
+ type=float,
234
+ default=7,
235
+ help="Text-to-image guidance scale",
236
+ )
237
+ parser.add_argument(
238
+ "--img_denoise_step",
239
+ type=int,
240
+ default=25,
241
+ help="Denoising steps for image generation",
242
+ )
243
+ parser.add_argument(
244
+ "--n_image_retry",
245
+ type=int,
246
+ default=2,
247
+ help="Max retry count for image generation",
248
+ )
249
+ parser.add_argument(
250
+ "--n_asset_retry",
251
+ type=int,
252
+ default=2,
253
+ help="Max retry count for 3D generation",
254
+ )
255
+ parser.add_argument(
256
+ "--n_pipe_retry",
257
+ type=int,
258
+ default=1,
259
+ help="Max retry count for 3D asset generation",
260
+ )
261
+ parser.add_argument(
262
+ "--seed_img",
263
+ type=int,
264
+ default=None,
265
+ help="Random seed for image generation",
266
+ )
267
+ parser.add_argument(
268
+ "--seed_3d",
269
+ type=int,
270
+ default=0,
271
+ help="Random seed for 3D generation",
272
+ )
273
+ parser.add_argument("--keep_intermediate", action="store_true")
274
+ parser.add_argument("--disable_decompose_convex", action="store_true")
275
+
276
+ args, unknown = parser.parse_known_args()
277
+
278
+ return args
279
+
280
+
281
+ if __name__ == "__main__":
282
+ text_to_3d()
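A minimal sketch of running the full text-to-3D pipeline defined above; the prompts, asset names, and output directory are placeholders, and the quality checkers assume GPT_CLIENT is configured.

# Placeholder prompts/paths; flags match the argparse options in textto3d.py.
python3 embodied_gen/scripts/textto3d.py \
    --prompts "a red ceramic mug" "a wooden toy car" \
    --asset_names mug toy_car \
    --output_root outputs/textto3d \
    --seed_img 0 \
    --seed_3d 0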
embodied_gen/scripts/textto3d.sh ADDED
@@ -0,0 +1,94 @@
1
+ #!/bin/bash
2
+
3
+ # Initialize variables
4
+ prompts=()
5
+ asset_types=()
6
+ output_root=""
7
+ seed=0
8
+
9
+ # Parse arguments
10
+ while [[ $# -gt 0 ]]; do
11
+ case "$1" in
12
+ --prompts)
13
+ shift
14
+ while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
15
+ prompts+=("$1")
16
+ shift
17
+ done
18
+ ;;
19
+ --asset_types)
20
+ shift
21
+ while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do
22
+ asset_types+=("$1")
23
+ shift
24
+ done
25
+ ;;
26
+ --output_root)
27
+ output_root="$2"
28
+ shift 2
29
+ ;;
30
+ --seed)
31
+ seed="$2"
32
+ shift 2
33
+ ;;
34
+ *)
35
+ echo "Unknown argument: $1"
36
+ exit 1
37
+ ;;
38
+ esac
39
+ done
40
+
41
+ # Validate required arguments
42
+ if [[ ${#prompts[@]} -eq 0 || -z "$output_root" ]]; then
43
+ echo "Missing required arguments."
44
+ echo "Usage: bash run_text2asset3d.sh --prompts \"Prompt1\" \"Prompt2\" \
45
+ --asset_types \"type1\" \"type2\" --seed <seed_value> --output_root <path>"
46
+ exit 1
47
+ fi
48
+
49
+ # If no asset_types provided, default to ""
50
+ if [[ ${#asset_types[@]} -eq 0 ]]; then
51
+ for (( i=0; i<${#prompts[@]}; i++ )); do
52
+ asset_types+=("")
53
+ done
54
+ fi
55
+
56
+ # Ensure the number of asset_types matches the number of prompts
57
+ if [[ ${#prompts[@]} -ne ${#asset_types[@]} ]]; then
58
+ echo "The number of asset types must match the number of prompts."
59
+ exit 1
60
+ fi
61
+
62
+ # Print arguments (for debugging)
63
+ echo "Prompts:"
64
+ for p in "${prompts[@]}"; do
65
+ echo " - $p"
66
+ done
67
+ # echo "Asset types:"
68
+ # for at in "${asset_types[@]}"; do
69
+ # echo " - $at"
70
+ # done
71
+ echo "Output root: ${output_root}"
72
+ echo "Seed: ${seed}"
73
+
74
+ # Concatenate prompts and asset types for Python command
75
+ prompt_args=""
76
+ asset_type_args=""
77
+ for i in "${!prompts[@]}"; do
78
+ prompt_args+="\"${prompts[$i]}\" "
79
+ asset_type_args+="\"${asset_types[$i]}\" "
80
+ done
81
+
82
+
83
+ # Step 1: Text-to-Image
84
+ echo ${prompt_args}
85
+ eval python3 embodied_gen/scripts/text2image.py \
86
+ --prompts ${prompt_args} \
87
+ --output_root "${output_root}/images" \
88
+ --seed ${seed}
89
+
90
+ # Step 2: Image-to-3D
91
+ python3 embodied_gen/scripts/imageto3d.py \
92
+ --image_root "${output_root}/images" \
93
+ --output_root "${output_root}/asset3d" \
94
+ --asset_type ${asset_type_args}
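Following the usage message inside the script, a hypothetical invocation with placeholder prompts, asset types, and output path.

# Placeholder values; the script forwards them to text2image.py and imageto3d.py.
bash embodied_gen/scripts/textto3d.sh \
    --prompts "a red ceramic mug" "a wooden toy car" \
    --asset_types "mug" "toy car" \
    --seed 0 \
    --output_root outputs/textto3d_run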
embodied_gen/scripts/texture_gen.sh ADDED
@@ -0,0 +1,78 @@
1
+ #!/bin/bash
2
+
3
+ while [[ $# -gt 0 ]]; do
4
+ case $1 in
5
+ --mesh_path)
6
+ mesh_path="$2"
7
+ shift 2
8
+ ;;
9
+ --prompt)
10
+ prompt="$2"
11
+ shift 2
12
+ ;;
13
+ --output_root)
14
+ output_root="$2"
15
+ shift 2
16
+ ;;
17
+ *)
18
+ echo "unknown: $1"
19
+ exit 1
20
+ ;;
21
+ esac
22
+ done
23
+
24
+
25
+ if [[ -z "$mesh_path" || -z "$prompt" || -z "$output_root" ]]; then
26
+ echo "params missing"
27
+ echo "usage: bash run.sh --mesh_path <path> --prompt <text> --output_root <path>"
28
+ exit 1
29
+ fi
30
+
31
+ echo "Will be deprecated, recommended to use 'texture-cli' instead."
32
+ uuid=$(basename "$output_root")
33
+ # Step 1: drender-cli for condition rendering
34
+ drender-cli --mesh_path ${mesh_path} \
35
+ --output_root ${output_root}/condition \
36
+ --uuid ${uuid}
37
+
38
+ # Step 2: multi-view rendering
39
+ python embodied_gen/scripts/render_mv.py \
40
+ --index_file "${output_root}/condition/index.json" \
41
+ --controlnet_cond_scale 0.7 \
42
+ --guidance_scale 9 \
43
+ --strength 0.9 \
44
+ --num_inference_steps 40 \
45
+ --ip_adapt_scale 0 \
46
+ --ip_img_path None \
47
+ --uid ${uuid} \
48
+ --prompt "${prompt}" \
49
+ --save_dir "${output_root}/multi_view" \
50
+ --sub_idxs "[[0,1,2],[3,4,5]]" \
51
+ --seed 0
52
+
53
+ # Step 3: backprojection
54
+ backproject-cli --mesh_path ${mesh_path} \
55
+ --color_path ${output_root}/multi_view/color_sample0.png \
56
+ --output_path "${output_root}/texture_mesh/${uuid}.obj" \
57
+ --save_glb_path "${output_root}/texture_mesh/${uuid}.glb" \
58
+ --skip_fix_mesh \
59
+ --delight \
60
+ --no_save_delight_img
61
+
62
+ # Step 4: final rendering of textured mesh
63
+ drender-cli --mesh_path "${output_root}/texture_mesh/${uuid}.obj" \
64
+ --output_root ${output_root}/texture_mesh \
65
+ --num_images 90 \
66
+ --elevation 20 \
67
+ --with_mtl \
68
+ --gen_color_mp4 \
69
+ --pbr_light_factor 1.2
70
+
71
+ # Organize folders
72
+ rm -rf ${output_root}/condition
73
+ video_path="${output_root}/texture_mesh/${uuid}/color.mp4"
74
+ if [ -f "${video_path}" ]; then
75
+ cp "${video_path}" "${output_root}/texture_mesh/color.mp4"
76
+ echo "Resave video to ${output_root}/texture_mesh/color.mp4"
77
+ fi
78
+ rm -rf ${output_root}/texture_mesh/${uuid}
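A hypothetical invocation of texture_gen.sh with placeholder paths; it assumes drender-cli and backproject-cli are installed on PATH, as used inside the script.

# Placeholder mesh/prompt/output; see the four steps inside texture_gen.sh above.
bash embodied_gen/scripts/texture_gen.sh \
    --mesh_path outputs/sample/mesh/raw_mesh.obj \
    --prompt "a wooden chair with red cushions" \
    --output_root outputs/texture_run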
embodied_gen/trainer/gsplat_trainer.py ADDED
@@ -0,0 +1,678 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+ # Part of the code comes from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
17
+ # Both under the Apache License, Version 2.0.
18
+
19
+
20
+ import json
21
+ import os
22
+ import time
23
+ from collections import defaultdict
24
+ from typing import Dict, Optional, Tuple
25
+
26
+ import cv2
27
+ import imageio
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import tqdm
32
+ import tyro
33
+ import yaml
34
+ from fused_ssim import fused_ssim
35
+ from gsplat.distributed import cli
36
+ from gsplat.rendering import rasterization
37
+ from gsplat.strategy import DefaultStrategy, MCMCStrategy
38
+ from torch import Tensor
39
+ from torch.utils.tensorboard import SummaryWriter
40
+ from torchmetrics.image import (
41
+ PeakSignalNoiseRatio,
42
+ StructuralSimilarityIndexMeasure,
43
+ )
44
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
45
+ from typing_extensions import Literal, assert_never
46
+ from embodied_gen.data.datasets import PanoGSplatDataset
47
+ from embodied_gen.utils.config import GsplatTrainConfig
48
+ from embodied_gen.utils.gaussian import (
49
+ create_splats_with_optimizers,
50
+ export_splats,
51
+ resize_pinhole_intrinsics,
52
+ set_random_seed,
53
+ )
54
+
55
+
56
+ class Runner:
57
+ """Engine for training and testing from gsplat example.
58
+
59
+ Code from https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ local_rank: int,
65
+ world_rank,
66
+ world_size: int,
67
+ cfg: GsplatTrainConfig,
68
+ ) -> None:
69
+ set_random_seed(42 + local_rank)
70
+
71
+ self.cfg = cfg
72
+ self.world_rank = world_rank
73
+ self.local_rank = local_rank
74
+ self.world_size = world_size
75
+ self.device = f"cuda:{local_rank}"
76
+
77
+ # Where to dump results.
78
+ os.makedirs(cfg.result_dir, exist_ok=True)
79
+
80
+ # Setup output directories.
81
+ self.ckpt_dir = f"{cfg.result_dir}/ckpts"
82
+ os.makedirs(self.ckpt_dir, exist_ok=True)
83
+ self.stats_dir = f"{cfg.result_dir}/stats"
84
+ os.makedirs(self.stats_dir, exist_ok=True)
85
+ self.render_dir = f"{cfg.result_dir}/renders"
86
+ os.makedirs(self.render_dir, exist_ok=True)
87
+ self.ply_dir = f"{cfg.result_dir}/ply"
88
+ os.makedirs(self.ply_dir, exist_ok=True)
89
+
90
+ # Tensorboard
91
+ self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
92
+ self.trainset = PanoGSplatDataset(cfg.data_dir, split="train")
93
+ self.valset = PanoGSplatDataset(
94
+ cfg.data_dir, split="train", max_sample_num=6
95
+ )
96
+ self.testset = PanoGSplatDataset(cfg.data_dir, split="eval")
97
+ self.scene_scale = cfg.scene_scale
98
+
99
+ # Model
100
+ self.splats, self.optimizers = create_splats_with_optimizers(
101
+ self.trainset.points,
102
+ self.trainset.points_rgb,
103
+ init_num_pts=cfg.init_num_pts,
104
+ init_extent=cfg.init_extent,
105
+ init_opacity=cfg.init_opa,
106
+ init_scale=cfg.init_scale,
107
+ means_lr=cfg.means_lr,
108
+ scales_lr=cfg.scales_lr,
109
+ opacities_lr=cfg.opacities_lr,
110
+ quats_lr=cfg.quats_lr,
111
+ sh0_lr=cfg.sh0_lr,
112
+ shN_lr=cfg.shN_lr,
113
+ scene_scale=self.scene_scale,
114
+ sh_degree=cfg.sh_degree,
115
+ sparse_grad=cfg.sparse_grad,
116
+ visible_adam=cfg.visible_adam,
117
+ batch_size=cfg.batch_size,
118
+ feature_dim=None,
119
+ device=self.device,
120
+ world_rank=world_rank,
121
+ world_size=world_size,
122
+ )
123
+ print("Model initialized. Number of GS:", len(self.splats["means"]))
124
+
125
+ # Densification Strategy
126
+ self.cfg.strategy.check_sanity(self.splats, self.optimizers)
127
+
128
+ if isinstance(self.cfg.strategy, DefaultStrategy):
129
+ self.strategy_state = self.cfg.strategy.initialize_state(
130
+ scene_scale=self.scene_scale
131
+ )
132
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
133
+ self.strategy_state = self.cfg.strategy.initialize_state()
134
+ else:
135
+ assert_never(self.cfg.strategy)
136
+
137
+ # Losses & Metrics.
138
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(
139
+ self.device
140
+ )
141
+ self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
142
+
143
+ if cfg.lpips_net == "alex":
144
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
145
+ net_type="alex", normalize=True
146
+ ).to(self.device)
147
+ elif cfg.lpips_net == "vgg":
148
+ # The 3DGS official repo uses LPIPS VGG, which is equivalent to the following:
149
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
150
+ net_type="vgg", normalize=False
151
+ ).to(self.device)
152
+ else:
153
+ raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
154
+
155
+ def rasterize_splats(
156
+ self,
157
+ camtoworlds: Tensor,
158
+ Ks: Tensor,
159
+ width: int,
160
+ height: int,
161
+ masks: Optional[Tensor] = None,
162
+ rasterize_mode: Optional[Literal["classic", "antialiased"]] = None,
163
+ camera_model: Optional[Literal["pinhole", "ortho", "fisheye"]] = None,
164
+ **kwargs,
165
+ ) -> Tuple[Tensor, Tensor, Dict]:
166
+ means = self.splats["means"] # [N, 3]
167
+ # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
168
+ # rasterization does normalization internally
169
+ quats = self.splats["quats"] # [N, 4]
170
+ scales = torch.exp(self.splats["scales"]) # [N, 3]
171
+ opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
172
+ image_ids = kwargs.pop("image_ids", None)
173
+
174
+ colors = torch.cat(
175
+ [self.splats["sh0"], self.splats["shN"]], 1
176
+ ) # [N, K, 3]
177
+
178
+ if rasterize_mode is None:
179
+ rasterize_mode = (
180
+ "antialiased" if self.cfg.antialiased else "classic"
181
+ )
182
+ if camera_model is None:
183
+ camera_model = self.cfg.camera_model
184
+
185
+ render_colors, render_alphas, info = rasterization(
186
+ means=means,
187
+ quats=quats,
188
+ scales=scales,
189
+ opacities=opacities,
190
+ colors=colors,
191
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
192
+ Ks=Ks, # [C, 3, 3]
193
+ width=width,
194
+ height=height,
195
+ packed=self.cfg.packed,
196
+ absgrad=(
197
+ self.cfg.strategy.absgrad
198
+ if isinstance(self.cfg.strategy, DefaultStrategy)
199
+ else False
200
+ ),
201
+ sparse_grad=self.cfg.sparse_grad,
202
+ rasterize_mode=rasterize_mode,
203
+ distributed=self.world_size > 1,
204
+ camera_model=self.cfg.camera_model,
205
+ with_ut=self.cfg.with_ut,
206
+ with_eval3d=self.cfg.with_eval3d,
207
+ **kwargs,
208
+ )
209
+ if masks is not None:
210
+ render_colors[~masks] = 0
211
+ return render_colors, render_alphas, info
212
+
213
+ def train(self):
214
+ cfg = self.cfg
215
+ device = self.device
216
+ world_rank = self.world_rank
217
+
218
+ # Dump cfg.
219
+ if world_rank == 0:
220
+ with open(f"{cfg.result_dir}/cfg.yml", "w") as f:
221
+ yaml.dump(vars(cfg), f)
222
+
223
+ max_steps = cfg.max_steps
224
+ init_step = 0
225
+
226
+ schedulers = [
227
+ # means has a learning rate schedule that ends at 0.01 of the initial value
228
+ torch.optim.lr_scheduler.ExponentialLR(
229
+ self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
230
+ ),
231
+ ]
232
+ trainloader = torch.utils.data.DataLoader(
233
+ self.trainset,
234
+ batch_size=cfg.batch_size,
235
+ shuffle=True,
236
+ num_workers=4,
237
+ persistent_workers=True,
238
+ pin_memory=True,
239
+ )
240
+ trainloader_iter = iter(trainloader)
241
+
242
+ # Training loop.
243
+ global_tic = time.time()
244
+ pbar = tqdm.tqdm(range(init_step, max_steps))
245
+ for step in pbar:
246
+ try:
247
+ data = next(trainloader_iter)
248
+ except StopIteration:
249
+ trainloader_iter = iter(trainloader)
250
+ data = next(trainloader_iter)
251
+
252
+ camtoworlds = data["camtoworld"].to(device) # [1, 4, 4]
253
+ Ks = data["K"].to(device) # [1, 3, 3]
254
+ pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
255
+ image_ids = data["image_id"].to(device)
256
+ masks = (
257
+ data["mask"].to(device) if "mask" in data else None
258
+ ) # [1, H, W]
259
+ if cfg.depth_loss:
260
+ points = data["points"].to(device) # [1, M, 2]
261
+ depths_gt = data["depths"].to(device) # [1, M]
262
+
263
+ height, width = pixels.shape[1:3]
264
+
265
+ # sh schedule
266
+ sh_degree_to_use = min(
267
+ step // cfg.sh_degree_interval, cfg.sh_degree
268
+ )
269
+
270
+ # forward
271
+ renders, alphas, info = self.rasterize_splats(
272
+ camtoworlds=camtoworlds,
273
+ Ks=Ks,
274
+ width=width,
275
+ height=height,
276
+ sh_degree=sh_degree_to_use,
277
+ near_plane=cfg.near_plane,
278
+ far_plane=cfg.far_plane,
279
+ image_ids=image_ids,
280
+ render_mode="RGB+ED" if cfg.depth_loss else "RGB",
281
+ masks=masks,
282
+ )
283
+ if renders.shape[-1] == 4:
284
+ colors, depths = renders[..., 0:3], renders[..., 3:4]
285
+ else:
286
+ colors, depths = renders, None
287
+
288
+ if cfg.random_bkgd:
289
+ bkgd = torch.rand(1, 3, device=device)
290
+ colors = colors + bkgd * (1.0 - alphas)
291
+
292
+ self.cfg.strategy.step_pre_backward(
293
+ params=self.splats,
294
+ optimizers=self.optimizers,
295
+ state=self.strategy_state,
296
+ step=step,
297
+ info=info,
298
+ )
299
+
300
+ # loss
301
+ l1loss = F.l1_loss(colors, pixels)
302
+ ssimloss = 1.0 - fused_ssim(
303
+ colors.permute(0, 3, 1, 2),
304
+ pixels.permute(0, 3, 1, 2),
305
+ padding="valid",
306
+ )
307
+ loss = (
308
+ l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
309
+ )
310
+ if cfg.depth_loss:
311
+ # query depths from depth map
312
+ points = torch.stack(
313
+ [
314
+ points[:, :, 0] / (width - 1) * 2 - 1,
315
+ points[:, :, 1] / (height - 1) * 2 - 1,
316
+ ],
317
+ dim=-1,
318
+ ) # normalize to [-1, 1]
319
+ grid = points.unsqueeze(2) # [1, M, 1, 2]
320
+ depths = F.grid_sample(
321
+ depths.permute(0, 3, 1, 2), grid, align_corners=True
322
+ ) # [1, 1, M, 1]
323
+ depths = depths.squeeze(3).squeeze(1) # [1, M]
324
+ # calculate loss in disparity space
325
+ disp = torch.where(
326
+ depths > 0.0, 1.0 / depths, torch.zeros_like(depths)
327
+ )
328
+ disp_gt = 1.0 / depths_gt # [1, M]
329
+ depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
330
+ loss += depthloss * cfg.depth_lambda
331
+
332
+ # regularizations
333
+ if cfg.opacity_reg > 0.0:
334
+ loss += (
335
+ cfg.opacity_reg
336
+ * torch.sigmoid(self.splats["opacities"]).mean()
337
+ )
338
+ if cfg.scale_reg > 0.0:
339
+ loss += cfg.scale_reg * torch.exp(self.splats["scales"]).mean()
340
+
341
+ loss.backward()
342
+
343
+ desc = (
344
+ f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
345
+ )
346
+ if cfg.depth_loss:
347
+ desc += f"depth loss={depthloss.item():.6f}| "
348
+ pbar.set_description(desc)
349
+
350
+ # write images (gt and render)
351
+ # if world_rank == 0 and step % 800 == 0:
352
+ # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
353
+ # canvas = canvas.reshape(-1, *canvas.shape[2:])
354
+ # imageio.imwrite(
355
+ # f"{self.render_dir}/train_rank{self.world_rank}.png",
356
+ # (canvas * 255).astype(np.uint8),
357
+ # )
358
+
359
+ if (
360
+ world_rank == 0
361
+ and cfg.tb_every > 0
362
+ and step % cfg.tb_every == 0
363
+ ):
364
+ mem = torch.cuda.max_memory_allocated() / 1024**3
365
+ self.writer.add_scalar("train/loss", loss.item(), step)
366
+ self.writer.add_scalar("train/l1loss", l1loss.item(), step)
367
+ self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
368
+ self.writer.add_scalar(
369
+ "train/num_GS", len(self.splats["means"]), step
370
+ )
371
+ self.writer.add_scalar("train/mem", mem, step)
372
+ if cfg.depth_loss:
373
+ self.writer.add_scalar(
374
+ "train/depthloss", depthloss.item(), step
375
+ )
376
+ if cfg.tb_save_image:
377
+ canvas = (
378
+ torch.cat([pixels, colors], dim=2)
379
+ .detach()
380
+ .cpu()
381
+ .numpy()
382
+ )
383
+ canvas = canvas.reshape(-1, *canvas.shape[2:])
384
+ self.writer.add_image("train/render", canvas, step)
385
+ self.writer.flush()
386
+
387
+ # save checkpoint before updating the model
388
+ if (
389
+ step in [i - 1 for i in cfg.save_steps]
390
+ or step == max_steps - 1
391
+ ):
392
+ mem = torch.cuda.max_memory_allocated() / 1024**3
393
+ stats = {
394
+ "mem": mem,
395
+ "ellipse_time": time.time() - global_tic,
396
+ "num_GS": len(self.splats["means"]),
397
+ }
398
+ print("Step: ", step, stats)
399
+ with open(
400
+ f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json",
401
+ "w",
402
+ ) as f:
403
+ json.dump(stats, f)
404
+ data = {"step": step, "splats": self.splats.state_dict()}
405
+ torch.save(
406
+ data,
407
+ f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt",
408
+ )
409
+ if (
410
+ step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1
411
+ ) and cfg.save_ply:
412
+ sh0 = self.splats["sh0"]
413
+ shN = self.splats["shN"]
414
+ means = self.splats["means"]
415
+ scales = self.splats["scales"]
416
+ quats = self.splats["quats"]
417
+ opacities = self.splats["opacities"]
418
+ export_splats(
419
+ means=means,
420
+ scales=scales,
421
+ quats=quats,
422
+ opacities=opacities,
423
+ sh0=sh0,
424
+ shN=shN,
425
+ format="ply",
426
+ save_to=f"{self.ply_dir}/point_cloud_{step}.ply",
427
+ )
428
+
429
+ # Turn Gradients into Sparse Tensor before running optimizer
430
+ if cfg.sparse_grad:
431
+ assert (
432
+ cfg.packed
433
+ ), "Sparse gradients only work with packed mode."
434
+ gaussian_ids = info["gaussian_ids"]
435
+ for k in self.splats.keys():
436
+ grad = self.splats[k].grad
437
+ if grad is None or grad.is_sparse:
438
+ continue
439
+ self.splats[k].grad = torch.sparse_coo_tensor(
440
+ indices=gaussian_ids[None], # [1, nnz]
441
+ values=grad[gaussian_ids], # [nnz, ...]
442
+ size=self.splats[k].size(), # [N, ...]
443
+ is_coalesced=len(Ks) == 1,
444
+ )
445
+
446
+ if cfg.visible_adam:
447
+ gaussian_cnt = self.splats.means.shape[0]
448
+ if cfg.packed:
449
+ visibility_mask = torch.zeros_like(
450
+ self.splats["opacities"], dtype=bool
451
+ )
452
+ visibility_mask.scatter_(0, info["gaussian_ids"], 1)
453
+ else:
454
+ visibility_mask = (info["radii"] > 0).all(-1).any(0)
455
+
456
+ # optimize
457
+ for optimizer in self.optimizers.values():
458
+ if cfg.visible_adam:
459
+ optimizer.step(visibility_mask)
460
+ else:
461
+ optimizer.step()
462
+ optimizer.zero_grad(set_to_none=True)
463
+ for scheduler in schedulers:
464
+ scheduler.step()
465
+
466
+ # Run post-backward steps after backward and optimizer
467
+ if isinstance(self.cfg.strategy, DefaultStrategy):
468
+ self.cfg.strategy.step_post_backward(
469
+ params=self.splats,
470
+ optimizers=self.optimizers,
471
+ state=self.strategy_state,
472
+ step=step,
473
+ info=info,
474
+ packed=cfg.packed,
475
+ )
476
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
477
+ self.cfg.strategy.step_post_backward(
478
+ params=self.splats,
479
+ optimizers=self.optimizers,
480
+ state=self.strategy_state,
481
+ step=step,
482
+ info=info,
483
+ lr=schedulers[0].get_last_lr()[0],
484
+ )
485
+ else:
486
+ assert_never(self.cfg.strategy)
487
+
488
+ # eval the full set
489
+ if step in [i - 1 for i in cfg.eval_steps]:
490
+ self.eval(step)
491
+ self.render_video(step)
492
+
493
+ @torch.no_grad()
494
+ def eval(
495
+ self,
496
+ step: int,
497
+ stage: str = "val",
498
+ canvas_h: int = 512,
499
+ canvas_w: int = 1024,
500
+ ):
501
+ """Entry for evaluation."""
502
+ print("Running evaluation...")
503
+ cfg = self.cfg
504
+ device = self.device
505
+ world_rank = self.world_rank
506
+
507
+ valloader = torch.utils.data.DataLoader(
508
+ self.valset, batch_size=1, shuffle=False, num_workers=1
509
+ )
510
+ ellipse_time = 0
511
+ metrics = defaultdict(list)
512
+ for i, data in enumerate(valloader):
513
+ camtoworlds = data["camtoworld"].to(device)
514
+ Ks = data["K"].to(device)
515
+ pixels = data["image"].to(device) / 255.0
516
+ height, width = pixels.shape[1:3]
517
+ masks = data["mask"].to(device) if "mask" in data else None
518
+
519
+ pixels = pixels.permute(0, 3, 1, 2) # NHWC -> NCHW
520
+ pixels = F.interpolate(pixels, size=(canvas_h, canvas_w // 2))
521
+
522
+ torch.cuda.synchronize()
523
+ tic = time.time()
524
+ colors, _, _ = self.rasterize_splats(
525
+ camtoworlds=camtoworlds,
526
+ Ks=Ks,
527
+ width=width,
528
+ height=height,
529
+ sh_degree=cfg.sh_degree,
530
+ near_plane=cfg.near_plane,
531
+ far_plane=cfg.far_plane,
532
+ masks=masks,
533
+ ) # [1, H, W, 3]
534
+ torch.cuda.synchronize()
535
+ ellipse_time += max(time.time() - tic, 1e-10)
536
+
537
+ colors = colors.permute(0, 3, 1, 2) # NHWC -> NCHW
538
+ colors = F.interpolate(colors, size=(canvas_h, canvas_w // 2))
539
+ colors = torch.clamp(colors, 0.0, 1.0)
540
+ canvas_list = [pixels, colors]
541
+
542
+ if world_rank == 0:
543
+ canvas = torch.cat(canvas_list, dim=2).squeeze(0)
544
+ canvas = canvas.permute(1, 2, 0) # CHW -> HWC
545
+ canvas = (canvas * 255).to(torch.uint8).cpu().numpy()
546
+ cv2.imwrite(
547
+ f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
548
+ canvas[..., ::-1],
549
+ )
550
+ metrics["psnr"].append(self.psnr(colors, pixels))
551
+ metrics["ssim"].append(self.ssim(colors, pixels))
552
+ metrics["lpips"].append(self.lpips(colors, pixels))
553
+
554
+ if world_rank == 0:
555
+ ellipse_time /= len(valloader)
556
+
557
+ stats = {
558
+ k: torch.stack(v).mean().item() for k, v in metrics.items()
559
+ }
560
+ stats.update(
561
+ {
562
+ "ellipse_time": ellipse_time,
563
+ "num_GS": len(self.splats["means"]),
564
+ }
565
+ )
566
+ print(
567
+ f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
568
+ f"Time: {stats['ellipse_time']:.3f}s/image "
569
+ f"Number of GS: {stats['num_GS']}"
570
+ )
571
+ # save stats as json
572
+ with open(
573
+ f"{self.stats_dir}/{stage}_step{step:04d}.json", "w"
574
+ ) as f:
575
+ json.dump(stats, f)
576
+ # save stats to tensorboard
577
+ for k, v in stats.items():
578
+ self.writer.add_scalar(f"{stage}/{k}", v, step)
579
+ self.writer.flush()
580
+
581
+ @torch.no_grad()
582
+ def render_video(
583
+ self, step: int, canvas_h: int = 512, canvas_w: int = 1024
584
+ ):
585
+ testloader = torch.utils.data.DataLoader(
586
+ self.testset, batch_size=1, shuffle=False, num_workers=1
587
+ )
588
+
589
+ images_cache = []
590
+ depth_global_min, depth_global_max = float("inf"), -float("inf")
591
+ for data in testloader:
592
+ camtoworlds = data["camtoworld"].to(self.device)
593
+ Ks = resize_pinhole_intrinsics(
594
+ data["K"].squeeze(),
595
+ raw_hw=(data["image_h"].item(), data["image_w"].item()),
596
+ new_hw=(canvas_h, canvas_w // 2),
597
+ ).to(self.device)
598
+ renders, _, _ = self.rasterize_splats(
599
+ camtoworlds=camtoworlds,
600
+ Ks=Ks[None, ...],
601
+ width=canvas_w // 2,
602
+ height=canvas_h,
603
+ sh_degree=self.cfg.sh_degree,
604
+ near_plane=self.cfg.near_plane,
605
+ far_plane=self.cfg.far_plane,
606
+ render_mode="RGB+ED",
607
+ ) # [1, H, W, 4]
608
+ colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3]
609
+ colors = (colors * 255).to(torch.uint8).cpu().numpy()
610
+ depths = renders[0, ..., 3:4] # [H, W, 1], tensor in device.
611
+ images_cache.append([colors, depths])
612
+ depth_global_min = min(depth_global_min, depths.min().item())
613
+ depth_global_max = max(depth_global_max, depths.max().item())
614
+
615
+ video_path = f"{self.render_dir}/video_step{step}.mp4"
616
+ writer = imageio.get_writer(video_path, fps=30)
617
+ for rgb, depth in images_cache:
618
+ depth_normalized = torch.clip(
619
+ (depth - depth_global_min)
620
+ / (depth_global_max - depth_global_min + 1e-8),
621
+ 0,
622
+ 1,
623
+ )
624
+ depth_normalized = (
625
+ (depth_normalized * 255).to(torch.uint8).cpu().numpy()
626
+ )
627
+ depth_map = cv2.applyColorMap(depth_normalized, cv2.COLORMAP_JET)
628
+ image = np.concatenate([rgb, depth_map], axis=1)
629
+ writer.append_data(image)
630
+
631
+ writer.close()
632
+
633
+
634
+ def entrypoint(
635
+ local_rank: int, world_rank, world_size: int, cfg: GsplatTrainConfig
636
+ ):
637
+ runner = Runner(local_rank, world_rank, world_size, cfg)
638
+
639
+ if cfg.ckpt is not None:
640
+ # run eval only
641
+ ckpts = [
642
+ torch.load(file, map_location=runner.device, weights_only=True)
643
+ for file in cfg.ckpt
644
+ ]
645
+ for k in runner.splats.keys():
646
+ runner.splats[k].data = torch.cat(
647
+ [ckpt["splats"][k] for ckpt in ckpts]
648
+ )
649
+ step = ckpts[0]["step"]
650
+ runner.eval(step=step)
651
+ runner.render_video(step=step)
652
+ else:
653
+ runner.train()
654
+ runner.render_video(step=cfg.max_steps - 1)
655
+
656
+
657
+ if __name__ == "__main__":
658
+ configs = {
659
+ "default": (
660
+ "Gaussian splatting training using densification heuristics from the original paper.",
661
+ GsplatTrainConfig(
662
+ strategy=DefaultStrategy(verbose=True),
663
+ ),
664
+ ),
665
+ "mcmc": (
666
+ "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
667
+ GsplatTrainConfig(
668
+ init_scale=0.1,
669
+ opacity_reg=0.01,
670
+ scale_reg=0.01,
671
+ strategy=MCMCStrategy(verbose=True),
672
+ ),
673
+ ),
674
+ }
675
+ cfg = tyro.extras.overridable_config_cli(configs)
676
+ cfg.adjust_steps(cfg.steps_scaler)
677
+
678
+ cli(entrypoint, cfg, verbose=True)
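# Usage sketch: a minimal single-process invocation of the trainer above with the
# "mcmc" preset, assuming a single CUDA device; the data/result paths are placeholders.
from gsplat.strategy import MCMCStrategy

from embodied_gen.trainer.gsplat_trainer import entrypoint
from embodied_gen.utils.config import GsplatTrainConfig

cfg = GsplatTrainConfig(
    data_dir="outputs/bg",    # placeholder: directory holding the 3DGS training data
    result_dir="outputs/bg",  # placeholder: where checkpoints, renders and stats go
    init_scale=0.1,
    opacity_reg=0.01,
    scale_reg=0.01,
    strategy=MCMCStrategy(verbose=True),
)
cfg.adjust_steps(cfg.steps_scaler)
entrypoint(local_rank=0, world_rank=0, world_size=1, cfg=cfg)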
embodied_gen/trainer/pono2mesh_trainer.py ADDED
@@ -0,0 +1,538 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ from embodied_gen.utils.monkey_patches import monkey_patch_pano2room
19
+
20
+ monkey_patch_pano2room()
21
+
22
+ import os
23
+
24
+ import cv2
25
+ import numpy as np
26
+ import torch
27
+ import trimesh
28
+ from equilib import cube2equi, equi2pers
29
+ from kornia.morphology import dilation
30
+ from PIL import Image
31
+ from embodied_gen.models.sr_model import ImageRealESRGAN
32
+ from embodied_gen.utils.config import Pano2MeshSRConfig
33
+ from embodied_gen.utils.geometry import compute_pinhole_intrinsics
34
+ from embodied_gen.utils.log import logger
35
+ from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
36
+ from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
37
+ PanoFusionDistancePredictor,
38
+ )
39
+ from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
40
+ from thirdparty.pano2room.modules.mesh_fusion.render import (
41
+ features_to_world_space_mesh,
42
+ render_mesh,
43
+ )
44
+ from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
45
+ from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
46
+ from thirdparty.pano2room.utils.functions import (
47
+ depth_to_distance,
48
+ get_cubemap_views_world_to_cam,
49
+ resize_image_with_aspect_ratio,
50
+ rot_z_world_to_cam,
51
+ tensor_to_pil,
52
+ )
53
+
54
+
55
+ class Pano2MeshSRPipeline:
56
+ """Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement.
57
+
58
+ This class integrates several key components including:
59
+ - Depth estimation from RGB panorama
60
+ - Inpainting of missing regions under offsets
61
+ - RGB-D to mesh conversion
62
+ - Multi-view mesh repair
63
+ - 3D Gaussian Splatting (3DGS) dataset generation
64
+
65
+ Args:
66
+ config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.
67
+
68
+ Example:
69
+ ```python
70
+ pipeline = Pano2MeshSRPipeline(config)
71
+ pipeline(pano_image='example.png', output_dir='./output')
72
+ ```
73
+ """
74
+
75
+ def __init__(self, config: Pano2MeshSRConfig) -> None:
76
+ self.cfg = config
77
+ self.device = config.device
78
+
79
+ # Init models.
80
+ self.inpainter = PanoPersFusionInpainter(save_path=None)
81
+ self.geo_predictor = PanoJointPredictor(save_path=None)
82
+ self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
83
+ self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)
84
+
85
+ # Init poses.
86
+ cubemap_w2cs = get_cubemap_views_world_to_cam()
87
+ self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
88
+ self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)
89
+
90
+ kernel = cv2.getStructuringElement(
91
+ cv2.MORPH_ELLIPSE, self.cfg.kernel_size
92
+ )
93
+ self.kernel = torch.from_numpy(kernel).float().to(self.device)
94
+
95
+ def init_mesh_params(self) -> None:
96
+ torch.set_default_device(self.device)
97
+ self.inpaint_mask = torch.ones(
98
+ (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
99
+ )
100
+ self.vertices = torch.empty((3, 0), requires_grad=False)
101
+ self.colors = torch.empty((3, 0), requires_grad=False)
102
+ self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)
103
+
104
+ @staticmethod
105
+ def read_camera_pose_file(filepath: str) -> np.ndarray:
106
+ with open(filepath, "r") as f:
107
+ values = [float(num) for line in f for num in line.split()]
108
+
109
+ return np.array(values).reshape(4, 4)
110
+
111
+ def load_camera_poses(
112
+ self, trajectory_dir: str
113
+ ) -> list[torch.Tensor]:
114
+ pose_filenames = sorted(
115
+ [
116
+ fname
117
+ for fname in os.listdir(trajectory_dir)
118
+ if fname.startswith("camera_pose")
119
+ ]
120
+ )
121
+
122
+ pano_pose_world = None
123
+ relative_poses = []
124
+ for idx, filename in enumerate(pose_filenames):
125
+ pose_path = os.path.join(trajectory_dir, filename)
126
+ pose_matrix = self.read_camera_pose_file(pose_path)
127
+
128
+ if pano_pose_world is None:
129
+ pano_pose_world = pose_matrix.copy()
130
+ pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
131
+ pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]
132
+
133
+ # Use different reference for the first 6 cubemap views
134
+ reference_pose = pose_matrix if idx < 6 else pano_pose_world
135
+ relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
136
+ relative_matrix[0:2, :] *= -1 # flip_xy
137
+ relative_matrix = (
138
+ relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
139
+ )
140
+ relative_matrix[:3, 3] *= self.cfg.pose_scale
141
+ relative_matrix = torch.tensor(
142
+ relative_matrix, dtype=torch.float32
143
+ )
144
+ relative_poses.append(relative_matrix)
145
+
146
+ return relative_poses
147
+
148
+ def load_inpaint_poses(
149
+ self, poses: torch.Tensor
150
+ ) -> dict[int, torch.Tensor]:
151
+ inpaint_poses = dict()
152
+ sampled_views = poses[:: self.cfg.inpaint_frame_stride]
153
+ init_pose = torch.eye(4)
154
+ for idx, w2c_tensor in enumerate(sampled_views):
155
+ w2c = w2c_tensor.cpu().numpy().astype(np.float32)
156
+ c2w = np.linalg.inv(w2c)
157
+ pose_tensor = init_pose.clone()
158
+ pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
159
+ pose_tensor[:3, 3] *= -1
160
+ inpaint_poses[idx] = pose_tensor.to(self.device)
161
+
162
+ return inpaint_poses
163
+
164
+ def project(self, world_to_cam: torch.Tensor):
165
+ (
166
+ project_image,
167
+ project_depth,
168
+ inpaint_mask,
169
+ _,
170
+ z_buf,
171
+ mesh,
172
+ ) = render_mesh(
173
+ vertices=self.vertices,
174
+ faces=self.faces,
175
+ vertex_features=self.colors,
176
+ H=self.cfg.cubemap_h,
177
+ W=self.cfg.cubemap_w,
178
+ fov_in_degrees=self.cfg.fov,
179
+ RT=world_to_cam,
180
+ blur_radius=self.cfg.blur_radius,
181
+ faces_per_pixel=self.cfg.faces_per_pixel,
182
+ )
183
+ project_image = project_image * ~inpaint_mask
184
+
185
+ return project_image[:3, ...], inpaint_mask, project_depth
186
+
187
+ def render_pano(self, pose: torch.Tensor):
188
+ cubemap_list = []
189
+ for cubemap_pose in self.cubemap_w2cs:
190
+ project_pose = cubemap_pose @ pose
191
+ rgb, inpaint_mask, depth = self.project(project_pose)
192
+ distance_map = depth_to_distance(depth[None, ...])
193
+ mask = inpaint_mask[None, ...]
194
+ cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))
195
+
196
+ # Set default tensor type for CPU operation in cube2equi
197
+ with torch.device("cpu"):
198
+ pano_rgbd = cube2equi(
199
+ cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
200
+ )
201
+
202
+ pano_rgb = pano_rgbd[:3, :, :]
203
+ pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
204
+ pano_mask = pano_rgbd[4:, :, :].squeeze(0)
205
+
206
+ return pano_rgb, pano_depth, pano_mask
207
+
208
+ def rgbd_to_mesh(
209
+ self,
210
+ rgb: torch.Tensor,
211
+ depth: torch.Tensor,
212
+ inpaint_mask: torch.Tensor,
213
+ world_to_cam: torch.Tensor = None,
214
+ using_distance_map: bool = True,
215
+ ) -> None:
216
+ if world_to_cam is None:
217
+ world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
218
+
219
+ if inpaint_mask.sum() == 0:
220
+ return
221
+
222
+ vertices, faces, colors = features_to_world_space_mesh(
223
+ colors=rgb.squeeze(0),
224
+ depth=depth,
225
+ fov_in_degrees=self.cfg.fov,
226
+ world_to_cam=world_to_cam,
227
+ mask=inpaint_mask,
228
+ faces=self.faces,
229
+ vertices=self.vertices,
230
+ using_distance_map=using_distance_map,
231
+ edge_threshold=0.05,
232
+ )
233
+
234
+ faces += self.vertices.shape[1]
235
+ self.vertices = torch.cat([self.vertices, vertices], dim=1)
236
+ self.colors = torch.cat([self.colors, colors], dim=1)
237
+ self.faces = torch.cat([self.faces, faces], dim=1)
238
+
239
+ def get_edge_image_by_depth(
240
+ self, depth: torch.Tensor, dilate_iter: int = 1
241
+ ) -> np.ndarray:
242
+ if isinstance(depth, torch.Tensor):
243
+ depth = depth.cpu().detach().numpy()
244
+
245
+ gray = (depth / depth.max() * 255).astype(np.uint8)
246
+ edges = cv2.Canny(gray, 60, 150)
247
+ if dilate_iter > 0:
248
+ kernel = np.ones((3, 3), np.uint8)
249
+ edges = cv2.dilate(edges, kernel, iterations=dilate_iter)
250
+
251
+ return edges
252
+
253
+ def mesh_repair_by_greedy_view_selection(
254
+ self, pose_dict: dict[str, torch.Tensor], output_dir: str
255
+ ) -> list:
256
+ inpainted_panos_w_pose = []
257
+ while len(pose_dict) > 0:
258
+ logger.info(f"Repairing mesh left rounds {len(pose_dict)}")
259
+ sampled_views = []
260
+ for key, pose in pose_dict.items():
261
+ pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
262
+ completeness = torch.sum(1 - pano_mask) / (pano_mask.numel())
263
+ sampled_views.append((key, completeness.item(), pose))
264
+
265
+ if len(sampled_views) == 0:
266
+ break
267
+
268
+ # Sort views by completeness and pick the one at the 2/3 quantile to inpaint next.
269
+ sampled_views = sorted(sampled_views, key=lambda x: x[1])
270
+ key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
271
+ pose_dict.pop(key)
272
+
273
+ pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
274
+
275
+ colors = pano_rgb.permute(1, 2, 0).clone()
276
+ distances = pano_distance.unsqueeze(-1).clone()
277
+ pano_inpaint_mask = pano_mask.clone()
278
+ init_pose = pose.clone()
279
+ normals = None
280
+ if pano_inpaint_mask.min().item() < 0.5:
281
+ colors, distances, normals = self.inpaint_panorama(
282
+ idx=key,
283
+ colors=colors,
284
+ distances=distances,
285
+ pano_mask=pano_inpaint_mask,
286
+ )
287
+
288
+ init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
289
+ -pose[0, 3],
290
+ pose[2, 3],
291
+ 0,
292
+ )
293
+ rays = gen_pano_rays(
294
+ init_pose, self.cfg.pano_h, self.cfg.pano_w
295
+ )
296
+ conflict_mask = self.sup_pool.geo_check(
297
+ rays, distances.unsqueeze(-1)
298
+ ) # 0 is conflict, 1 not conflict
299
+ pano_inpaint_mask *= conflict_mask
300
+
301
+ self.rgbd_to_mesh(
302
+ colors.permute(2, 0, 1),
303
+ distances,
304
+ pano_inpaint_mask,
305
+ world_to_cam=pose,
306
+ )
307
+
308
+ self.sup_pool.register_sup_info(
309
+ pose=init_pose,
310
+ mask=pano_inpaint_mask.clone(),
311
+ rgb=colors,
312
+ distance=distances.unsqueeze(-1),
313
+ normal=normals,
314
+ )
315
+
316
+ colors = colors.permute(2, 0, 1).unsqueeze(0)
317
+ inpainted_panos_w_pose.append([colors, pose])
318
+
319
+ if self.cfg.visualize:
320
+ from embodied_gen.data.utils import DiffrastRender
321
+
322
+ tensor_to_pil(pano_rgb.unsqueeze(0)).save(
323
+ f"{output_dir}/rendered_pano_{key}.jpg"
324
+ )
325
+ tensor_to_pil(colors).save(
326
+ f"{output_dir}/inpainted_pano_{key}.jpg"
327
+ )
328
+ norm_depth = DiffrastRender.normalize_map_by_mask(
329
+ distances, torch.ones_like(distances)
330
+ )
331
+ heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
332
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
333
+ Image.fromarray(heatmap).save(
334
+ f"{output_dir}/inpainted_depth_{key}.png"
335
+ )
336
+
337
+ return inpainted_panos_w_pose
338
+
339
+ def inpaint_panorama(
340
+ self,
341
+ idx: int,
342
+ colors: torch.Tensor,
343
+ distances: torch.Tensor,
344
+ pano_mask: torch.Tensor,
345
+ ) -> tuple[torch.Tensor]:
346
+ mask = (pano_mask[None, ..., None] > 0.5).float()
347
+ mask = mask.permute(0, 3, 1, 2)
348
+ mask = dilation(mask, kernel=self.kernel)
349
+ mask = mask[0, 0, ..., None] # hwc
350
+ inpainted_img = self.inpainter.inpaint(idx, colors, mask)
351
+ inpainted_img = colors * (1 - mask) + inpainted_img * mask
352
+ inpainted_distances, inpainted_normals = self.geo_predictor(
353
+ idx,
354
+ inpainted_img,
355
+ distances[..., None],
356
+ mask=mask,
357
+ reg_loss_weight=0.0,
358
+ normal_loss_weight=5e-2,
359
+ normal_tv_loss_weight=5e-2,
360
+ )
361
+
362
+ return inpainted_img, inpainted_distances.squeeze(), inpainted_normals
363
+
364
+ def preprocess_pano(
365
+ self, image: Image.Image | str
366
+ ) -> tuple[torch.Tensor, torch.Tensor]:
367
+ if isinstance(image, str):
368
+ image = Image.open(image)
369
+
370
+ image = image.convert("RGB")
371
+
372
+ if image.size[0] < image.size[1]:
373
+ image = image.transpose(Image.TRANSPOSE)
374
+
375
+ image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
376
+ image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
377
+ image_rgb = image_rgb.to(self.device)
378
+ image_depth = self.pano_fusion_distance_predictor.predict(
379
+ image_rgb.permute(1, 2, 0)
380
+ )
381
+ image_depth = (
382
+ image_depth / image_depth.max() * self.cfg.depth_scale_factor
383
+ )
384
+
385
+ return image_rgb, image_depth
386
+
387
+ def pano_to_perpective(
388
+ self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
389
+ ) -> torch.Tensor:
390
+ rots = dict(
391
+ roll=0,
392
+ pitch=pitch,
393
+ yaw=yaw,
394
+ )
395
+ perspective = equi2pers(
396
+ equi=pano_image.squeeze(0),
397
+ rots=rots,
398
+ height=self.cfg.cubemap_h,
399
+ width=self.cfg.cubemap_w,
400
+ fov_x=fov,
401
+ mode="bilinear",
402
+ ).unsqueeze(0)
403
+
404
+ return perspective
405
+
406
+ def pano_to_cubemap(self, pano_rgb: torch.Tensor):
407
+ # Define six canonical cube directions in (pitch, yaw)
408
+ directions = [
409
+ (0, 0),
410
+ (0, 1.5 * np.pi),
411
+ (0, 1.0 * np.pi),
412
+ (0, 0.5 * np.pi),
413
+ (-0.5 * np.pi, 0),
414
+ (0.5 * np.pi, 0),
415
+ ]
416
+
417
+ cubemaps_rgb = []
418
+ for pitch, yaw in directions:
419
+ rgb_view = self.pano_to_perpective(
420
+ pano_rgb, pitch, yaw, fov=self.cfg.fov
421
+ )
422
+ cubemaps_rgb.append(rgb_view.cpu())
423
+
424
+ return cubemaps_rgb
425
+
426
+ def save_mesh(self, output_path: str) -> None:
427
+ vertices_np = self.vertices.T.cpu().numpy()
428
+ colors_np = self.colors.T.cpu().numpy()
429
+ faces_np = self.faces.T.cpu().numpy()
430
+ mesh = trimesh.Trimesh(
431
+ vertices=vertices_np, faces=faces_np, vertex_colors=colors_np
432
+ )
433
+
434
+ mesh.export(output_path)
435
+
436
+ def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
437
+ pose = mesh_pose.clone()
438
+ pose[0, :] *= -1
439
+ pose[1, :] *= -1
440
+
441
+ Rw2c = pose[:3, :3].cpu().numpy()
442
+ Tw2c = pose[:3, 3:].cpu().numpy()
443
+ yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
444
+
445
+ Rc2w = (yz_reverse @ Rw2c).T
446
+ Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
447
+ c2w = np.concatenate((Rc2w, Tc2w), axis=1)
448
+ c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)
449
+
450
+ return c2w
451
+
452
+ def __call__(self, pano_image: Image.Image | str, output_dir: str):
453
+ self.init_mesh_params()
454
+ pano_rgb, pano_depth = self.preprocess_pano(pano_image)
455
+ self.sup_pool = SupInfoPool()
456
+ self.sup_pool.register_sup_info(
457
+ pose=torch.eye(4).to(self.device),
458
+ mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
459
+ rgb=pano_rgb.permute(1, 2, 0),
460
+ distance=pano_depth[..., None],
461
+ )
462
+ self.sup_pool.gen_occ_grid(res=256)
463
+
464
+ logger.info("Init mesh from pano RGBD image...")
465
+ depth_edge = self.get_edge_image_by_depth(pano_depth)
466
+ inpaint_edge_mask = (
467
+ ~torch.from_numpy(depth_edge).to(self.device).bool()
468
+ )
469
+ self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)
470
+
471
+ repair_poses = self.load_inpaint_poses(self.camera_poses)
472
+ inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
473
+ repair_poses, output_dir
474
+ )
475
+ torch.cuda.empty_cache()
476
+ torch.set_default_device("cpu")
477
+
478
+ if self.cfg.mesh_file is not None:
479
+ mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
480
+ self.save_mesh(mesh_path)
481
+
482
+ if self.cfg.gs_data_file is None:
483
+ return
484
+
485
+ logger.info(f"Dump data for 3DGS training...")
486
+ points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
487
+ data = {
488
+ "points": self.vertices.permute(1, 0).cpu().numpy(), # (N, 3)
489
+ "points_rgb": points_rgb.permute(1, 0).cpu().numpy(), # (N, 3)
490
+ "train": [],
491
+ "eval": [],
492
+ }
493
+ image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
494
+ image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
495
+ Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
496
+ for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
497
+ cubemaps = self.pano_to_cubemap(pano_img)
498
+ for i in range(len(cubemaps)):
499
+ cubemap = tensor_to_pil(cubemaps[i])
500
+ cubemap = self.super_model(cubemap)
501
+ mesh_pose = self.cubemap_w2cs[i] @ pano_pose
502
+ c2w = self.mesh_pose_to_gs_pose(mesh_pose)
503
+ data["train"].append(
504
+ {
505
+ "camtoworld": c2w.astype(np.float32),
506
+ "K": Ks.astype(np.float32),
507
+ "image": np.array(cubemap),
508
+ "image_h": image_h,
509
+ "image_w": image_w,
510
+ "image_id": len(cubemaps) * idx + i,
511
+ }
512
+ )
513
+
514
+ # Camera poses for evaluation.
515
+ for idx in range(len(self.camera_poses)):
516
+ c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
517
+ data["eval"].append(
518
+ {
519
+ "camtoworld": c2w.astype(np.float32),
520
+ "K": Ks.astype(np.float32),
521
+ "image_h": image_h,
522
+ "image_w": image_w,
523
+ "image_id": idx,
524
+ }
525
+ )
526
+
527
+ data_path = os.path.join(output_dir, self.cfg.gs_data_file)
528
+ torch.save(data, data_path)
529
+
530
+ return
531
+
532
+
533
+ if __name__ == "__main__":
534
+ output_dir = "outputs/bg_v2/test3"
535
+ input_pano = "apps/assets/example_scene/result_pano.png"
536
+ config = Pano2MeshSRConfig()
537
+ pipeline = Pano2MeshSRPipeline(config)
538
+ pipeline(input_pano, output_dir)
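# Sketch of the camera-trajectory files read by Pano2MeshSRPipeline.read_camera_pose_file:
# each camera_pose*.txt is assumed to hold 16 whitespace-separated floats that reshape
# into a 4x4 pose matrix. The file name below is a placeholder.
import numpy as np

np.savetxt("camera_pose_example.txt", np.eye(4), fmt="%.6f")
pose = Pano2MeshSRPipeline.read_camera_pose_file("camera_pose_example.txt")
assert pose.shape == (4, 4)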
embodied_gen/utils/config.py ADDED
@@ -0,0 +1,202 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import List, Optional, Union
19
+
20
+ from dataclasses_json import DataClassJsonMixin
21
+ from gsplat.strategy import DefaultStrategy, MCMCStrategy
22
+ from typing_extensions import Literal, assert_never
23
+
24
+ __all__ = [
25
+ "GptParamsConfig",
26
+ "Pano2MeshSRConfig",
27
+ "GsplatTrainConfig",
28
+ ]
29
+
30
+
31
+ @dataclass
32
+ class GptParamsConfig(DataClassJsonMixin):
33
+ temperature: float = 0.1
34
+ top_p: float = 0.1
35
+ frequency_penalty: float = 0.0
36
+ presence_penalty: float = 0.0
37
+ stop: int | None = None
38
+ max_tokens: int = 500
39
+
40
+
41
+ @dataclass
42
+ class Pano2MeshSRConfig:
43
+ mesh_file: str = "mesh_model.ply"
44
+ gs_data_file: str = "gs_data.pt"
45
+ device: str = "cuda"
46
+ blur_radius: int = 0
47
+ faces_per_pixel: int = 8
48
+ fov: int = 90
49
+ pano_w: int = 2048
50
+ pano_h: int = 1024
51
+ cubemap_w: int = 512
52
+ cubemap_h: int = 512
53
+ pose_scale: float = 0.6
54
+ pano_center_offset: tuple = (-0.2, 0.3)
55
+ inpaint_frame_stride: int = 20
56
+ trajectory_dir: str = "apps/assets/example_scene/camera_trajectory"
57
+ visualize: bool = False
58
+ depth_scale_factor: float = 3.4092
59
+ kernel_size: tuple = (9, 9)
60
+ upscale_factor: int = 4
61
+
62
+
63
+ @dataclass
64
+ class GsplatTrainConfig:
65
+ # Path to the .pt files. If provided, training is skipped and only evaluation runs.
66
+ ckpt: Optional[List[str]] = None
67
+ # Render trajectory path
68
+ render_traj_path: str = "interp"
69
+
70
+ # Path to the Mip-NeRF 360 dataset
71
+ data_dir: str = "outputs/bg"
72
+ # Downsample factor for the dataset
73
+ data_factor: int = 4
74
+ # Directory to save results
75
+ result_dir: str = "outputs/bg"
76
+ # Every N images there is a test image
77
+ test_every: int = 8
78
+ # Random crop size for training (experimental)
79
+ patch_size: Optional[int] = None
80
+ # A global scaler that applies to the scene size related parameters
81
+ global_scale: float = 1.0
82
+ # Normalize the world space
83
+ normalize_world_space: bool = True
84
+ # Camera model
85
+ camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
86
+
87
+ # Port for the viewer server
88
+ port: int = 8080
89
+
90
+ # Batch size for training. Learning rates are scaled automatically
91
+ batch_size: int = 1
92
+ # A global factor to scale the number of training steps
93
+ steps_scaler: float = 1.0
94
+
95
+ # Number of training steps
96
+ max_steps: int = 30_000
97
+ # Steps to evaluate the model
98
+ eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
99
+ # Steps to save the model
100
+ save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
101
+ # Whether to save ply file (storage size can be large)
102
+ save_ply: bool = True
103
+ # Steps to save the model as ply
104
+ ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
105
+ # Whether to disable video generation during training and evaluation
106
+ disable_video: bool = False
107
+
108
+ # Initial number of GSs. Ignored if using sfm
109
+ init_num_pts: int = 100_000
110
+ # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
111
+ init_extent: float = 3.0
112
+ # Degree of spherical harmonics
113
+ sh_degree: int = 1
114
+ # Turn on another SH degree every this steps
115
+ sh_degree_interval: int = 1000
116
+ # Initial opacity of GS
117
+ init_opa: float = 0.1
118
+ # Initial scale of GS
119
+ init_scale: float = 1.0
120
+ # Weight for SSIM loss
121
+ ssim_lambda: float = 0.2
122
+
123
+ # Near plane clipping distance
124
+ near_plane: float = 0.01
125
+ # Far plane clipping distance
126
+ far_plane: float = 1e10
127
+
128
+ # Strategy for GS densification
129
+ strategy: Union[DefaultStrategy, MCMCStrategy] = field(
130
+ default_factory=DefaultStrategy
131
+ )
132
+ # Use packed mode for rasterization; this reduces memory usage but is slightly slower.
133
+ packed: bool = False
134
+ # Use sparse gradients for optimization. (experimental)
135
+ sparse_grad: bool = False
136
+ # Use visible adam from Taming 3DGS. (experimental)
137
+ visible_adam: bool = False
138
+ # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
139
+ antialiased: bool = False
140
+
141
+ # Use random background for training to discourage transparency
142
+ random_bkgd: bool = False
143
+
144
+ # LR for 3D point positions
145
+ means_lr: float = 1.6e-4
146
+ # LR for Gaussian scale factors
147
+ scales_lr: float = 5e-3
148
+ # LR for alpha blending weights
149
+ opacities_lr: float = 5e-2
150
+ # LR for orientation (quaternions)
151
+ quats_lr: float = 1e-3
152
+ # LR for SH band 0 (brightness)
153
+ sh0_lr: float = 2.5e-3
154
+ # LR for higher-order SH (detail)
155
+ shN_lr: float = 2.5e-3 / 20
156
+
157
+ # Opacity regularization
158
+ opacity_reg: float = 0.0
159
+ # Scale regularization
160
+ scale_reg: float = 0.0
161
+
162
+ # Enable depth loss. (experimental)
163
+ depth_loss: bool = False
164
+ # Weight for depth loss
165
+ depth_lambda: float = 1e-2
166
+
167
+ # Dump information to tensorboard every this steps
168
+ tb_every: int = 200
169
+ # Save training images to tensorboard
170
+ tb_save_image: bool = False
171
+
172
+ lpips_net: Literal["vgg", "alex"] = "alex"
173
+
174
+ # 3DGUT (unscented transform + eval 3D)
175
+ with_ut: bool = False
176
+ with_eval3d: bool = False
177
+
178
+ scene_scale: float = 1.0
179
+
180
+ def adjust_steps(self, factor: float):
181
+ self.eval_steps = [int(i * factor) for i in self.eval_steps]
182
+ self.save_steps = [int(i * factor) for i in self.save_steps]
183
+ self.ply_steps = [int(i * factor) for i in self.ply_steps]
184
+ self.max_steps = int(self.max_steps * factor)
185
+ self.sh_degree_interval = int(self.sh_degree_interval * factor)
186
+
187
+ strategy = self.strategy
188
+ if isinstance(strategy, DefaultStrategy):
189
+ strategy.refine_start_iter = int(
190
+ strategy.refine_start_iter * factor
191
+ )
192
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
193
+ strategy.reset_every = int(strategy.reset_every * factor)
194
+ strategy.refine_every = int(strategy.refine_every * factor)
195
+ elif isinstance(strategy, MCMCStrategy):
196
+ strategy.refine_start_iter = int(
197
+ strategy.refine_start_iter * factor
198
+ )
199
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
200
+ strategy.refine_every = int(strategy.refine_every * factor)
201
+ else:
202
+ assert_never(strategy)
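# Sketch: adjust_steps rescales every schedule-related counter by the given factor,
# e.g. steps_scaler=0.5 halves the training length and the eval/save milestones.
cfg = GsplatTrainConfig(steps_scaler=0.5)
cfg.adjust_steps(cfg.steps_scaler)
assert cfg.max_steps == 15_000
assert cfg.eval_steps == [3_500, 15_000]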
embodied_gen/utils/enum.py ADDED
@@ -0,0 +1,108 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+ from enum import Enum
19
+
20
+ from dataclasses_json import DataClassJsonMixin
21
+
22
+ __all__ = [
23
+ "RenderItems",
24
+ "Scene3DItemEnum",
25
+ "SpatialRelationEnum",
26
+ "RobotItemEnum",
27
+ ]
28
+
29
+
30
+ @dataclass
31
+ class RenderItems(str, Enum):
32
+ IMAGE = "image_color"
33
+ ALPHA = "image_mask"
34
+ VIEW_NORMAL = "image_view_normal"
35
+ GLOBAL_NORMAL = "image_global_normal"
36
+ POSITION_MAP = "image_position"
37
+ DEPTH = "image_depth"
38
+ ALBEDO = "image_albedo"
39
+ DIFFUSE = "image_diffuse"
40
+
41
+
42
+ @dataclass
43
+ class Scene3DItemEnum(str, Enum):
44
+ BACKGROUND = "background"
45
+ CONTEXT = "context"
46
+ ROBOT = "robot"
47
+ MANIPULATED_OBJS = "manipulated_objs"
48
+ DISTRACTOR_OBJS = "distractor_objs"
49
+ OTHERS = "others"
50
+
51
+ @classmethod
52
+ def object_list(cls, layout_relation: dict) -> list:
53
+ return (
54
+ [
55
+ layout_relation[cls.BACKGROUND.value],
56
+ layout_relation[cls.CONTEXT.value],
57
+ ]
58
+ + layout_relation[cls.MANIPULATED_OBJS.value]
59
+ + layout_relation[cls.DISTRACTOR_OBJS.value]
60
+ )
61
+
62
+ @classmethod
63
+ def object_mapping(cls, layout_relation):
64
+ relation_mapping = {
65
+ # layout_relation[cls.ROBOT.value]: cls.ROBOT.value,
66
+ layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value,
67
+ layout_relation[cls.CONTEXT.value]: cls.CONTEXT.value,
68
+ }
69
+ relation_mapping.update(
70
+ {
71
+ item: cls.MANIPULATED_OBJS.value
72
+ for item in layout_relation[cls.MANIPULATED_OBJS.value]
73
+ }
74
+ )
75
+ relation_mapping.update(
76
+ {
77
+ item: cls.DISTRACTOR_OBJS.value
78
+ for item in layout_relation[cls.DISTRACTOR_OBJS.value]
79
+ }
80
+ )
81
+
82
+ return relation_mapping
83
+
84
+
85
+ @dataclass
86
+ class SpatialRelationEnum(str, Enum):
87
+ ON = "ON" # objects on the table
88
+ IN = "IN" # objects in the room
89
+ INSIDE = "INSIDE" # objects inside the shelf/rack
90
+ FLOOR = "FLOOR" # object floor room/bin
91
+
92
+
93
+ @dataclass
94
+ class RobotItemEnum(str, Enum):
95
+ FRANKA = "franka"
96
+ UR5 = "ur5"
97
+ PIPER = "piper"
98
+
99
+
100
+ @dataclass
101
+ class LayoutInfo(DataClassJsonMixin):
102
+ tree: dict[str, list]
103
+ relation: dict[str, str | list[str]]
104
+ objs_desc: dict[str, str] = field(default_factory=dict)
105
+ objs_mapping: dict[str, str] = field(default_factory=dict)
106
+ assets: dict[str, str] = field(default_factory=dict)
107
+ quality: dict[str, str] = field(default_factory=dict)
108
+ position: dict[str, list[float]] = field(default_factory=dict)
embodied_gen/utils/gaussian.py ADDED
@@ -0,0 +1,330 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+ # Part of the code comes from https://github.com/nerfstudio-project/gsplat
17
+ # Both under the Apache License, Version 2.0.
18
+
19
+
20
+ import math
21
+ import random
22
+ from io import BytesIO
23
+ from typing import Dict, Literal, Optional, Tuple
24
+
25
+ import numpy as np
26
+ import torch
27
+ import trimesh
28
+ from gsplat.optimizers import SelectiveAdam
29
+ from scipy.spatial.transform import Rotation
30
+ from sklearn.neighbors import NearestNeighbors
31
+ from torch import Tensor
32
+ from embodied_gen.models.gs_model import GaussianOperator
33
+
34
+ __all__ = [
35
+ "set_random_seed",
36
+ "export_splats",
37
+ "create_splats_with_optimizers",
38
+ "resize_pinhole_intrinsics",
39
+ "restore_scene_scale_and_position",
40
+ ]
41
+
42
+
43
+ def knn(x: Tensor, K: int = 4) -> Tensor:
44
+ x_np = x.cpu().numpy()
45
+ model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
46
+ distances, _ = model.kneighbors(x_np)
47
+ return torch.from_numpy(distances).to(x)
48
+
49
+
50
+ def rgb_to_sh(rgb: Tensor) -> Tensor:
51
+ C0 = 0.28209479177387814
52
+ return (rgb - 0.5) / C0
53
+
54
+
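# Note (a sketch): C0 is the zeroth spherical-harmonic basis value 1 / (2 * sqrt(pi)),
# so the DC coefficient and RGB are related by rgb = sh0 * C0 + 0.5.
import math

import torch

_C0 = 1.0 / (2.0 * math.sqrt(math.pi))  # ~0.2820948
_rgb = torch.tensor([0.2, 0.5, 0.9])
assert torch.allclose(rgb_to_sh(_rgb) * _C0 + 0.5, _rgb)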
55
+ def set_random_seed(seed: int):
56
+ random.seed(seed)
57
+ np.random.seed(seed)
58
+ torch.manual_seed(seed)
59
+
60
+
61
+ def splat2ply_bytes(
62
+ means: torch.Tensor,
63
+ scales: torch.Tensor,
64
+ quats: torch.Tensor,
65
+ opacities: torch.Tensor,
66
+ sh0: torch.Tensor,
67
+ shN: torch.Tensor,
68
+ ) -> bytes:
69
+ num_splats = means.shape[0]
70
+ buffer = BytesIO()
71
+
72
+ # Write PLY header
73
+ buffer.write(b"ply\n")
74
+ buffer.write(b"format binary_little_endian 1.0\n")
75
+ buffer.write(f"element vertex {num_splats}\n".encode())
76
+ buffer.write(b"property float x\n")
77
+ buffer.write(b"property float y\n")
78
+ buffer.write(b"property float z\n")
79
+ for i, data in enumerate([sh0, shN]):
80
+ prefix = "f_dc" if i == 0 else "f_rest"
81
+ for j in range(data.shape[1]):
82
+ buffer.write(f"property float {prefix}_{j}\n".encode())
83
+ buffer.write(b"property float opacity\n")
84
+ for i in range(scales.shape[1]):
85
+ buffer.write(f"property float scale_{i}\n".encode())
86
+ for i in range(quats.shape[1]):
87
+ buffer.write(f"property float rot_{i}\n".encode())
88
+ buffer.write(b"end_header\n")
89
+
90
+ # Concatenate all tensors in the correct order
91
+ splat_data = torch.cat(
92
+ [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
93
+ )
94
+ # Ensure correct dtype
95
+ splat_data = splat_data.to(torch.float32)
96
+
97
+ # Write binary data
98
+ float_dtype = np.dtype(np.float32).newbyteorder("<")
99
+ buffer.write(
100
+ splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
101
+ )
102
+
103
+ return buffer.getvalue()
104
+
105
+
106
+ def export_splats(
107
+ means: torch.Tensor,
108
+ scales: torch.Tensor,
109
+ quats: torch.Tensor,
110
+ opacities: torch.Tensor,
111
+ sh0: torch.Tensor,
112
+ shN: torch.Tensor,
113
+ format: Literal["ply"] = "ply",
114
+ save_to: Optional[str] = None,
115
+ ) -> bytes:
116
+ """Export a Gaussian Splats model to bytes in PLY file format."""
117
+ total_splats = means.shape[0]
118
+ assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
119
+ assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
120
+ assert quats.shape == (
121
+ total_splats,
122
+ 4,
123
+ ), "Quaternions must be of shape (N, 4)"
124
+ assert opacities.shape == (
125
+ total_splats,
126
+ ), "Opacities must be of shape (N,)"
127
+ assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
128
+ assert (
129
+ shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
130
+ ), f"shN must be of shape (N, K, 3), got {shN.shape}"
131
+
132
+ # Reshape spherical harmonics
133
+ sh0 = sh0.squeeze(1) # Shape (N, 3)
134
+ shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1) # Shape (N, K * 3)
135
+
136
+ # Check for NaN or Inf values
137
+ invalid_mask = (
138
+ torch.isnan(means).any(dim=1)
139
+ | torch.isinf(means).any(dim=1)
140
+ | torch.isnan(scales).any(dim=1)
141
+ | torch.isinf(scales).any(dim=1)
142
+ | torch.isnan(quats).any(dim=1)
143
+ | torch.isinf(quats).any(dim=1)
144
+ | torch.isnan(opacities).any(dim=0)
145
+ | torch.isinf(opacities).any(dim=0)
146
+ | torch.isnan(sh0).any(dim=1)
147
+ | torch.isinf(sh0).any(dim=1)
148
+ | torch.isnan(shN).any(dim=1)
149
+ | torch.isinf(shN).any(dim=1)
150
+ )
151
+
152
+ # Filter out invalid entries
153
+ valid_mask = ~invalid_mask
154
+ means = means[valid_mask]
155
+ scales = scales[valid_mask]
156
+ quats = quats[valid_mask]
157
+ opacities = opacities[valid_mask]
158
+ sh0 = sh0[valid_mask]
159
+ shN = shN[valid_mask]
160
+
161
+ if format == "ply":
162
+ data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
163
+ else:
164
+ raise ValueError(f"Unsupported format: {format}")
165
+
166
+ if save_to:
167
+ with open(save_to, "wb") as binary_file:
168
+ binary_file.write(data)
169
+
170
+ return data
171
+
172
+
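# Usage sketch with synthetic tensors, showing the shapes export_splats expects
# for sh_degree=1; the output path is a placeholder.
import torch

_N, _sh_degree = 100, 1
_ply_bytes = export_splats(
    means=torch.randn(_N, 3),
    scales=torch.rand(_N, 3),
    quats=torch.randn(_N, 4),
    opacities=torch.rand(_N),
    sh0=torch.zeros(_N, 1, 3),
    shN=torch.zeros(_N, (_sh_degree + 1) ** 2 - 1, 3),
    save_to="example_splats.ply",
)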
173
+ def create_splats_with_optimizers(
174
+ points: np.ndarray = None,
175
+ points_rgb: np.ndarray = None,
176
+ init_num_pts: int = 100_000,
177
+ init_extent: float = 3.0,
178
+ init_opacity: float = 0.1,
179
+ init_scale: float = 1.0,
180
+ means_lr: float = 1.6e-4,
181
+ scales_lr: float = 5e-3,
182
+ opacities_lr: float = 5e-2,
183
+ quats_lr: float = 1e-3,
184
+ sh0_lr: float = 2.5e-3,
185
+ shN_lr: float = 2.5e-3 / 20,
186
+ scene_scale: float = 1.0,
187
+ sh_degree: int = 3,
188
+ sparse_grad: bool = False,
189
+ visible_adam: bool = False,
190
+ batch_size: int = 1,
191
+ feature_dim: Optional[int] = None,
192
+ device: str = "cuda",
193
+ world_rank: int = 0,
194
+ world_size: int = 1,
195
+ ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
196
+ if points is not None and points_rgb is not None:
197
+ points = torch.from_numpy(points).float()
198
+ rgbs = torch.from_numpy(points_rgb / 255.0).float()
199
+ else:
200
+ points = (
201
+ init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
202
+ )
203
+ rgbs = torch.rand((init_num_pts, 3))
204
+
205
+ # Initialize the GS size to be the average dist of the 3 nearest neighbors
206
+ dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,]
207
+ dist_avg = torch.sqrt(dist2_avg)
208
+ scales = (
209
+ torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
210
+ ) # [N, 3]
211
+
212
+ # Distribute the GSs to different ranks (also works for single rank)
213
+ points = points[world_rank::world_size]
214
+ rgbs = rgbs[world_rank::world_size]
215
+ scales = scales[world_rank::world_size]
216
+
217
+ N = points.shape[0]
218
+ quats = torch.rand((N, 4)) # [N, 4]
219
+ opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]
220
+
221
+ params = [
222
+ # name, value, lr
223
+ ("means", torch.nn.Parameter(points), means_lr * scene_scale),
224
+ ("scales", torch.nn.Parameter(scales), scales_lr),
225
+ ("quats", torch.nn.Parameter(quats), quats_lr),
226
+ ("opacities", torch.nn.Parameter(opacities), opacities_lr),
227
+ ]
228
+
229
+ if feature_dim is None:
230
+ # color is SH coefficients.
231
+ colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3]
232
+ colors[:, 0, :] = rgb_to_sh(rgbs)
233
+ params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
234
+ params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
235
+ else:
236
+ # features will be used for appearance and view-dependent shading
237
+ features = torch.rand(N, feature_dim) # [N, feature_dim]
238
+ params.append(("features", torch.nn.Parameter(features), sh0_lr))
239
+ colors = torch.logit(rgbs) # [N, 3]
240
+ params.append(("colors", torch.nn.Parameter(colors), sh0_lr))
241
+
242
+ splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
243
+ # Scale learning rate based on batch size, reference:
244
+ # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
245
+ # Note that this would not make the training exactly equivalent, see
246
+ # https://arxiv.org/pdf/2402.18824v1
247
+ BS = batch_size * world_size
248
+ optimizer_class = None
249
+ if sparse_grad:
250
+ optimizer_class = torch.optim.SparseAdam
251
+ elif visible_adam:
252
+ optimizer_class = SelectiveAdam
253
+ else:
254
+ optimizer_class = torch.optim.Adam
255
+ optimizers = {
256
+ name: optimizer_class(
257
+ [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
258
+ eps=1e-15 / math.sqrt(BS),
259
+ # TODO: check betas logic; when BS is larger than 10, betas[0] becomes zero.
260
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
261
+ )
262
+ for name, _, lr in params
263
+ }
264
+ return splats, optimizers
265
+
266
+
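# Usage sketch (assumes a CUDA device): initialize splats and per-parameter Adam
# optimizers from a synthetic colored point cloud.
import numpy as np

_pts = np.random.rand(1000, 3).astype(np.float32)
_rgb = (np.random.rand(1000, 3) * 255).astype(np.uint8)
_splats, _optims = create_splats_with_optimizers(
    points=_pts, points_rgb=_rgb, sh_degree=1, device="cuda"
)
# _splats["means"].shape == (1000, 3); _optims holds one optimizer per parameter group.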
267
+ def compute_intrinsics_from_fovy(
268
+ image_w: int, image_h: int, fovy_deg: float
269
+ ) -> np.ndarray:
270
+ fovy_rad = np.deg2rad(fovy_deg)
271
+ fy = image_h / (2 * np.tan(fovy_rad / 2))
272
+ fx = fy * (image_w / image_h)
273
+ cx = image_w / 2
274
+ cy = image_h / 2
275
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
276
+
277
+ return K
278
+
279
+
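# Worked example (a sketch): for a 90-degree vertical FoV, fy equals image_h / 2.
import numpy as np

_K = compute_intrinsics_from_fovy(image_w=1024, image_h=512, fovy_deg=90.0)
assert np.isclose(_K[1, 1], 256.0)             # fy = 512 / (2 * tan(45 deg))
assert np.isclose(_K[0, 0], 512.0)             # fx = fy * (1024 / 512)
assert np.allclose(_K[:2, 2], [512.0, 256.0])  # principal point at the image center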
280
+ def resize_pinhole_intrinsics(
281
+ raw_K: np.ndarray | torch.Tensor,
282
+ raw_hw: tuple[int, int],
283
+ new_hw: tuple[int, int],
284
+ ) -> np.ndarray:
285
+ raw_h, raw_w = raw_hw
286
+ new_h, new_w = new_hw
287
+
288
+ scale_x = new_w / raw_w
289
+ scale_y = new_h / raw_h
290
+
291
+ new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
292
+ new_K[0, 0] *= scale_x # fx
293
+ new_K[0, 2] *= scale_x # cx
294
+ new_K[1, 1] *= scale_y # fy
295
+ new_K[1, 2] *= scale_y # cy
296
+
297
+ return new_K
298
+
299
+
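# Worked example (a sketch): halving the image resolution halves fx, fy, cx and cy.
import numpy as np

_K = np.array([[800.0, 0.0, 512.0], [0.0, 800.0, 512.0], [0.0, 0.0, 1.0]])
_K_small = resize_pinhole_intrinsics(_K, raw_hw=(1024, 1024), new_hw=(512, 512))
assert np.allclose(_K_small[0], [400.0, 0.0, 256.0])
assert np.allclose(_K_small[1], [0.0, 400.0, 256.0])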
300
+ def restore_scene_scale_and_position(
301
+ real_height: float, mesh_path: str, gs_path: str
302
+ ) -> None:
303
+ """Scales a mesh and corresponding GS model to match a given real-world height.
304
+
305
+ Uses the 1st and 99th percentile of mesh Z-axis to estimate height,
306
+ applies scaling and vertical alignment, and updates both the mesh and GS model.
307
+
308
+ Args:
309
+ real_height (float): Target real-world height along the Z axis.
310
+ mesh_path (str): Path to the input mesh file.
311
+ gs_path (str): Path to the Gaussian Splatting model file.
312
+ """
313
+ mesh = trimesh.load(mesh_path)
314
+ z_min = np.percentile(mesh.vertices[:, 1], 1)
315
+ z_max = np.percentile(mesh.vertices[:, 1], 99)
316
+ height = z_max - z_min
317
+ scale = real_height / height
318
+
319
+ rot = Rotation.from_quat([0, 1, 0, 0])
320
+ mesh.vertices = rot.apply(mesh.vertices)
321
+ mesh.vertices[:, 1] -= z_min
322
+ mesh.vertices *= scale
323
+ mesh.export(mesh_path)
324
+
325
+ gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
326
+ gs_model = gs_model.get_gaussians(
327
+ instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
328
+ )
329
+ gs_model.rescale(scale)
330
+ gs_model.save_to_ply(gs_path)
embodied_gen/utils/geometry.py ADDED
@@ -0,0 +1,515 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ import random
20
+ from collections import defaultdict, deque
21
+ from functools import wraps
22
+ from typing import Literal
23
+
24
+ import numpy as np
25
+ import torch
26
+ import trimesh
27
+ from matplotlib.path import Path
28
+ from pyquaternion import Quaternion
29
+ from scipy.spatial import ConvexHull
30
+ from scipy.spatial.transform import Rotation as R
31
+ from shapely.geometry import Polygon
32
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
33
+ from embodied_gen.utils.log import logger
34
+
35
+ __all__ = [
36
+ "with_seed",
37
+ "matrix_to_pose",
38
+ "pose_to_matrix",
39
+ "quaternion_multiply",
40
+ "check_reachable",
41
+ "bfs_placement",
42
+ "compose_mesh_scene",
43
+ "compute_pinhole_intrinsics",
44
+ ]
45
+
46
+
47
+ def matrix_to_pose(matrix: np.ndarray) -> list[float]:
48
+ """Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw).
49
+
50
+ Args:
51
+ matrix (np.ndarray): 4x4 transformation matrix.
52
+
53
+ Returns:
54
+ List[float]: Pose as [x, y, z, qx, qy, qz, qw].
55
+ """
56
+ x, y, z = matrix[:3, 3]
57
+ rot_mat = matrix[:3, :3]
58
+ quat = R.from_matrix(rot_mat).as_quat()
59
+ qx, qy, qz, qw = quat
60
+
61
+ return [x, y, z, qx, qy, qz, qw]
62
+
63
+
64
+ def pose_to_matrix(pose: list[float]) -> np.ndarray:
65
+ """Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix.
66
+
67
+ Args:
68
+ pose (List[float]): Pose as [x, y, z, qx, qy, qz, qw].
69
+
70
+ Returns:
71
+ np.ndarray: 4x4 transformation matrix.
72
+ """
73
+ x, y, z, qx, qy, qz, qw = pose
74
+ r = R.from_quat([qx, qy, qz, qw])
75
+ matrix = np.eye(4)
76
+ matrix[:3, :3] = r.as_matrix()
77
+ matrix[:3, 3] = [x, y, z]
78
+
79
+ return matrix
80
+
81
+
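# Round-trip sketch: converting a 7-DoF pose to a matrix and back preserves the transform.
import numpy as np

_pose = [0.1, 0.2, 0.3, 0.0, 0.0, 0.7071068, 0.7071068]  # 90 deg about Z
_T = pose_to_matrix(_pose)
assert np.allclose(pose_to_matrix(matrix_to_pose(_T)), _T, atol=1e-6)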
82
+ def compute_xy_bbox(
83
+ vertices: np.ndarray, col_x: int = 0, col_y: int = 1
84
+ ) -> tuple[float, float, float, float]:
85
+ x_vals = vertices[:, col_x]
86
+ y_vals = vertices[:, col_y]
87
+ return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max()
88
+
89
+
90
+ def has_iou_conflict(
91
+ new_box: list[float],
92
+ placed_boxes: list[list[float]],
93
+ iou_threshold: float = 0.0,
94
+ ) -> bool:
95
+ new_min_x, new_max_x, new_min_y, new_max_y = new_box
96
+ for min_x, max_x, min_y, max_y in placed_boxes:
97
+ ix1 = max(new_min_x, min_x)
98
+ iy1 = max(new_min_y, min_y)
99
+ ix2 = min(new_max_x, max_x)
100
+ iy2 = min(new_max_y, max_y)
101
+ inter_area = max(0, ix2 - ix1) * max(0, iy2 - iy1)
102
+ if inter_area > iou_threshold:
103
+ return True
104
+ return False
105
+
106
+
107
+ def with_seed(seed_attr_name: str = "seed"):
108
+ """A parameterized decorator that temporarily sets the random seed."""
109
+
110
+ def decorator(func):
111
+ @wraps(func)
112
+ def wrapper(*args, **kwargs):
113
+ seed = kwargs.get(seed_attr_name, None)
114
+ if seed is not None:
115
+ py_state = random.getstate()
116
+ np_state = np.random.get_state()
117
+ torch_state = torch.get_rng_state()
118
+
119
+ random.seed(seed)
120
+ np.random.seed(seed)
121
+ torch.manual_seed(seed)
122
+ try:
123
+ result = func(*args, **kwargs)
124
+ finally:
125
+ random.setstate(py_state)
126
+ np.random.set_state(np_state)
127
+ torch.set_rng_state(torch_state)
128
+ return result
129
+ else:
130
+ return func(*args, **kwargs)
131
+
132
+ return wrapper
133
+
134
+ return decorator
135
+
136
+
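# Usage sketch: with_seed pins the RNG state only when the call passes `seed=` as a
# keyword, and restores the previous state afterwards. The function below is a made-up
# example.
import random

@with_seed("seed")
def _sample_offset(scale: float = 1.0, seed: int = None) -> float:
    return scale * random.random()

assert _sample_offset(seed=42) == _sample_offset(seed=42)  # reproducible with a seed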
137
+ def compute_convex_hull_path(
138
+ vertices: np.ndarray,
139
+ z_threshold: float = 0.05,
140
+ interp_per_edge: int = 10,
141
+ margin: float = -0.02,
142
+ x_axis: int = 0,
143
+ y_axis: int = 1,
144
+ z_axis: int = 2,
145
+ ) -> Path:
146
+ top_vertices = vertices[
147
+ vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold
148
+ ]
149
+ top_xy = top_vertices[:, [x_axis, y_axis]]
150
+
151
+ if len(top_xy) < 3:
152
+ raise ValueError("Not enough points to form a convex hull")
153
+
154
+ hull = ConvexHull(top_xy)
155
+ hull_points = top_xy[hull.vertices]
156
+
157
+ polygon = Polygon(hull_points)
158
+ polygon = polygon.buffer(margin)
159
+ hull_points = np.array(polygon.exterior.coords)
160
+
161
+ dense_points = []
162
+ for i in range(len(hull_points)):
163
+ p1 = hull_points[i]
164
+ p2 = hull_points[(i + 1) % len(hull_points)]
165
+ for t in np.linspace(0, 1, interp_per_edge, endpoint=False):
166
+ pt = (1 - t) * p1 + t * p2
167
+ dense_points.append(pt)
168
+
169
+ return Path(np.array(dense_points), closed=True)
170
+
171
+
172
+ def find_parent_node(node: str, tree: dict) -> str | None:
173
+ for parent, children in tree.items():
174
+ if any(child[0] == node for child in children):
175
+ return parent
176
+ return None
177
+
178
+
179
+ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool:
180
+ x1, x2, y1, y2 = box
181
+ corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]]
182
+
183
+ num_inside = sum(hull.contains_point(c) for c in corners)
184
+ return num_inside >= threshold
185
+
186
+
187
+ def compute_axis_rotation_quat(
188
+ axis: Literal["x", "y", "z"], angle_rad: float
189
+ ) -> list[float]:
190
+ if axis.lower() == "x":
191
+ q = Quaternion(axis=[1, 0, 0], angle=angle_rad)
192
+ elif axis.lower() == "y":
193
+ q = Quaternion(axis=[0, 1, 0], angle=angle_rad)
194
+ elif axis.lower() == "z":
195
+ q = Quaternion(axis=[0, 0, 1], angle=angle_rad)
196
+ else:
197
+ raise ValueError(f"Unsupported axis '{axis}', must be one of x, y, z")
198
+
199
+ return [q.x, q.y, q.z, q.w]
200
+
201
+
202
+ def quaternion_multiply(
203
+ init_quat: list[float], rotate_quat: list[float]
204
+ ) -> list[float]:
205
+ qx, qy, qz, qw = init_quat
206
+ q1 = Quaternion(w=qw, x=qx, y=qy, z=qz)
207
+ qx, qy, qz, qw = rotate_quat
208
+ q2 = Quaternion(w=qw, x=qx, y=qy, z=qz)
209
+ quat = q2 * q1
210
+
211
+ return [quat.x, quat.y, quat.z, quat.w]
212
+
213
+
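# Sketch: composing two 90-degree Z rotations with quaternion_multiply gives a
# 180-degree Z rotation (quaternions are in [qx, qy, qz, qw] order).
import numpy as np

_q90 = compute_axis_rotation_quat("z", np.pi / 2)
assert np.allclose(quaternion_multiply(_q90, _q90), [0.0, 0.0, 1.0, 0.0], atol=1e-6)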
214
+ def check_reachable(
215
+ base_xyz: np.ndarray,
216
+ reach_xyz: np.ndarray,
217
+ min_reach: float = 0.25,
218
+ max_reach: float = 0.85,
219
+ ) -> bool:
220
+ """Check if the target point is within the reachable range."""
221
+ distance = np.linalg.norm(reach_xyz - base_xyz)
222
+
223
+ return min_reach < distance < max_reach
224
+
225
+
226
+ @with_seed("seed")
227
+ def bfs_placement(
228
+ layout_file: str,
229
+ floor_margin: float = 0,
230
+ beside_margin: float = 0.1,
231
+ max_attempts: int = 3000,
232
+ init_rpy: tuple = (1.5708, 0.0, 0.0),
233
+ rotate_objs: bool = True,
234
+ rotate_bg: bool = True,
235
+ rotate_context: bool = True,
236
+ limit_reach_range: tuple[float, float] | None = (0.20, 0.85),
237
+ max_orient_diff: float | None = 60,
238
+ robot_dim: float = 0.12,
239
+ seed: int = None,
240
+ ) -> LayoutInfo:
241
+ """Place objects in the layout using BFS traversal.
242
+
243
+ Args:
244
+ layout_file: Path to the JSON file defining the layout structure and assets.
245
+ floor_margin: Z-offset for the background object, typically for objects placed on the floor.
246
+ beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails.
247
+ max_attempts: Maximum number of attempts to find a non-overlapping position for an object.
248
+ init_rpy: Initial Roll-Pitch-Yaw rotation rad applied to all object meshes to align the mesh's
249
+ coordinate system with the world's (e.g., Z-up).
250
+ rotate_objs: If True, apply a random rotation around the Z-axis for manipulated and distractor objects.
251
+ rotate_bg: If True, apply a random rotation around the Y-axis for the background object.
252
+ rotate_context: If True, apply a random rotation around the Z-axis for the context object.
253
+ limit_reach_range: If set, enforce a check that manipulated objects are within the robot's reach range, in meter.
254
+ max_orient_diff: If set, enforce a check that manipulated objects are within the robot's orientation range, in degree.
255
+ robot_dim: The approximate dimension (e.g., diameter) of the robot for box representation.
256
+ seed: Random seed for reproducible placement.
257
+
258
+ Returns:
259
+ A :class:`LayoutInfo` object containing the objects and their final computed 7D poses
260
+ ([x, y, z, qx, qy, qz, qw]).
261
+ """
262
+ layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r")))
263
+ asset_dir = os.path.dirname(layout_file)
264
+ object_mapping = layout_info.objs_mapping
265
+ position = {} # node: [x, y, z, qx, qy, qz, qw]
266
+ parent_bbox_xy = {}
267
+ placed_boxes_map = defaultdict(list)
268
+ mesh_info = defaultdict(dict)
269
+ robot_node = layout_info.relation[Scene3DItemEnum.ROBOT.value]
270
+ for node in object_mapping:
271
+ if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
272
+ bg_quat = (
273
+ compute_axis_rotation_quat(
274
+ axis="y",
275
+ angle_rad=np.random.uniform(0, 2 * np.pi),
276
+ )
277
+ if rotate_bg
278
+ else [0, 0, 0, 1]
279
+ )
280
+ bg_quat = [round(q, 4) for q in bg_quat]
281
+ continue
282
+
283
+ mesh_path = (
284
+ f"{layout_info.assets[node]}/mesh/{node.replace(' ', '_')}.obj"
285
+ )
286
+ mesh_path = os.path.join(asset_dir, mesh_path)
287
+ mesh_info[node]["path"] = mesh_path
288
+ mesh = trimesh.load(mesh_path)
289
+ rotation = R.from_euler("xyz", init_rpy, degrees=False)
290
+ vertices = mesh.vertices @ rotation.as_matrix().T
291
+ z1 = np.percentile(vertices[:, 2], 1)
292
+ z2 = np.percentile(vertices[:, 2], 99)
293
+
294
+ if object_mapping[node] == Scene3DItemEnum.CONTEXT.value:
295
+ object_quat = [0, 0, 0, 1]
296
+ if rotate_context:
297
+ angle_rad = np.random.uniform(0, 2 * np.pi)
298
+ object_quat = compute_axis_rotation_quat(
299
+ axis="z", angle_rad=angle_rad
300
+ )
301
+ rotation = R.from_quat(object_quat).as_matrix()
302
+ vertices = vertices @ rotation.T
303
+
304
+ mesh_info[node]["surface"] = compute_convex_hull_path(vertices)
305
+
306
+ # Put robot in the CONTEXT edge.
307
+ x, y = random.choice(mesh_info[node]["surface"].vertices)
308
+ theta = np.arctan2(y, x)
309
+ quat_initial = Quaternion(axis=[0, 0, 1], angle=theta)
310
+ quat_extra = Quaternion(axis=[0, 0, 1], angle=np.pi)
311
+ quat = quat_extra * quat_initial
312
+ _pose = [x, y, z2 - z1, quat.x, quat.y, quat.z, quat.w]
313
+ position[robot_node] = [round(v, 4) for v in _pose]
314
+ node_box = [
315
+ x - robot_dim / 2,
316
+ x + robot_dim / 2,
317
+ y - robot_dim / 2,
318
+ y + robot_dim / 2,
319
+ ]
320
+ placed_boxes_map[node].append(node_box)
321
+ elif rotate_objs:
322
+ # For manipulated and distractor objects, apply random rotation
323
+ angle_rad = np.random.uniform(0, 2 * np.pi)
324
+ object_quat = compute_axis_rotation_quat(
325
+ axis="z", angle_rad=angle_rad
326
+ )
327
+ rotation = R.from_quat(object_quat).as_matrix()
328
+ vertices = vertices @ rotation.T
329
+
330
+ x1, x2, y1, y2 = compute_xy_bbox(vertices)
331
+ mesh_info[node]["pose"] = [x1, x2, y1, y2, z1, z2, *object_quat]
332
+ mesh_info[node]["area"] = max(1e-5, (x2 - x1) * (y2 - y1))
333
+
334
+ root = list(layout_info.tree.keys())[0]
335
+ queue = deque([((root, None), layout_info.tree.get(root, []))])
336
+ while queue:
337
+ (node, relation), children = queue.popleft()
338
+ if node not in object_mapping:
339
+ continue
340
+
341
+ if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
342
+ position[node] = [0, 0, floor_margin, *bg_quat]
343
+ else:
344
+ x1, x2, y1, y2, z1, z2, qx, qy, qz, qw = mesh_info[node]["pose"]
345
+ if object_mapping[node] == Scene3DItemEnum.CONTEXT.value:
346
+ position[node] = [0, 0, -round(z1, 4), qx, qy, qz, qw]
347
+ parent_bbox_xy[node] = [x1, x2, y1, y2, z1, z2]
348
+ elif object_mapping[node] in [
349
+ Scene3DItemEnum.MANIPULATED_OBJS.value,
350
+ Scene3DItemEnum.DISTRACTOR_OBJS.value,
351
+ ]:
352
+ parent_node = find_parent_node(node, layout_info.tree)
353
+ parent_pos = position[parent_node]
354
+ (
355
+ p_x1,
356
+ p_x2,
357
+ p_y1,
358
+ p_y2,
359
+ p_z1,
360
+ p_z2,
361
+ ) = parent_bbox_xy[parent_node]
362
+
363
+ obj_dx = x2 - x1
364
+ obj_dy = y2 - y1
365
+ hull_path = mesh_info[parent_node].get("surface")
366
+ for _ in range(max_attempts):
367
+ node_x1 = random.uniform(p_x1, p_x2 - obj_dx)
368
+ node_y1 = random.uniform(p_y1, p_y2 - obj_dy)
369
+ node_box = [
370
+ node_x1,
371
+ node_x1 + obj_dx,
372
+ node_y1,
373
+ node_y1 + obj_dy,
374
+ ]
375
+ if hull_path and not all_corners_inside(
376
+ hull_path, node_box
377
+ ):
378
+ continue
379
+ # Make sure the manipulated object is reachable by the robot.
380
+ if (
381
+ limit_reach_range is not None
382
+ and object_mapping[node]
383
+ == Scene3DItemEnum.MANIPULATED_OBJS.value
384
+ ):
385
+ cx = parent_pos[0] + node_box[0] + obj_dx / 2
386
+ cy = parent_pos[1] + node_box[2] + obj_dy / 2
387
+ cz = parent_pos[2] + p_z2 - z1
388
+ robot_pos = position[robot_node][:3]
389
+ if not check_reachable(
390
+ base_xyz=np.array(robot_pos),
391
+ reach_xyz=np.array([cx, cy, cz]),
392
+ min_reach=limit_reach_range[0],
393
+ max_reach=limit_reach_range[1],
394
+ ):
395
+ continue
396
+
397
+ # Make sure the manipulated object lies within the robot's facing direction.
398
+ if (
399
+ max_orient_diff is not None
400
+ and object_mapping[node]
401
+ == Scene3DItemEnum.MANIPULATED_OBJS.value
402
+ ):
403
+ cx = parent_pos[0] + node_box[0] + obj_dx / 2
404
+ cy = parent_pos[1] + node_box[2] + obj_dy / 2
405
+ cx2, cy2 = position[robot_node][:2]
406
+ v1 = np.array([-cx2, -cy2])
407
+ v2 = np.array([cx - cx2, cy - cy2])
408
+ dot = np.dot(v1, v2)
409
+ norms = np.linalg.norm(v1) * np.linalg.norm(v2)
410
+ theta = np.arccos(np.clip(dot / norms, -1.0, 1.0))
411
+ theta = np.rad2deg(theta)
412
+ if theta > max_orient_diff:
413
+ continue
414
+
415
+ if not has_iou_conflict(
416
+ node_box, placed_boxes_map[parent_node]
417
+ ):
418
+ z_offset = 0
419
+ break
420
+ else:
421
+ logger.warning(
422
+ f"Cannot place {node} on {parent_node} without overlap"
423
+ f" after {max_attempts} attempts, place beside {parent_node}."
424
+ )
425
+ for _ in range(max_attempts):
426
+ node_x1 = random.choice(
427
+ [
428
+ random.uniform(
429
+ p_x1 - obj_dx - beside_margin,
430
+ p_x1 - obj_dx,
431
+ ),
432
+ random.uniform(p_x2, p_x2 + beside_margin),
433
+ ]
434
+ )
435
+ node_y1 = random.choice(
436
+ [
437
+ random.uniform(
438
+ p_y1 - obj_dy - beside_margin,
439
+ p_y1 - obj_dy,
440
+ ),
441
+ random.uniform(p_y2, p_y2 + beside_margin),
442
+ ]
443
+ )
444
+ node_box = [
445
+ node_x1,
446
+ node_x1 + obj_dx,
447
+ node_y1,
448
+ node_y1 + obj_dy,
449
+ ]
450
+ z_offset = -(parent_pos[2] + p_z2)
451
+ if not has_iou_conflict(
452
+ node_box, placed_boxes_map[parent_node]
453
+ ):
454
+ break
455
+
456
+ placed_boxes_map[parent_node].append(node_box)
457
+
458
+ abs_cx = parent_pos[0] + node_box[0] + obj_dx / 2
459
+ abs_cy = parent_pos[1] + node_box[2] + obj_dy / 2
460
+ abs_cz = parent_pos[2] + p_z2 - z1 + z_offset
461
+ position[node] = [
462
+ round(v, 4)
463
+ for v in [abs_cx, abs_cy, abs_cz, qx, qy, qz, qw]
464
+ ]
465
+ parent_bbox_xy[node] = [x1, x2, y1, y2, z1, z2]
466
+
467
+ sorted_children = sorted(
468
+ children, key=lambda x: -mesh_info[x[0]].get("area", 0)
469
+ )
470
+ for child, rel in sorted_children:
471
+ queue.append(((child, rel), layout_info.tree.get(child, [])))
472
+
473
+ layout_info.position = position
474
+
475
+ return layout_info
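The orientation gate in the placement loop above reduces to the angle between two planar vectors: the robot-to-scene-origin direction and the robot-to-object direction. A minimal self-contained sketch of that check, with illustrative names that are not part of this module:

import numpy as np

def orientation_diff_deg(robot_xy, obj_xy):
    # Vector from the robot towards the scene origin, and towards the object.
    robot_xy = np.asarray(robot_xy, dtype=float)
    v1 = -robot_xy
    v2 = np.asarray(obj_xy, dtype=float) - robot_xy
    cos_t = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    return np.rad2deg(np.arccos(np.clip(cos_t, -1.0, 1.0)))

# An object ~27 degrees off the robot's facing direction passes a 60-degree gate.
assert orientation_diff_deg([1.0, 0.0], [0.0, 0.5]) < 60.0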
476
+
477
+
478
+ def compose_mesh_scene(
479
+ layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False
480
+ ) -> None:
481
+ object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation)
482
+ scene = trimesh.Scene()
483
+ for node in layout_info.assets:
484
+ if object_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
485
+ mesh_path = f"{layout_info.assets[node]}/mesh_model.ply"
486
+ if not with_bg:
487
+ continue
488
+ else:
489
+ mesh_path = (
490
+ f"{layout_info.assets[node]}/mesh/{node.replace(' ', '_')}.obj"
491
+ )
492
+
493
+ mesh = trimesh.load(mesh_path)
494
+ offset = np.array(layout_info.position[node])[[0, 2, 1]]
495
+ mesh.vertices += offset
496
+ scene.add_geometry(mesh, node_name=node)
497
+
498
+ os.makedirs(os.path.dirname(out_scene_path), exist_ok=True)
499
+ scene.export(out_scene_path)
500
+ logger.info(f"Composed interactive 3D layout saved in {out_scene_path}")
501
+
502
+ return
503
+
504
+
505
+ def compute_pinhole_intrinsics(
506
+ image_w: int, image_h: int, fov_deg: float
507
+ ) -> np.ndarray:
508
+ fov_rad = np.deg2rad(fov_deg)
509
+ fx = image_w / (2 * np.tan(fov_rad / 2))
510
+ fy = fx # assuming square pixels
511
+ cx = image_w / 2
512
+ cy = image_h / 2
513
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
514
+
515
+ return K
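As a quick sanity check of the intrinsics formula above (fx = W / (2 * tan(fov/2))), a hedged sketch assuming `compute_pinhole_intrinsics` from this file is in scope; the resolution and FOV values are arbitrary examples:

import numpy as np

# For a 640x480 image with a 60-degree FOV:
# fx = 640 / (2 * tan(30 deg)) ~= 554.26, cx = 320, cy = 240.
K = compute_pinhole_intrinsics(image_w=640, image_h=480, fov_deg=60.0)
assert np.isclose(K[0, 0], 640 / (2 * np.tan(np.deg2rad(30.0))))
assert np.allclose(K[:2, 2], [320.0, 240.0])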
embodied_gen/utils/gpt_clients.py ADDED
@@ -0,0 +1,218 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import base64
19
+ import logging
20
+ import os
21
+ from io import BytesIO
22
+ from typing import Optional
23
+
24
+ import yaml
25
+ from openai import AzureOpenAI, OpenAI # pip install openai
26
+ from PIL import Image
27
+ from tenacity import (
28
+ retry,
29
+ stop_after_attempt,
30
+ stop_after_delay,
31
+ wait_random_exponential,
32
+ )
33
+ from embodied_gen.utils.process_media import combine_images_to_grid
34
+
35
+ logging.getLogger("httpx").setLevel(logging.WARNING)
36
+ logging.basicConfig(level=logging.WARNING)
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ __all__ = [
41
+ "GPTclient",
42
+ ]
43
+
44
+ CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"
45
+
46
+
47
+ class GPTclient:
48
+ """A client to interact with the GPT model via OpenAI or Azure API."""
49
+
50
+ def __init__(
51
+ self,
52
+ endpoint: str,
53
+ api_key: str,
54
+ model_name: str = "yfb-gpt-4o",
55
+ api_version: str = None,
56
+ check_connection: bool = True,
57
+ verbose: bool = False,
58
+ ):
59
+ if api_version is not None:
60
+ self.client = AzureOpenAI(
61
+ azure_endpoint=endpoint,
62
+ api_key=api_key,
63
+ api_version=api_version,
64
+ )
65
+ else:
66
+ self.client = OpenAI(
67
+ base_url=endpoint,
68
+ api_key=api_key,
69
+ )
70
+
71
+ self.endpoint = endpoint
72
+ self.model_name = model_name
73
+ self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
74
+ self.verbose = verbose
75
+ if check_connection:
76
+ self.check_connection()
77
+
78
+ logger.info(f"Using GPT model: {self.model_name}.")
79
+
80
+ @retry(
81
+ wait=wait_random_exponential(min=1, max=20),
82
+ stop=(stop_after_attempt(10) | stop_after_delay(30)),
83
+ )
84
+ def completion_with_backoff(self, **kwargs):
85
+ return self.client.chat.completions.create(**kwargs)
86
+
87
+ def query(
88
+ self,
89
+ text_prompt: str,
90
+ image_base64: Optional[list[str | Image.Image]] = None,
91
+ system_role: Optional[str] = None,
92
+ params: Optional[dict] = None,
93
+ ) -> Optional[str]:
94
+ """Queries the GPT model with a text and optional image prompts.
95
+
96
+ Args:
97
+ text_prompt (str): The main text input that the model responds to.
98
+ image_base64 (Optional[list[str | Image.Image]]): A list of base64-encoded image strings,
99
+ local image paths, or PIL.Image objects to accompany the text prompt.
100
+ system_role (Optional[str]): Optional system-level instructions
101
+ that specify the behavior of the assistant.
102
+ params (Optional[dict]): Additional parameters for GPT setting.
103
+
104
+ Returns:
105
+ Optional[str]: The response content generated by the model based on
106
+ the prompt. Returns `None` if an error occurs.
107
+ """
108
+ if system_role is None:
109
+ system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa
110
+
111
+ content_user = [
112
+ {
113
+ "type": "text",
114
+ "text": text_prompt,
115
+ },
116
+ ]
117
+
118
+ # Process images if provided
119
+ if image_base64 is not None:
120
+ if not isinstance(image_base64, list):
121
+ image_base64 = [image_base64]
122
+ # Temporary workaround: OpenRouter cannot accept multiple images, so merge them into one grid.
123
+ if "openrouter" in self.endpoint:
124
+ image_base64 = combine_images_to_grid(image_base64)
125
+ for img in image_base64:
126
+ if isinstance(img, Image.Image):
127
+ buffer = BytesIO()
128
+ img.save(buffer, format=img.format or "PNG")
129
+ buffer.seek(0)
130
+ image_binary = buffer.read()
131
+ img = base64.b64encode(image_binary).decode("utf-8")
132
+ elif (
133
+ len(os.path.splitext(img)) > 1
134
+ and os.path.splitext(img)[-1].lower() in self.image_formats
135
+ ):
136
+ if not os.path.exists(img):
137
+ raise FileNotFoundError(f"Image file not found: {img}")
138
+ with open(img, "rb") as f:
139
+ img = base64.b64encode(f.read()).decode("utf-8")
140
+
141
+ content_user.append(
142
+ {
143
+ "type": "image_url",
144
+ "image_url": {"url": f"data:image/png;base64,{img}"},
145
+ }
146
+ )
147
+
148
+ payload = {
149
+ "messages": [
150
+ {"role": "system", "content": system_role},
151
+ {"role": "user", "content": content_user},
152
+ ],
153
+ "temperature": 0.1,
154
+ "max_tokens": 500,
155
+ "top_p": 0.1,
156
+ "frequency_penalty": 0,
157
+ "presence_penalty": 0,
158
+ "stop": None,
159
+ "model": self.model_name,
160
+ }
161
+
162
+ if params:
163
+ payload.update(params)
164
+
165
+ response = None
166
+ try:
167
+ response = self.completion_with_backoff(**payload)
168
+ response = response.choices[0].message.content
169
+ except Exception as e:
170
+ logger.error(f"Error GPTclint {self.endpoint} API call: {e}")
171
+ response = None
172
+
173
+ if self.verbose:
174
+ logger.info(f"Prompt: {text_prompt}")
175
+ logger.info(f"Response: {response}")
176
+
177
+ return response
178
+
179
+ def check_connection(self) -> None:
180
+ """Check whether the GPT API connection is working."""
181
+ try:
182
+ response = self.completion_with_backoff(
183
+ messages=[
184
+ {"role": "system", "content": "You are a test system."},
185
+ {"role": "user", "content": "Hello"},
186
+ ],
187
+ model=self.model_name,
188
+ temperature=0,
189
+ max_tokens=100,
190
+ )
191
+ content = response.choices[0].message.content
192
+ logger.info(f"Connection check success.")
193
+ except Exception as e:
194
+ raise ConnectionError(
195
+ f"Failed to connect to GPT API at {self.endpoint}, "
196
+ f"please check setting in `{CONFIG_FILE}` and `README`."
197
+ ) from e
198
+
199
+
200
+ with open(CONFIG_FILE, "r") as f:
201
+ config = yaml.safe_load(f)
202
+
203
+ agent_type = config["agent_type"]
204
+ agent_config = config.get(agent_type, {})
205
+
206
+ # Prefer environment variables, fallback to YAML config
207
+ endpoint = os.environ.get("ENDPOINT", agent_config.get("endpoint"))
208
+ api_key = os.environ.get("API_KEY", agent_config.get("api_key"))
209
+ api_version = os.environ.get("API_VERSION", agent_config.get("api_version"))
210
+ model_name = os.environ.get("MODEL_NAME", agent_config.get("model_name"))
211
+
212
+ GPT_CLIENT = GPTclient(
213
+ endpoint=endpoint,
214
+ api_key=api_key,
215
+ api_version=api_version,
216
+ model_name=model_name,
217
+ check_connection=False,
218
+ )
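A minimal usage sketch of the client defined above; the endpoint, key, and image path below are placeholders, and the module-level `GPT_CLIENT` can be reused once `gpt_config.yaml` is filled in:

from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient

# Reuse the pre-configured module-level client ...
answer = GPT_CLIENT.query(
    text_prompt="Describe the object in one sentence.",
    image_base64=["path/to/object.png"],  # hypothetical image path
)

# ... or build one explicitly (placeholder credentials shown).
client = GPTclient(
    endpoint="https://openrouter.ai/api/v1",
    api_key="sk-or-v1-xxx",
    model_name="qwen/qwen2.5-vl-72b-instruct:free",
    check_connection=False,
)
print(client.query("What material is a ceramic mug typically made of?"))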
embodied_gen/utils/gpt_config.yaml ADDED
@@ -0,0 +1,14 @@
1
+ # config.yaml
2
+ agent_type: "qwen2.5-vl" # gpt-4o or qwen2.5-vl
3
+
4
+ gpt-4o:
5
+ endpoint: https://xxx.openai.azure.com
6
+ api_key: xxx
7
+ api_version: 2025-xx-xx
8
+ model_name: yfb-gpt-4o
9
+
10
+ qwen2.5-vl:
11
+ endpoint: https://openrouter.ai/api/v1
12
+ api_key: sk-or-v1-xxx
13
+ api_version: null
14
+ model_name: qwen/qwen2.5-vl-72b-instruct:free
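The client above resolves its settings from this YAML (keyed by `agent_type`), with environment variables taking precedence; a hedged sketch of overriding the config at runtime (values are placeholders):

import os

# Set these before importing embodied_gen.utils.gpt_clients,
# since GPT_CLIENT is constructed at import time.
os.environ["ENDPOINT"] = "https://openrouter.ai/api/v1"            # placeholder
os.environ["API_KEY"] = "sk-or-v1-xxx"                             # placeholder
os.environ["MODEL_NAME"] = "qwen/qwen2.5-vl-72b-instruct:free"     # placeholder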
embodied_gen/utils/log.py ADDED
@@ -0,0 +1,48 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import logging
18
+
19
+ from colorlog import ColoredFormatter
20
+
21
+ __all__ = [
22
+ "logger",
23
+ ]
24
+
25
+ LOG_FORMAT = (
26
+ "%(log_color)s[%(asctime)s] %(levelname)-8s | %(message)s%(reset)s"
27
+ )
28
+ DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
29
+
30
+ formatter = ColoredFormatter(
31
+ LOG_FORMAT,
32
+ datefmt=DATE_FORMAT,
33
+ log_colors={
34
+ "DEBUG": "cyan",
35
+ "INFO": "green",
36
+ "WARNING": "yellow",
37
+ "ERROR": "red",
38
+ "CRITICAL": "bold_red",
39
+ },
40
+ )
41
+
42
+ handler = logging.StreamHandler()
43
+ handler.setFormatter(formatter)
44
+
45
+ logger = logging.getLogger(__name__)
46
+ logger.setLevel(logging.INFO)
47
+ logger.addHandler(handler)
48
+ logger.propagate = False
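A brief usage sketch of the shared colored logger configured above:

from embodied_gen.utils.log import logger

logger.info("Asset generation started.")   # green
logger.warning("Falling back to CPU.")     # yellow
logger.error("Mesh file not found.")       # red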
embodied_gen/utils/monkey_patches.py ADDED
@@ -0,0 +1,218 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import os
18
+ import sys
19
+ import zipfile
20
+
21
+ import numpy as np
22
+ import torch
23
+ from huggingface_hub import hf_hub_download
24
+ from omegaconf import OmegaConf
25
+ from PIL import Image
26
+ from torchvision import transforms
27
+
28
+
29
+ def monkey_patch_pano2room():
30
+ current_file_path = os.path.abspath(__file__)
31
+ current_dir = os.path.dirname(current_file_path)
32
+ sys.path.append(os.path.join(current_dir, "../.."))
33
+ sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
34
+ from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
35
+ OmnidataNormalPredictor,
36
+ )
37
+ from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
38
+ OmnidataPredictor,
39
+ )
40
+
41
+ def patched_omni_depth_init(self):
42
+ self.img_size = 384
43
+ self.model = torch.hub.load(
44
+ 'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
45
+ )
46
+ self.model.eval()
47
+ self.trans_totensor = transforms.Compose(
48
+ [
49
+ transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
50
+ transforms.CenterCrop(self.img_size),
51
+ transforms.Normalize(mean=0.5, std=0.5),
52
+ ]
53
+ )
54
+
55
+ OmnidataPredictor.__init__ = patched_omni_depth_init
56
+
57
+ def patched_omni_normal_init(self):
58
+ self.img_size = 384
59
+ self.model = torch.hub.load(
60
+ 'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
61
+ )
62
+ self.model.eval()
63
+ self.trans_totensor = transforms.Compose(
64
+ [
65
+ transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
66
+ transforms.CenterCrop(self.img_size),
67
+ transforms.Normalize(mean=0.5, std=0.5),
68
+ ]
69
+ )
70
+
71
+ OmnidataNormalPredictor.__init__ = patched_omni_normal_init
72
+
73
+ def patched_panojoint_init(self, save_path=None):
74
+ self.depth_predictor = OmnidataPredictor()
75
+ self.normal_predictor = OmnidataNormalPredictor()
76
+ self.save_path = save_path
77
+
78
+ from modules.geo_predictors import PanoJointPredictor
79
+
80
+ PanoJointPredictor.__init__ = patched_panojoint_init
81
+
82
+ # NOTE: We use gsplat instead.
83
+ # import depth_diff_gaussian_rasterization_min as ddgr
84
+ # from dataclasses import dataclass
85
+ # @dataclass
86
+ # class PatchedGaussianRasterizationSettings:
87
+ # image_height: int
88
+ # image_width: int
89
+ # tanfovx: float
90
+ # tanfovy: float
91
+ # bg: torch.Tensor
92
+ # scale_modifier: float
93
+ # viewmatrix: torch.Tensor
94
+ # projmatrix: torch.Tensor
95
+ # sh_degree: int
96
+ # campos: torch.Tensor
97
+ # prefiltered: bool
98
+ # debug: bool = False
99
+ # ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings
100
+
101
+ # disable get_has_ddp_rank print in `BaseInpaintingTrainingModule`
102
+ os.environ["NODE_RANK"] = "0"
103
+
104
+ from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
105
+ load_checkpoint,
106
+ )
107
+ from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
108
+ LamaInpainter,
109
+ )
110
+
111
+ def patched_lama_inpaint_init(self):
112
+ zip_path = hf_hub_download(
113
+ repo_id="smartywu/big-lama",
114
+ filename="big-lama.zip",
115
+ repo_type="model",
116
+ )
117
+ extract_dir = os.path.splitext(zip_path)[0]
118
+
119
+ if not os.path.exists(extract_dir):
120
+ os.makedirs(extract_dir, exist_ok=True)
121
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
122
+ zip_ref.extractall(extract_dir)
123
+
124
+ config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
125
+ checkpoint_path = os.path.join(
126
+ extract_dir, 'big-lama/models/best.ckpt'
127
+ )
128
+ train_config = OmegaConf.load(config_path)
129
+ train_config.training_model.predict_only = True
130
+ train_config.visualizer.kind = 'noop'
131
+
132
+ self.model = load_checkpoint(
133
+ train_config, checkpoint_path, strict=False, map_location='cpu'
134
+ )
135
+ self.model.freeze()
136
+
137
+ LamaInpainter.__init__ = patched_lama_inpaint_init
138
+
139
+ from diffusers import StableDiffusionInpaintPipeline
140
+ from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
141
+ SDFTInpainter,
142
+ )
143
+
144
+ def patched_sd_inpaint_init(self, subset_name=None):
145
+ super(SDFTInpainter, self).__init__()
146
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
147
+ "stabilityai/stable-diffusion-2-inpainting",
148
+ torch_dtype=torch.float16,
149
+ ).to("cuda")
150
+ pipe.enable_model_cpu_offload()
151
+ self.inpaint_pipe = pipe
152
+
153
+ SDFTInpainter.__init__ = patched_sd_inpaint_init
154
+
155
+
156
+ def monkey_patch_maniskill():
157
+ from mani_skill.envs.scene import ManiSkillScene
158
+
159
+ def get_sensor_images(
160
+ self, obs: dict[str, any]
161
+ ) -> dict[str, dict[str, torch.Tensor]]:
162
+ sensor_data = dict()
163
+ for name, sensor in self.sensors.items():
164
+ sensor_data[name] = sensor.get_images(obs[name])
165
+ return sensor_data
166
+
167
+ def get_human_render_camera_images(
168
+ self, camera_name: str = None, return_alpha: bool = False
169
+ ) -> dict[str, torch.Tensor]:
170
+ def get_rgba_tensor(camera, return_alpha):
171
+ color = camera.get_obs(
172
+ rgb=True, depth=False, segmentation=False, position=False
173
+ )["rgb"]
174
+ if return_alpha:
175
+ seg_labels = camera.get_obs(
176
+ rgb=False, depth=False, segmentation=True, position=False
177
+ )["segmentation"]
178
+ masks = np.where((seg_labels.cpu() > 1), 255, 0).astype(
179
+ np.uint8
180
+ )
181
+ masks = torch.tensor(masks).to(color.device)
182
+ color = torch.concat([color, masks], dim=-1)
183
+
184
+ return color
185
+
186
+ image_data = dict()
187
+ if self.gpu_sim_enabled:
188
+ if self.parallel_in_single_scene:
189
+ for name, camera in self.human_render_cameras.items():
190
+ camera.camera._render_cameras[0].take_picture()
191
+ rgba = get_rgba_tensor(camera, return_alpha)
192
+ image_data[name] = rgba
193
+ else:
194
+ for name, camera in self.human_render_cameras.items():
195
+ if camera_name is not None and name != camera_name:
196
+ continue
197
+ assert camera.config.shader_config.shader_pack not in [
198
+ "rt",
199
+ "rt-fast",
200
+ "rt-med",
201
+ ], "ray tracing shaders do not work with parallel rendering"
202
+ camera.capture()
203
+ rgba = get_rgba_tensor(camera, return_alpha)
204
+ image_data[name] = rgba
205
+ else:
206
+ for name, camera in self.human_render_cameras.items():
207
+ if camera_name is not None and name != camera_name:
208
+ continue
209
+ camera.capture()
210
+ rgba = get_rgba_tensor(camera, return_alpha)
211
+ image_data[name] = rgba
212
+
213
+ return image_data
214
+
215
+ ManiSkillScene.get_sensor_images = get_sensor_images
216
+ ManiSkillScene.get_human_render_camera_images = (
217
+ get_human_render_camera_images
218
+ )
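These patches replace third-party initializers in place, so they must run before the patched classes are instantiated; a hedged usage sketch (the surrounding pipeline code is assumed, not shown here):

from embodied_gen.utils.monkey_patches import (
    monkey_patch_maniskill,
    monkey_patch_pano2room,
)

# Apply before building ManiSkill scenes or pano2room predictors/inpainters.
monkey_patch_maniskill()
monkey_patch_pano2room()  # requires the thirdparty/pano2room checkout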
embodied_gen/utils/process_media.py ADDED
@@ -0,0 +1,467 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import logging
19
+ import math
20
+ import mimetypes
21
+ import os
22
+ import textwrap
23
+ from glob import glob
24
+ from typing import Union
25
+
26
+ import cv2
27
+ import imageio
28
+ import matplotlib.pyplot as plt
29
+ import networkx as nx
30
+ import numpy as np
31
+ import spaces
32
+ from matplotlib.patches import Patch
33
+ from moviepy.editor import VideoFileClip, clips_array
34
+ from PIL import Image
35
+ from embodied_gen.data.differentiable_render import entrypoint as render_api
36
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
37
+
38
+ logging.basicConfig(level=logging.INFO)
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ __all__ = [
43
+ "render_asset3d",
44
+ "merge_images_video",
45
+ "filter_small_connected_components",
46
+ "filter_image_small_connected_components",
47
+ "combine_images_to_grid",
48
+ "SceneTreeVisualizer",
49
+ "is_image_file",
50
+ "parse_text_prompts",
51
+ "check_object_edge_truncated",
52
+ "vcat_pil_images",
53
+ ]
54
+
55
+
56
+ @spaces.GPU
57
+ def render_asset3d(
58
+ mesh_path: str,
59
+ output_root: str,
60
+ distance: float = 5.0,
61
+ num_images: int = 1,
62
+ elevation: list[float] = (0.0,),
63
+ pbr_light_factor: float = 1.2,
64
+ return_key: str = "image_color/*",
65
+ output_subdir: str = "renders",
66
+ gen_color_mp4: bool = False,
67
+ gen_viewnormal_mp4: bool = False,
68
+ gen_glonormal_mp4: bool = False,
69
+ no_index_file: bool = False,
70
+ with_mtl: bool = True,
71
+ ) -> list[str]:
72
+ input_args = dict(
73
+ mesh_path=mesh_path,
74
+ output_root=output_root,
75
+ uuid=output_subdir,
76
+ distance=distance,
77
+ num_images=num_images,
78
+ elevation=elevation,
79
+ pbr_light_factor=pbr_light_factor,
80
+ with_mtl=with_mtl,
81
+ gen_color_mp4=gen_color_mp4,
82
+ gen_viewnormal_mp4=gen_viewnormal_mp4,
83
+ gen_glonormal_mp4=gen_glonormal_mp4,
84
+ no_index_file=no_index_file,
85
+ )
86
+
87
+ try:
88
+ _ = render_api(**input_args)
89
+ except Exception as e:
90
+ logger.error(f"Error occurred during rendering: {e}.")
91
+
92
+ dst_paths = glob(os.path.join(output_root, output_subdir, return_key))
93
+
94
+ return dst_paths
95
+
96
+
97
+ def merge_images_video(color_images, normal_images, output_path) -> None:
98
+ width = color_images[0].shape[1]
99
+ combined_video = [
100
+ np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]])
101
+ for rgb_img, normal_img in zip(color_images, normal_images)
102
+ ]
103
+ imageio.mimsave(output_path, combined_video, fps=50)
104
+
105
+ return
106
+
107
+
108
+ def merge_video_video(
109
+ video_path1: str, video_path2: str, output_path: str
110
+ ) -> None:
111
+ """Merge two videos by the left half and the right half of the videos."""
112
+ clip1 = VideoFileClip(video_path1)
113
+ clip2 = VideoFileClip(video_path2)
114
+
115
+ if clip1.size != clip2.size:
116
+ raise ValueError("The resolutions of the two videos do not match.")
117
+
118
+ width, height = clip1.size
119
+ clip1_half = clip1.crop(x1=0, y1=0, x2=width // 2, y2=height)
120
+ clip2_half = clip2.crop(x1=width // 2, y1=0, x2=width, y2=height)
121
+ final_clip = clips_array([[clip1_half, clip2_half]])
122
+ final_clip.write_videofile(output_path, codec="libx264")
123
+
124
+
125
+ def filter_small_connected_components(
126
+ mask: Union[Image.Image, np.ndarray],
127
+ area_ratio: float,
128
+ connectivity: int = 8,
129
+ ) -> np.ndarray:
130
+ if isinstance(mask, Image.Image):
131
+ mask = np.array(mask)
132
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
133
+ mask,
134
+ connectivity=connectivity,
135
+ )
136
+
137
+ small_components = np.zeros_like(mask, dtype=np.uint8)
138
+ mask_area = (mask != 0).sum()
139
+ min_area = mask_area // area_ratio
140
+ for label in range(1, num_labels):
141
+ area = stats[label, cv2.CC_STAT_AREA]
142
+ if area < min_area:
143
+ small_components[labels == label] = 255
144
+
145
+ mask = cv2.bitwise_and(mask, cv2.bitwise_not(small_components))
146
+
147
+ return mask
148
+
149
+
150
+ def filter_image_small_connected_components(
151
+ image: Union[Image.Image, np.ndarray],
152
+ area_ratio: float = 10,
153
+ connectivity: int = 8,
154
+ ) -> np.ndarray:
155
+ if isinstance(image, Image.Image):
156
+ image = image.convert("RGBA")
157
+ image = np.array(image)
158
+
159
+ mask = image[..., 3]
160
+ mask = filter_small_connected_components(mask, area_ratio, connectivity)
161
+ image[..., 3] = mask
162
+
163
+ return image
164
+
165
+
166
+ def combine_images_to_grid(
167
+ images: list[str | Image.Image],
168
+ cat_row_col: tuple[int, int] = None,
169
+ target_wh: tuple[int, int] = (512, 512),
170
+ image_mode: str = "RGB",
171
+ ) -> list[Image.Image]:
172
+ n_images = len(images)
173
+ if n_images == 1:
174
+ return images
175
+
176
+ if cat_row_col is None:
177
+ n_col = math.ceil(math.sqrt(n_images))
178
+ n_row = math.ceil(n_images / n_col)
179
+ else:
180
+ n_row, n_col = cat_row_col
181
+
182
+ images = [
183
+ Image.open(p).convert(image_mode) if isinstance(p, str) else p
184
+ for p in images
185
+ ]
186
+ images = [img.resize(target_wh) for img in images]
187
+
188
+ grid_w, grid_h = n_col * target_wh[0], n_row * target_wh[1]
189
+ grid = Image.new(image_mode, (grid_w, grid_h), (0, 0, 0))
190
+
191
+ for idx, img in enumerate(images):
192
+ row, col = divmod(idx, n_col)
193
+ grid.paste(img, (col * target_wh[0], row * target_wh[1]))
194
+
195
+ return [grid]
196
+
197
+
198
+ class SceneTreeVisualizer:
199
+ def __init__(self, layout_info: LayoutInfo) -> None:
200
+ self.tree = layout_info.tree
201
+ self.relation = layout_info.relation
202
+ self.objs_desc = layout_info.objs_desc
203
+ self.G = nx.DiGraph()
204
+ self.root = self._find_root()
205
+ self._build_graph()
206
+
207
+ self.role_colors = {
208
+ Scene3DItemEnum.BACKGROUND.value: "plum",
209
+ Scene3DItemEnum.CONTEXT.value: "lightblue",
210
+ Scene3DItemEnum.ROBOT.value: "lightcoral",
211
+ Scene3DItemEnum.MANIPULATED_OBJS.value: "lightgreen",
212
+ Scene3DItemEnum.DISTRACTOR_OBJS.value: "lightgray",
213
+ Scene3DItemEnum.OTHERS.value: "orange",
214
+ }
215
+
216
+ def _find_root(self) -> str:
217
+ children = {c for cs in self.tree.values() for c, _ in cs}
218
+ parents = set(self.tree.keys())
219
+ roots = parents - children
220
+ if not roots:
221
+ raise ValueError("No root node found.")
222
+ return next(iter(roots))
223
+
224
+ def _build_graph(self):
225
+ for parent, children in self.tree.items():
226
+ for child, relation in children:
227
+ self.G.add_edge(parent, child, relation=relation)
228
+
229
+ def _get_node_role(self, node: str) -> str:
230
+ if node == self.relation.get(Scene3DItemEnum.BACKGROUND.value):
231
+ return Scene3DItemEnum.BACKGROUND.value
232
+ if node == self.relation.get(Scene3DItemEnum.CONTEXT.value):
233
+ return Scene3DItemEnum.CONTEXT.value
234
+ if node == self.relation.get(Scene3DItemEnum.ROBOT.value):
235
+ return Scene3DItemEnum.ROBOT.value
236
+ if node in self.relation.get(
237
+ Scene3DItemEnum.MANIPULATED_OBJS.value, []
238
+ ):
239
+ return Scene3DItemEnum.MANIPULATED_OBJS.value
240
+ if node in self.relation.get(
241
+ Scene3DItemEnum.DISTRACTOR_OBJS.value, []
242
+ ):
243
+ return Scene3DItemEnum.DISTRACTOR_OBJS.value
244
+ return Scene3DItemEnum.OTHERS.value
245
+
246
+ def _get_positions(
247
+ self, root, width=1.0, vert_gap=0.1, vert_loc=1, xcenter=0.5, pos=None
248
+ ):
249
+ if pos is None:
250
+ pos = {root: (xcenter, vert_loc)}
251
+ else:
252
+ pos[root] = (xcenter, vert_loc)
253
+
254
+ children = list(self.G.successors(root))
255
+ if children:
256
+ dx = width / len(children)
257
+ next_x = xcenter - width / 2 - dx / 2
258
+ for child in children:
259
+ next_x += dx
260
+ pos = self._get_positions(
261
+ child,
262
+ width=dx,
263
+ vert_gap=vert_gap,
264
+ vert_loc=vert_loc - vert_gap,
265
+ xcenter=next_x,
266
+ pos=pos,
267
+ )
268
+ return pos
269
+
270
+ def render(
271
+ self,
272
+ save_path: str,
273
+ figsize=(8, 6),
274
+ dpi=300,
275
+ title: str = "Scene 3D Hierarchy Tree",
276
+ ):
277
+ node_colors = [
278
+ self.role_colors[self._get_node_role(n)] for n in self.G.nodes
279
+ ]
280
+ pos = self._get_positions(self.root)
281
+
282
+ plt.figure(figsize=figsize)
283
+ nx.draw(
284
+ self.G,
285
+ pos,
286
+ with_labels=True,
287
+ arrows=False,
288
+ node_size=2000,
289
+ node_color=node_colors,
290
+ font_size=10,
291
+ font_weight="bold",
292
+ )
293
+
294
+ # Draw edge labels
295
+ edge_labels = nx.get_edge_attributes(self.G, "relation")
296
+ nx.draw_networkx_edge_labels(
297
+ self.G,
298
+ pos,
299
+ edge_labels=edge_labels,
300
+ font_size=9,
301
+ font_color="black",
302
+ )
303
+
304
+ # Draw small description text under each node (if available)
305
+ for node, (x, y) in pos.items():
306
+ desc = self.objs_desc.get(node)
307
+ if desc:
308
+ wrapped = "\n".join(textwrap.wrap(desc, width=30))
309
+ plt.text(
310
+ x,
311
+ y - 0.006,
312
+ wrapped,
313
+ fontsize=6,
314
+ ha="center",
315
+ va="top",
316
+ wrap=True,
317
+ color="black",
318
+ bbox=dict(
319
+ facecolor="dimgray",
320
+ edgecolor="darkgray",
321
+ alpha=0.1,
322
+ boxstyle="round,pad=0.2",
323
+ ),
324
+ )
325
+
326
+ plt.title(title, fontsize=12)
327
+ task_desc = self.relation.get("task_desc", "")
328
+ if task_desc:
329
+ plt.suptitle(
330
+ f"Task Description: {task_desc}", fontsize=10, y=0.999
331
+ )
332
+
333
+ plt.axis("off")
334
+
335
+ legend_handles = [
336
+ Patch(facecolor=color, edgecolor='black', label=role)
337
+ for role, color in self.role_colors.items()
338
+ ]
339
+ plt.legend(
340
+ handles=legend_handles,
341
+ loc="lower center",
342
+ ncol=3,
343
+ bbox_to_anchor=(0.5, -0.1),
344
+ fontsize=9,
345
+ )
346
+
347
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
348
+ plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
349
+ plt.close()
350
+
351
+
352
+ def load_scene_dict(file_path: str) -> dict:
353
+ scene_dict = {}
354
+ with open(file_path, "r", encoding='utf-8') as f:
355
+ for line in f:
356
+ line = line.strip()
357
+ if not line or ":" not in line:
358
+ continue
359
+ scene_id, desc = line.split(":", 1)
360
+ scene_dict[scene_id.strip()] = desc.strip()
361
+
362
+ return scene_dict
363
+
364
+
365
+ def is_image_file(filename: str) -> bool:
366
+ mime_type, _ = mimetypes.guess_type(filename)
367
+
368
+ return mime_type is not None and mime_type.startswith('image')
369
+
370
+
371
+ def parse_text_prompts(prompts: list[str]) -> list[str]:
372
+ if len(prompts) == 1 and prompts[0].endswith(".txt"):
373
+ with open(prompts[0], "r") as f:
374
+ prompts = [
375
+ line.strip()
376
+ for line in f
377
+ if line.strip() and not line.strip().startswith("#")
378
+ ]
379
+ return prompts
380
+
381
+
382
+ def alpha_blend_rgba(
383
+ fg_image: Union[str, Image.Image, np.ndarray],
384
+ bg_image: Union[str, Image.Image, np.ndarray],
385
+ ) -> Image.Image:
386
+ """Alpha blends a foreground RGBA image over a background RGBA image.
387
+
388
+ Args:
389
+ fg_image: Foreground image. Can be a file path (str), a PIL Image,
390
+ or a NumPy ndarray.
391
+ bg_image: Background image. Can be a file path (str), a PIL Image,
392
+ or a NumPy ndarray.
393
+
394
+ Returns:
395
+ A PIL Image representing the alpha-blended result in RGBA mode.
396
+ """
397
+ if isinstance(fg_image, str):
398
+ fg_image = Image.open(fg_image)
399
+ elif isinstance(fg_image, np.ndarray):
400
+ fg_image = Image.fromarray(fg_image)
401
+
402
+ if isinstance(bg_image, str):
403
+ bg_image = Image.open(bg_image)
404
+ elif isinstance(bg_image, np.ndarray):
405
+ bg_image = Image.fromarray(bg_image)
406
+
407
+ if fg_image.size != bg_image.size:
408
+ raise ValueError(
409
+ f"Image sizes not match {fg_image.size} v.s. {bg_image.size}."
410
+ )
411
+
412
+ fg = fg_image.convert("RGBA")
413
+ bg = bg_image.convert("RGBA")
414
+
415
+ return Image.alpha_composite(bg, fg)
416
+
417
+
418
+ def check_object_edge_truncated(
419
+ mask: np.ndarray, edge_threshold: int = 5
420
+ ) -> bool:
421
+ """Checks if a binary object mask is truncated at the image edges.
422
+
423
+ Args:
424
+ mask: A 2D binary NumPy array where nonzero values indicate the object region.
425
+ edge_threshold: Number of pixels from each image edge to consider for truncation.
426
+ Defaults to 5.
427
+
428
+ Returns:
429
+ True if the object is fully enclosed (not truncated).
430
+ False if the object touches or crosses any image boundary.
431
+ """
432
+ top = mask[:edge_threshold, :].any()
433
+ bottom = mask[-edge_threshold:, :].any()
434
+ left = mask[:, :edge_threshold].any()
435
+ right = mask[:, -edge_threshold:].any()
436
+
437
+ return not (top or bottom or left or right)
438
+
439
+
440
+ def vcat_pil_images(
441
+ images: list[Image.Image], image_mode: str = "RGB"
442
+ ) -> Image.Image:
443
+ widths, heights = zip(*(img.size for img in images))
444
+ total_height = sum(heights)
445
+ max_width = max(widths)
446
+ new_image = Image.new(image_mode, (max_width, total_height))
447
+ y_offset = 0
448
+ for image in images:
449
+ new_image.paste(image, (0, y_offset))
450
+ y_offset += image.size[1]
451
+
452
+ return new_image
453
+
454
+
455
+ if __name__ == "__main__":
456
+ image_paths = [
457
+ "outputs/layouts_sim/task_0000/images/pen.png",
458
+ "outputs/layouts_sim/task_0000/images/notebook.png",
459
+ "outputs/layouts_sim/task_0000/images/mug.png",
460
+ "outputs/layouts_sim/task_0000/images/lamp.png",
461
+ "outputs/layouts_sim2/task_0014/images/cloth.png", # TODO
462
+ ]
463
+ for image_path in image_paths:
464
+ image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
465
+ mask = image[..., -1]
466
+ flag = check_object_edge_truncated(mask)
467
+ print(flag, image_path)
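A short usage sketch for the grid and truncation helpers above; the images here are synthetic placeholders:

import numpy as np
from PIL import Image
from embodied_gen.utils.process_media import (
    check_object_edge_truncated,
    combine_images_to_grid,
)

# Tile four views into one 2x2 grid (single-image lists are returned unchanged).
views = [Image.new("RGB", (512, 512), c) for c in ("red", "green", "blue", "gray")]
grid = combine_images_to_grid(views, target_wh=(512, 512))[0]

# A mask that stays away from all image borders is not truncated (returns True).
mask = np.zeros((64, 64), dtype=np.uint8)
mask[20:40, 20:40] = 255
assert check_object_edge_truncated(mask)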
embodied_gen/utils/simulation.py ADDED
@@ -0,0 +1,667 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ import xml.etree.ElementTree as ET
20
+ from collections import defaultdict
21
+ from typing import Literal
22
+
23
+ import mplib
24
+ import numpy as np
25
+ import sapien.core as sapien
26
+ import sapien.physx as physx
27
+ import torch
28
+ from mani_skill.agents.base_agent import BaseAgent
29
+ from mani_skill.envs.scene import ManiSkillScene
30
+ from mani_skill.examples.motionplanning.panda.utils import (
31
+ compute_grasp_info_by_obb,
32
+ )
33
+ from mani_skill.utils.geometry.trimesh_utils import get_component_mesh
34
+ from PIL import Image, ImageColor
35
+ from scipy.spatial.transform import Rotation as R
36
+ from embodied_gen.data.utils import DiffrastRender
37
+ from embodied_gen.utils.enum import LayoutInfo, Scene3DItemEnum
38
+ from embodied_gen.utils.geometry import quaternion_multiply
39
+ from embodied_gen.utils.log import logger
40
+
41
+ COLORMAP = list(set(ImageColor.colormap.values()))
42
+ COLOR_PALETTE = np.array(
43
+ [ImageColor.getrgb(c) for c in COLORMAP], dtype=np.uint8
44
+ )
45
+ SIM_COORD_ALIGN = np.array(
46
+ [
47
+ [1.0, 0.0, 0.0, 0.0],
48
+ [0.0, -1.0, 0.0, 0.0],
49
+ [0.0, 0.0, -1.0, 0.0],
50
+ [0.0, 0.0, 0.0, 1.0],
51
+ ]
52
+ ) # Used to align the SAPIEN and MuJoCo coordinate systems with the world coordinate system.
53
+
54
+ __all__ = [
55
+ "SIM_COORD_ALIGN",
56
+ "FrankaPandaGrasper",
57
+ "load_assets_from_layout_file",
58
+ "load_mani_skill_robot",
59
+ "render_images",
60
+ ]
61
+
62
+
63
+ def load_actor_from_urdf(
64
+ scene: sapien.Scene | ManiSkillScene,
65
+ file_path: str,
66
+ pose: sapien.Pose | None = None,
67
+ env_idx: int = None,
68
+ use_static: bool = False,
69
+ update_mass: bool = False,
70
+ scale: float | np.ndarray = 1.0,
71
+ ) -> sapien.pysapien.Entity:
72
+ def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose:
73
+ local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
74
+ if origin_tag is not None:
75
+ xyz = list(map(float, origin_tag.get("xyz", "0 0 0").split()))
76
+ rpy = list(map(float, origin_tag.get("rpy", "0 0 0").split()))
77
+ qx, qy, qz, qw = R.from_euler("xyz", rpy, degrees=False).as_quat()
78
+ local_pose = sapien.Pose(p=xyz, q=[qw, qx, qy, qz])
79
+
80
+ return local_pose
81
+
82
+ tree = ET.parse(file_path)
83
+ root = tree.getroot()
84
+ node_name = root.get("name")
85
+ file_dir = os.path.dirname(file_path)
86
+
87
+ visual_mesh = root.find(".//visual/geometry/mesh")
88
+ visual_file = visual_mesh.get("filename")
89
+ visual_scale = visual_mesh.get("scale", "1.0 1.0 1.0")
90
+ visual_scale = np.array([float(x) for x in visual_scale.split()])
91
+ visual_scale *= np.array(scale)
92
+
93
+ collision_mesh = root.find(".//collision/geometry/mesh")
94
+ collision_file = collision_mesh.get("filename")
95
+ collision_scale = collision_mesh.get("scale", "1.0 1.0 1.0")
96
+ collision_scale = np.array([float(x) for x in collision_scale.split()])
97
+ collision_scale *= np.array(scale)
98
+
99
+ visual_pose = _get_local_pose(root.find(".//visual/origin"))
100
+ collision_pose = _get_local_pose(root.find(".//collision/origin"))
101
+
102
+ visual_file = os.path.join(file_dir, visual_file)
103
+ collision_file = os.path.join(file_dir, collision_file)
104
+ static_fric = root.find(".//collision/gazebo/mu1").text
105
+ dynamic_fric = root.find(".//collision/gazebo/mu2").text
106
+
107
+ material = physx.PhysxMaterial(
108
+ static_friction=np.clip(float(static_fric), 0.1, 0.7),
109
+ dynamic_friction=np.clip(float(dynamic_fric), 0.1, 0.6),
110
+ restitution=0.05,
111
+ )
112
+ builder = scene.create_actor_builder()
113
+
114
+ body_type = "static" if use_static else "dynamic"
115
+ builder.set_physx_body_type(body_type)
116
+ builder.add_multiple_convex_collisions_from_file(
117
+ collision_file,
118
+ material=material,
119
+ scale=collision_scale,
120
+ # decomposition="coacd",
121
+ # decomposition_params=dict(
122
+ # threshold=0.05, max_convex_hull=64, verbose=False
123
+ # ),
124
+ pose=collision_pose,
125
+ )
126
+
127
+ builder.add_visual_from_file(
128
+ visual_file,
129
+ scale=visual_scale,
130
+ pose=visual_pose,
131
+ )
132
+ if pose is None:
133
+ pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0])
134
+
135
+ builder.set_initial_pose(pose)
136
+ if isinstance(scene, ManiSkillScene) and env_idx is not None:
137
+ builder.set_scene_idxs([env_idx])
138
+
139
+ actor = builder.build(
140
+ name=node_name if env_idx is None else f"{node_name}-{env_idx}"
141
+ )
142
+
143
+ if update_mass and hasattr(actor.components[1], "mass"):
144
+ node_mass = float(root.find(".//inertial/mass").get("value"))
145
+ actor.components[1].set_mass(node_mass)
146
+
147
+ return actor
148
+
149
+
150
+ def load_assets_from_layout_file(
151
+ scene: ManiSkillScene | sapien.Scene,
152
+ layout: str,
153
+ z_offset: float = 0.0,
154
+ init_quat: list[float] = [0, 0, 0, 1],
155
+ env_idx: int = None,
156
+ ) -> dict[str, sapien.pysapien.Entity]:
157
+ """Load assets from `EmbodiedGen` layout-gen output and create actors in the scene.
158
+
159
+ Args:
160
+ scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into.
161
+ layout (str): The layout file path.
162
+ z_offset (float): Offset to apply to the Z-coordinate of non-context objects.
163
+ init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment.
164
+ env_idx (int): Environment index for multi-environment setup.
165
+ """
166
+ asset_root = os.path.dirname(layout)
167
+ layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
168
+ actors = dict()
169
+ for node in layout.assets:
170
+ file_dir = layout.assets[node]
171
+ file_name = f"{node.replace(' ', '_')}.urdf"
172
+ urdf_file = os.path.join(asset_root, file_dir, file_name)
173
+
174
+ if layout.objs_mapping[node] == Scene3DItemEnum.BACKGROUND.value:
175
+ continue
176
+
177
+ position = layout.position[node].copy()
178
+ if layout.objs_mapping[node] != Scene3DItemEnum.CONTEXT.value:
179
+ position[2] += z_offset
180
+
181
+ use_static = (
182
+ layout.relation.get(Scene3DItemEnum.CONTEXT.value, None) == node
183
+ )
184
+
185
+ # Combine initial quaternion with object quaternion
186
+ x, y, z, qx, qy, qz, qw = position
187
+ qx, qy, qz, qw = quaternion_multiply([qx, qy, qz, qw], init_quat)
188
+ actor = load_actor_from_urdf(
189
+ scene,
190
+ urdf_file,
191
+ sapien.Pose(p=[x, y, z], q=[qw, qx, qy, qz]),
192
+ env_idx,
193
+ use_static=use_static,
194
+ update_mass=False,
195
+ )
196
+ actors[node] = actor
197
+
198
+ return actors
199
+
200
+
201
+ def load_mani_skill_robot(
202
+ scene: sapien.Scene | ManiSkillScene,
203
+ layout: LayoutInfo | str,
204
+ control_freq: int = 20,
205
+ robot_init_qpos_noise: float = 0.0,
206
+ control_mode: str = "pd_joint_pos",
207
+ backend_str: tuple[str, str] = ("cpu", "gpu"),
208
+ ) -> BaseAgent:
209
+ from mani_skill.agents import REGISTERED_AGENTS
210
+ from mani_skill.envs.scene import ManiSkillScene
211
+ from mani_skill.envs.utils.system.backend import (
212
+ parse_sim_and_render_backend,
213
+ )
214
+
215
+ if isinstance(layout, str) and layout.endswith(".json"):
216
+ layout = LayoutInfo.from_dict(json.load(open(layout, "r")))
217
+
218
+ robot_name = layout.relation[Scene3DItemEnum.ROBOT.value]
219
+ x, y, z, qx, qy, qz, qw = layout.position[robot_name]
220
+ delta_z = 0.002 # Add small offset to avoid collision.
221
+ pose = sapien.Pose([x, y, z + delta_z], [qw, qx, qy, qz])
222
+
223
+ if robot_name not in REGISTERED_AGENTS:
224
+ logger.warning(
225
+ f"Robot `{robot_name}` not registered, chosen from {REGISTERED_AGENTS.keys()}, use `panda` instead."
226
+ )
227
+ robot_name = "panda"
228
+
229
+ ROBOT_CLS = REGISTERED_AGENTS[robot_name].agent_cls
230
+ backend = parse_sim_and_render_backend(*backend_str)
231
+ if isinstance(scene, sapien.Scene):
232
+ scene = ManiSkillScene([scene], device=backend_str[0], backend=backend)
233
+ robot = ROBOT_CLS(
234
+ scene=scene,
235
+ control_freq=control_freq,
236
+ control_mode=control_mode,
237
+ initial_pose=pose,
238
+ )
239
+
240
+ # Set the robot's initial joint angles in radians (joint0 to joint6, plus 2 fingers).
241
+ qpos = np.array(
242
+ [
243
+ 0.0,
244
+ np.pi / 8,
245
+ 0,
246
+ -np.pi * 3 / 8,
247
+ 0,
248
+ np.pi * 3 / 4,
249
+ np.pi / 4,
250
+ 0.04,
251
+ 0.04,
252
+ ]
253
+ )
254
+ qpos = (
255
+ np.random.normal(
256
+ 0, robot_init_qpos_noise, (len(scene.sub_scenes), len(qpos))
257
+ )
258
+ + qpos
259
+ )
260
+ qpos[:, -2:] = 0.04
261
+ robot.reset(qpos)
262
+ robot.init_qpos = robot.robot.qpos
263
+ robot.controller.controllers["gripper"].reset()
264
+
265
+ return robot
266
+
267
+
268
+ def render_images(
269
+ camera: sapien.render.RenderCameraComponent,
270
+ render_keys: list[
271
+ Literal[
272
+ "Color",
273
+ "Segmentation",
274
+ "Normal",
275
+ "Mask",
276
+ "Depth",
277
+ "Foreground",
278
+ ]
279
+ ] = None,
280
+ ) -> dict[str, Image.Image]:
281
+ """Render images from a given sapien camera.
282
+
283
+ Args:
284
+ camera (sapien.render.RenderCameraComponent): The camera to render from.
285
+ render_keys (List[str]): Types of images to render (e.g., Color, Segmentation).
286
+
287
+ Returns:
288
+ Dict[str, Image.Image]: Dictionary of rendered images.
289
+ """
290
+ if render_keys is None:
291
+ render_keys = [
292
+ "Color",
293
+ "Segmentation",
294
+ "Normal",
295
+ "Mask",
296
+ "Depth",
297
+ "Foreground",
298
+ ]
299
+
300
+ results: dict[str, Image.Image] = {}
301
+ if "Color" in render_keys:
302
+ color = camera.get_picture("Color")
303
+ color_rgb = (np.clip(color[..., :3], 0, 1) * 255).astype(np.uint8)
304
+ results["Color"] = Image.fromarray(color_rgb)
305
+
306
+ if "Mask" in render_keys:
307
+ alpha = (np.clip(color[..., 3], 0, 1) * 255).astype(np.uint8)
308
+ results["Mask"] = Image.fromarray(alpha)
309
+
310
+ if "Segmentation" in render_keys:
311
+ seg_labels = camera.get_picture("Segmentation")
312
+ label0 = seg_labels[..., 0].astype(np.uint8)
313
+ seg_color = COLOR_PALETTE[label0]
314
+ results["Segmentation"] = Image.fromarray(seg_color)
315
+
316
+ if "Foreground" in render_keys:
317
+ seg_labels = camera.get_picture("Segmentation")
318
+ label0 = seg_labels[..., 0]
319
+ mask = np.where((label0 > 1), 255, 0).astype(np.uint8)
320
+ color = camera.get_picture("Color")
321
+ color_rgb = (np.clip(color[..., :3], 0, 1) * 255).astype(np.uint8)
322
+ foreground = np.concatenate([color_rgb, mask[..., None]], axis=-1)
323
+ results["Foreground"] = Image.fromarray(foreground)
324
+
325
+ if "Normal" in render_keys:
326
+ normal = camera.get_picture("Normal")[..., :3]
327
+ normal_img = (((normal + 1) / 2) * 255).astype(np.uint8)
328
+ results["Normal"] = Image.fromarray(normal_img)
329
+
330
+ if "Depth" in render_keys:
331
+ position_map = camera.get_picture("Position")
332
+ depth = -position_map[..., 2]
333
+ alpha = torch.tensor(color[..., 3], dtype=torch.float32)
334
+ norm_depth = DiffrastRender.normalize_map_by_mask(
335
+ torch.tensor(depth), alpha
336
+ )
337
+ depth_img = (norm_depth * 255).to(torch.uint8).numpy()
338
+ results["Depth"] = Image.fromarray(depth_img)
339
+
340
+ return results
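# Usage sketch (illustrative, not part of this module): after stepping the
# simulation, update the renderer and capture before reading images, e.g.:
#   scene.update_render()
#   camera.take_picture()
#   frames = render_images(camera, render_keys=["Color", "Depth"])
#   frames["Color"].save("debug_view.png")  # placeholder output path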
341
+
342
+
343
+ class SapienSceneManager:
344
+ """A class to manage SAPIEN simulator."""
345
+
346
+ def __init__(
347
+ self, sim_freq: int, ray_tracing: bool, device: str = "cuda"
348
+ ) -> None:
349
+ self.sim_freq = sim_freq
350
+ self.ray_tracing = ray_tracing
351
+ self.device = device
352
+ self.renderer = sapien.SapienRenderer()
353
+ self.scene = self._setup_scene()
354
+ self.cameras: list[sapien.render.RenderCameraComponent] = []
355
+ self.actors: dict[str, sapien.pysapien.Entity] = {}
356
+
357
+ def _setup_scene(self) -> sapien.Scene:
358
+ """Set up the SAPIEN scene with lighting and ground."""
359
+ # Ray tracing settings
360
+ if self.ray_tracing:
361
+ sapien.render.set_camera_shader_dir("rt")
362
+ sapien.render.set_ray_tracing_samples_per_pixel(64)
363
+ sapien.render.set_ray_tracing_path_depth(10)
364
+ sapien.render.set_ray_tracing_denoiser("oidn")
365
+
366
+ scene = sapien.Scene()
367
+ scene.set_timestep(1 / self.sim_freq)
368
+
369
+ # Add lighting
370
+ scene.set_ambient_light([0.2, 0.2, 0.2])
371
+ scene.add_directional_light(
372
+ direction=[0, 1, -1],
373
+ color=[1.5, 1.45, 1.4],
374
+ shadow=True,
375
+ shadow_map_size=2048,
376
+ )
377
+ scene.add_directional_light(
378
+ direction=[0, -0.5, 1], color=[0.8, 0.8, 0.85], shadow=False
379
+ )
380
+ scene.add_directional_light(
381
+ direction=[0, -1, 1], color=[1.0, 1.0, 1.0], shadow=False
382
+ )
383
+
384
+ ground_material = self.renderer.create_material()
385
+ ground_material.base_color = [0.5, 0.5, 0.5, 1] # rgba, gray
386
+ ground_material.roughness = 0.7
387
+ ground_material.metallic = 0.0
388
+ scene.add_ground(0, render_material=ground_material)
389
+
390
+ return scene
391
+
392
+ def step_action(
393
+ self,
394
+ agent: BaseAgent,
395
+ action: torch.Tensor,
396
+ cameras: list[sapien.render.RenderCameraComponent],
397
+ render_keys: list[str],
398
+ sim_steps_per_control: int = 1,
399
+ ) -> dict:
400
+ agent.set_action(action)
401
+ frames = defaultdict(list)
402
+ for _ in range(sim_steps_per_control):
403
+ self.scene.step()
404
+
405
+ self.scene.update_render()
406
+ for camera in cameras:
407
+ camera.take_picture()
408
+ images = render_images(camera, render_keys=render_keys)
409
+ frames[camera.name].append(images)
410
+
411
+ return frames
412
+
413
+ def create_camera(
414
+ self,
415
+ cam_name: str,
416
+ pose: sapien.Pose,
417
+ image_hw: tuple[int, int],
418
+ fovy_deg: float,
419
+ ) -> sapien.render.RenderCameraComponent:
420
+ """Create a single camera in the scene.
421
+
422
+ Args:
423
+ cam_name (str): Name of the camera.
424
+ pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z)
425
+ image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
426
+ fovy_deg (float): Field of view in degrees for cameras.
427
+
428
+ Returns:
429
+ sapien.render.RenderCameraComponent: The created camera.
430
+ """
431
+ cam_actor = self.scene.create_actor_builder().build_kinematic()
432
+ cam_actor.set_pose(pose)
433
+ camera = self.scene.add_mounted_camera(
434
+ name=cam_name,
435
+ mount=cam_actor,
436
+ pose=sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0]),
437
+ width=image_hw[1],
438
+ height=image_hw[0],
439
+ fovy=np.deg2rad(fovy_deg),
440
+ near=0.01,
441
+ far=100,
442
+ )
443
+ self.cameras.append(camera)
444
+
445
+ return camera
446
+
447
+ def initialize_circular_cameras(
448
+ self,
449
+ num_cameras: int,
450
+ radius: float,
451
+ height: float,
452
+ target_pt: list[float],
453
+ image_hw: tuple[int, int],
454
+ fovy_deg: float,
455
+ ) -> list[sapien.render.RenderCameraComponent]:
456
+ """Initialize multiple cameras arranged in a circle.
457
+
458
+ Args:
459
+ num_cameras (int): Number of cameras to create.
460
+ radius (float): Radius of the camera circle.
461
+ height (float): Fixed Z-coordinate of the cameras.
462
+ target_pt (list[float]): 3D point (x, y, z) that cameras look at.
463
+ image_hw (Tuple[int, int]): Image resolution (height, width) for cameras.
464
+ fovy_deg (float): Field of view in degrees for cameras.
465
+
466
+ Returns:
467
+ List[sapien.render.RenderCameraComponent]: List of created cameras.
468
+ """
469
+ angle_step = 2 * np.pi / num_cameras
470
+ world_up_vec = np.array([0.0, 0.0, 1.0])
471
+ target_pt = np.array(target_pt)
472
+
473
+ for i in range(num_cameras):
474
+ angle = i * angle_step
475
+ cam_x = radius * np.cos(angle)
476
+ cam_y = radius * np.sin(angle)
477
+ cam_z = height
478
+ eye_pos = [cam_x, cam_y, cam_z]
479
+
480
+ forward_vec = target_pt - eye_pos
481
+ forward_vec = forward_vec / np.linalg.norm(forward_vec)
482
+ temp_right_vec = np.cross(forward_vec, world_up_vec)
483
+
484
+ if np.linalg.norm(temp_right_vec) < 1e-6:
485
+ temp_right_vec = np.array([1.0, 0.0, 0.0])
486
+ if np.abs(np.dot(temp_right_vec, forward_vec)) > 0.99:
487
+ temp_right_vec = np.array([0.0, 1.0, 0.0])
488
+
489
+ right_vec = temp_right_vec / np.linalg.norm(temp_right_vec)
490
+ up_vec = np.cross(right_vec, forward_vec)
491
+ rotation_matrix = np.array([forward_vec, -right_vec, up_vec]).T
492
+
493
+ rot = R.from_matrix(rotation_matrix)
494
+ scipy_quat = rot.as_quat() # (x, y, z, w)
495
+ quat = [
496
+ scipy_quat[3],
497
+ scipy_quat[0],
498
+ scipy_quat[1],
499
+ scipy_quat[2],
500
+ ] # (w, x, y, z)
501
+
502
+ self.create_camera(
503
+ f"camera_{i}",
504
+ sapien.Pose(p=eye_pos, q=quat),
505
+ image_hw,
506
+ fovy_deg,
507
+ )
508
+
509
+ return self.cameras
510
+
511
+
512
+ class FrankaPandaGrasper(object):
513
+ def __init__(
514
+ self,
515
+ agent: BaseAgent,
516
+ control_freq: float,
517
+ joint_vel_limits: float = 2.0,
518
+ joint_acc_limits: float = 1.0,
519
+ finger_length: float = 0.025,
520
+ ) -> None:
521
+ self.agent = agent
522
+ self.robot = agent.robot
523
+ self.control_freq = control_freq
524
+ self.control_timestep = 1 / control_freq
525
+ self.joint_vel_limits = joint_vel_limits
526
+ self.joint_acc_limits = joint_acc_limits
527
+ self.finger_length = finger_length
528
+ self.planners = self._setup_planner()
529
+
530
+ def _setup_planner(self) -> list[mplib.Planner]:
531
+ planners = []
532
+ for pose in self.robot.pose:
533
+ link_names = [link.get_name() for link in self.robot.get_links()]
534
+ joint_names = [
535
+ joint.get_name() for joint in self.robot.get_active_joints()
536
+ ]
537
+ planner = mplib.Planner(
538
+ urdf=self.agent.urdf_path,
539
+ srdf=self.agent.urdf_path.replace(".urdf", ".srdf"),
540
+ user_link_names=link_names,
541
+ user_joint_names=joint_names,
542
+ move_group="panda_hand_tcp",
543
+ joint_vel_limits=np.ones(7) * self.joint_vel_limits,
544
+ joint_acc_limits=np.ones(7) * self.joint_acc_limits,
545
+ )
546
+ planner.set_base_pose(pose.raw_pose[0].tolist())
547
+ planners.append(planner)
548
+
549
+ return planners
550
+
551
+ def control_gripper(
552
+ self,
553
+ gripper_state: Literal[-1, 1],
554
+ n_step: int = 10,
+ env_idx: int = 0,
555
+ ) -> np.ndarray:
556
+ qpos = self.robot.get_qpos()[env_idx, :-2].cpu().numpy()
557
+ actions = []
558
+ for _ in range(n_step):
559
+ action = np.hstack([qpos, gripper_state])[None, ...]
560
+ actions.append(action)
561
+
562
+ return np.concatenate(actions, axis=0)
563
+
564
+ def move_to_pose(
565
+ self,
566
+ pose: sapien.Pose,
567
+ control_timestep: float,
568
+ gripper_state: Literal[-1, 1],
569
+ use_point_cloud: bool = False,
570
+ n_max_step: int = 100,
571
+ action_key: str = "position",
572
+ env_idx: int = 0,
573
+ ) -> np.ndarray | None:
574
+ result = self.planners[env_idx].plan_qpos_to_pose(
575
+ np.concatenate([pose.p, pose.q]),
576
+ self.robot.get_qpos().cpu().numpy()[0],
577
+ time_step=control_timestep,
578
+ use_point_cloud=use_point_cloud,
579
+ )
580
+
581
+ if result["status"] != "Success":
582
+ result = self.planners[env_idx].plan_screw(
583
+ np.concatenate([pose.p, pose.q]),
584
+ self.robot.get_qpos().cpu().numpy()[0],
585
+ time_step=control_timestep,
586
+ use_point_cloud=use_point_cloud,
587
+ )
588
+
589
+ if result["status"] != "Success":
590
+ return
591
+
592
+ sample_ratio = (len(result[action_key]) // n_max_step) + 1
593
+ result[action_key] = result[action_key][::sample_ratio]
594
+
595
+ n_step = len(result[action_key])
596
+ actions = []
597
+ for i in range(n_step):
598
+ qpos = result[action_key][i]
599
+ action = np.hstack([qpos, gripper_state])[None, ...]
600
+ actions.append(action)
601
+
602
+ return np.concatenate(actions, axis=0)
603
+
604
+ def compute_grasp_action(
605
+ self,
606
+ actor: sapien.pysapien.Entity,
607
+ reach_target_only: bool = True,
608
+ offset: tuple[float, float, float] = (0, 0, -0.05),
609
+ env_idx: int = 0,
610
+ ) -> np.ndarray | None:
611
+ physx_rigid = actor.components[1]
612
+ mesh = get_component_mesh(physx_rigid, to_world_frame=True)
613
+ obb = mesh.bounding_box_oriented
614
+ approaching = np.array([0, 0, -1])
615
+ tcp_pose = self.agent.tcp.pose[env_idx]
616
+ target_closing = (
617
+ tcp_pose.to_transformation_matrix()[0, :3, 1].cpu().numpy()
618
+ )
619
+ grasp_info = compute_grasp_info_by_obb(
620
+ obb,
621
+ approaching=approaching,
622
+ target_closing=target_closing,
623
+ depth=self.finger_length,
624
+ )
625
+
626
+ closing, center = grasp_info["closing"], grasp_info["center"]
627
+ raw_tcp_pose = tcp_pose.sp
628
+ grasp_pose = self.agent.build_grasp_pose(approaching, closing, center)
629
+ reach_pose = grasp_pose * sapien.Pose(p=offset)
630
+ grasp_pose = grasp_pose * sapien.Pose(p=[0, 0, 0.01])
631
+ actions = []
632
+ reach_actions = self.move_to_pose(
633
+ reach_pose,
634
+ self.control_timestep,
635
+ gripper_state=1,
636
+ env_idx=env_idx,
637
+ )
638
+ actions.append(reach_actions)
639
+
640
+ if reach_actions is None:
641
+ logger.warning(
642
+ f"Failed to reach the grasp pose for node `{actor.name}`, skipping grasping."
643
+ )
644
+ return None
645
+
646
+ if not reach_target_only:
647
+ grasp_actions = self.move_to_pose(
648
+ grasp_pose,
649
+ self.control_timestep,
650
+ gripper_state=1,
651
+ env_idx=env_idx,
652
+ )
653
+ actions.append(grasp_actions)
654
+ close_actions = self.control_gripper(
655
+ gripper_state=-1,
656
+ env_idx=env_idx,
657
+ )
658
+ actions.append(close_actions)
659
+ back_actions = self.move_to_pose(
660
+ raw_tcp_pose,
661
+ self.control_timestep,
662
+ gripper_state=-1,
663
+ env_idx=env_idx,
664
+ )
665
+ actions.append(back_actions)
666
+
667
+ return np.concatenate(actions, axis=0)
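For context, a hedged usage sketch of `FrankaPandaGrasper` (the environment `env`, the target entity `target_actor`, and the 20 Hz control frequency are assumptions; the env is expected to expose a ManiSkill-style Panda `agent` and a step API taking an 8-dim [arm qpos, gripper] action):

```python
# Hypothetical driver loop; not the project's official entry point.
grasper = FrankaPandaGrasper(agent=env.agent, control_freq=20)

# `target_actor` is assumed to be a sapien.Entity for the object to grasp,
# obtained from however the scene was built.
actions = grasper.compute_grasp_action(target_actor, reach_target_only=False)

if actions is not None:            # planning may fail and return None
    for action in actions:         # each row: 7 arm joint targets + 1 gripper command
        env.step(action)
```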
embodied_gen/utils/tags.py ADDED
@@ -0,0 +1 @@
1
+ VERSION = "v0.1.5"
embodied_gen/utils/trender.py ADDED
@@ -0,0 +1,90 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+ import os
18
+ import sys
19
+
20
+ import numpy as np
21
+ import spaces
22
+ import torch
23
+ from tqdm import tqdm
24
+
25
+ current_file_path = os.path.abspath(__file__)
26
+ current_dir = os.path.dirname(current_file_path)
27
+ sys.path.append(os.path.join(current_dir, "../.."))
28
+ from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer
29
+ from thirdparty.TRELLIS.trellis.representations import MeshExtractResult
30
+ from thirdparty.TRELLIS.trellis.utils.render_utils import (
31
+ render_frames,
32
+ yaw_pitch_r_fov_to_extrinsics_intrinsics,
33
+ )
34
+
35
+ __all__ = [
36
+ "render_video",
37
+ ]
38
+
39
+
40
+ @spaces.GPU
41
+ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs):
42
+ renderer = MeshRenderer()
43
+ renderer.rendering_options.resolution = options.get("resolution", 512)
44
+ renderer.rendering_options.near = options.get("near", 1)
45
+ renderer.rendering_options.far = options.get("far", 100)
46
+ renderer.rendering_options.ssaa = options.get("ssaa", 4)
47
+ rets = {}
48
+ for extr, intr in tqdm(zip(extrinsics, intrinsics), desc="Rendering"):
49
+ res = renderer.render(sample, extr, intr)
50
+ if "normal" not in rets:
51
+ rets["normal"] = []
52
+ normal = torch.lerp(
53
+ torch.zeros_like(res["normal"]), res["normal"], res["mask"]
54
+ )
55
+ normal = np.clip(
56
+ normal.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255
57
+ ).astype(np.uint8)
58
+ rets["normal"].append(normal)
59
+
60
+ return rets
61
+
62
+
63
+ @spaces.GPU
64
+ def render_video(
65
+ sample,
66
+ resolution=512,
67
+ bg_color=(0, 0, 0),
68
+ num_frames=300,
69
+ r=2,
70
+ fov=40,
71
+ **kwargs,
72
+ ):
73
+ yaws = torch.linspace(0, 2 * np.pi, num_frames)
74
+ yaws = yaws.tolist()
75
+ pitch = [0.5] * num_frames
76
+ extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(
77
+ yaws, pitch, r, fov
78
+ )
79
+ render_fn = (
80
+ render_mesh if isinstance(sample, MeshExtractResult) else render_frames
81
+ )
82
+ result = render_fn(
83
+ sample,
84
+ extrinsics,
85
+ intrinsics,
86
+ {"resolution": resolution, "bg_color": bg_color},
87
+ **kwargs,
88
+ )
89
+
90
+ return result
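A hedged usage sketch for `render_video` (the upstream `mesh` object and the use of `imageio` for writing the video are assumptions; only the normal-map frames produced by `render_mesh` are shown here):

```python
import imageio

# `mesh` is assumed to be a TRELLIS MeshExtractResult produced earlier in the pipeline.
result = render_video(mesh, resolution=512, num_frames=120, r=2, fov=40)
imageio.mimsave("normal_orbit.mp4", result["normal"], fps=30)
```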
embodied_gen/validators/aesthetic_predictor.py ADDED
@@ -0,0 +1,137 @@
1
+ # Project EmbodiedGen
2
+ #
3
+ # Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
14
+ # implied. See the License for the specific language governing
15
+ # permissions and limitations under the License.
16
+
17
+
18
+ import os
19
+
20
+ import clip
21
+ import numpy as np
22
+ import pytorch_lightning as pl
23
+ import torch
24
+ import torch.nn as nn
25
+ from huggingface_hub import snapshot_download
26
+ from PIL import Image
27
+
28
+
29
+ class AestheticPredictor:
30
+ """Aesthetic Score Predictor.
31
+
32
+ Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main
33
+
34
+ Args:
35
+ clip_model_dir (str): Path to the directory of the CLIP model.
36
+ sac_model_path (str): Path to the pre-trained SAC model.
37
+ device (str): Device to use for computation ("cuda" or "cpu").
38
+ """
39
+
40
+ def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"):
41
+
42
+ self.device = device
43
+
44
+ if clip_model_dir is None:
45
+ suffix = "aesthetic"
49
+ model_path = snapshot_download(
50
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
51
+ )
52
+ clip_model_dir = os.path.join(model_path, suffix)
53
+
54
+ if sac_model_path is None:
55
+ suffix = "aesthetic"
59
+ model_path = snapshot_download(
60
+ repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
61
+ )
62
+ sac_model_path = os.path.join(
63
+ model_path, suffix, "sac+logos+ava1-l14-linearMSE.pth"
64
+ )
65
+
66
+ self.clip_model, self.preprocess = self._load_clip_model(
67
+ clip_model_dir
68
+ )
69
+ self.sac_model = self._load_sac_model(sac_model_path, input_size=768)
70
+
71
+ class MLP(pl.LightningModule): # noqa
72
+ def __init__(self, input_size):
73
+ super().__init__()
74
+ self.layers = nn.Sequential(
75
+ nn.Linear(input_size, 1024),
76
+ nn.Dropout(0.2),
77
+ nn.Linear(1024, 128),
78
+ nn.Dropout(0.2),
79
+ nn.Linear(128, 64),
80
+ nn.Dropout(0.1),
81
+ nn.Linear(64, 16),
82
+ nn.Linear(16, 1),
83
+ )
84
+
85
+ def forward(self, x):
86
+ return self.layers(x)
87
+
88
+ @staticmethod
89
+ def normalized(a, axis=-1, order=2):
90
+ """Normalize the array to unit norm."""
91
+ l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
92
+ l2[l2 == 0] = 1
93
+ return a / np.expand_dims(l2, axis)
94
+
95
+ def _load_clip_model(self, model_dir: str, model_name: str = "ViT-L/14"):
96
+ """Load the CLIP model."""
97
+ model, preprocess = clip.load(
98
+ model_name, download_root=model_dir, device=self.device
99
+ )
100
+ return model, preprocess
101
+
102
+ def _load_sac_model(self, model_path, input_size):
103
+ """Load the SAC model."""
104
+ model = self.MLP(input_size)
105
+ ckpt = torch.load(model_path, weights_only=True)
106
+ model.load_state_dict(ckpt)
107
+ model.to(self.device)
108
+ model.eval()
109
+ return model
110
+
111
+ def predict(self, image_path):
112
+ """Predict the aesthetic score for a given image.
113
+
114
+ Args:
115
+ image_path (str): Path to the image file.
116
+
117
+ Returns:
118
+ float: Predicted aesthetic score.
119
+ """
120
+ pil_image = Image.open(image_path)
121
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
122
+
123
+ with torch.no_grad():
124
+ # Extract CLIP features
125
+ image_features = self.clip_model.encode_image(image)
126
+ # Normalize features
127
+ normalized_features = self.normalized(
128
+ image_features.cpu().detach().numpy()
129
+ )
130
+ # Predict score
131
+ prediction = self.sac_model(
132
+ torch.from_numpy(normalized_features)
133
+ .type(torch.FloatTensor)
134
+ .to(self.device)
135
+ )
136
+
137
+ return prediction.item()
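A hedged usage sketch for `AestheticPredictor` (the image path is a placeholder; with no explicit paths, the CLIP and SAC weights are pulled from the Hugging Face Hub as in `__init__` above):

```python
import torch

predictor = AestheticPredictor(device="cuda" if torch.cuda.is_available() else "cpu")
score = predictor.predict("example_render.png")   # placeholder image path
print(f"aesthetic score: {score:.2f}")
```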