wlyu-adobe committed
Commit bd9eb72 · 1 Parent(s): dc5c57c

Initial commit
.gitignore ADDED
@@ -0,0 +1,57 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ ENV/
26
+ env/
27
+
28
+ # Model checkpoints and data (downloaded from HuggingFace)
29
+ checkpoints/
30
+ *.pt
31
+ *.pth
32
+ *.ckpt
33
+ *.safetensors
34
+
35
+ # Output files
36
+ outputs/
37
+ *.ply
38
+ *.mp4
39
+
40
+ # Examples (if large)
41
+ examples/
42
+
43
+ # IDE
44
+ .vscode/
45
+ .idea/
46
+ *.swp
47
+ *.swo
48
+ *~
49
+
50
+ # OS
51
+ .DS_Store
52
+ Thumbs.db
53
+
54
+ # Gradio cache
55
+ gradio_cached_examples/
56
+ flagged/
57
+
README.md CHANGED
@@ -10,4 +10,42 @@ pinned: false
10
  license: cc-by-nc-sa-4.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
10
  license: cc-by-nc-sa-4.0
11
  ---
12
 
13
+ # FaceLift: Single Image 3D Face Reconstruction
14
+
15
+ Transform a single portrait image into a complete 3D head model using multi-view diffusion and Gaussian Splatting.
16
+
17
+ ## Features
18
+
19
+ - **Single Image Input**: Upload any portrait photo
20
+ - **Automatic Face Detection**: Optional auto-cropping and alignment
21
+ - **Multi-view Generation**: Creates 6 consistent views using diffusion models
22
+ - **3D Reconstruction**: Generates high-quality 3D Gaussian splats
23
+ - **Turntable Animation**: Exports rotating 360° video
24
+ - **Downloadable Model**: Get the 3D model as a .ply file
25
+
26
+ ## Usage
27
+
28
+ 1. Upload a portrait image
29
+ 2. Adjust parameters (optional):
30
+ - Auto Cropping: Enable for automatic face detection
31
+ - Guidance Scale: Controls generation quality (default: 3.0)
32
+ - Random Seed: For reproducible results
33
+ - Generation Steps: Higher = better quality but slower
34
+ 3. Click Submit and wait for processing
35
+ 4. Download the 3D model or turntable video
36
+
37
+ ## Citation
38
+
39
+ ```
40
+ @article{facelift2025,
41
+ title={FaceLift: Single Image 3D Face Reconstruction},
42
+ author={FaceLift Research Group},
43
+ year={2025}
44
+ }
45
+ ```
46
+
47
+ ## License
48
+
49
+ This software is free for non-commercial, research and evaluation use under the CC-BY-NC-SA-4.0 license.
50
+
51
+ For commercial use inquiries, contact: wlyu3@ucmerced.edu
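Editorial aside (not part of this commit): the Usage steps above describe the hosted Gradio UI, but the same flow can be driven programmatically through the `FaceLiftPipeline` class added in `app.py` below. A minimal sketch, assuming the checkpoints have been downloaded and a GPU is available; `portrait.jpg` is a placeholder input path:

```python
# Minimal sketch: call the pipeline from app.py directly, bypassing the Gradio UI.
from app import FaceLiftPipeline

pipeline = FaceLiftPipeline()  # downloads the wlyu/OpenFaceLift weights on first run
input_png, multiview_png, recon_png, turntable_mp4, ply_path = pipeline.generate_3d_head(
    "portrait.jpg",        # placeholder portrait path
    auto_crop=True,        # automatic face detection and cropping
    guidance_scale=3.0,    # diffusion guidance (README default)
    random_seed=4,
    num_steps=50,          # diffusion inference steps
)
print(ply_path)            # downloadable Gaussian-splat .ply
```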
app.py CHANGED
@@ -1,7 +1,272 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
9
+ """
10
+ FaceLift: Single Image 3D Face Reconstruction
11
+ Generates 3D head models from single images using multi-view diffusion and GS-LRM.
12
+ """
13
+
14
+ import json
15
+ from pathlib import Path
16
+ from datetime import datetime
17
+
18
  import gradio as gr
19
+ import numpy as np
20
+ import torch
21
+ import yaml
22
+ from easydict import EasyDict as edict
23
+ from einops import rearrange
24
+ from PIL import Image
25
+ from huggingface_hub import snapshot_download
26
+
27
+ from gslrm.model.gaussians_renderer import render_turntable, imageseq2video
28
+ from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
29
+ from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping
30
+
31
+ # HuggingFace repository configuration
32
+ HF_REPO_ID = "wlyu/OpenFaceLift"
33
+
34
+ def download_weights_from_hf() -> Path:
35
+ """Download model weights from HuggingFace if not already present.
36
+
37
+ Returns:
38
+ Path to the downloaded repository
39
+ """
40
+ workspace_dir = Path(__file__).parent
41
+
42
+ # Check if weights already exist locally
43
+ mvdiffusion_path = workspace_dir / "checkpoints/mvdiffusion/pipeckpts"
44
+ gslrm_path = workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt"
45
+
46
+ if mvdiffusion_path.exists() and gslrm_path.exists():
47
+ print("Using local model weights")
48
+ return workspace_dir
49
+
50
+ print(f"Downloading model weights from HuggingFace: {HF_REPO_ID}")
51
+ print("This may take a few minutes on first run...")
52
+
53
+ # Download to local directory
54
+ snapshot_download(
55
+ repo_id=HF_REPO_ID,
56
+ local_dir=str(workspace_dir / "checkpoints"),
57
+ local_dir_use_symlinks=False,
58
+ )
59
+
60
+ print("Model weights downloaded successfully!")
61
+ return workspace_dir
62
+
63
+ class FaceLiftPipeline:
64
+ """Pipeline for FaceLift 3D head generation from single images."""
65
+
66
+ def __init__(self):
67
+ # Download weights from HuggingFace if needed
68
+ workspace_dir = download_weights_from_hf()
69
+
70
+ # Setup paths
71
+ self.output_dir = workspace_dir / "outputs"
72
+ self.examples_dir = workspace_dir / "examples"
73
+ self.output_dir.mkdir(exist_ok=True)
74
+
75
+ # Parameters
76
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
77
+ self.image_size = 512
78
+ self.camera_indices = [2, 1, 0, 5, 4, 3]
79
+
80
+ # Load models
81
+ print("Loading models...")
82
+ self.mvdiffusion_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
83
+ str(workspace_dir / "checkpoints/mvdiffusion/pipeckpts"),
84
+ torch_dtype=torch.float16,
85
+ )
86
+ self.mvdiffusion_pipeline.unet.enable_xformers_memory_efficient_attention()
87
+ self.mvdiffusion_pipeline.to(self.device)
88
+
89
+ with open(workspace_dir / "configs/gslrm.yaml", "r") as f:
90
+ config = edict(yaml.safe_load(f))
91
+
92
+ module_name, class_name = config.model.class_name.rsplit(".", 1)
93
+ module = __import__(module_name, fromlist=[class_name])
94
+ ModelClass = getattr(module, class_name)
95
+
96
+ self.gs_lrm_model = ModelClass(config)
97
+ checkpoint = torch.load(
98
+ workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
99
+ map_location="cpu"
100
+ )
101
+ self.gs_lrm_model.load_state_dict(checkpoint["model"])
102
+ self.gs_lrm_model.to(self.device)
103
+
104
+ self.color_prompt_embedding = torch.load(
105
+ workspace_dir / "mvdiffusion/fixed_prompt_embeds_6view/clr_embeds.pt",
106
+ map_location=self.device
107
+ )
108
+
109
+ with open(workspace_dir / "utils_folder/opencv_cameras.json", 'r') as f:
110
+ self.cameras_data = json.load(f)["frames"]
111
+
112
+ print("Models loaded successfully!")
113
+
114
+ def generate_3d_head(self, image_path, auto_crop=True, guidance_scale=3.0,
115
+ random_seed=4, num_steps=50):
116
+ """Generate 3D head from single image."""
117
+ try:
118
+ # Setup output directory
119
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
120
+ output_dir = self.output_dir / timestamp
121
+ output_dir.mkdir(exist_ok=True)
122
+
123
+ # Preprocess input
124
+ original_img = np.array(Image.open(image_path))
125
+ input_image = preprocess_image(original_img) if auto_crop else \
126
+ preprocess_image_without_cropping(original_img)
127
+
128
+ if input_image.size != (self.image_size, self.image_size):
129
+ input_image = input_image.resize((self.image_size, self.image_size))
130
+
131
+ input_path = output_dir / "input.png"
132
+ input_image.save(input_path)
133
+
134
+ # Generate multi-view images
135
+ generator = torch.Generator(device=self.mvdiffusion_pipeline.unet.device)
136
+ generator.manual_seed(random_seed)
137
+
138
+ result = self.mvdiffusion_pipeline(
139
+ input_image, None,
140
+ prompt_embeds=self.color_prompt_embedding,
141
+ guidance_scale=guidance_scale,
142
+ num_images_per_prompt=1,
143
+ num_inference_steps=num_steps,
144
+ generator=generator,
145
+ eta=1.0,
146
+ )
147
+
148
+ selected_views = result.images[:6]
149
+
150
+ # Save multi-view composite
151
+ multiview_image = Image.new("RGB", (self.image_size * 6, self.image_size))
152
+ for i, view in enumerate(selected_views):
153
+ multiview_image.paste(view, (self.image_size * i, 0))
154
+
155
+ multiview_path = output_dir / "multiview.png"
156
+ multiview_image.save(multiview_path)
157
+
158
+ # Prepare 3D reconstruction input
159
+ view_arrays = [np.array(view) for view in selected_views]
160
+ lrm_input = torch.from_numpy(np.stack(view_arrays, axis=0)).float()
161
+ lrm_input = lrm_input[None].to(self.device) / 255.0
162
+ lrm_input = rearrange(lrm_input, "b v h w c -> b v c h w")
163
+
164
+ # Prepare camera parameters
165
+ selected_cameras = [self.cameras_data[i] for i in self.camera_indices]
166
+ fxfycxcy_list = [[c["fx"], c["fy"], c["cx"], c["cy"]] for c in selected_cameras]
167
+ c2w_list = [np.linalg.inv(np.array(c["w2c"])) for c in selected_cameras]
168
+
169
+ fxfycxcy = torch.from_numpy(np.stack(fxfycxcy_list, axis=0).astype(np.float32))
170
+ c2w = torch.from_numpy(np.stack(c2w_list, axis=0).astype(np.float32))
171
+ fxfycxcy = fxfycxcy[None].to(self.device)
172
+ c2w = c2w[None].to(self.device)
173
+
174
+ batch_indices = torch.stack([
175
+ torch.zeros(lrm_input.size(1)).long(),
176
+ torch.arange(lrm_input.size(1)).long(),
177
+ ], dim=-1)[None].to(self.device)
178
+
179
+ batch = edict({
180
+ "image": lrm_input,
181
+ "c2w": c2w,
182
+ "fxfycxcy": fxfycxcy,
183
+ "index": batch_indices,
184
+ })
185
+
186
+ # Run 3D reconstruction
187
+ with torch.autocast(enabled=True, device_type="cuda", dtype=torch.float16):
188
+ result = self.gs_lrm_model.forward(batch, create_visual=False, split_data=True)
189
+
190
+ comp_image = result.render[0].unsqueeze(0).detach()
191
+ gaussians = result.gaussians[0]
192
+
193
+ # Save filtered gaussians
194
+ filtered_gaussians = gaussians.apply_all_filters(
195
+ cam_origins=None,
196
+ opacity_thres=0.04,
197
+ scaling_thres=0.2,
198
+ floater_thres=0.75,
199
+ crop_bbx=[-0.91, 0.91, -0.91, 0.91, -1.0, 1.0],
200
+ nearfar_percent=(0.0001, 1.0),
201
+ )
202
+
203
+ ply_path = output_dir / "gaussians.ply"
204
+ filtered_gaussians.save_ply(str(ply_path))
205
+
206
+ # Save output image
207
+ comp_image = rearrange(comp_image, "x v c h w -> (x h) (v w) c")
208
+ comp_image = (comp_image.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
209
+ output_path = output_dir / "output.png"
210
+ Image.fromarray(comp_image).save(output_path)
211
+
212
+ # Generate turntable video
213
+ turntable_frames = render_turntable(gaussians, rendering_resolution=self.image_size,
214
+ num_views=180)
215
+ turntable_frames = rearrange(turntable_frames, "h (v w) c -> v h w c", v=180)
216
+ turntable_frames = np.ascontiguousarray(turntable_frames)
217
+
218
+ turntable_path = output_dir / "turntable.mp4"
219
+ imageseq2video(turntable_frames, str(turntable_path), fps=30)
220
+
221
+ return str(input_path), str(multiview_path), str(output_path), \
222
+ str(turntable_path), str(ply_path)
223
+
224
+ except Exception as e:
225
+ raise gr.Error(f"Generation failed: {str(e)}")
226
+
227
+
228
+ def main():
229
+ """Run the FaceLift application."""
230
+ pipeline = FaceLiftPipeline()
231
+
232
+ # Load examples
233
+ examples = []
234
+ if pipeline.examples_dir.exists():
235
+ examples = [[str(f)] for f in sorted(pipeline.examples_dir.iterdir())
236
+ if f.suffix.lower() in {'.png', '.jpg', '.jpeg'}]
237
+
238
+ # Create interface
239
+ demo = gr.Interface(
240
+ fn=pipeline.generate_3d_head,
241
+ title="FaceLift: Single Image 3D Face Reconstruction",
242
+ description="""
243
+ Transform a single portrait image into a complete 3D head model.
244
+
245
+ **Tips:**
246
+ - Use high-quality portrait images with clear facial features
247
+ - If face detection fails, try disabling auto-cropping and manually crop to square
248
+ """,
249
+ inputs=[
250
+ gr.Image(type="filepath", label="Input Portrait Image"),
251
+ gr.Checkbox(value=True, label="Auto Cropping"),
252
+ gr.Slider(1.0, 10.0, 3.0, step=0.1, label="Guidance Scale"),
253
+ gr.Number(value=4, label="Random Seed"),
254
+ gr.Slider(10, 100, 50, step=5, label="Generation Steps"),
255
+ ],
256
+ outputs=[
257
+ gr.Image(label="Processed Input"),
258
+ gr.Image(label="Multi-view Generation"),
259
+ gr.Image(label="3D Reconstruction"),
260
+ gr.PlayableVideo(label="Turntable Animation"),
261
+ gr.File(label="3D Model (.ply)"),
262
+ ],
263
+ examples=examples,
264
+ allow_flagging="never",
265
+ )
266
+
267
+ demo.queue(max_size=10)
268
+ demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
269
 
 
 
270
 
271
+ if __name__ == "__main__":
272
+ main()
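Editorial note: `generate_3d_head` above reads per-view camera parameters from `utils_folder/opencv_cameras.json`, which is not part of this commit. Judging only from how the file is consumed (the `frames` list and its `fx`/`fy`/`cx`/`cy`/`w2c` fields), each frame record presumably looks like the sketch below; the numeric values are purely illustrative:

```python
# Hypothetical shape of one entry of opencv_cameras.json["frames"] (values illustrative).
frame = {
    "fx": 548.0, "fy": 548.0,   # focal lengths in pixels
    "cx": 256.0, "cy": 256.0,   # principal point
    "w2c": [                    # 4x4 world-to-camera matrix; app.py inverts it to c2w
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 2.7],
        [0.0, 0.0, 0.0, 1.0],
    ],
}
```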
configs/gslrm.yaml ADDED
@@ -0,0 +1,140 @@
1
+ # =============================================================================
2
+ # General Configuration
3
+ # =============================================================================
4
+ profile: false
5
+ debug: false
6
+
7
+ # =============================================================================
8
+ # Model Configuration
9
+ # =============================================================================
10
+ model:
11
+ class_name: gslrm.model.gslrm.GSLRM
12
+
13
+ # Image processing settings
14
+ image_tokenizer:
15
+ image_size: 512
16
+ patch_size: 8
17
+ in_channels: 9 # 3 RGB + 3 direction + 3 Reference
18
+
19
+ # Transformer architecture
20
+ transformer:
21
+ d: 1024
22
+ d_head: 64
23
+ n_layer: 24
24
+
25
+ # Gaussian splatting configuration
26
+ gaussians:
27
+ n_gaussians: 2 # 12288
28
+ sh_degree: 0
29
+
30
+ upsampler:
31
+ upsample_factor: 1
32
+
33
+ # Model behavior flags
34
+ add_refsrc_marker: false
35
+ hard_pixelalign: true
36
+ use_custom_plucker: true
37
+ clip_xyz: true
38
+
39
+ # =============================================================================
40
+ # Training Configuration
41
+ # =============================================================================
42
+ training:
43
+ # Training runtime settings
44
+ runtime:
45
+ use_tf32: true
46
+ use_amp: true
47
+ amp_dtype: "bf16"
48
+ torch_compile: false
49
+ grad_accum_steps: 1
50
+ grad_clip_norm: 1.0
51
+ grad_checkpoint_every: 1
52
+
53
+ # Dataset configuration
54
+ dataset:
55
+ dataset_path: "data_sample/gslrm/data_gslrm_train.txt"
56
+
57
+ # View configuration
58
+ maximize_view_overlap: true
59
+ num_views: 8
60
+ num_input_views: 6 # set to 4 during training and 6 at inference
61
+ target_has_input: true
62
+
63
+ # Data preprocessing
64
+ normalize_distance_to: 0.0
65
+ remove_alpha: false
66
+ background_color: "white"
67
+
68
+ # Data loader settings
69
+ dataloader:
70
+ batch_size_per_gpu: 2
71
+ num_workers: 4
72
+ num_threads: 32
73
+ prefetch_factor: 32
74
+
75
+ # Loss function weights
76
+ losses:
77
+ l2_loss_weight: 1.0
78
+ lpips_loss_weight: 0.0
79
+ perceptual_loss_weight: 0.5
80
+ ssim_loss_weight: 0.0
81
+ pixelalign_loss_weight: 0.0
82
+ masked_pixelalign_loss: true
83
+ pointsdist_loss_weight: 0.0
84
+ warmup_pointsdist: false
85
+ distill_loss_weight: 0.0
86
+
87
+ # Optimizer configuration (AdamW)
88
+ optimizer:
89
+ lr: 0.0001
90
+ beta1: 0.9
91
+ beta2: 0.95
92
+ weight_decay: 0.05
93
+ reset_lr: false
94
+ reset_weight_decay: false
95
+ reset_training_state: true
96
+
97
+ # Training schedule
98
+ schedule:
99
+ num_epochs: 100000 # dataset epochs
100
+ early_stop_after_epochs: 100000 # 40
101
+ max_fwdbwd_passes: 20000 # forward/backward pass steps
102
+ warmup: 500 # parameter update steps
103
+ l2_warmup_steps: 500
104
+
105
+ # Checkpointing
106
+ checkpointing:
107
+ resume_ckpt: "checkpoints/gslrm/stage_2"
108
+ checkpoint_every: 5000 # forward/backward pass steps
109
+ checkpoint_dir: "checkpoints/gslrm/stage_3"
110
+
111
+ # Logging and monitoring
112
+ logging:
113
+ print_every: 20 # forward/backward pass steps
114
+ vis_every: 250 # forward/backward pass steps
115
+
116
+ # Weights & Biases configuration
117
+ wandb:
118
+ project: "facelift_gslrm"
119
+ exp_name: "stage_3"
120
+ group: "facelift"
121
+ job_type: "train"
122
+ log_every: 50 # forward/backward pass steps
123
+ offline: false
124
+
125
+
126
+ # =============================================================================
127
+ # Inference Configuration
128
+ # =============================================================================
129
+ inference:
130
+ enabled: false
131
+ output_dir: "outputs/inference/gslrm/stage_3"
132
+
133
+ # =============================================================================
134
+ # Validation Configuration
135
+ # =============================================================================
136
+ validation:
137
+ enabled: true
138
+ val_every: 5000
139
+ output_dir: "outputs/validation/gslrm/stage_3"
140
+ dataset_path: "data_sample/gslrm/data_gslrm_val.txt"
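For orientation (not part of the commit): `app.py` above loads this file with `yaml` + `EasyDict` and resolves `model.class_name` dynamically rather than importing `GSLRM` by name. A trimmed sketch of that loading path:

```python
import yaml
from easydict import EasyDict as edict

with open("configs/gslrm.yaml") as f:
    config = edict(yaml.safe_load(f))

# model.class_name is "gslrm.model.gslrm.GSLRM"; resolve it without a hard-coded import.
module_name, class_name = config.model.class_name.rsplit(".", 1)
ModelClass = getattr(__import__(module_name, fromlist=[class_name]), class_name)
model = ModelClass(config)  # app.py then loads checkpoints/gslrm/ckpt_*.pt into it
```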
gslrm/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
gslrm/model/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
gslrm/model/gaussians_renderer.py ADDED
@@ -0,0 +1,1028 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
9
+ import math
10
+ import os
11
+
12
+ import cv2
13
+ import matplotlib
14
+ import numpy as np
15
+ import torch
16
+ from diff_gaussian_rasterization import (
17
+ GaussianRasterizationSettings,
18
+ GaussianRasterizer,
19
+ )
20
+ from einops import rearrange
21
+ from plyfile import PlyData, PlyElement
22
+ from torch import nn
23
+
24
+ from collections import OrderedDict
25
+ import videoio
26
+
27
+ @torch.no_grad()
28
+ def get_turntable_cameras(
29
+ hfov=50,
30
+ num_views=8,
31
+ w=384,
32
+ h=384,
33
+ radius=2.7,
34
+ elevation=20,
35
+ up_vector=np.array([0, 0, 1]),
36
+ ):
37
+ fx = w / (2 * np.tan(np.deg2rad(hfov) / 2.0))
38
+ fy = fx
39
+ cx, cy = w / 2.0, h / 2.0
40
+ fxfycxcy = (
41
+ np.array([fx, fy, cx, cy]).reshape(1, 4).repeat(num_views, axis=0)
42
+ ) # [num_views, 4]
43
+ # azimuths = np.linspace(0, 360, num_views, endpoint=False)
44
+ azimuths = np.linspace(270, 630, num_views, endpoint=False)
45
+ elevations = np.ones_like(azimuths) * elevation
46
+ c2ws = []
47
+ for elev, azim in zip(elevations, azimuths):
48
+ elev, azim = np.deg2rad(elev), np.deg2rad(azim)
49
+ z = radius * np.sin(elev)
50
+ base = radius * np.cos(elev)
51
+ x = base * np.cos(azim)
52
+ y = base * np.sin(azim)
53
+ cam_pos = np.array([x, y, z])
54
+ forward = -cam_pos / np.linalg.norm(cam_pos)
55
+ right = np.cross(forward, up_vector)
56
+ right = right / np.linalg.norm(right)
57
+ up = np.cross(right, forward)
58
+ up = up / np.linalg.norm(up)
59
+ R = np.stack((right, -up, forward), axis=1)
60
+ c2w = np.eye(4)
61
+ c2w[:3, :4] = np.concatenate((R, cam_pos[:, None]), axis=1)
62
+ c2ws.append(c2w)
63
+ c2ws = np.stack(c2ws, axis=0) # [num_views, 4, 4]
64
+ return w, h, num_views, fxfycxcy, c2ws
65
+
66
+ def imageseq2video(images, filename, fps=24):
67
+ # if the frames are uint8, convert them to float32 in [0, 1]
68
+ if images.dtype == np.uint8:
69
+ images = images.astype(np.float32) / 255.0
70
+
71
+ videoio.videosave(filename, images, lossless=True, preset="veryfast", fps=fps)
72
+
73
+
74
+ # copied from: utils.general_utils
75
+ def strip_lowerdiag(L):
76
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=L.device)
77
+
78
+ uncertainty[:, 0] = L[:, 0, 0]
79
+ uncertainty[:, 1] = L[:, 0, 1]
80
+ uncertainty[:, 2] = L[:, 0, 2]
81
+ uncertainty[:, 3] = L[:, 1, 1]
82
+ uncertainty[:, 4] = L[:, 1, 2]
83
+ uncertainty[:, 5] = L[:, 2, 2]
84
+ return uncertainty
85
+
86
+
87
+ def strip_symmetric(sym):
88
+ return strip_lowerdiag(sym)
89
+
90
+
91
+ def build_rotation(r):
92
+ norm = torch.sqrt(
93
+ r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
94
+ )
95
+
96
+ q = r / norm[:, None]
97
+
98
+ R = torch.zeros((q.size(0), 3, 3), device=r.device)
99
+
100
+ r = q[:, 0]
101
+ x = q[:, 1]
102
+ y = q[:, 2]
103
+ z = q[:, 3]
104
+
105
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
106
+ R[:, 0, 1] = 2 * (x * y - r * z)
107
+ R[:, 0, 2] = 2 * (x * z + r * y)
108
+ R[:, 1, 0] = 2 * (x * y + r * z)
109
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
110
+ R[:, 1, 2] = 2 * (y * z - r * x)
111
+ R[:, 2, 0] = 2 * (x * z - r * y)
112
+ R[:, 2, 1] = 2 * (y * z + r * x)
113
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
114
+ return R
115
+
116
+
117
+ def build_scaling_rotation(s, r):
118
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
119
+ R = build_rotation(r)
120
+
121
+ L[:, 0, 0] = s[:, 0]
122
+ L[:, 1, 1] = s[:, 1]
123
+ L[:, 2, 2] = s[:, 2]
124
+
125
+ L = R @ L
126
+ return L
127
+
128
+
129
+ # copied from: utils.sh_utils
130
+ C0 = 0.28209479177387814
131
+ C1 = 0.4886025119029199
132
+ C2 = [
133
+ 1.0925484305920792,
134
+ -1.0925484305920792,
135
+ 0.31539156525252005,
136
+ -1.0925484305920792,
137
+ 0.5462742152960396,
138
+ ]
139
+ C3 = [
140
+ -0.5900435899266435,
141
+ 2.890611442640554,
142
+ -0.4570457994644658,
143
+ 0.3731763325901154,
144
+ -0.4570457994644658,
145
+ 1.445305721320277,
146
+ -0.5900435899266435,
147
+ ]
148
+ C4 = [
149
+ 2.5033429417967046,
150
+ -1.7701307697799304,
151
+ 0.9461746957575601,
152
+ -0.6690465435572892,
153
+ 0.10578554691520431,
154
+ -0.6690465435572892,
155
+ 0.47308734787878004,
156
+ -1.7701307697799304,
157
+ 0.6258357354491761,
158
+ ]
159
+
160
+
161
+ def eval_sh(deg, sh, dirs):
162
+ """
163
+ Evaluate spherical harmonics at unit directions
164
+ using hardcoded SH polynomials.
165
+ Works with torch/np/jnp.
166
+ ... Can be 0 or more batch dimensions.
167
+ Args:
168
+ deg: int SH deg. Currently, 0-3 supported
169
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
170
+ dirs: jnp.ndarray unit directions [..., 3]
171
+ Returns:
172
+ [..., C]
173
+ """
174
+ assert deg <= 4 and deg >= 0
175
+ coeff = (deg + 1) ** 2
176
+ assert sh.shape[-1] >= coeff
177
+
178
+ result = C0 * sh[..., 0]
179
+ if deg > 0:
180
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
181
+ result = (
182
+ result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
183
+ )
184
+
185
+ if deg > 1:
186
+ xx, yy, zz = x * x, y * y, z * z
187
+ xy, yz, xz = x * y, y * z, x * z
188
+ result = (
189
+ result
190
+ + C2[0] * xy * sh[..., 4]
191
+ + C2[1] * yz * sh[..., 5]
192
+ + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
193
+ + C2[3] * xz * sh[..., 7]
194
+ + C2[4] * (xx - yy) * sh[..., 8]
195
+ )
196
+
197
+ if deg > 2:
198
+ result = (
199
+ result
200
+ + C3[0] * y * (3 * xx - yy) * sh[..., 9]
201
+ + C3[1] * xy * z * sh[..., 10]
202
+ + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
203
+ + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
204
+ + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
205
+ + C3[5] * z * (xx - yy) * sh[..., 14]
206
+ + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
207
+ )
208
+
209
+ if deg > 3:
210
+ result = (
211
+ result
212
+ + C4[0] * xy * (xx - yy) * sh[..., 16]
213
+ + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
214
+ + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
215
+ + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
216
+ + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
217
+ + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
218
+ + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
219
+ + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
220
+ + C4[8]
221
+ * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
222
+ * sh[..., 24]
223
+ )
224
+ return result
225
+
226
+
227
+ def RGB2SH(rgb):
228
+ return (rgb - 0.5) / C0
229
+
230
+
231
+ def SH2RGB(sh):
232
+ return sh * C0 + 0.5
233
+
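Editorial note (not part of the file): `RGB2SH`/`SH2RGB` are the standard degree-0 spherical-harmonic DC conversions,

$$C_0=\tfrac{1}{2\sqrt{\pi}}\approx 0.2820948,\qquad \mathrm{SH2RGB}(c)=C_0\,c+0.5,\qquad \mathrm{RGB2SH}(\rho)=\frac{\rho-0.5}{C_0}.$$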
234
+
235
+ def create_video(image_folder, output_video_file, framerate=30):
236
+ # Get all image file paths to a list.
237
+ images = [img for img in os.listdir(image_folder) if img.endswith(".png")]
238
+ images.sort()
239
+
240
+ # Read the first image to know the height and width
241
+ frame = cv2.imread(os.path.join(image_folder, images[0]))
242
+ height, width, layers = frame.shape
243
+
244
+ video = cv2.VideoWriter(
245
+ output_video_file, cv2.VideoWriter_fourcc(*"mp4v"), framerate, (width, height)
246
+ )
247
+
248
+ # iterate over each image and add it to the video sequence
249
+ for image in images:
250
+ video.write(cv2.imread(os.path.join(image_folder, image)))
251
+
252
+ cv2.destroyAllWindows()
253
+ video.release()
254
+
255
+
256
+ class Camera(nn.Module):
257
+ def __init__(self, C2W, fxfycxcy, h, w):
258
+ """
259
+ C2W: 4x4 camera-to-world matrix; opencv convention
260
+ fxfycxcy: 4
261
+ """
262
+ super().__init__()
263
+ self.C2W = C2W.clone().float()
264
+ self.W2C = self.C2W.inverse()
265
+ self.h = h
266
+ self.w = w
267
+
268
+ self.znear = 0.01
269
+ self.zfar = 100.0
270
+
271
+ fx, fy, cx, cy = fxfycxcy[0], fxfycxcy[1], fxfycxcy[2], fxfycxcy[3]
272
+ self.tanfovX = w / (2 * fx)
273
+ self.tanfovY = h / (2 * fy)
274
+
275
+ def getProjectionMatrix(W, H, fx, fy, cx, cy, znear, zfar):
276
+ P = torch.zeros(4, 4, device=fx.device)
277
+ P[0, 0] = 2 * fx / W
278
+ P[1, 1] = 2 * fy / H
279
+ P[0, 2] = 2 * (cx / W) - 1
280
+ P[1, 2] = 2 * (cy / H) - 1
281
+ P[2, 2] = -(zfar + znear) / (zfar - znear)
282
+ P[3, 2] = 1.0
283
+ P[2, 3] = -(2 * zfar * znear) / (zfar - znear)
284
+ return P
285
+
286
+ self.world_view_transform = self.W2C.transpose(0, 1)
287
+ self.projection_matrix = getProjectionMatrix(
288
+ self.w, self.h, fx, fy, cx, cy, self.znear, self.zfar
289
+ ).transpose(0, 1)
290
+ self.full_proj_transform = (
291
+ self.world_view_transform.unsqueeze(0).bmm(
292
+ self.projection_matrix.unsqueeze(0)
293
+ )
294
+ ).squeeze(0)
295
+ self.camera_center = self.C2W[:3, 3]
296
+
297
+
298
+ # modified from scene/gaussian_model.py
299
+ class GaussianModel:
300
+ def setup_functions(self):
301
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
302
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
303
+ actual_covariance = L @ L.transpose(1, 2)
304
+ symm = strip_symmetric(actual_covariance)
305
+ return symm
306
+
307
+ self.scaling_activation = torch.exp
308
+ self.inv_scaling_activation = torch.log
309
+ self.rotation_activation = torch.nn.functional.normalize
310
+ self.opacity_activation = torch.sigmoid
311
+ self.covariance_activation = build_covariance_from_scaling_rotation
312
+
313
+ def __init__(self, sh_degree: int, scaling_modifier=None):
314
+ self.sh_degree = sh_degree
315
+ self._xyz = torch.empty(0)
316
+ self._features_dc = torch.empty(0)
317
+ if self.sh_degree > 0:
318
+ self._features_rest = torch.empty(0)
319
+ else:
320
+ self._features_rest = None
321
+ self._scaling = torch.empty(0)
322
+ self._rotation = torch.empty(0)
323
+ self._opacity = torch.empty(0)
324
+ self.setup_functions()
325
+
326
+ self.scaling_modifier = scaling_modifier
327
+
328
+ def empty(self):
329
+ self.__init__(self.sh_degree, self.scaling_modifier)
330
+
331
+ def set_data(self, xyz, features, scaling, rotation, opacity):
332
+ """
333
+ xyz : torch.tensor of shape (N, 3)
334
+ features : torch.tensor of shape (N, (self.sh_degree + 1) ** 2, 3)
335
+ scaling : torch.tensor of shape (N, 3)
336
+ rotation : torch.tensor of shape (N, 4)
337
+ opacity : torch.tensor of shape (N, 1)
338
+ """
339
+ self._xyz = xyz
340
+ self._features_dc = features[:, :1, :].contiguous()
341
+ if self.sh_degree > 0:
342
+ self._features_rest = features[:, 1:, :].contiguous()
343
+ else:
344
+ self._features_rest = None
345
+ self._scaling = scaling
346
+ self._rotation = rotation
347
+ self._opacity = opacity
348
+ return self
349
+
350
+ def to(self, device):
351
+ self._xyz = self._xyz.to(device)
352
+ self._features_dc = self._features_dc.to(device)
353
+ if self.sh_degree > 0:
354
+ self._features_rest = self._features_rest.to(device)
355
+ self._scaling = self._scaling.to(device)
356
+ self._rotation = self._rotation.to(device)
357
+ self._opacity = self._opacity.to(device)
358
+ return self
359
+
360
+ def filter(self, valid_mask):
361
+ self._xyz = self._xyz[valid_mask]
362
+ self._features_dc = self._features_dc[valid_mask]
363
+ if self.sh_degree > 0:
364
+ self._features_rest = self._features_rest[valid_mask]
365
+ self._scaling = self._scaling[valid_mask]
366
+ self._rotation = self._rotation[valid_mask]
367
+ self._opacity = self._opacity[valid_mask]
368
+ return self
369
+
370
+ def crop(self, crop_bbx=[-1, 1, -1, 1, -1, 1]):
371
+ x_min, x_max, y_min, y_max, z_min, z_max = crop_bbx
372
+ xyz = self._xyz
373
+ invalid_mask = (
374
+ (xyz[:, 0] < x_min)
375
+ | (xyz[:, 0] > x_max)
376
+ | (xyz[:, 1] < y_min)
377
+ | (xyz[:, 1] > y_max)
378
+ | (xyz[:, 2] < z_min)
379
+ | (xyz[:, 2] > z_max)
380
+ )
381
+ valid_mask = ~invalid_mask
382
+
383
+ return self.filter(valid_mask)
384
+
385
+ def crop_by_xyz(self, floater_thres=0.75):
386
+ xyz = self._xyz
387
+ invalid_mask = (
388
+ (((xyz[:, 0] < -floater_thres) & (xyz[:, 1] < -floater_thres))
389
+ | ((xyz[:, 0] < -floater_thres) & (xyz[:, 1] > floater_thres))
390
+ | ((xyz[:, 0] > floater_thres) & (xyz[:, 1] < -floater_thres))
391
+ | ((xyz[:, 0] > floater_thres) & (xyz[:, 1] > floater_thres)))
392
+ & (xyz[:, 2] < -floater_thres)
393
+ )
394
+ valid_mask = ~invalid_mask
395
+
396
+ return self.filter(valid_mask)
397
+
398
+ def prune(self, opacity_thres=0.05):
399
+ opacity = self.get_opacity.squeeze(1)
400
+ valid_mask = opacity > opacity_thres
401
+
402
+ return self.filter(valid_mask)
403
+
404
+ def prune_by_scaling(self, scaling_thres=0.1):
405
+ scaling = self.get_scaling
406
+ valid_mask = scaling.max(dim=1).values < scaling_thres
407
+ position_mask = self._xyz[:, 2] > 0
408
+
409
+ valid_mask = valid_mask | position_mask
410
+
411
+ return self.filter(valid_mask)
412
+
413
+ def prune_by_nearfar(self, cam_origins, nearfar_percent=(0.01, 0.99)):
414
+ # cam_origins: [num_cams, 3]
415
+ # nearfar_percent: [near, far]
416
+ assert len(nearfar_percent) == 2
417
+ assert nearfar_percent[0] < nearfar_percent[1]
418
+ assert nearfar_percent[0] >= 0 and nearfar_percent[1] <= 1
419
+
420
+ device = self._xyz.device
421
+ # compute distance of all points to all cameras
422
+ # [num_points, num_cams]
423
+ dists = torch.cdist(self._xyz[None], cam_origins[None].to(device))[0]
424
+ # [2, num_cams]
425
+ dists_percentile = torch.quantile(
426
+ dists, torch.tensor(nearfar_percent).to(device), dim=0
427
+ )
428
+ # prune all points that are outside the percentile range
429
+ # [num_points, num_cams]
430
+ # goal: prune points that are too close or too far from any camera
431
+ reject_mask = (dists < dists_percentile[0:1, :]) | (
432
+ dists > dists_percentile[1:2, :]
433
+ )
434
+ reject_mask = reject_mask.any(dim=1)
435
+ valid_mask = ~reject_mask
436
+
437
+ return self.filter(valid_mask)
438
+
439
+ def apply_all_filters(
440
+ self,
441
+ opacity_thres=0.05,
442
+ scaling_thres=None,
443
+ floater_thres=None,
444
+ crop_bbx=[-1, 1, -1, 1, -1, 1],
445
+ cam_origins=None,
446
+ nearfar_percent=(0.005, 1.0),
447
+ ):
448
+ self.prune(opacity_thres)
449
+ if scaling_thres is not None:
450
+ self.prune_by_scaling(scaling_thres)
451
+ if floater_thres is not None:
452
+ self.crop_by_xyz(floater_thres)
453
+ if crop_bbx is not None:
454
+ self.crop(crop_bbx)
455
+ if cam_origins is not None:
456
+ self.prune_by_nearfar(cam_origins, nearfar_percent)
457
+ return self
458
+
459
+ def shrink_bbx(self, drop_ratio=0.05):
460
+ xyz = self._xyz
461
+ xyz_min, xyz_max = torch.quantile(
462
+ xyz,
463
+ torch.tensor([drop_ratio, 1 - drop_ratio]).float().to(xyz.device),
464
+ dim=0,
465
+ ) # [2, N]
466
+ xyz_min = xyz_min.detach().cpu().numpy()
467
+ xyz_max = xyz_max.detach().cpu().numpy()
468
+ crop_bbx = [
469
+ xyz_min[0],
470
+ xyz_max[0],
471
+ xyz_min[1],
472
+ xyz_max[1],
473
+ xyz_min[2],
474
+ xyz_max[2],
475
+ ]
476
+ print(f"Shrinking bbx to {crop_bbx}")
477
+ return self.crop(crop_bbx)
478
+
479
+ def report_stats(self):
480
+ print(
481
+ f"xyz: {self._xyz.shape}, {self._xyz.min().item()}, {self._xyz.max().item()}"
482
+ )
483
+ print(
484
+ f"features_dc: {self._features_dc.shape}, {self._features_dc.min().item()}, {self._features_dc.max().item()}"
485
+ )
486
+ if self.sh_degree > 0:
487
+ print(
488
+ f"features_rest: {self._features_rest.shape}, {self._features_rest.min().item()}, {self._features_rest.max().item()}"
489
+ )
490
+ print(
491
+ f"scaling: {self._scaling.shape}, {self._scaling.min().item()}, {self._scaling.max().item()}"
492
+ )
493
+ print(
494
+ f"rotation: {self._rotation.shape}, {self._rotation.min().item()}, {self._rotation.max().item()}"
495
+ )
496
+ print(
497
+ f"opacity: {self._opacity.shape}, {self._opacity.min().item()}, {self._opacity.max().item()}"
498
+ )
499
+
500
+ print(
501
+ f"after activation, xyz: {self.get_xyz.shape}, {self.get_xyz.min().item()}, {self.get_xyz.max().item()}"
502
+ )
503
+ print(
504
+ f"after activation, features: {self.get_features.shape}, {self.get_features.min().item()}, {self.get_features.max().item()}"
505
+ )
506
+ print(
507
+ f"after activation, scaling: {self.get_scaling.shape}, {self.get_scaling.min().item()}, {self.get_scaling.max().item()}"
508
+ )
509
+ print(
510
+ f"after activation, rotation: {self.get_rotation.shape}, {self.get_rotation.min().item()}, {self.get_rotation.max().item()}"
511
+ )
512
+ print(
513
+ f"after activation, opacity: {self.get_opacity.shape}, {self.get_opacity.min().item()}, {self.get_opacity.max().item()}"
514
+ )
515
+ print(
516
+ f"after activation, covariance: {self.get_covariance().shape}, {self.get_covariance().min().item()}, {self.get_covariance().max().item()}"
517
+ )
518
+
519
+ @property
520
+ def get_scaling(self):
521
+ if self.scaling_modifier is not None:
522
+ return self.scaling_activation(self._scaling) * self.scaling_modifier
523
+ else:
524
+ return self.scaling_activation(self._scaling)
525
+
526
+ @property
527
+ def get_rotation(self):
528
+ return self.rotation_activation(self._rotation)
529
+
530
+ @property
531
+ def get_xyz(self):
532
+ return self._xyz
533
+
534
+ @property
535
+ def get_features(self):
536
+ if self.sh_degree > 0:
537
+ features_dc = self._features_dc
538
+ features_rest = self._features_rest
539
+ return torch.cat((features_dc, features_rest), dim=1)
540
+ else:
541
+ return self._features_dc
542
+
543
+ @property
544
+ def get_opacity(self):
545
+ return self.opacity_activation(self._opacity)
546
+
547
+ def get_covariance(self, scaling_modifier=1):
548
+ return self.covariance_activation(
549
+ self.get_scaling, scaling_modifier, self._rotation
550
+ )
551
+
552
+ def construct_dtypes(self, use_fp16=False, enable_gs_viewer=True):
553
+ if not use_fp16:
554
+ l = [
555
+ ("x", "f4"),
556
+ ("y", "f4"),
557
+ ("z", "f4"),
558
+ ("red", "u1"),
559
+ ("green", "u1"),
560
+ ("blue", "u1"),
561
+ ]
562
+ # All channels except the 3 DC
563
+ for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
564
+ l.append((f"f_dc_{i}", "f4"))
565
+
566
+ if enable_gs_viewer:
567
+ assert self.sh_degree <= 3, "GS viewer only supports SH up to degree 3"
568
+ sh_degree = 3
569
+ for i in range(((sh_degree + 1) ** 2 - 1) * 3):
570
+ l.append((f"f_rest_{i}", "f4"))
571
+ else:
572
+ if self.sh_degree > 0:
573
+ for i in range(
574
+ self._features_rest.shape[1] * self._features_rest.shape[2]
575
+ ):
576
+ l.append((f"f_rest_{i}", "f4"))
577
+
578
+ l.append(("opacity", "f4"))
579
+ for i in range(self._scaling.shape[1]):
580
+ l.append((f"scale_{i}", "f4"))
581
+ for i in range(self._rotation.shape[1]):
582
+ l.append((f"rot_{i}", "f4"))
583
+ else:
584
+ l = [
585
+ ("x", "f2"),
586
+ ("y", "f2"),
587
+ ("z", "f2"),
588
+ ("red", "u1"),
589
+ ("green", "u1"),
590
+ ("blue", "u1"),
591
+ ]
592
+ # All channels except the 3 DC
593
+ for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
594
+ l.append((f"f_dc_{i}", "f2"))
595
+
596
+ if self.sh_degree > 0:
597
+ for i in range(
598
+ self._features_rest.shape[1] * self._features_rest.shape[2]
599
+ ):
600
+ l.append((f"f_rest_{i}", "f2"))
601
+ l.append(("opacity", "f2"))
602
+ for i in range(self._scaling.shape[1]):
603
+ l.append((f"scale_{i}", "f2"))
604
+ for i in range(self._rotation.shape[1]):
605
+ l.append((f"rot_{i}", "f2"))
606
+ return l
607
+
608
+ def save_ply(
609
+ self,
610
+ path,
611
+ use_fp16=False,
612
+ enable_gs_viewer=True,
613
+ color_code=False,
614
+ filter_mask=None,
615
+ ):
616
+ os.makedirs(os.path.dirname(path), exist_ok=True)
617
+
618
+ xyz = self._xyz.detach().cpu().numpy()
619
+ f_dc = (
620
+ self._features_dc.detach()
621
+ .transpose(1, 2)
622
+ .flatten(start_dim=1)
623
+ .contiguous()
624
+ .cpu()
625
+ .numpy()
626
+ )
627
+ if not color_code:
628
+ rgb = (SH2RGB(f_dc) * 255.0).clip(0.0, 255.0).astype(np.uint8)
629
+ else:
630
+ # use a color map to color-code the point indices
631
+ index = np.linspace(0, 1, xyz.shape[0])
632
+ rgb = matplotlib.colormaps["viridis"](index)[..., :3]
633
+ rgb = (rgb * 255.0).clip(0.0, 255.0).astype(np.uint8)
634
+
635
+ opacities = self._opacity.detach().cpu().numpy()
636
+ if self.scaling_modifier is not None:
637
+ scale = self.inv_scaling_activation(self.get_scaling).detach().cpu().numpy()
638
+ else:
639
+ scale = self._scaling.detach().cpu().numpy()
640
+ rotation = self._rotation.detach().cpu().numpy()
641
+
642
+ dtype_full = self.construct_dtypes(use_fp16, enable_gs_viewer)
643
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
644
+
645
+ f_rest = None
646
+ if self.sh_degree > 0:
647
+ f_rest = (
648
+ self._features_rest.detach()
649
+ .transpose(1, 2)
650
+ .flatten(start_dim=1)
651
+ .contiguous()
652
+ .cpu()
653
+ .numpy()
654
+ ) # (N, 3 * ((self.sh_degree + 1) ** 2 - 1))
655
+
656
+ if enable_gs_viewer:
657
+ sh_degree = 3
658
+ if f_rest is None:
659
+ f_rest = np.zeros(
660
+ (xyz.shape[0], 3 * ((sh_degree + 1) ** 2 - 1)), dtype=np.float32
661
+ )
662
+ elif f_rest.shape[1] < 3 * ((sh_degree + 1) ** 2 - 1):
663
+ f_rest_pad = np.zeros(
664
+ (xyz.shape[0], 3 * ((sh_degree + 1) ** 2 - 1)), dtype=np.float32
665
+ )
666
+ f_rest_pad[:, : f_rest.shape[1]] = f_rest
667
+ f_rest = f_rest_pad
668
+
669
+ if f_rest is not None:
670
+ attributes = np.concatenate(
671
+ (xyz, rgb, f_dc, f_rest, opacities, scale, rotation), axis=1
672
+ )
673
+ else:
674
+ attributes = np.concatenate(
675
+ (xyz, rgb, f_dc, opacities, scale, rotation), axis=1
676
+ )
677
+
678
+ if filter_mask is not None:
679
+ attributes = attributes[filter_mask]
680
+ elements = elements[filter_mask]
681
+
682
+ elements[:] = list(map(tuple, attributes))
683
+ el = PlyElement.describe(elements, "vertex")
684
+ PlyData([el]).write(path)
685
+
686
+ def load_ply(self, path):
687
+ plydata = PlyData.read(path)
688
+
689
+ xyz = np.stack(
690
+ (
691
+ np.asarray(plydata.elements[0]["x"]),
692
+ np.asarray(plydata.elements[0]["y"]),
693
+ np.asarray(plydata.elements[0]["z"]),
694
+ ),
695
+ axis=1,
696
+ )
697
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
698
+
699
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
700
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
701
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
702
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
703
+
704
+ if self.sh_degree > 0:
705
+ extra_f_names = [
706
+ p.name
707
+ for p in plydata.elements[0].properties
708
+ if p.name.startswith("f_rest_")
709
+ ]
710
+ extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
711
+ assert len(extra_f_names) == 3 * (self.sh_degree + 1) ** 2 - 3
712
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
713
+ for idx, attr_name in enumerate(extra_f_names):
714
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
715
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
716
+ features_extra = features_extra.reshape(
717
+ (features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1)
718
+ )
719
+
720
+ scale_names = [
721
+ p.name
722
+ for p in plydata.elements[0].properties
723
+ if p.name.startswith("scale_")
724
+ ]
725
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
726
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
727
+ for idx, attr_name in enumerate(scale_names):
728
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
729
+
730
+ rot_names = [
731
+ p.name for p in plydata.elements[0].properties if p.name.startswith("rot")
732
+ ]
733
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
734
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
735
+ for idx, attr_name in enumerate(rot_names):
736
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
737
+
738
+ self._xyz = torch.from_numpy(xyz.astype(np.float32))
739
+ self._features_dc = (
740
+ torch.from_numpy(features_dc.astype(np.float32))
741
+ .transpose(1, 2)
742
+ .contiguous()
743
+ )
744
+ if self.sh_degree > 0:
745
+ self._features_rest = (
746
+ torch.from_numpy(features_extra.astype(np.float32))
747
+ .transpose(1, 2)
748
+ .contiguous()
749
+ )
750
+ self._opacity = torch.from_numpy(
751
+ np.copy(opacities).astype(np.float32)
752
+ ).contiguous()
753
+ self._scaling = torch.from_numpy(scales.astype(np.float32)).contiguous()
754
+ self._rotation = torch.from_numpy(rots.astype(np.float32)).contiguous()
755
+
756
+
757
+ def render_opencv_cam(
758
+ pc: GaussianModel,
759
+ height: int,
760
+ width: int,
761
+ C2W: torch.Tensor,
762
+ fxfycxcy: torch.Tensor,
763
+ bg_color=(1.0, 1.0, 1.0),
764
+ scaling_modifier=1.0,
765
+ ):
766
+ """
767
+ Render the scene.
768
+
769
+ Background tensor (bg_color) must be on GPU!
770
+ """
771
+ # Create a placeholder tensor so that PyTorch can return gradients of the 2D (screen-space) means
772
+ screenspace_points = torch.empty_like(
773
+ pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
774
+ )
775
+ # try:
776
+ # screenspace_points.retain_grad()
777
+ # except:
778
+ # pass
779
+
780
+ viewpoint_camera = Camera(C2W=C2W, fxfycxcy=fxfycxcy, h=height, w=width)
781
+
782
+ bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=C2W.device)
783
+
784
+ # Set up rasterization configuration
785
+ raster_settings = GaussianRasterizationSettings(
786
+ image_height=int(viewpoint_camera.h),
787
+ image_width=int(viewpoint_camera.w),
788
+ tanfovx=viewpoint_camera.tanfovX,
789
+ tanfovy=viewpoint_camera.tanfovY,
790
+ bg=bg_color,
791
+ scale_modifier=scaling_modifier,
792
+ viewmatrix=viewpoint_camera.world_view_transform,
793
+ projmatrix=viewpoint_camera.full_proj_transform,
794
+ sh_degree=pc.sh_degree,
795
+ campos=viewpoint_camera.camera_center,
796
+ prefiltered=False,
797
+ debug=False,
798
+ )
799
+
800
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
801
+
802
+ means3D = pc.get_xyz
803
+ means2D = screenspace_points
804
+ opacity = pc.get_opacity
805
+ scales = pc.get_scaling
806
+ rotations = pc.get_rotation
807
+ shs = pc.get_features
808
+
809
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
810
+ rendered_image, radii, _, _ = rasterizer(
811
+ means3D=means3D,
812
+ means2D=means2D,
813
+ shs=shs,
814
+ colors_precomp=None,
815
+ opacities=opacity,
816
+ scales=scales,
817
+ rotations=rotations,
818
+ cov3D_precomp=None,
819
+ )
820
+
821
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
822
+ # They will be excluded from value updates used in the splitting criteria.
823
+ return {
824
+ "render": rendered_image,
825
+ "viewspace_points": screenspace_points,
826
+ "visibility_filter": radii > 0,
827
+ "radii": radii,
828
+ }
829
+
830
+
831
+ class DeferredGaussianRender(torch.autograd.Function):
832
+ @staticmethod
833
+ def forward(
834
+ ctx,
835
+ xyz,
836
+ features,
837
+ scaling,
838
+ rotation,
839
+ opacity,
840
+ height,
841
+ width,
842
+ C2W,
843
+ fxfycxcy,
844
+ scaling_modifier=None,
845
+ ):
846
+ """
847
+ xyz: [b, n_gaussians, 3]
848
+ features: [b, n_gaussians, (sh_degree+1)^2, 3]
849
+ scaling: [b, n_gaussians, 3]
850
+ rotation: [b, n_gaussians, 4]
851
+ opacity: [b, n_gaussians, 1]
852
+
853
+ height: int
854
+ width: int
855
+ C2W: [b, v, 4, 4]
856
+ fxfycxcy: [b, v, 4]
857
+
858
+ output: [b, v, 3, height, width]
859
+ """
860
+ ctx.scaling_modifier = scaling_modifier
861
+
862
+ # Infer sh_degree from features
863
+ sh_degree = int(math.sqrt(features.shape[-2])) - 1
864
+
865
+ # Create a temp class to hold the data and for rendering
866
+ gaussians_model = GaussianModel(sh_degree, scaling_modifier)
867
+
868
+ with torch.no_grad():
869
+ b, v = C2W.size(0), C2W.size(1)
870
+ renders = []
871
+ for i in range(b):
872
+ pc = gaussians_model.set_data(
873
+ xyz[i], features[i], scaling[i], rotation[i], opacity[i]
874
+ )
875
+ for j in range(v):
876
+ renders.append(
877
+ render_opencv_cam(pc, height, width, C2W[i, j], fxfycxcy[i, j])[
878
+ "render"
879
+ ]
880
+ )
881
+ renders = torch.stack(renders, dim=0)
882
+ renders = renders.reshape(b, v, 3, height, width)
883
+
884
+ renders = renders.requires_grad_()
885
+
886
+ # Save_for_backward only supports tensors
887
+ ctx.save_for_backward(xyz, features, scaling, rotation, opacity, C2W, fxfycxcy)
888
+ ctx.rendering_size = (height, width)
889
+ ctx.sh_degree = sh_degree
890
+
891
+ # Release the temp class; do not save it.
892
+ del gaussians_model
893
+
894
+ return renders
895
+
896
+ @staticmethod
897
+ def backward(ctx, grad_output):
898
+ # Restore params
899
+ xyz, features, scaling, rotation, opacity, C2W, fxfycxcy = ctx.saved_tensors
900
+ height, width = ctx.rendering_size
901
+ sh_degree = ctx.sh_degree
902
+
903
+ # **The order of this dict should not be changed**
904
+ input_dict = OrderedDict(
905
+ [
906
+ ("xyz", xyz),
907
+ ("features", features),
908
+ ("scaling", scaling),
909
+ ("rotation", rotation),
910
+ ("opacity", opacity),
911
+ ]
912
+ )
913
+ input_dict = {k: v.detach().requires_grad_() for k, v in input_dict.items()}
914
+
915
+ # Create a temp class to hold the data and for rendering
916
+ gaussians_model = GaussianModel(sh_degree, ctx.scaling_modifier)
917
+
918
+ with torch.enable_grad():
919
+ b, v = C2W.size(0), C2W.size(1)
920
+ for i in range(b):
921
+ for j in range(v):
922
+ # backward() frees the autograd graph, so the Gaussian data is re-set for every view
923
+ pc = gaussians_model.set_data(
924
+ **{k: v[i] for k, v in input_dict.items()}
925
+ )
926
+
927
+ # Forward
928
+ render = render_opencv_cam(
929
+ pc, height, width, C2W[i, j], fxfycxcy[i, j]
930
+ )["render"]
931
+
932
+ # Backward, suppose that only values in input_dict will get gradients.
933
+ render.backward(grad_output[i, j])
934
+
935
+ del gaussians_model
936
+
937
+ return *[var.grad for var in input_dict.values()], None, None, None, None, None
938
+
939
+
940
+ # Function for the class
941
+ deferred_gaussian_render = DeferredGaussianRender.apply
942
+
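A shape-level usage sketch of the deferred renderer defined above (editorial, not part of the commit); the tensors are dummies chosen only to satisfy the shapes in the `forward` docstring, and a CUDA device plus `diff_gaussian_rasterization` are assumed:

```python
import torch
from gslrm.model.gaussians_renderer import deferred_gaussian_render

b, n, v, h, w = 1, 2048, 2, 128, 128
dev = "cuda"
xyz      = 0.3 * torch.randn(b, n, 3, device=dev)
features = torch.randn(b, n, 1, 3, device=dev)       # sh_degree 0 -> (0 + 1)**2 = 1 coefficient
scaling  = torch.full((b, n, 3), -5.0, device=dev)   # log-scale; exp-activated inside GaussianModel
rotation = torch.randn(b, n, 4, device=dev)           # normalized to unit quaternions inside
opacity  = torch.zeros(b, n, 1, device=dev)           # sigmoid-activated inside
C2W      = torch.eye(4, device=dev).repeat(b, v, 1, 1)
C2W[..., 2, 3] = -2.0                                  # cameras at z = -2 looking toward the origin
fxfycxcy = torch.tensor([100.0, 100.0, w / 2, h / 2], device=dev).repeat(b, v, 1)

renders = deferred_gaussian_render(xyz, features, scaling, rotation, opacity,
                                    h, w, C2W, fxfycxcy)
print(renders.shape)  # torch.Size([1, 2, 3, 128, 128])
```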
943
+ @torch.no_grad()
944
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
945
+ def render_turntable(pc: GaussianModel, rendering_resolution=384, num_views=8):
946
+ w, h, v, fxfycxcy, c2w = get_turntable_cameras(
947
+ h=rendering_resolution, w=rendering_resolution, num_views=num_views,
948
+ elevation=0, # For MAX SNEAK
949
+ )
950
+
951
+ device = pc._xyz.device
952
+ fxfycxcy = torch.from_numpy(fxfycxcy).float().to(device) # [v, 4]
953
+ c2w = torch.from_numpy(c2w).float().to(device) # [v, 4, 4]
954
+
955
+ renderings = torch.zeros(v, 3, h, w, dtype=torch.float32, device=device)
956
+ for j in range(v):
957
+ renderings[j] = render_opencv_cam(pc, h, w, c2w[j], fxfycxcy[j])["render"]
958
+ torch.cuda.empty_cache() # free up memory on GPU
959
+ renderings = renderings.detach().cpu().numpy()
960
+ renderings = (renderings * 255).clip(0, 255).astype(np.uint8)
961
+ renderings = rearrange(renderings, "v c h w -> h (v w) c")
962
+ return renderings
963
+
964
+
965
+ if __name__ == "__main__":
966
+ import json
967
+
968
+ from PIL import Image
969
+ from tqdm import tqdm
970
+
971
+ out_dir = "/mnt/localssd/debug-3dgs"
972
+ os.makedirs(out_dir, exist_ok=True)
973
+
974
+ os.system(
975
+ f"wget https://phidias.s3.us-west-2.amazonaws.com/kaiz/neural-capture/eval-3dgs-lowres/AWS_test_set/results/1.fashion_boots_rubber_boots__short__Feb_21__2023_at_5_19_25_PM_yf/point_cloud/iteration_30000_fg/point_cloud.ply -O {out_dir}/point_cloud.ply"
976
+ )
977
+ os.system(
978
+ f"wget https://neural-capture.s3.us-west-2.amazonaws.com/data/AWS_test_set/preprocessed/1.fashion_boots_rubber_boots__short__Feb_21__2023_at_5_19_25_PM_yf/opencv_cameras_traj_norm.json -O {out_dir}/opencv_cameras_traj_norm.json"
979
+ )
980
+
981
+ device = "cuda:0"
982
+
983
+ pc = GaussianModel(sh_degree=3)
984
+ pc.load_ply(f"{out_dir}/point_cloud.ply")
985
+ pc = pc.to(device)
986
+
987
+ # pc.save_ply(f"{out_dir}/point_cloud_shrink.ply")
988
+ # pc.load_ply(f"{out_dir}/point_cloud_shrink.ply")
989
+ # pc = pc.to(device)
990
+
991
+ # pc.prune(opacity_thres=0.05)
992
+ # pc.save_ply(f"{out_dir}/point_cloud_shrink_prune.ply")
993
+ # pc = pc.to(device)
994
+
995
+ # pc.shrink_bbx(drop_ratio=0.01)
996
+ # pc.save_ply(f"{out_dir}/point_cloud_shrink_prune.ply")
997
+ # pc = pc.to(device)
998
+
999
+ pc.report_stats()
1000
+
1001
+ with open(f"{out_dir}/opencv_cameras_traj_norm.json", "r") as f:
1002
+ cam_traj = json.load(f)
1003
+
1004
+ for i, cam in tqdm(enumerate(cam_traj["frames"]), desc="Rendering progress"):
1005
+ w2c = np.array(cam["w2c"])
1006
+ c2w = np.linalg.inv(w2c)
1007
+ c2w = torch.from_numpy(c2w.astype(np.float32)).to(device)
1008
+
1009
+ fx = cam["fx"]
1010
+ fy = cam["fy"]
1011
+ cx = cam["cx"]
1012
+ cy = cam["cy"]
1013
+ cx = cx - 5
1014
+ cy = cy + 4
1015
+ fxfycxcy = torch.tensor([fx, fy, cx, cy], dtype=torch.float32, device=device)
1016
+
1017
+ h = cam["h"]
1018
+ w = cam["w"]
1019
+
1020
+ im = render_opencv_cam(pc, h, w, c2w, fxfycxcy, bg_color=[0.0, 0.0, 0.0])[
1021
+ "render"
1022
+ ]
1023
+ im = im.detach().cpu().numpy().transpose(1, 2, 0)
1024
+ im = (im * 255).astype(np.uint8)
1025
+ Image.fromarray(im).save(f"{out_dir}/render_{i:08d}.png")
1026
+
1027
+ create_video(out_dir, f"{out_dir}/render.mp4", framerate=30)
1028
+ print(f"Saved {out_dir}/render.mp4")
gslrm/model/gslrm.py ADDED
@@ -0,0 +1,1647 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
9
+ """
10
+ GSLRM (Gaussian Splatting Large Reconstruction Model)
11
+
12
+ This module implements a transformer-based model for generating 3D Gaussian splats
13
+ from multi-view images. The model uses a combination of image tokenization,
14
+ transformer processing, and Gaussian splatting for novel view synthesis.
15
+
16
+ Classes:
17
+ Renderer: Handles Gaussian splatting rendering operations
18
+ GaussiansUpsampler: Converts transformer tokens to Gaussian parameters
19
+ LossComputer: Computes various loss functions for training
20
+ TransformTarget: Handles target image transformations (cropping, etc.)
21
+ GSLRM: Main model class that orchestrates the entire pipeline
22
+ """
23
+
24
+ import copy
25
+ import os
26
+ import time
27
+ from typing import Dict, List, Optional, Tuple, Union
28
+
29
+ import cv2
30
+ import lpips
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from easydict import EasyDict as edict
36
+ from einops import rearrange
37
+ from einops.layers.torch import Rearrange
38
+ from PIL import Image
39
+
40
+ # Local imports
41
+ from .utils_losses import PerceptualLoss, SsimLoss
42
+ from .gaussians_renderer import (
43
+ GaussianModel,
44
+ RGB2SH,
45
+ deferred_gaussian_render,
46
+ imageseq2video,
47
+ render_opencv_cam,
48
+ render_turntable,
49
+ )
50
+ from .transform_data import SplitData, TransformInput, TransformTarget
51
+ from .utils_transformer import (
52
+ TransformerBlock,
53
+ _init_weights,
54
+ )
55
+
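+ # Minimal usage sketch (illustrative only; assumes a config object shaped like the
+ # training configs in this repo and a batch dict with "image", "c2w" and "fxfycxcy"):
+ #
+ #     model = GSLRM(config).cuda().eval()
+ #     with torch.no_grad():
+ #         result = model(batch, create_visual=False, split_data=False)
+ #     result.gaussians[0].save_ply("gaussians.ply")  # per-item GaussianModel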
56
+ class Renderer(nn.Module):
57
+ """
58
+ Handles Gaussian splatting rendering operations.
59
+
60
+ Supports both deferred rendering (for training with gradients) and
61
+ standard rendering (for inference).
62
+ """
63
+
64
+ def __init__(self, config: edict):
65
+ super().__init__()
66
+ self.config = config
67
+
68
+ # Initialize Gaussian model with scaling modifier
69
+ self.scaling_modifier = config.model.gaussians.get("scaling_modifier", None)
70
+ self.gaussians_model = GaussianModel(
71
+ config.model.gaussians.sh_degree,
72
+ self.scaling_modifier
73
+ )
74
+
75
+ print(f"Renderer initialized with scaling_modifier: {self.scaling_modifier}")
76
+
77
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
78
+ def forward(
79
+ self,
80
+ xyz: torch.Tensor, # [b, n_gaussians, 3]
81
+ features: torch.Tensor, # [b, n_gaussians, (sh_degree+1)^2, 3]
82
+ scaling: torch.Tensor, # [b, n_gaussians, 3]
83
+ rotation: torch.Tensor, # [b, n_gaussians, 4]
84
+ opacity: torch.Tensor, # [b, n_gaussians, 1]
85
+ height: int,
86
+ width: int,
87
+ C2W: torch.Tensor, # [b, v, 4, 4]
88
+ fxfycxcy: torch.Tensor, # [b, v, 4]
89
+ deferred: bool = True,
90
+ ) -> torch.Tensor: # [b, v, 3, height, width]
91
+ """
92
+ Render Gaussian splats to images.
93
+
94
+ Args:
95
+ xyz: Gaussian positions
96
+ features: Gaussian spherical harmonic features
97
+ scaling: Gaussian scaling parameters
98
+ rotation: Gaussian rotation quaternions
99
+ opacity: Gaussian opacity values
100
+ height: Output image height
101
+ width: Output image width
102
+ C2W: Camera-to-world transformation matrices
103
+ fxfycxcy: Camera intrinsics (fx, fy, cx, cy)
104
+ deferred: Whether to use deferred rendering (maintains gradients)
105
+
106
+ Returns:
107
+ Rendered images
108
+ """
109
+ if deferred:
110
+ return deferred_gaussian_render(
111
+ xyz, features, scaling, rotation, opacity,
112
+ height, width, C2W, fxfycxcy, self.scaling_modifier
113
+ )
114
+ else:
115
+ return self._render_sequential(
116
+ xyz, features, scaling, rotation, opacity,
117
+ height, width, C2W, fxfycxcy
118
+ )
119
+
120
+ def _render_sequential(
121
+ self, xyz, features, scaling, rotation, opacity,
122
+ height, width, C2W, fxfycxcy
123
+ ) -> torch.Tensor:
124
+ """Sequential rendering without gradient support (used for inference)."""
125
+ b, v = C2W.size(0), C2W.size(1)
126
+ renderings = torch.zeros(
127
+ b, v, 3, height, width, dtype=torch.float32, device=xyz.device
128
+ )
129
+
130
+ for i in range(b):
131
+ pc = self.gaussians_model.set_data(
132
+ xyz[i], features[i], scaling[i], rotation[i], opacity[i]
133
+ )
134
+ for j in range(v):
135
+ renderings[i, j] = render_opencv_cam(
136
+ pc, height, width, C2W[i, j], fxfycxcy[i, j]
137
+ )["render"]
138
+
139
+ return renderings
140
+
141
+
142
+ class GaussiansUpsampler(nn.Module):
143
+ """
144
+ Converts transformer output tokens to Gaussian splatting parameters.
145
+
146
+ Takes high-dimensional transformer features and projects them to the
147
+ concatenated Gaussian parameter space (xyz + features + scaling + rotation + opacity).
148
+ """
149
+
150
+ def __init__(self, config: edict):
151
+ super().__init__()
152
+ self.config = config
153
+
154
+ # Layer normalization before final projection
155
+ self.layernorm = nn.LayerNorm(config.model.transformer.d, bias=False)
156
+
157
+ # Calculate output dimension for Gaussian parameters
158
+ sh_dim = (config.model.gaussians.sh_degree + 1) ** 2 * 3
159
+ gaussian_param_dim = 3 + sh_dim + 3 + 4 + 1 # xyz + features + scaling + rotation + opacity
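+ # e.g. sh_degree=0 -> 3 + 3 + 3 + 4 + 1 = 14 values per Gaussian; sh_degree=3 -> 3 + 48 + 3 + 4 + 1 = 59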
160
+
161
+ # Check upsampling factor (currently only supports 1x)
162
+ upsample_factor = config.model.gaussians.upsampler.upsample_factor
163
+ if upsample_factor > 1:
164
+ raise NotImplementedError("GaussiansUpsampler only supports upsample_factor=1")
165
+
166
+ # Linear projection to Gaussian parameters
167
+ self.linear = nn.Linear(
168
+ config.model.transformer.d,
169
+ gaussian_param_dim,
170
+ bias=False,
171
+ )
172
+
173
+ def forward(
174
+ self,
175
+ gaussians: torch.Tensor, # [b, n_gaussians, d]
176
+ images: torch.Tensor # [b, l, d] (unused but kept for interface compatibility)
177
+ ) -> torch.Tensor: # [b, n_gaussians, gaussian_param_dim]
178
+ """
179
+ Convert transformer tokens to Gaussian parameters.
180
+
181
+ Args:
182
+ gaussians: Transformer output tokens for Gaussians
183
+ images: Image tokens (unused but kept for compatibility)
184
+
185
+ Returns:
186
+ Raw Gaussian parameters (before conversion to final format)
187
+ """
188
+ upsample_factor = self.config.model.gaussians.upsampler.upsample_factor
189
+ if upsample_factor > 1:
190
+ raise NotImplementedError("GaussiansUpsampler only supports upsample_factor=1")
191
+
192
+ return self.linear(self.layernorm(gaussians))
193
+
194
+ def to_gs(self, gaussians: torch.Tensor) -> Tuple[torch.Tensor, ...]:
195
+ """
196
+ Convert raw Gaussian parameters to final format.
197
+
198
+ Args:
199
+ gaussians: Raw Gaussian parameters [b, n_gaussians, param_dim]
200
+
201
+ Returns:
202
+ Tuple of (xyz, features, scaling, rotation, opacity)
203
+ """
204
+ sh_dim = (self.config.model.gaussians.sh_degree + 1) ** 2 * 3
205
+
206
+ # Split concatenated parameters
207
+ xyz, features, scaling, rotation, opacity = gaussians.split(
208
+ [3, sh_dim, 3, 4, 1], dim=2
209
+ )
210
+
211
+ # Reshape features to proper spherical harmonics format
212
+ features = features.reshape(
213
+ features.size(0),
214
+ features.size(1),
215
+ (self.config.model.gaussians.sh_degree + 1) ** 2,
216
+ 3,
217
+ )
218
+
219
+ # Bias the raw outputs so the downstream activations start in a sensible range
220
+ # Scaling: log-scale biased by -2.3 and clamped at -1.20 so exp(scaling) stays below ~0.30
221
+ scaling = (scaling - 2.3).clamp(max=-1.20)
222
+
223
+ # Opacity: logit biased by -2.0; the sigmoid applied later maps it into [0, 1] (starting near 0.12)
224
+ opacity = opacity - 2.0
225
+
226
+ return xyz, features, scaling, rotation, opacity
227
+
228
+
229
+ class LossComputer(nn.Module):
230
+ """
231
+ Computes various loss functions for training the GSLRM model.
232
+
233
+ Supports multiple loss types:
234
+ - L2 (MSE) loss
235
+ - LPIPS perceptual loss
236
+ - Custom perceptual loss
237
+ - SSIM loss
238
+ - Pixel alignment loss
239
+ - Point distance regularization loss
240
+ """
241
+
242
+ def __init__(self, config: edict):
243
+ super().__init__()
244
+ self.config = config
245
+
246
+ # Initialize loss modules based on config
247
+ self._init_loss_modules()
248
+
249
+ def _init_loss_modules(self):
250
+ """Initialize the various loss computation modules."""
251
+ # LPIPS loss
252
+ if self.config.training.losses.lpips_loss_weight > 0.0:
253
+ self.lpips_loss_module = lpips.LPIPS(net="vgg")
254
+ self.lpips_loss_module.eval()
255
+ # Freeze LPIPS parameters
256
+ for param in self.lpips_loss_module.parameters():
257
+ param.requires_grad = False
258
+
259
+ # Perceptual loss
260
+ if self.config.training.losses.perceptual_loss_weight > 0.0:
261
+ self.perceptual_loss_module = PerceptualLoss()
262
+ self.perceptual_loss_module.eval()
263
+ # Freeze perceptual loss parameters
264
+ for param in self.perceptual_loss_module.parameters():
265
+ param.requires_grad = False
266
+
267
+ # SSIM loss
268
+ if self.config.training.losses.ssim_loss_weight > 0.0:
269
+ self.ssim_loss_module = SsimLoss()
270
+ self.ssim_loss_module.eval()
271
+ # Freeze SSIM parameters
272
+ for param in self.ssim_loss_module.parameters():
273
+ param.requires_grad = False
274
+
275
+ def forward(
276
+ self,
277
+ rendering: torch.Tensor, # [b, v, 3, h, w]
278
+ target: torch.Tensor, # [b, v, 3, h, w]
279
+ img_aligned_xyz: torch.Tensor, # [b, v, 3, h, w]
280
+ input: edict,
281
+ result_softpa: Optional[edict] = None,
282
+ create_visual: bool = False,
283
+ ) -> edict:
284
+ """
285
+ Compute all losses between rendered and target images.
286
+
287
+ Args:
288
+ rendering: Rendered images in range [0, 1]
289
+ target: Target images in range [0, 1]
290
+ img_aligned_xyz: Image-aligned 3D positions
291
+ input: Input data containing ray information
292
+ result_softpa: Additional results (unused)
293
+ create_visual: Whether to create visualization images
294
+
295
+ Returns:
296
+ Dictionary containing all loss values and metrics
297
+ """
298
+ b, v, _, h, w = rendering.size()
299
+ rendering_flat = rendering.reshape(b * v, -1, h, w)
300
+ target_flat = target.reshape(b * v, -1, h, w)
301
+
302
+ # Handle alpha channel if present
303
+ mask = None
304
+ if target_flat.size(1) == 4:
305
+ target_flat, mask = target_flat.split([3, 1], dim=1)
306
+
307
+ # Compute individual losses
308
+ losses = self._compute_all_losses(
309
+ rendering_flat, target_flat, img_aligned_xyz, input, mask, b, v, h, w
310
+ )
311
+
312
+ # Compute total weighted loss
313
+ total_loss = self._compute_total_loss(losses)
314
+
315
+ # Create visualization if requested
316
+ visual = self._create_visual(rendering_flat, target_flat, v) if create_visual else None
317
+
318
+ # Compile loss metrics
319
+ return self._compile_loss_metrics(losses, total_loss, visual)
320
+
321
+ def _compute_all_losses(self, rendering, target, img_aligned_xyz, input, mask, b, v, h, w):
322
+ """Compute all individual loss components."""
323
+ losses = {}
324
+
325
+ # L2 (MSE) loss
326
+ losses['l2'] = self._compute_l2_loss(rendering, target)
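+ # PSNR follows directly from the MSE of [0, 1] images: PSNR = -10 * log10(MSE)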
327
+ losses['psnr'] = -10.0 * torch.log10(losses['l2'])
328
+
329
+ # LPIPS loss
330
+ losses['lpips'] = self._compute_lpips_loss(rendering, target)
331
+
332
+ # Perceptual loss
333
+ losses['perceptual'] = self._compute_perceptual_loss(rendering, target)
334
+
335
+ # SSIM loss
336
+ losses['ssim'] = self._compute_ssim_loss(rendering, target)
337
+
338
+ # Pixel alignment loss
339
+ losses['pixelalign'] = self._compute_pixelalign_loss(
340
+ img_aligned_xyz, input, mask, b, v, h, w
341
+ )
342
+
343
+ # Point distance loss
344
+ losses['pointsdist'] = self._compute_pointsdist_loss(
345
+ img_aligned_xyz, input, b, v, h, w
346
+ )
347
+
348
+ return losses
349
+
350
+ def _compute_l2_loss(self, rendering, target):
351
+ """Compute L2 (MSE) loss."""
352
+ if self.config.training.losses.l2_loss_weight > 0.0:
353
+ return F.mse_loss(rendering, target)
354
+ return torch.tensor(1e-8, device=rendering.device)
355
+
356
+ def _compute_lpips_loss(self, rendering, target):
357
+ """Compute LPIPS perceptual loss."""
358
+ if self.config.training.losses.lpips_loss_weight > 0.0:
359
+ # LPIPS expects inputs in range [-1, 1]
360
+ return self.lpips_loss_module(
361
+ rendering * 2.0 - 1.0, target * 2.0 - 1.0
362
+ ).mean()
363
+ return torch.tensor(0.0, device=rendering.device)
364
+
365
+ def _compute_perceptual_loss(self, rendering, target):
366
+ """Compute custom perceptual loss."""
367
+ if self.config.training.losses.perceptual_loss_weight > 0.0:
368
+ return self.perceptual_loss_module(rendering, target)
369
+ return torch.tensor(0.0, device=rendering.device)
370
+
371
+ def _compute_ssim_loss(self, rendering, target):
372
+ """Compute SSIM loss."""
373
+ if self.config.training.losses.ssim_loss_weight > 0.0:
374
+ return self.ssim_loss_module(rendering, target)
375
+ return torch.tensor(0.0, device=rendering.device)
376
+
377
+ def _compute_pixelalign_loss(self, img_aligned_xyz, input, mask, b, v, h, w):
378
+ """Compute pixel alignment loss."""
379
+ if self.config.training.losses.pixelalign_loss_weight > 0.0:
380
+ # Compute orthogonal component to ray direction
381
+ xyz_vec = img_aligned_xyz - input.ray_o
382
+ ortho_vec = (
383
+ xyz_vec
384
+ - torch.sum(xyz_vec.detach() * input.ray_d, dim=2, keepdim=True)
385
+ * input.ray_d
386
+ )
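+ # With unit-norm ray directions, ||ortho_vec|| is each predicted center's perpendicular distance to its pixel ray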
387
+
388
+ # Apply mask if enabled
389
+ if self.config.training.losses.get("masked_pixelalign_loss", False):
390
+ assert mask is not None, "mask is None but masked_pixelalign_loss is enabled"
391
+ mask_reshaped = mask.view(b, v, 1, h, w)
392
+ ortho_vec = ortho_vec * mask_reshaped
393
+
394
+ return torch.mean(ortho_vec.norm(dim=2, p=2))
395
+
396
+ return torch.tensor(0.0, device=img_aligned_xyz.device)
397
+
398
+ def _compute_pointsdist_loss(self, img_aligned_xyz, input, b, v, h, w):
399
+ """Compute point distance regularization loss."""
400
+ if self.config.training.losses.pointsdist_loss_weight > 0.0:
401
+ # Target mean distance (distance from origin to ray origin)
402
+ target_mean_dist = torch.norm(input.ray_o, dim=2, p=2, keepdim=True)
403
+ target_std_dist = 0.5
404
+
405
+ # Predicted distance
406
+ pred_dist = (img_aligned_xyz - input.ray_o).norm(dim=2, p=2, keepdim=True)
407
+
408
+ # Normalize to target distribution
409
+ pred_dist_detach = pred_dist.detach()
410
+ pred_mean = pred_dist_detach.mean(dim=(2, 3, 4), keepdim=True)
411
+ pred_std = pred_dist_detach.std(dim=(2, 3, 4), keepdim=True)
412
+
413
+ target_dist = (pred_dist_detach - pred_mean) / (pred_std + 1e-8) * target_std_dist + target_mean_dist
414
+
415
+ return torch.mean((pred_dist - target_dist) ** 2)
416
+
417
+ return torch.tensor(0.0, device=img_aligned_xyz.device)
418
+
419
+ def _compute_total_loss(self, losses):
420
+ """Compute weighted sum of all losses."""
421
+ weights = self.config.training.losses
422
+ return (
423
+ weights.l2_loss_weight * losses['l2']
424
+ + weights.lpips_loss_weight * losses['lpips']
425
+ + weights.perceptual_loss_weight * losses['perceptual']
426
+ + weights.ssim_loss_weight * losses['ssim']
427
+ + weights.pixelalign_loss_weight * losses['pixelalign']
428
+ + weights.pointsdist_loss_weight * losses['pointsdist']
429
+ )
430
+
431
+ def _create_visual(self, rendering, target, v):
432
+ """Create visualization by concatenating target and rendering."""
433
+ visual = torch.cat((target, rendering), dim=3).detach().cpu() # [b*v, c, h, w*2]
434
+ visual = rearrange(visual, "(b v) c h (m w) -> (b h) (v m w) c", v=v, m=2)
435
+ return (visual.numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
436
+
437
+ def _compile_loss_metrics(self, losses, total_loss, visual):
438
+ """Compile all loss metrics into a dictionary."""
439
+ l2_loss = losses['l2']
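+ # _compute_l2_loss returns 1e-8 (not 0.0) when the L2 weight is disabled, keeping the ratios below finite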
440
+
441
+ return edict(
442
+ loss=total_loss,
443
+ l2_loss=l2_loss,
444
+ psnr=losses['psnr'],
445
+ lpips_loss=losses['lpips'],
446
+ perceptual_loss=losses['perceptual'],
447
+ ssim_loss=losses['ssim'],
448
+ pixelalign_loss=losses['pixelalign'],
449
+ pointsdist_loss=losses['pointsdist'],
450
+ visual=visual,
451
+ # Normalized losses for logging
452
+ norm_perceptual_loss=losses['perceptual'] / l2_loss,
453
+ norm_lpips_loss=losses['lpips'] / l2_loss,
454
+ norm_ssim_loss=losses['ssim'] / l2_loss,
455
+ norm_pixelalign_loss=losses['pixelalign'] / l2_loss,
456
+ norm_pointsdist_loss=losses['pointsdist'] / l2_loss,
457
+ )
458
+
459
+
460
+ class GSLRM(nn.Module):
461
+ """
462
+ Gaussian Splatting Large Reconstruction Model.
463
+
464
+ A transformer-based model that generates 3D Gaussian splats from multi-view images.
465
+ The model processes input images through tokenization, transformer layers, and
466
+ generates Gaussian parameters for novel view synthesis.
467
+
468
+ Architecture:
469
+ 1. Image tokenization with patch-based encoding
470
+ 2. Transformer processing with Gaussian positional embeddings
471
+ 3. Gaussian parameter generation and upsampling
472
+ 4. Rendering and loss computation
473
+ """
474
+
475
+ def __init__(self, config: edict):
476
+ super().__init__()
477
+ self.config = config
478
+
479
+ # Initialize data processing modules
480
+ self._init_data_processors(config)
481
+
482
+ # Initialize core model components
483
+ self._init_tokenizer(config)
484
+ self._init_positional_embeddings(config)
485
+ self._init_transformer(config)
486
+ self._init_gaussian_modules(config)
487
+ self._init_rendering_modules(config)
488
+
489
+ # Initialize training state management
490
+ self._init_training_state(config)
491
+
492
+ def _init_data_processors(self, config: edict) -> None:
493
+ """Initialize data splitting and transformation modules."""
494
+ self.data_splitter = SplitData(config)
495
+ self.input_transformer = TransformInput(config)
496
+ self.target_transformer = TransformTarget(config)
497
+
498
+ def _init_tokenizer(self, config: edict) -> None:
499
+ """Initialize image tokenization pipeline."""
500
+ patch_size = config.model.image_tokenizer.patch_size
501
+ input_channels = config.model.image_tokenizer.in_channels
502
+ hidden_dim = config.model.transformer.d
503
+
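+ # Each view is cut into (H / patch_size) * (W / patch_size) non-overlapping patches;
+ # every patch is flattened to patch_size^2 * in_channels values and projected to the transformer width d.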
504
+ self.patch_embedder = nn.Sequential(
505
+ Rearrange(
506
+ "batch views channels (height patch_h) (width patch_w) -> (batch views) (height width) (patch_h patch_w channels)",
507
+ patch_h=patch_size,
508
+ patch_w=patch_size,
509
+ ),
510
+ nn.Linear(
511
+ input_channels * (patch_size ** 2),
512
+ hidden_dim,
513
+ bias=False,
514
+ ),
515
+ )
516
+ self.patch_embedder.apply(_init_weights)
517
+
518
+ def _init_positional_embeddings(self, config: edict) -> None:
519
+ """Initialize positional embeddings for reference/source markers and Gaussians."""
520
+ hidden_dim = config.model.transformer.d
521
+
522
+ # Optional reference/source view markers
523
+ self.view_type_embeddings = None
524
+ if config.model.get("add_refsrc_marker", False):
525
+ self.view_type_embeddings = nn.Parameter(
526
+ torch.randn(2, hidden_dim) # [reference_marker, source_marker]
527
+ )
528
+ nn.init.trunc_normal_(self.view_type_embeddings, std=0.02)
529
+
530
+ # Gaussian positional embeddings
531
+ num_gaussians = config.model.gaussians.n_gaussians
532
+ self.gaussian_position_embeddings = nn.Parameter(
533
+ torch.randn(num_gaussians, hidden_dim)
534
+ )
535
+ nn.init.trunc_normal_(self.gaussian_position_embeddings, std=0.02)
536
+
537
+ def _init_transformer(self, config: edict) -> None:
538
+ """Initialize transformer architecture."""
539
+ hidden_dim = config.model.transformer.d
540
+ head_dim = config.model.transformer.d_head
541
+ num_layers = config.model.transformer.n_layer
542
+
543
+ self.input_layer_norm = nn.LayerNorm(hidden_dim, bias=False)
544
+ self.transformer_layers = nn.ModuleList([
545
+ TransformerBlock(hidden_dim, head_dim)
546
+ for _ in range(num_layers)
547
+ ])
548
+ self.transformer_layers.apply(_init_weights)
549
+
550
+ def _init_gaussian_modules(self, config: edict) -> None:
551
+ """Initialize Gaussian parameter generation modules."""
552
+ hidden_dim = config.model.transformer.d
553
+ patch_size = config.model.image_tokenizer.patch_size
554
+ sh_degree = config.model.gaussians.sh_degree
555
+
556
+ # Calculate output dimension for pixel-aligned Gaussians
557
+ # Components: xyz(3) + sh_features((sh_degree+1)^2*3) + scaling(3) + rotation(4) + opacity(1)
558
+ gaussian_param_dim = 3 + (sh_degree + 1) ** 2 * 3 + 3 + 4 + 1
559
+
560
+ # Gaussian upsampler for transformer tokens
561
+ self.gaussian_upsampler = GaussiansUpsampler(config)
562
+ self.gaussian_upsampler.apply(_init_weights)
563
+
564
+ # Pixel-aligned Gaussian decoder
565
+ self.pixel_gaussian_decoder = nn.Sequential(
566
+ nn.LayerNorm(hidden_dim, bias=False),
567
+ nn.Linear(
568
+ hidden_dim,
569
+ (patch_size ** 2) * gaussian_param_dim,
570
+ bias=False,
571
+ ),
572
+ )
573
+ self.pixel_gaussian_decoder.apply(_init_weights)
574
+
575
+ def _init_rendering_modules(self, config: edict) -> None:
576
+ """Initialize rendering and loss computation modules."""
577
+ self.gaussian_renderer = Renderer(config)
578
+ self.loss_calculator = LossComputer(config)
579
+
580
+ def _init_training_state(self, config: edict) -> None:
581
+ """Initialize training state management variables."""
582
+ self.training_step = None
583
+ self.training_start_step = None
584
+ self.training_max_step = None
585
+ self.original_config = copy.deepcopy(config)
586
+
587
+ def set_training_step(self, current_step: int, start_step: int, max_step: int) -> None:
588
+ """
589
+ Update training step and dynamically adjust configuration based on training phase.
590
+
591
+ Args:
592
+ current_step: Current training step
593
+ start_step: Starting step of training
594
+ max_step: Maximum training steps
595
+ """
596
+ self.training_step = current_step
597
+ self.training_start_step = start_step
598
+ self.training_max_step = max_step
599
+
600
+ # Determine if config modification is needed based on warmup settings
601
+ needs_config_modification = self._should_modify_config_for_warmup(current_step)
602
+
603
+ if needs_config_modification:
604
+ # Always use original config as base for modifications
605
+ self.config = copy.deepcopy(self.original_config)
606
+ self._apply_warmup_modifications(current_step)
607
+ else:
608
+ # Restore original configuration
609
+ self.config = self.original_config
610
+
611
+ # Update loss calculator with current config
612
+ self.loss_calculator.config = self.config
613
+
614
+ def _should_modify_config_for_warmup(self, current_step: int) -> bool:
615
+ """Check if configuration should be modified for warmup phases."""
616
+ pointsdist_warmup = (
617
+ self.config.training.losses.get("warmup_pointsdist", False)
618
+ and current_step < 1000
619
+ )
620
+ l2_warmup = (
621
+ self.config.training.schedule.get("l2_warmup_steps", 0) > 0
622
+ and current_step < self.config.training.schedule.l2_warmup_steps
623
+ )
624
+ return pointsdist_warmup or l2_warmup
625
+
626
+ def _apply_warmup_modifications(self, current_step: int) -> None:
627
+ """Apply configuration modifications for warmup phases."""
628
+ # Point distance warmup phase
629
+ if (self.config.training.losses.get("warmup_pointsdist", False)
630
+ and current_step < 1000):
631
+ self.config.training.losses.l2_loss_weight = 0.0
632
+ self.config.training.losses.perceptual_loss_weight = 0.0
633
+ self.config.training.losses.pointsdist_loss_weight = 0.1
634
+ self.config.model.clip_xyz = False # Disable xyz clipping during warmup
635
+
636
+ # L2 loss warmup phase
637
+ if (self.config.training.schedule.get("l2_warmup_steps", 0) > 0
638
+ and current_step < self.config.training.schedule.l2_warmup_steps):
639
+ self.config.training.losses.perceptual_loss_weight = 0.0
640
+ self.config.training.losses.lpips_loss_weight = 0.0
641
+
642
+ def set_current_step(self, current_step: int, start_step: int, max_step: int) -> None:
643
+ """Backward compatibility wrapper for set_training_step."""
644
+ self.set_training_step(current_step, start_step, max_step)
645
+
646
+ def train(self, mode: bool = True) -> None:
647
+ """
648
+ Override train method to keep frozen modules in eval mode.
649
+
650
+ Args:
651
+ mode: Whether to set training mode (True) or evaluation mode (False)
652
+ """
653
+ super().train(mode)
654
+ # Keep loss calculator in eval mode to prevent training of frozen components
655
+ if self.loss_calculator is not None:
656
+ self.loss_calculator.eval()
657
+
658
+ def get_parameter_overview(self) -> edict:
659
+ """
660
+ Get overview of trainable parameters in each module.
661
+
662
+ Returns:
663
+ Dictionary containing parameter counts for each major component
664
+ """
665
+ def count_trainable_params(module: nn.Module) -> int:
666
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
667
+
668
+ return edict(
669
+ patch_embedder=count_trainable_params(self.patch_embedder),
670
+ gaussian_position_embeddings=self.gaussian_position_embeddings.data.numel(),
671
+ transformer_total=(
672
+ count_trainable_params(self.transformer_layers) +
673
+ count_trainable_params(self.input_layer_norm)
674
+ ),
675
+ gaussian_upsampler=count_trainable_params(self.gaussian_upsampler),
676
+ pixel_gaussian_decoder=count_trainable_params(self.pixel_gaussian_decoder),
677
+ )
678
+
679
+ def get_overview(self) -> edict:
680
+ """Backward compatibility wrapper for get_parameter_overview."""
681
+ return self.get_parameter_overview()
682
+
683
+ def _create_transformer_layer_runner(self, start_layer: int, end_layer: int):
684
+ """
685
+ Create a function to run a subset of transformer layers.
686
+
687
+ Args:
688
+ start_layer: Starting layer index
689
+ end_layer: Ending layer index (exclusive)
690
+
691
+ Returns:
692
+ Function that processes tokens through specified layers
693
+ """
694
+ def run_transformer_layers(token_sequence: torch.Tensor) -> torch.Tensor:
695
+ for layer_idx in range(start_layer, min(end_layer, len(self.transformer_layers))):
696
+ token_sequence = self.transformer_layers[layer_idx](token_sequence)
697
+ return token_sequence
698
+ return run_transformer_layers
699
+
700
+ def _create_posed_images_with_plucker(self, input_data: edict) -> torch.Tensor:
701
+ """
702
+ Create posed images by concatenating RGB with Plucker coordinates.
703
+
704
+ Args:
705
+ input_data: Input data containing images and ray information
706
+
707
+ Returns:
708
+ Posed images with Plucker coordinates [batch, views, channels, height, width]
709
+ """
710
+ # Normalize RGB to [-1, 1] range
711
+ normalized_rgb = input_data.image[:, :, :3, :, :] * 2.0 - 1.0
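+ # Channel layouts produced below (RGB always comes first; "nearest point" is the point on the
+ # ray closest to the world origin, assuming unit-norm ray directions):
+ #   standard plucker : rgb(3) + (ray_o x ray_d)(3) + ray_d(3)                     -> 9 channels
+ #   custom plucker   : rgb(3) + ray_d(3) + nearest point(3)                       -> 9 channels
+ #   augmented plucker: rgb(3) + (ray_o x ray_d)(3) + ray_d(3) + nearest point(3)  -> 12 channels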
712
+
713
+ if self.config.model.get("use_custom_plucker", False):
714
+ # Custom Plucker: RGB + ray_direction + nearest_points
715
+ ray_origin_dot_direction = torch.sum(
716
+ -input_data.ray_o * input_data.ray_d, dim=2, keepdim=True
717
+ )
718
+ nearest_points = input_data.ray_o + ray_origin_dot_direction * input_data.ray_d
719
+
720
+ return torch.cat([
721
+ normalized_rgb,
722
+ input_data.ray_d,
723
+ nearest_points,
724
+ ], dim=2)
725
+
726
+ elif self.config.model.get("use_aug_plucker", False):
727
+ # Augmented Plucker: RGB + cross_product + ray_direction + nearest_points
728
+ ray_cross_product = torch.cross(input_data.ray_o, input_data.ray_d, dim=2)
729
+ ray_origin_dot_direction = torch.sum(
730
+ -input_data.ray_o * input_data.ray_d, dim=2, keepdim=True
731
+ )
732
+ nearest_points = input_data.ray_o + ray_origin_dot_direction * input_data.ray_d
733
+
734
+ return torch.cat([
735
+ normalized_rgb,
736
+ ray_cross_product,
737
+ input_data.ray_d,
738
+ nearest_points,
739
+ ], dim=2)
740
+
741
+ else:
742
+ # Standard Plucker: RGB + cross_product + ray_direction
743
+ ray_cross_product = torch.cross(input_data.ray_o, input_data.ray_d, dim=2)
744
+
745
+ return torch.cat([
746
+ normalized_rgb,
747
+ ray_cross_product,
748
+ input_data.ray_d,
749
+ ], dim=2)
750
+
751
+ def _add_view_type_embeddings(
752
+ self,
753
+ image_tokens: torch.Tensor,
754
+ batch_size: int,
755
+ num_views: int,
756
+ num_patches: int,
757
+ hidden_dim: int
758
+ ) -> torch.Tensor:
759
+ """Add view type embeddings to distinguish reference vs source views."""
760
+ image_tokens = image_tokens.reshape(batch_size, num_views, num_patches, hidden_dim)
761
+
762
+ # Create view type markers: first view is reference, rest are source
763
+ view_markers = [self.view_type_embeddings[0]] + [
764
+ self.view_type_embeddings[1] for _ in range(1, num_views)
765
+ ]
766
+ view_markers = torch.stack(view_markers, dim=0)[None, :, None, :] # [1, views, 1, hidden_dim]
767
+
768
+ # Add markers to image tokens
769
+ image_tokens = image_tokens + view_markers
770
+ return image_tokens.reshape(batch_size, num_views * num_patches, hidden_dim)
771
+
772
+ def _process_through_transformer(
773
+ self,
774
+ gaussian_tokens: torch.Tensor,
775
+ image_tokens: torch.Tensor
776
+ ) -> torch.Tensor:
777
+ """Process combined tokens through transformer with gradient checkpointing."""
778
+ # Combine Gaussian and image tokens
779
+ combined_tokens = torch.cat((gaussian_tokens, image_tokens), dim=1)
780
+ combined_tokens = self.input_layer_norm(combined_tokens)
781
+
782
+ # Process through transformer layers with gradient checkpointing
783
+ checkpoint_interval = self.config.training.runtime.grad_checkpoint_every
784
+ num_layers = len(self.transformer_layers)
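+ # e.g. 24 layers with grad_checkpoint_every = 4 run as six checkpointed chunks of four layers each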
785
+
786
+ for start_idx in range(0, num_layers, checkpoint_interval):
787
+ end_idx = start_idx + checkpoint_interval
788
+ layer_runner = self._create_transformer_layer_runner(start_idx, end_idx)
789
+
790
+ combined_tokens = torch.utils.checkpoint.checkpoint(
791
+ layer_runner,
792
+ combined_tokens,
793
+ use_reentrant=False,
794
+ )
795
+
796
+ return combined_tokens
797
+
798
+ def _apply_hard_pixel_alignment(
799
+ self,
800
+ pixel_aligned_xyz: torch.Tensor,
801
+ input_data: edict
802
+ ) -> torch.Tensor:
803
+ """Apply hard pixel alignment to ensure Gaussians align with ray directions."""
804
+ depth_bias = self.config.model.get("depth_preact_bias", 0.0)
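+ # The raw xyz prediction is collapsed to one pre-depth per pixel (mean over the coordinate axis),
+ # squashed by a sigmoid, affinely mapped to a depth range, and the center is placed at
+ # ray_o + depth * ray_d, so every pixel-aligned Gaussian lies on its own camera ray.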
805
+
806
+ # Apply sigmoid activation to depth values
807
+ depth_values = torch.sigmoid(
808
+ pixel_aligned_xyz.mean(dim=2, keepdim=True) + depth_bias
809
+ )
810
+
811
+ # Apply different depth computation strategies
812
+ if (self.config.model.get("use_aug_plucker", False) or
813
+ self.config.model.get("use_custom_plucker", False)):
814
+ # For Plucker coordinates: use dot product offset
815
+ ray_origin_dot_direction = torch.sum(
816
+ -input_data.ray_o * input_data.ray_d, dim=2, keepdim=True
817
+ )
818
+ depth_values = (2.0 * depth_values - 1.0) * 1.8 + ray_origin_dot_direction
819
+
820
+ elif (self.config.model.get("depth_min", -1.0) > 0.0 and
821
+ self.config.model.get("depth_max", -1.0) > 0.0):
822
+ # Use explicit depth range
823
+ depth_min = self.config.model.depth_min
824
+ depth_max = self.config.model.depth_max
825
+ depth_values = depth_values * (depth_max - depth_min) + depth_min
826
+
827
+ elif self.config.model.get("depth_reference_origin", False):
828
+ # Reference from ray origin norm
829
+ ray_origin_norm = input_data.ray_o.norm(dim=2, p=2, keepdim=True)
830
+ depth_values = (2.0 * depth_values - 1.0) * 1.8 + ray_origin_norm
831
+
832
+ else:
833
+ # Default depth computation
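+ # sigmoid output in (0, 1) maps affinely to depths in (1.2, 4.2) along the ray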
834
+ depth_values = (2.0 * depth_values - 1.0) * 1.5 + 2.7
835
+
836
+ # Compute final 3D positions along rays
837
+ aligned_positions = input_data.ray_o + depth_values * input_data.ray_d
838
+
839
+ # Apply coordinate clipping if enabled (only during training)
840
+ if (self.config.model.get("clip_xyz", False) and
841
+ not self.config.inference):
842
+ aligned_positions = aligned_positions.clamp(-1.0, 1.0)
843
+
844
+ return aligned_positions
845
+
846
+ @staticmethod
847
+ def translate_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
848
+ """
849
+ Translate legacy model parameter names to new parameter names.
850
+
851
+ This function allows loading models saved with the old variable names
852
+ by mapping them to the new, cleaner variable names.
853
+
854
+ Args:
855
+ state_dict: Dictionary containing model parameters with old names
856
+
857
+ Returns:
858
+ Dictionary with parameters mapped to new names
859
+ """
860
+ # Define the mapping from old names to new names
861
+ name_mapping = {
862
+ # Data processors
863
+ 'split_data.': 'data_splitter.',
864
+ 'transform_input.': 'input_transformer.',
865
+ 'transform_target.': 'target_transformer.',
866
+
867
+ # Tokenizer
868
+ 'image_tokenizer.': 'patch_embedder.',
869
+
870
+ # Positional embeddings
871
+ 'refsrc_marker': 'view_type_embeddings',
872
+ 'gaussians_pos_embedding': 'gaussian_position_embeddings',
873
+
874
+ # Transformer
875
+ 'transformer_input_layernorm.': 'input_layer_norm.',
876
+ 'transformer.': 'transformer_layers.',
877
+
878
+ # Gaussian modules
879
+ 'upsampler.': 'gaussian_upsampler.',
880
+ 'image_token_decoder.': 'pixel_gaussian_decoder.',
881
+
882
+ # Rendering modules
883
+ 'renderer.': 'gaussian_renderer.',
884
+ 'loss_computer.': 'loss_calculator.',
885
+ }
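+ # e.g. "image_tokenizer.1.weight" -> "patch_embedder.1.weight"; only the leading prefix is rewritten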
886
+
887
+ # Create new state dict with translated names
888
+ new_state_dict = {}
889
+
890
+ for old_key, value in state_dict.items():
891
+ new_key = old_key
892
+
893
+ # Apply name mappings
894
+ for old_pattern, new_pattern in name_mapping.items():
895
+ if old_key.startswith(old_pattern):
896
+ new_key = old_key.replace(old_pattern, new_pattern, 1)
897
+ break
898
+
899
+ # Fix specific key naming issues
900
+ # Change loss_computer.perceptual_loss_module.Net to loss_computer.perceptual_loss_module.net
901
+ if "loss_computer.perceptual_loss_module.Net" in new_key:
902
+ old_net_key = new_key
903
+ new_key = new_key.replace("loss_computer.perceptual_loss_module.Net", "loss_computer.perceptual_loss_module.net")
904
+ print(f"Renamed checkpoint key: {old_net_key} -> {new_key}")
905
+ # Also handle the new naming convention
906
+ elif "loss_calculator.perceptual_loss_module.Net" in new_key:
907
+ old_net_key = new_key
908
+ new_key = new_key.replace("loss_calculator.perceptual_loss_module.Net", "loss_calculator.perceptual_loss_module.net")
909
+ print(f"Renamed checkpoint key: {old_net_key} -> {new_key}")
910
+
911
+ new_state_dict[new_key] = value
912
+
913
+ return new_state_dict
914
+
915
+ def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
916
+ """
917
+ Load model state dict with automatic legacy name translation.
918
+
919
+ Args:
920
+ state_dict: Model state dictionary (potentially with old parameter names)
921
+ strict: Whether to strictly enforce parameter name matching
922
+ """
923
+ # Check if this is a legacy state dict by looking for old parameter names
924
+ legacy_indicators = [
925
+ 'image_tokenizer.',
926
+ 'refsrc_marker',
927
+ 'gaussians_pos_embedding',
928
+ 'transformer_input_layernorm.',
929
+ 'upsampler.',
930
+ 'image_token_decoder.',
931
+ 'renderer.',
932
+ 'loss_computer.'
933
+ ]
934
+
935
+ is_legacy = any(
936
+ any(key.startswith(indicator) for key in state_dict.keys())
937
+ for indicator in legacy_indicators
938
+ )
939
+
940
+ if is_legacy:
941
+ print("Detected legacy model format. Translating parameter names...")
942
+ state_dict = self.translate_legacy_state_dict(state_dict)
943
+ print("Parameter name translation completed.")
944
+
945
+ # Load the (potentially translated) state dict
946
+ return super().load_state_dict(state_dict, strict=strict)
947
+
948
+ @classmethod
949
+ def load_from_checkpoint(
950
+ cls,
951
+ checkpoint_path: str,
952
+ config: edict,
953
+ map_location: Optional[str] = None
954
+ ) -> 'GSLRM':
955
+ """
956
+ Load model from checkpoint with automatic legacy name translation.
957
+
958
+ Args:
959
+ checkpoint_path: Path to the checkpoint file
960
+ config: Model configuration
961
+ map_location: Device to map tensors to (e.g., 'cpu', 'cuda:0')
962
+
963
+ Returns:
964
+ Loaded GSLRM model
965
+ """
966
+ # Create model instance
967
+ model = cls(config)
968
+
969
+ # Load checkpoint
970
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
971
+
972
+ # Extract state dict (handle different checkpoint formats)
973
+ if isinstance(checkpoint, dict):
974
+ if 'model_state_dict' in checkpoint:
975
+ state_dict = checkpoint['model_state_dict']
976
+ elif 'state_dict' in checkpoint:
977
+ state_dict = checkpoint['state_dict']
978
+ else:
979
+ state_dict = checkpoint
980
+ else:
981
+ state_dict = checkpoint
982
+
983
+ # Load state dict with automatic translation
984
+ model.load_state_dict(state_dict)
985
+
986
+ print(f"Successfully loaded model from {checkpoint_path}")
987
+ return model
988
+
989
+ def _create_gaussian_models_and_stats(
990
+ self,
991
+ xyz: torch.Tensor,
992
+ features: torch.Tensor,
993
+ scaling: torch.Tensor,
994
+ rotation: torch.Tensor,
995
+ opacity: torch.Tensor,
996
+ num_pixel_aligned: int,
997
+ num_views: int,
998
+ height: int,
999
+ width: int,
1000
+ patch_size: int
1001
+ ) -> Tuple[List, torch.Tensor, List[float]]:
1002
+ """
1003
+ Create Gaussian models for each batch item and compute usage statistics.
1004
+
1005
+ Returns:
1006
+ Tuple of (gaussian_models, pixel_aligned_positions, usage_statistics)
1007
+ """
1008
+ gaussian_models = []
1009
+ pixel_aligned_positions_list = []
1010
+ usage_statistics = []
1011
+
1012
+ batch_size = xyz.size(0)
1013
+ opacity_threshold = 0.05
1014
+
1015
+ for batch_idx in range(batch_size):
1016
+ # Create fresh Gaussian model for this batch item
1017
+ self.gaussian_renderer.gaussians_model.empty()
1018
+ gaussian_model = copy.deepcopy(self.gaussian_renderer.gaussians_model)
1019
+
1020
+ # Set Gaussian data
1021
+ gaussian_model = gaussian_model.set_data(
1022
+ xyz[batch_idx].detach().float(),
1023
+ features[batch_idx].detach().float(),
1024
+ scaling[batch_idx].detach().float(),
1025
+ rotation[batch_idx].detach().float(),
1026
+ opacity[batch_idx].detach().float(),
1027
+ )
1028
+ gaussian_models.append(gaussian_model)
1029
+
1030
+ # Compute usage statistics (fraction of Gaussians above opacity threshold)
1031
+ opacity_mask = gaussian_model.get_opacity > opacity_threshold
1032
+ usage_ratio = opacity_mask.sum() / opacity_mask.numel()
1033
+ if torch.is_tensor(usage_ratio):
1034
+ usage_ratio = usage_ratio.item()
1035
+ usage_statistics.append(usage_ratio)
1036
+
1037
+ # Extract pixel-aligned positions and reshape
1038
+ pixel_xyz = gaussian_model.get_xyz[-num_pixel_aligned:, :]
1039
+ pixel_xyz_reshaped = rearrange(
1040
+ pixel_xyz,
1041
+ "(views height width patch_h patch_w) coords -> views coords (height patch_h) (width patch_w)",
1042
+ views=num_views,
1043
+ height=height // patch_size,
1044
+ width=width // patch_size,
1045
+ patch_h=patch_size,
1046
+ patch_w=patch_size,
1047
+ )
1048
+ pixel_aligned_positions_list.append(pixel_xyz_reshaped)
1049
+
1050
+ # Stack pixel-aligned positions
1051
+ pixel_aligned_positions = torch.stack(pixel_aligned_positions_list, dim=0)
1052
+
1053
+ return gaussian_models, pixel_aligned_positions, usage_statistics
1054
+
1055
+ def forward(
1056
+ self,
1057
+ batch_data: edict,
1058
+ create_visual: bool = False,
1059
+ split_data: bool = True
1060
+ ) -> edict:
1061
+ """
1062
+ Forward pass of the GSLRM model.
1063
+
1064
+ Args:
1065
+ batch_data: Input batch containing:
1066
+ - image: Multi-view images [batch, views, channels, height, width]
1067
+ - fxfycxcy: Camera intrinsics [batch, views, 4]
1068
+ - c2w: Camera-to-world matrices [batch, views, 4, 4]
1069
+ create_visual: Whether to create visualization outputs
1070
+ split_data: Whether to split input/target data
1071
+
1072
+ Returns:
1073
+ Dictionary containing model outputs including Gaussians, renders, and losses
1074
+ """
1075
+ with torch.no_grad():
1076
+ target_data = None
1077
+ if split_data:
1078
+ batch_data, target_data = self.data_splitter(
1079
+ batch_data, self.config.training.dataset.target_has_input
1080
+ )
1081
+ target_data = self.target_transformer(target_data)
1082
+
1083
+ input_data = self.input_transformer(batch_data)
1084
+
1085
+ # Prepare posed images with Plucker coordinates [batch, views, channels, height, width]
1086
+ posed_images = self._create_posed_images_with_plucker(input_data)
1087
+
1088
+ # Process images through tokenization and transformer
1089
+ batch_size, num_views, channels, height, width = posed_images.size()
1090
+
1091
+ # Tokenize images into patches
1092
+ image_patch_tokens = self.patch_embedder(posed_images) # [batch*views, num_patches, hidden_dim]
1093
+ _, num_patches, hidden_dim = image_patch_tokens.size()
1094
+ image_patch_tokens = image_patch_tokens.reshape(
1095
+ batch_size, num_views * num_patches, hidden_dim
1096
+ ) # [batch, views*patches, hidden_dim]
1097
+
1098
+ # Add view type embeddings if enabled (reference vs source views)
1099
+ if self.view_type_embeddings is not None:
1100
+ image_patch_tokens = self._add_view_type_embeddings(
1101
+ image_patch_tokens, batch_size, num_views, num_patches, hidden_dim
1102
+ )
1103
+
1104
+ # Prepare Gaussian tokens with positional embeddings
1105
+ gaussian_tokens = self.gaussian_position_embeddings.expand(batch_size, -1, -1)
1106
+
1107
+ # Process through transformer with gradient checkpointing
1108
+ combined_tokens = self._process_through_transformer(
1109
+ gaussian_tokens, image_patch_tokens
1110
+ )
1111
+
1112
+ # Split back into Gaussian and image tokens
1113
+ num_gaussians = self.config.model.gaussians.n_gaussians
1114
+ gaussian_tokens, image_patch_tokens = combined_tokens.split(
1115
+ [num_gaussians, num_views * num_patches], dim=1
1116
+ )
1117
+
1118
+ # Generate Gaussian parameters from transformer outputs
1119
+ gaussian_params = self.gaussian_upsampler(gaussian_tokens, image_patch_tokens)
1120
+
1121
+ # Generate pixel-aligned Gaussians from image tokens
1122
+ pixel_aligned_gaussian_params = self.pixel_gaussian_decoder(image_patch_tokens)
1123
+
1124
+ # Calculate Gaussian parameter dimensions
1125
+ sh_degree = self.config.model.gaussians.sh_degree
1126
+ gaussian_param_dim = 3 + (sh_degree + 1) ** 2 * 3 + 3 + 4 + 1
1127
+
1128
+ pixel_aligned_gaussian_params = pixel_aligned_gaussian_params.reshape(
1129
+ batch_size, -1, gaussian_param_dim
1130
+ ) # [batch, views*pixels, gaussian_params]
1131
+ num_pixel_aligned_gaussians = pixel_aligned_gaussian_params.size(1)
1132
+
1133
+ # Combine all Gaussian parameters
1134
+ all_gaussian_params = torch.cat((gaussian_params, pixel_aligned_gaussian_params), dim=1)
1135
+
1136
+ # Convert to final Gaussian format
1137
+ xyz, features, scaling, rotation, opacity = self.gaussian_upsampler.to_gs(all_gaussian_params)
1138
+
1139
+ # Extract pixel-aligned Gaussian positions for processing
1140
+ pixel_aligned_xyz = xyz[:, -num_pixel_aligned_gaussians:, :]
1141
+ patch_size = self.config.model.image_tokenizer.patch_size
1142
+
1143
+ pixel_aligned_xyz = rearrange(
1144
+ pixel_aligned_xyz,
1145
+ "batch (views height width patch_h patch_w) coords -> batch views coords (height patch_h) (width patch_w)",
1146
+ views=num_views,
1147
+ height=height // patch_size,
1148
+ width=width // patch_size,
1149
+ patch_h=patch_size,
1150
+ patch_w=patch_size,
1151
+ )
1152
+
1153
+ # Apply hard pixel alignment if enabled
1154
+ if self.config.model.hard_pixelalign:
1155
+ pixel_aligned_xyz = self._apply_hard_pixel_alignment(
1156
+ pixel_aligned_xyz, input_data
1157
+ )
1158
+
1159
+ # Reshape back to flat format and update xyz
1160
+ pixel_aligned_xyz_flat = rearrange(
1161
+ pixel_aligned_xyz,
1162
+ "batch views coords (height patch_h) (width patch_w) -> batch (views height width patch_h patch_w) coords",
1163
+ patch_h=patch_size,
1164
+ patch_w=patch_size,
1165
+ )
1166
+
1167
+ # Replace pixel-aligned Gaussians in the full xyz tensor
1168
+ xyz = torch.cat(
1169
+ (xyz[:, :-num_pixel_aligned_gaussians, :], pixel_aligned_xyz_flat),
1170
+ dim=1
1171
+ )
1172
+
1173
+ # Create Gaussian splatting result structure
1174
+ gaussian_splat_result = edict(
1175
+ xyz=xyz,
1176
+ features=features,
1177
+ scaling=scaling,
1178
+ rotation=rotation,
1179
+ opacity=opacity,
1180
+ )
1181
+
1182
+ # Perform rendering and loss computation if target data is available
1183
+ loss_metrics = None
1184
+ rendered_images = None
1185
+
1186
+ if target_data is not None:
1187
+ target_height, target_width = target_data.image.size(3), target_data.image.size(4)
1188
+
1189
+ # Render images using Gaussian splatting
1190
+ rendered_images = self.gaussian_renderer(
1191
+ xyz, features, scaling, rotation, opacity,
1192
+ target_height, target_width,
1193
+ C2W=target_data.c2w,
1194
+ fxfycxcy=target_data.fxfycxcy,
1195
+ )
1196
+
1197
+ # Compute losses if rendered and target have matching dimensions
1198
+ if rendered_images.shape[1] == target_data.image.shape[1]:
1199
+ loss_metrics = self.loss_calculator(
1200
+ rendered_images,
1201
+ target_data.image,
1202
+ pixel_aligned_xyz,
1203
+ input_data,
1204
+ create_visual=create_visual,
1205
+ result_softpa=gaussian_splat_result,
1206
+ )
1207
+
1208
+ # Create Gaussian models for each batch item and compute usage statistics
1209
+ gaussian_models, pixel_aligned_positions, usage_statistics = self._create_gaussian_models_and_stats(
1210
+ xyz, features, scaling, rotation, opacity,
1211
+ num_pixel_aligned_gaussians, num_views, height, width, patch_size
1212
+ )
1213
+
1214
+ # Add usage statistics to loss metrics for logging
1215
+ if loss_metrics is not None:
1216
+ loss_metrics.gaussians_usage = torch.tensor(
1217
+ np.mean(np.array(usage_statistics))
1218
+ ).float()
1219
+
1220
+ # Compile final results
1221
+ return edict(
1222
+ input=input_data,
1223
+ target=target_data,
1224
+ gaussians=gaussian_models,
1225
+ pixelalign_xyz=pixel_aligned_positions,
1226
+ img_tokens=image_patch_tokens,
1227
+ loss_metrics=loss_metrics,
1228
+ render=rendered_images,
1229
+ )
1230
+
1231
+ @torch.no_grad()
1232
+ def save_visualization_outputs(
1233
+ self,
1234
+ output_directory: str,
1235
+ model_results: edict,
1236
+ batch_data: edict,
1237
+ save_all_items: bool = False
1238
+ ) -> None:
1239
+ """
1240
+ Save visualization outputs including rendered images and Gaussian models.
1241
+
1242
+ Args:
1243
+ output_directory: Directory to save outputs
1244
+ model_results: Results from model forward pass
1245
+ batch_data: Original batch data
1246
+ save_all_items: Whether to save all batch items or just the first
1247
+ """
1248
+ os.makedirs(output_directory, exist_ok=True)
1249
+
1250
+ input_data, target_data = model_results.input, model_results.target
1251
+
1252
+ # Save supervision visualization if available
1253
+ if (model_results.loss_metrics is not None and
1254
+ model_results.loss_metrics.visual is not None):
1255
+
1256
+ batch_uids = [
1257
+ target_data.index[b, 0, -1].item()
1258
+ for b in range(target_data.index.size(0))
1259
+ ]
1260
+
1261
+ uid_range = f"{batch_uids[0]:08}_{batch_uids[-1]:08}"
1262
+
1263
+ # Save supervision comparison image
1264
+ Image.fromarray(model_results.loss_metrics.visual).save(
1265
+ os.path.join(output_directory, f"supervision_{uid_range}.jpg")
1266
+ )
1267
+
1268
+ # Save UIDs for reference
1269
+ with open(os.path.join(output_directory, "uids.txt"), "w") as f:
1270
+ uid_string = "_".join([f"{uid:08}" for uid in batch_uids])
1271
+ f.write(uid_string)
1272
+
1273
+ # Save input images
1274
+ input_visualization = rearrange(
1275
+ input_data.image, "batch views channels height width -> (batch height) (views width) channels"
1276
+ )
1277
+ input_visualization = (
1278
+ (input_visualization.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1279
+ )
1280
+ Image.fromarray(input_visualization[..., :3]).save(
1281
+ os.path.join(output_directory, f"input_{uid_range}.jpg")
1282
+ )
1283
+
1284
+ # Process each batch item individually
1285
+ batch_size = input_data.image.size(0)
1286
+ for batch_idx in range(batch_size):
1287
+ item_uid = input_data.index[batch_idx, 0, -1].item()
1288
+
1289
+ # Render turntable visualization
1290
+ turntable_image = render_turntable(model_results.gaussians[batch_idx])
1291
+ Image.fromarray(turntable_image).save(
1292
+ os.path.join(output_directory, f"turntable_{item_uid}.jpg")
1293
+ )
1294
+
1295
+ # Save individual input images during inference
1296
+ if self.config.inference:
1297
+ individual_input = rearrange(
1298
+ input_data.image[batch_idx], "views channels height width -> height (views width) channels"
1299
+ )
1300
+ individual_input = (
1301
+ (individual_input.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1302
+ )
1303
+ Image.fromarray(individual_input[..., :3]).save(
1304
+ os.path.join(output_directory, f"input_{item_uid}.jpg")
1305
+ )
1306
+
1307
+ # Extract image dimensions and create opacity/depth visualizations
1308
+ _, num_views, _, img_height, img_width = input_data.image.size()
1309
+ patch_size = self.config.model.image_tokenizer.patch_size
1310
+
1311
+ # Get opacity values for pixel-aligned Gaussians
1312
+ gaussian_opacity = model_results.gaussians[batch_idx].get_opacity
1313
+ pixel_opacity = gaussian_opacity[-num_views * img_height * img_width:]
1314
+
1315
+ # Reshape opacity to image format
1316
+ opacity_visualization = rearrange(
1317
+ pixel_opacity,
1318
+ "(views height width patch_h patch_w) channels -> (height patch_h) (views width patch_w) channels",
1319
+ views=num_views,
1320
+ height=img_height // patch_size,
1321
+ width=img_width // patch_size,
1322
+ patch_h=patch_size,
1323
+ patch_w=patch_size,
1324
+ ).squeeze(-1).cpu().numpy()
1325
+ opacity_visualization = (opacity_visualization * 255.0).clip(0.0, 255.0).astype(np.uint8)
1326
+
1327
+ # Get 3D positions and compute depth visualization
1328
+ gaussian_positions = model_results.gaussians[batch_idx].get_xyz
1329
+ pixel_positions = gaussian_positions[-num_views * img_height * img_width:]
1330
+
1331
+ # Reshape positions to image format
1332
+ pixel_positions_reshaped = rearrange(
1333
+ pixel_positions,
1334
+ "(views height width patch_h patch_w) coords -> views coords (height patch_h) (width patch_w)",
1335
+ views=num_views,
1336
+ height=img_height // patch_size,
1337
+ width=img_width // patch_size,
1338
+ patch_h=patch_size,
1339
+ patch_w=patch_size,
1340
+ )
1341
+
1342
+ # Compute distances from ray origins
1343
+ ray_distances = (pixel_positions_reshaped - input_data.ray_o[batch_idx]).norm(dim=1, p=2)
1344
+ distance_visualization = rearrange(ray_distances, "views height width -> height (views width)")
1345
+ distance_visualization = distance_visualization.cpu().numpy()
1346
+
1347
+ # Normalize distances for visualization
1348
+ dist_min, dist_max = distance_visualization.min(), distance_visualization.max()
1349
+ distance_visualization = (distance_visualization - dist_min) / (dist_max - dist_min)
1350
+ distance_visualization = (distance_visualization * 255.0).clip(0.0, 255.0).astype(np.uint8)
1351
+
1352
+ # Combine opacity and depth visualizations
1353
+ combined_visualization = np.concatenate([opacity_visualization, distance_visualization], axis=0)
1354
+ Image.fromarray(combined_visualization).save(
1355
+ os.path.join(output_directory, f"aligned_gs_opacity_depth_{item_uid}.jpg")
1356
+ )
1357
+
1358
+ # Save unfiltered Gaussian model for small images during early training
1359
+ if (self.config.model.image_tokenizer.image_size <= 256 and
1360
+ self.training_step is not None and self.training_step <= 5000):
1361
+ model_results.gaussians[batch_idx].save_ply(
1362
+ os.path.join(output_directory, f"gaussians_{item_uid}_unfiltered.ply")
1363
+ )
1364
+
1365
+ # Save filtered Gaussian model
1366
+ camera_origins = None # Could use input_data.ray_o[batch_idx, :, :, 0, 0] if needed
1367
+ default_crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
1368
+
1369
+ model_results.gaussians[batch_idx].apply_all_filters(
1370
+ opacity_thres=0.02,
1371
+ crop_bbx=default_crop_box,
1372
+ cam_origins=camera_origins,
1373
+ nearfar_percent=(0.0001, 1.0),
1374
+ ).save_ply(os.path.join(output_directory, f"gaussians_{item_uid}.ply"))
1375
+
1376
+ print(f"Saved visualization for UID: {item_uid}")
1377
+
1378
+ # Break after first item unless saving all
1379
+ if not save_all_items:
1380
+ break
1381
+
1382
+ @torch.no_grad()
1383
+ def save_visuals(self, out_dir: str, result: edict, batch: edict, save_all: bool = False) -> None:
1384
+ """Backward compatibility wrapper for save_visualization_outputs."""
1385
+ self.save_visualization_outputs(out_dir, result, batch, save_all)
1386
+
1387
+ @torch.no_grad()
1388
+ def save_evaluation_results(
1389
+ self,
1390
+ output_directory: str,
1391
+ model_results: edict,
1392
+ batch_data: edict,
1393
+ dataset
1394
+ ) -> None:
1395
+ """Save comprehensive evaluation results including metrics, visualizations, and 3D models."""
1396
+ from .utils_metrics import compute_psnr, compute_lpips, compute_ssim
1397
+
1398
+ os.makedirs(output_directory, exist_ok=True)
1399
+ input_data, target_data = model_results.input, model_results.target
1400
+
1401
+ for batch_idx in range(input_data.image.size(0)):
1402
+ item_uid = input_data.index[batch_idx, 0, -1].item()
1403
+ item_output_dir = os.path.join(output_directory, f"{item_uid:08d}")
1404
+ os.makedirs(item_output_dir, exist_ok=True)
1405
+
1406
+ # Save input image
1407
+ input_image = rearrange(
1408
+ input_data.image[batch_idx], "views channels height width -> height (views width) channels"
1409
+ )
1410
+ input_image = (input_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1411
+ Image.fromarray(input_image[..., :3]).save(os.path.join(item_output_dir, "input.png"))
1412
+
1413
+ # Save ground truth vs prediction comparison
1414
+ comparison_image = torch.stack((target_data.image[batch_idx], model_results.render[batch_idx]), dim=0)
1415
+ num_views = comparison_image.size(1)
1416
+ if num_views > 10:
1417
+ comparison_image = comparison_image[:, ::num_views // 10, :, :, :]
1418
+ comparison_image = rearrange(
1419
+ comparison_image, "comparison_type views channels height width -> (comparison_type height) (views width) channels"
1420
+ )
1421
+ comparison_image = (comparison_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1422
+ Image.fromarray(comparison_image).save(os.path.join(item_output_dir, "gt_vs_pred.png"))
1423
+
1424
+ # Compute and save metrics
1425
+ per_view_psnr = compute_psnr(target_data.image[batch_idx], model_results.render[batch_idx])
1426
+ per_view_lpips = compute_lpips(target_data.image[batch_idx], model_results.render[batch_idx])
1427
+ per_view_ssim = compute_ssim(target_data.image[batch_idx], model_results.render[batch_idx])
1428
+
1429
+ # Save per-view metrics
1430
+ view_ids = target_data.index[batch_idx, :, 0].cpu().numpy()
1431
+ with open(os.path.join(item_output_dir, "perview_metrics.txt"), "w") as f:
1432
+ for i in range(per_view_psnr.size(0)):
1433
+ f.write(
1434
+ f"view {view_ids[i]:0>6}, psnr: {per_view_psnr[i].item():.4f}, "
1435
+ f"lpips: {per_view_lpips[i].item():.4f}, ssim: {per_view_ssim[i].item():.4f}\n"
1436
+ )
1437
+
1438
+ # Save average metrics
1439
+ avg_psnr = per_view_psnr.mean().item()
1440
+ avg_lpips = per_view_lpips.mean().item()
1441
+ avg_ssim = per_view_ssim.mean().item()
1442
+
1443
+ with open(os.path.join(item_output_dir, "metrics.txt"), "w") as f:
1444
+ f.write(f"psnr: {avg_psnr:.4f}\nlpips: {avg_lpips:.4f}\nssim: {avg_ssim:.4f}\n")
1445
+
1446
+ print(f"UID {item_uid}: PSNR={avg_psnr:.4f}, LPIPS={avg_lpips:.4f}, SSIM={avg_ssim:.4f}")
1447
+
1448
+ # Save Gaussian model
1449
+ crop_box = None
1450
+ if self.config.model.get("clip_xyz", False):
1451
+ if self.config.model.get("half_bbx_size", None) is not None:
1452
+ half_size = self.config.model.half_bbx_size
1453
+ crop_box = [-half_size, half_size, -half_size, half_size, -half_size, half_size]
1454
+ else:
1455
+ crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
1456
+
1457
+ model_results.gaussians[batch_idx].apply_all_filters(
1458
+ opacity_thres=0.02, crop_bbx=crop_box, cam_origins=None, nearfar_percent=(0.0001, 1.0)
1459
+ ).save_ply(os.path.join(item_output_dir, "gaussians.ply"))
1460
+
1461
+ # Create turntable visualization
1462
+ num_turntable_views = 150
1463
+ render_resolution = input_image.shape[0]
1464
+
1465
+ turntable_frames = render_turntable(
1466
+ model_results.gaussians[batch_idx], rendering_resolution=render_resolution, num_views=num_turntable_views
1467
+ )
1468
+ turntable_frames = rearrange(
1469
+ turntable_frames, "height (views width) channels -> views height width channels", views=num_turntable_views
1470
+ )
1471
+ turntable_frames = np.ascontiguousarray(turntable_frames)
1472
+
1473
+ # Save basic turntable video
1474
+ imageseq2video(turntable_frames, os.path.join(item_output_dir, "turntable.mp4"), fps=30)
1475
+
1476
+ # Save description and preview if available
1477
+ try:
1478
+ description = dataset.get_description(item_uid)["prompt"]
1479
+ if len(description) > 0:
1480
+ with open(os.path.join(item_output_dir, "description.txt"), "w") as f:
1481
+ f.write(description)
1482
+
1483
+ # Create preview image (subsample to 10 views)
1484
+ preview_frames = turntable_frames[::num_turntable_views // 10]
1485
+ preview_image = rearrange(preview_frames, "views height width channels -> height (views width) channels")
1486
+ Image.fromarray(preview_image).save(os.path.join(item_output_dir, "turntable_preview.png"))
1487
+ except (AttributeError, KeyError):
1488
+ pass
1489
+
1490
+ # Create turntable with input overlay
1491
+ border_width = 2
1492
+ target_width = render_resolution
1493
+ target_height = int(input_image.shape[0] / input_image.shape[1] * target_width)
1494
+
1495
+ resized_input = cv2.resize(
1496
+ input_image, (target_width - border_width * 2, target_height - border_width * 2), interpolation=cv2.INTER_AREA
1497
+ )
1498
+ bordered_input = np.pad(
1499
+ resized_input, ((border_width, border_width), (border_width, border_width), (0, 0)),
1500
+ mode="constant", constant_values=200
1501
+ )
1502
+
1503
+ input_sequence = np.tile(bordered_input[None], (turntable_frames.shape[0], 1, 1, 1))
1504
+ combined_frames = np.concatenate((turntable_frames, input_sequence), axis=1)
1505
+
1506
+ imageseq2video(combined_frames, os.path.join(item_output_dir, "turntable_with_input.mp4"), fps=30)
1507
+
1508
+ @torch.no_grad()
1509
+ def save_evaluations(self, out_dir: str, result: edict, batch: edict, dataset) -> None:
1510
+ """Backward compatibility wrapper for save_evaluation_results."""
1511
+ self.save_evaluation_results(out_dir, result, batch, dataset)
1512
+
1513
+ @torch.no_grad()
1514
+ def save_validation_results(
1515
+ self,
1516
+ output_directory: str,
1517
+ model_results: edict,
1518
+ batch_data: edict,
1519
+ dataset,
1520
+ save_visualizations: bool = False
1521
+ ) -> Dict[str, float]:
1522
+ """Save validation results and compute aggregated metrics."""
1523
+ from .utils_metrics import compute_psnr, compute_lpips, compute_ssim
1524
+
1525
+ os.makedirs(output_directory, exist_ok=True)
1526
+ input_data, target_data = model_results.input, model_results.target
1527
+ validation_metrics = {"psnr": [], "lpips": [], "ssim": []}
1528
+
1529
+ for batch_idx in range(input_data.image.size(0)):
1530
+ item_uid = input_data.index[batch_idx, 0, -1].item()
1531
+ should_save_visuals = (batch_idx == 0) and save_visualizations
1532
+
1533
+ # Compute metrics (RGB only)
1534
+ target_image = target_data.image[batch_idx][:, :3, ...]
1535
+ per_view_psnr = compute_psnr(target_image, model_results.render[batch_idx])
1536
+ per_view_lpips = compute_lpips(target_image, model_results.render[batch_idx])
1537
+ per_view_ssim = compute_ssim(target_image, model_results.render[batch_idx])
1538
+
1539
+ avg_psnr = per_view_psnr.mean().item()
1540
+ avg_lpips = per_view_lpips.mean().item()
1541
+ avg_ssim = per_view_ssim.mean().item()
1542
+
1543
+ validation_metrics["psnr"].append(avg_psnr)
1544
+ validation_metrics["lpips"].append(avg_lpips)
1545
+ validation_metrics["ssim"].append(avg_ssim)
1546
+
1547
+ # Save visualizations only for first item if requested
1548
+ if should_save_visuals:
1549
+ item_output_dir = os.path.join(output_directory, f"{item_uid:08d}")
1550
+ os.makedirs(item_output_dir, exist_ok=True)
1551
+
1552
+ # Save input image
1553
+ input_image = rearrange(
1554
+ input_data.image[batch_idx][:, :3, ...], "views channels height width -> height (views width) channels"
1555
+ )
1556
+ input_image = (input_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1557
+ Image.fromarray(input_image).save(os.path.join(item_output_dir, "input.png"))
1558
+
1559
+ # Save ground truth vs prediction comparison
1560
+ comparison_image = torch.stack((target_image, model_results.render[batch_idx]), dim=0)
1561
+ num_views = comparison_image.size(1)
1562
+ if num_views > 10:
1563
+ comparison_image = comparison_image[:, ::num_views // 10, :, :, :]
1564
+ comparison_image = rearrange(
1565
+ comparison_image, "comparison_type views channels height width -> (comparison_type height) (views width) channels"
1566
+ )
1567
+ comparison_image = (comparison_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1568
+ Image.fromarray(comparison_image).save(os.path.join(item_output_dir, "gt_vs_pred.png"))
1569
+
1570
+ # Save per-view metrics
1571
+ view_ids = target_data.index[batch_idx, :, 0].cpu().numpy()
1572
+ with open(os.path.join(item_output_dir, "perview_metrics.txt"), "w") as f:
1573
+ for i in range(per_view_psnr.size(0)):
1574
+ f.write(
1575
+ f"view {view_ids[i]:0>6}, psnr: {per_view_psnr[i].item():.4f}, "
1576
+ f"lpips: {per_view_lpips[i].item():.4f}, ssim: {per_view_ssim[i].item():.4f}\n"
1577
+ )
1578
+
1579
+ # Save averaged metrics
1580
+ with open(os.path.join(item_output_dir, "metrics.txt"), "w") as f:
1581
+ f.write(f"psnr: {avg_psnr:.4f}\nlpips: {avg_lpips:.4f}\nssim: {avg_ssim:.4f}\n")
1582
+
1583
+ print(f"Validation UID {item_uid}: PSNR={avg_psnr:.4f}, LPIPS={avg_lpips:.4f}, SSIM={avg_ssim:.4f}")
1584
+
1585
+ # Save Gaussian model
1586
+ crop_box = None
1587
+ if self.config.model.get("clip_xyz", False):
1588
+ if self.config.model.get("half_bbx_size", None) is not None:
1589
+ half_size = self.config.model.half_bbx_size
1590
+ crop_box = [-half_size, half_size, -half_size, half_size, -half_size, half_size]
1591
+ else:
1592
+ crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
1593
+
1594
+ model_results.gaussians[batch_idx].apply_all_filters(
1595
+ opacity_thres=0.02, crop_bbx=crop_box, cam_origins=None, nearfar_percent=(0.0001, 1.0)
1596
+ ).save_ply(os.path.join(item_output_dir, "gaussians.ply"))
1597
+
1598
+ # Create turntable visualization
1599
+ num_turntable_views = 150
1600
+ render_resolution = input_image.shape[0]
1601
+
1602
+ turntable_frames = render_turntable(
1603
+ model_results.gaussians[batch_idx], rendering_resolution=render_resolution, num_views=num_turntable_views
1604
+ )
1605
+ turntable_frames = rearrange(
1606
+ turntable_frames, "height (views width) channels -> views height width channels", views=num_turntable_views
1607
+ )
1608
+ turntable_frames = np.ascontiguousarray(turntable_frames)
1609
+
1610
+ imageseq2video(turntable_frames, os.path.join(item_output_dir, "turntable.mp4"), fps=30)
1611
+
1612
+ # Create turntable with input overlay
1613
+ border_width = 2
1614
+ target_width = render_resolution
1615
+ target_height = int(input_image.shape[0] / input_image.shape[1] * target_width)
1616
+
1617
+ resized_input = cv2.resize(
1618
+ input_image, (target_width - border_width * 2, target_height - border_width * 2), interpolation=cv2.INTER_AREA
1619
+ )
1620
+ bordered_input = np.pad(
1621
+ resized_input, ((border_width, border_width), (border_width, border_width), (0, 0)),
1622
+ mode="constant", constant_values=200
1623
+ )
1624
+
1625
+ input_sequence = np.tile(bordered_input[None], (turntable_frames.shape[0], 1, 1, 1))
1626
+ combined_frames = np.concatenate((turntable_frames, input_sequence), axis=1)
1627
+
1628
+ imageseq2video(combined_frames, os.path.join(item_output_dir, "turntable_with_input.mp4"), fps=30)
1629
+
1630
+ # Return averaged metrics
1631
+ return {
1632
+ "psnr": torch.tensor(validation_metrics["psnr"]).mean().item(),
1633
+ "lpips": torch.tensor(validation_metrics["lpips"]).mean().item(),
1634
+ "ssim": torch.tensor(validation_metrics["ssim"]).mean().item(),
1635
+ }
1636
+
1637
+ @torch.no_grad()
1638
+ def save_validations(
1639
+ self,
1640
+ out_dir: str,
1641
+ result: edict,
1642
+ batch: edict,
1643
+ dataset,
1644
+ save_img: bool = False
1645
+ ) -> Dict[str, float]:
1646
+ """Backward compatibility wrapper for save_validation_results."""
1647
+ return self.save_validation_results(out_dir, result, batch, dataset, save_img)
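The dictionary returned by `save_validation_results` averages PSNR, LPIPS, and SSIM over the batch. Below is a minimal, hypothetical sketch of how those per-batch dictionaries might be accumulated across a validation loader; `system`, `val_loader`, and the forward call are illustrative assumptions and not part of this commit.

```python
# Hypothetical validation loop (not part of this commit): `system` is assumed to wrap the
# module above and `system(batch)` to return the edict consumed by save_validation_results.
from collections import defaultdict

def run_validation(system, val_loader, out_dir):
    totals, num_batches = defaultdict(float), 0
    for batch in val_loader:
        result = system(batch)  # assumed forward pass producing .input/.target/.render/.gaussians
        metrics = system.save_validation_results(out_dir, result, batch, val_loader.dataset)
        for name, value in metrics.items():  # keys: "psnr", "lpips", "ssim"
            totals[name] += value
        num_batches += 1
    return {name: total / max(num_batches, 1) for name, total in totals.items()}
```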
gslrm/model/transform_data.py ADDED
@@ -0,0 +1,410 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
9
+ """
10
+ Data transformation utilities for GSLRM model.
11
+
12
+ This module contains classes and utilities for transforming input and target data
13
+ for training and inference in the GSLRM (Gaussian Splatting Large Reconstruction Model).
14
+ """
15
+
16
+ import itertools
17
+ import random
18
+ from typing import Dict, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from easydict import EasyDict as edict
25
+
26
+ # =============================================================================
27
+ # Utility Functions
28
+ # =============================================================================
29
+
30
+ def compute_camera_rays(
31
+ fxfycxcy: torch.Tensor,
32
+ c2w: torch.Tensor,
33
+ h: int,
34
+ w: int,
35
+ device: torch.device
36
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
37
+ """
38
+ Compute camera rays for given intrinsics and extrinsics.
39
+
40
+ Args:
41
+ fxfycxcy: Camera intrinsics [b*v, 4]
42
+ c2w: Camera-to-world matrices [b*v, 4, 4]
43
+ h: Image height
44
+ w: Image width
45
+ device: Target device
46
+
47
+ Returns:
48
+ Tuple of (ray_origins, ray_directions, ray_directions_camera)
49
+ """
50
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
51
+ y, x = y.to(device), x.to(device)
52
+
53
+ b_v = fxfycxcy.size(0)
54
+ x = x[None, :, :].expand(b_v, -1, -1).reshape(b_v, -1)
55
+ y = y[None, :, :].expand(b_v, -1, -1).reshape(b_v, -1)
56
+
57
+ # Convert to normalized camera coordinates
58
+ x = (x + 0.5 - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
59
+ y = (y + 0.5 - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
60
+ z = torch.ones_like(x)
61
+
62
+ ray_d_cam = torch.stack([x, y, z], dim=2) # [b*v, h*w, 3]
63
+ ray_d_cam = ray_d_cam / torch.norm(ray_d_cam, dim=2, keepdim=True)
64
+
65
+ # Transform to world coordinates
66
+ ray_d = torch.bmm(ray_d_cam, c2w[:, :3, :3].transpose(1, 2))
67
+ ray_d = ray_d / torch.norm(ray_d, dim=2, keepdim=True)
68
+ ray_o = c2w[:, :3, 3][:, None, :].expand_as(ray_d)
69
+
70
+ return ray_o, ray_d, ray_d_cam
71
+
72
+
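As a quick sanity check of the pixel-to-ray conversion above, the sketch below builds rays for a single identity-pose camera. The intrinsics are made-up toy values, and the import path assumes the repository root is on `PYTHONPATH`.

```python
# Toy check of compute_camera_rays: one view, camera at the origin with identity rotation.
import torch
from gslrm.model.transform_data import compute_camera_rays  # assumes repo root on PYTHONPATH

h = w = 4
fxfycxcy = torch.tensor([[2.0, 2.0, 2.0, 2.0]])  # fx, fy, cx, cy (toy values)
c2w = torch.eye(4).unsqueeze(0)                  # [1, 4, 4] identity camera-to-world

ray_o, ray_d, ray_d_cam = compute_camera_rays(fxfycxcy, c2w, h, w, torch.device("cpu"))
print(ray_o.shape, ray_d.shape)                               # torch.Size([1, 16, 3]) twice
print(torch.allclose(ray_d.norm(dim=-1), torch.ones(1, 16)))  # directions are unit length
print(torch.allclose(ray_d, ray_d_cam))                       # identity rotation: world frame == camera frame
```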
73
+ def sample_patch_rays(
74
+ image: torch.Tensor,
75
+ fxfycxcy: torch.Tensor,
76
+ c2w: torch.Tensor,
77
+ patch_size: int,
78
+ h: int,
79
+ w: int
80
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
81
+ """
82
+ Sample rays at patch centers for efficient processing.
83
+
84
+ Args:
85
+ image: Input images [b*v, c, h, w]
86
+ fxfycxcy: Camera intrinsics [b*v, 4]
87
+ c2w: Camera-to-world matrices [b*v, 4, 4]
88
+ patch_size: Size of patches
89
+ h: Image height
90
+ w: Image width
91
+
92
+ Returns:
93
+ Tuple of (colors, ray_origins, ray_directions, xy_norm, projection_matrices)
94
+ """
95
+ b_v, c = image.shape[:2]
96
+ device = image.device
97
+
98
+ start_patch_center = patch_size / 2.0
99
+ y, x = torch.meshgrid(
100
+ torch.arange(h // patch_size) * patch_size + start_patch_center,
101
+ torch.arange(w // patch_size) * patch_size + start_patch_center,
102
+ indexing="ij",
103
+ )
104
+ y, x = y.to(device), x.to(device)
105
+
106
+ x_flat = x[None, :, :].expand(b_v, -1, -1).reshape(b_v, -1)
107
+ y_flat = y[None, :, :].expand(b_v, -1, -1).reshape(b_v, -1)
108
+
109
+ # Sample colors at patch centers
110
+ ray_color = F.grid_sample(
111
+ image,
112
+ torch.stack([x_flat / w * 2.0 - 1.0, y_flat / h * 2.0 - 1.0], dim=2).reshape(
113
+ b_v, -1, 1, 2
114
+ ),
115
+ align_corners=False,
116
+ ).squeeze(-1).permute(0, 2, 1).contiguous()
117
+
118
+ # Compute normalized coordinates
119
+ ray_xy_norm = torch.stack([x_flat / w, y_flat / h], dim=2)
120
+
121
+ # Compute projection matrices
122
+ K_norm = torch.eye(3, device=device).unsqueeze(0).repeat(b_v, 1, 1)
123
+ K_norm[:, 0, 0] = fxfycxcy[:, 0] / w
124
+ K_norm[:, 1, 1] = fxfycxcy[:, 1] / h
125
+ K_norm[:, 0, 2] = fxfycxcy[:, 2] / w
126
+ K_norm[:, 1, 2] = fxfycxcy[:, 3] / h
127
+
128
+ w2c = torch.inverse(c2w)
129
+ proj_mat = torch.bmm(K_norm, w2c[:, :3, :4])
130
+ proj_mat = proj_mat.reshape(b_v, 12)
131
+ proj_mat = proj_mat / (proj_mat.norm(dim=1, keepdim=True) + 1e-6)
132
+ proj_mat = proj_mat.reshape(b_v, 3, 4)
133
+ proj_mat = proj_mat * proj_mat[:, 0:1, 0:1].sign()
134
+
135
+ # Compute ray directions
136
+ x_norm = (x_flat - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
137
+ y_norm = (y_flat - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
138
+ z_norm = torch.ones_like(x_norm)
139
+
140
+ ray_d = torch.stack([x_norm, y_norm, z_norm], dim=2)
141
+ ray_d = torch.bmm(ray_d, c2w[:, :3, :3].transpose(1, 2))
142
+ ray_d = ray_d / torch.norm(ray_d, dim=2, keepdim=True)
143
+ ray_o = c2w[:, :3, 3][:, None, :].expand_as(ray_d)
144
+
145
+ return ray_color, ray_o, ray_d, ray_xy_norm, proj_mat
146
+
147
+
148
+ # =============================================================================
149
+ # Main Classes
150
+ # =============================================================================
151
+
152
+ class SplitData(nn.Module):
153
+ """
154
+ Split data batch into input and target views for training.
155
+ """
156
+
157
+ def __init__(self, config):
158
+ super().__init__()
159
+ self.config = config
160
+
161
+ @torch.no_grad()
162
+ def forward(self, data_batch: Dict[str, torch.Tensor], target_has_input: bool = True) -> Tuple[edict, edict]:
163
+ """
164
+ Split data into input and target views.
165
+
166
+ Args:
167
+ data_batch: Dictionary containing batch data
168
+ target_has_input: Whether target views can overlap with input views
169
+
170
+ Returns:
171
+ Tuple of (input_data, target_data)
172
+ """
173
+ input_data, target_data = {}, {}
174
+ index = None
175
+
176
+ for key, value in data_batch.items():
177
+ # Always use first N views as input
178
+ input_data[key] = value[:, :self.config.training.dataset.num_input_views, ...]
179
+
180
+ # Calculate num_target_views from num_views (not explicitly in config)
181
+ num_target_views = self.config.training.dataset.num_views
182
+
183
+ if num_target_views >= value.size(1):
184
+ target_data[key] = value
185
+ else:
186
+ if index is None:
187
+ index = self._generate_target_indices(
188
+ value, target_has_input
189
+ )
190
+
191
+ target_data[key] = self._gather_target_data(value, index)
192
+
193
+ return edict(input_data), edict(target_data)
194
+
195
+ def _generate_target_indices(self, value: torch.Tensor, target_has_input: bool) -> torch.Tensor:
196
+ """Generate indices for target view selection."""
197
+ b, v = value.shape[:2]
198
+
199
+ # Get config values
200
+ num_input_views = self.config.training.dataset.num_input_views
201
+ num_views = self.config.training.dataset.num_views
202
+ num_target_views = num_views # Use all views as targets
203
+
204
+ if target_has_input:
205
+ # Random sampling from all views
206
+ index = np.array([
207
+ random.sample(range(v), num_target_views)
208
+ for _ in range(b)
209
+ ])
210
+ else:
211
+ # Use last N views to avoid overlap with input views
212
+ assert (
213
+ num_input_views + num_target_views <= num_views
214
+ ), "num_input_views + num_target_views must <= num_views to avoid duplicate views"
215
+
216
+ index = np.array([
217
+ [num_views - 1 - j for j in range(num_target_views)]
218
+ for _ in range(b)
219
+ ])
220
+
221
+ return torch.from_numpy(index).long().to(value.device)
222
+
223
+ def _gather_target_data(self, value: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
224
+ """Gather target data using provided indices."""
225
+ value_index = index
226
+ if value.dim() > 2:
227
+ dummy_dims = [1] * (value.dim() - 2)
228
+ value_index = index.reshape(index.size(0), index.size(1), *dummy_dims)
229
+
230
+ try:
231
+ return torch.gather(
232
+ value,
233
+ dim=1,
234
+ index=value_index.expand(-1, -1, *value.size()[2:]),
235
+ )
236
+ except Exception as e:
237
+ print(f"Error gathering data for key with value shape: {value.size()}")
238
+ print(f"Index shape: {value_index.size()}")
239
+ raise e
240
+
241
+
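The `torch.gather` call in `_gather_target_data` relies on broadcasting a `[b, num_target_views]` index tensor up to the full value shape. The following self-contained toy example (shapes chosen for illustration only) shows the same pattern in isolation.

```python
# Standalone illustration of the gather-with-expanded-index pattern used above.
import torch

b, v, c, h, w = 2, 6, 3, 4, 4
value = torch.randn(b, v, c, h, w)
index = torch.tensor([[0, 3, 5], [1, 2, 4]])  # per-item target view ids, shape [b, 3]

gather_index = index.reshape(b, -1, 1, 1, 1).expand(-1, -1, c, h, w)
target = torch.gather(value, dim=1, index=gather_index)  # [b, 3, c, h, w]

assert torch.equal(target[0, 1], value[0, 3])  # second selected view of item 0 is view 3
```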
242
+ class TransformInput(nn.Module):
243
+ """
244
+ Transform input data for feeding into the transformer network.
245
+ """
246
+
247
+ def __init__(self, config):
248
+ super().__init__()
249
+ self.config = config
250
+
251
+ @torch.no_grad()
252
+ def forward(self, data_batch: edict, patch_size: Optional[int] = None) -> edict:
253
+ """
254
+ Transform input images to rays and other representations.
255
+
256
+ Args:
257
+ data_batch: Input data batch
258
+ patch_size: Optional patch size for patch-based processing
259
+
260
+ Returns:
261
+ Transformed input data
262
+ """
263
+ self._validate_input(data_batch)
264
+
265
+ image, fxfycxcy, c2w, index = (
266
+ data_batch.image, data_batch.fxfycxcy,
267
+ data_batch.c2w, data_batch.index
268
+ )
269
+
270
+ b, v, c, h, w = image.size()
271
+
272
+ # Reshape for processing
273
+ image_flat = image.reshape(b * v, c, h * w)
274
+ fxfycxcy_flat = fxfycxcy.reshape(b * v, 4)
275
+ c2w_flat = c2w.reshape(b * v, 4, 4)
276
+
277
+ # Compute normalized coordinates for full image
278
+ xy_norm = self._compute_normalized_coordinates(b, v, h, w, image.device)
279
+
280
+ # Compute camera rays
281
+ ray_o, ray_d, ray_d_cam = compute_camera_rays(
282
+ fxfycxcy_flat, c2w_flat, h, w, image.device
283
+ )
284
+
285
+ # Process patches if patch_size is provided
286
+ patch_data = self._process_patches(
287
+ image_flat, fxfycxcy_flat, c2w_flat, patch_size, h, w, b, v, c
288
+ ) if patch_size is not None else (None, None, None, None, None)
289
+
290
+ # Reshape outputs
291
+ ray_o = ray_o.reshape(b, v, h, w, 3).permute(0, 1, 4, 2, 3)
292
+ ray_d = ray_d.reshape(b, v, h, w, 3).permute(0, 1, 4, 2, 3)
293
+ ray_d_cam = ray_d_cam.reshape(b, v, h, w, 3).permute(0, 1, 4, 2, 3)
294
+
295
+ return edict(
296
+ image=image,
297
+ ray_o=ray_o,
298
+ ray_d=ray_d,
299
+ ray_d_cam=ray_d_cam,
300
+ fxfycxcy=fxfycxcy,
301
+ c2w=c2w,
302
+ index=index,
303
+ xy_norm=xy_norm,
304
+ ray_color_patch=patch_data[0],
305
+ ray_o_patch=patch_data[1],
306
+ ray_d_patch=patch_data[2],
307
+ ray_xy_norm_patch=patch_data[3],
308
+ proj_mat=patch_data[4],
309
+ )
310
+
311
+ def _validate_input(self, data_batch: edict) -> None:
312
+ """Validate input data dimensions."""
313
+ assert data_batch.image.dim() == 5, f"image dim should be 5, got {data_batch.image.dim()}"
314
+ assert data_batch.fxfycxcy.dim() == 3, f"fxfycxcy dim should be 3, got {data_batch.fxfycxcy.dim()}"
315
+ assert data_batch.c2w.dim() == 4, f"c2w dim should be 4, got {data_batch.c2w.dim()}"
316
+
317
+ def _compute_normalized_coordinates(self, b: int, v: int, h: int, w: int, device: torch.device) -> torch.Tensor:
318
+ """Compute normalized coordinates for the full image."""
319
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
320
+ y, x = y.to(device), x.to(device)
321
+
322
+ y_norm = (y + 0.5) / h * 2 - 1
323
+ x_norm = (x + 0.5) / w * 2 - 1
324
+
325
+ return torch.stack([x_norm, y_norm], dim=0)[None, None, :, :, :].expand(b, v, -1, -1, -1)
326
+
327
+ def _process_patches(
328
+ self,
329
+ image: torch.Tensor,
330
+ fxfycxcy: torch.Tensor,
331
+ c2w: torch.Tensor,
332
+ patch_size: int,
333
+ h: int,
334
+ w: int,
335
+ b: int,
336
+ v: int,
337
+ c: int
338
+ ) -> Tuple[Optional[torch.Tensor], ...]:
339
+ """Process patch-based data if patch_size is provided."""
340
+ ray_color, ray_o, ray_d, ray_xy_norm, proj_mat = sample_patch_rays(
341
+ image.reshape(b * v, c, h, w), fxfycxcy, c2w, patch_size, h, w
342
+ )
343
+
344
+ n_patch = ray_color.size(1)
345
+
346
+ return (
347
+ ray_color.reshape(b, v, n_patch, c),
348
+ ray_o.reshape(b, v, n_patch, 3),
349
+ ray_d.reshape(b, v, n_patch, 3),
350
+ ray_xy_norm.reshape(b, v, n_patch, 2),
351
+ proj_mat.reshape(b, v, 3, 4),
352
+ )
353
+
354
+
355
+ class TransformTarget(nn.Module):
356
+ """
357
+ Handles target image transformations during training.
358
+
359
+ Currently implements random cropping for data augmentation.
360
+ """
361
+
362
+ def __init__(self, config: edict):
363
+ super().__init__()
364
+ self.config = config
365
+
366
+ @torch.no_grad()
367
+ def forward(self, data_batch: edict) -> edict:
368
+ """
369
+ Apply transformations to target data.
370
+
371
+ Args:
372
+ data_batch: Dictionary containing 'image' and 'fxfycxcy'
373
+
374
+ Returns:
375
+ Transformed data batch
376
+ """
377
+ image = data_batch["image"] # [b, v, c, h, w]
378
+ fxfycxcy = data_batch["fxfycxcy"] # [b, v, 4]
379
+
380
+ b, v, c, h, w = image.size()
381
+ crop_size = getattr(self.config.training, 'crop_size', min(h, w))
382
+
383
+ # Apply random cropping if image is larger than crop size
384
+ if h > crop_size or w > crop_size:
385
+ crop_image = torch.zeros(
386
+ (b, v, c, crop_size, crop_size),
387
+ dtype=image.dtype,
388
+ device=image.device
389
+ )
390
+ crop_fxfycxcy = fxfycxcy.clone()
391
+
392
+ for i in range(b):
393
+ for j in range(v):
394
+ # Random crop position
395
+ idx_x = torch.randint(low=0, high=w - crop_size + 1, size=(1,)).item()
396
+ idx_y = torch.randint(low=0, high=h - crop_size + 1, size=(1,)).item()
397
+
398
+ # Apply crop
399
+ crop_image[i, j] = image[
400
+ i, j, :, idx_y:idx_y + crop_size, idx_x:idx_x + crop_size
401
+ ]
402
+
403
+ # Adjust camera intrinsics
404
+ crop_fxfycxcy[i, j, 2] -= idx_x # cx
405
+ crop_fxfycxcy[i, j, 3] -= idx_y # cy
406
+
407
+ data_batch["image"] = crop_image
408
+ data_batch["fxfycxcy"] = crop_fxfycxcy
409
+
410
+ return data_batch
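The principal-point shift in `TransformTarget` follows directly from the pinhole projection: cropping moves the image origin by `(idx_x, idx_y)`, so `cx` and `cy` must move by the same amount. A toy numerical check (all values arbitrary):

```python
# Why the crop offsets are subtracted from cx/cy: projection into the cropped image
# equals projection into the full image minus the crop origin.
import math

fx, cx = 300.0, 256.0
idx_x = 40                      # horizontal crop origin drawn by TransformTarget
x, z = 0.1, 1.0                 # a point in camera coordinates (x right, z forward)

u_full = fx * x / z + cx            # pixel column in the full image
u_crop = fx * x / z + (cx - idx_x)  # pixel column with the adjusted principal point

assert math.isclose(u_crop, u_full - idx_x)  # same pixel, expressed in crop coordinates
```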
gslrm/model/utils_losses.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
9
+ """
10
+ Perceptual loss implementation using VGG19 features, plus an SSIM loss implementation.
11
+
12
+ Adapted from https://github.com/zhengqili/Crowdsampling-the-Plenoptic-Function/blob/f5216f312cf82d77f8d20454b5eeb3930324630a/models/networks.py#L1478
13
+ """
14
+ import os
15
+ from typing import List, Tuple, Union, Optional
16
+
17
+ import scipy.io
18
+ import torch
19
+ import torch.nn as nn
20
+ from pytorch_msssim import SSIM
21
+
22
+ # VGG19 ImageNet normalization constants
23
+ IMAGENET_MEAN = [123.6800, 116.7790, 103.9390]
24
+
25
+ # VGG19 layer configuration
26
+ VGG19_LAYER_INDICES = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
27
+ VGG19_LAYER_NAMES = [
28
+ "conv1", "conv2", "conv3", "conv4", "conv5", "conv6", "conv7", "conv8",
29
+ "conv9", "conv10", "conv11", "conv12", "conv13", "conv14", "conv15", "conv16"
30
+ ]
31
+ VGG19_CHANNEL_SIZES = [64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 512]
32
+
33
+ # Perceptual loss weighting factors
34
+ LAYER_WEIGHTS = [1.0, 1/2.6, 1/4.8, 1/3.7, 1/5.6, 10/1.5]
35
+
36
+ class VGG19(nn.Module):
37
+ """
38
+ VGG19 network implementation for perceptual loss computation.
39
+
40
+ This class implements the VGG19 architecture with specific layer outputs
41
+ used for computing perceptual losses at different scales.
42
+ """
43
+
44
+ def __init__(self) -> None:
45
+ """Initialize VGG19 network layers."""
46
+ super(VGG19, self).__init__()
47
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
48
+ self.relu1 = nn.ReLU(inplace=True)
49
+
50
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
51
+ self.relu2 = nn.ReLU(inplace=True)
52
+ self.max1 = nn.AvgPool2d(kernel_size=2, stride=2)
53
+
54
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
55
+ self.relu3 = nn.ReLU(inplace=True)
56
+
57
+ self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True)
58
+ self.relu4 = nn.ReLU(inplace=True)
59
+ self.max2 = nn.AvgPool2d(kernel_size=2, stride=2)
60
+
61
+ self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True)
62
+ self.relu5 = nn.ReLU(inplace=True)
63
+
64
+ self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
65
+ self.relu6 = nn.ReLU(inplace=True)
66
+
67
+ self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
68
+ self.relu7 = nn.ReLU(inplace=True)
69
+
70
+ self.conv8 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
71
+ self.relu8 = nn.ReLU(inplace=True)
72
+ self.max3 = nn.AvgPool2d(kernel_size=2, stride=2)
73
+
74
+ self.conv9 = nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True)
75
+ self.relu9 = nn.ReLU(inplace=True)
76
+
77
+ self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
78
+ self.relu10 = nn.ReLU(inplace=True)
79
+
80
+ self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
81
+ self.relu11 = nn.ReLU(inplace=True)
82
+
83
+ self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
84
+ self.relu12 = nn.ReLU(inplace=True)
85
+ self.max4 = nn.AvgPool2d(kernel_size=2, stride=2)
86
+
87
+ self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
88
+ self.relu13 = nn.ReLU(inplace=True)
89
+
90
+ self.conv14 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
91
+ self.relu14 = nn.ReLU(inplace=True)
92
+
93
+ self.conv15 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
94
+ self.relu15 = nn.ReLU(inplace=True)
95
+
96
+ self.conv16 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
97
+ self.relu16 = nn.ReLU(inplace=True)
98
+ self.max5 = nn.AvgPool2d(kernel_size=2, stride=2)
99
+
100
+ def forward(self, x: torch.Tensor, return_style: int) -> Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]:
101
+ """
102
+ Forward pass through VGG19 network.
103
+
104
+ Args:
105
+ x: Input tensor of shape [B, 3, H, W]
106
+ return_style: If > 0, return style features as list; otherwise return content features as tuple
107
+
108
+ Returns:
109
+ Either a list of style features or tuple of content features from different layers
110
+ """
111
+ out1 = self.conv1(x)
112
+ out2 = self.relu1(out1)
113
+
114
+ out3 = self.conv2(out2)
115
+ out4 = self.relu2(out3)
116
+ out5 = self.max1(out4)
117
+
118
+ out6 = self.conv3(out5)
119
+ out7 = self.relu3(out6)
120
+ out8 = self.conv4(out7)
121
+ out9 = self.relu4(out8)
122
+ out10 = self.max2(out9)
123
+ out11 = self.conv5(out10)
124
+ out12 = self.relu5(out11)
125
+ out13 = self.conv6(out12)
126
+ out14 = self.relu6(out13)
127
+ out15 = self.conv7(out14)
128
+ out16 = self.relu7(out15)
129
+ out17 = self.conv8(out16)
130
+ out18 = self.relu8(out17)
131
+ out19 = self.max3(out18)
132
+ out20 = self.conv9(out19)
133
+ out21 = self.relu9(out20)
134
+ out22 = self.conv10(out21)
135
+ out23 = self.relu10(out22)
136
+ out24 = self.conv11(out23)
137
+ out25 = self.relu11(out24)
138
+ out26 = self.conv12(out25)
139
+ out27 = self.relu12(out26)
140
+ out28 = self.max4(out27)
141
+ out29 = self.conv13(out28)
142
+ out30 = self.relu13(out29)
143
+ out31 = self.conv14(out30)
144
+ out32 = self.relu14(out31)
145
+
146
+ if return_style > 0:
147
+ return [out2, out7, out12, out21, out30]
148
+ else:
149
+ return out4, out9, out14, out23, out32
150
+
151
+
152
+ class PerceptualLoss(nn.Module):
153
+ """
154
+ Perceptual Loss module using pre-trained VGG19.
155
+
156
+ This class implements perceptual loss by comparing features extracted from
157
+ different layers of a pre-trained VGG19 network. It computes weighted
158
+ differences across multiple scales to capture both low-level and high-level
159
+ visual differences between images.
160
+ """
161
+
162
+ def __init__(self, device: str = "cpu", weight_file: Optional[str] = None) -> None:
163
+ """
164
+ Initialize PerceptualLoss module.
165
+
166
+ Args:
167
+ device: Device to run computations on ('cpu' or 'cuda')
168
+ weight_file: Path to VGG19 weight file. If None, uses default path or environment variable.
169
+
170
+ Raises:
171
+ FileNotFoundError: If weight file is not found
172
+ RuntimeError: If weight file cannot be loaded
173
+ """
174
+ super().__init__()
175
+ self.device = device
176
+ self.net = VGG19()
177
+
178
+ # Determine weight file path
179
+ if weight_file is None:
180
+ # Check environment variable first
181
+ weight_file = os.environ.get('VGG19_WEIGHTS_PATH')
182
+ if weight_file is None:
183
+ # Fallback to default path
184
+ weight_file = "/sensei-fs/users/kaiz/repos/weight-collections/imagenet-vgg-verydeep-19.mat"
185
+
186
+ # Load VGG19 weights
187
+ if not os.path.isfile(weight_file):
188
+ raise FileNotFoundError(
189
+ f"VGG19 weight file not found: {weight_file}\n"
190
+ f"Download it from: https://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat\n"
191
+ f"Expected MD5: 106118b7cf60435e6d8e04f6a6dc3657"
192
+ )
193
+
194
+ try:
195
+ vgg_rawnet = scipy.io.loadmat(weight_file)
196
+ vgg_layers = vgg_rawnet["layers"][0]
197
+ except Exception as e:
198
+ raise RuntimeError(f"Failed to load VGG19 weights from {weight_file}: {e}")
199
+
200
+ # Load pre-trained weights into the network
201
+ self._load_pretrained_weights(vgg_layers)
202
+
203
+ # Set network to evaluation mode and freeze parameters
204
+ self.net = self.net.eval().to(device)
205
+ for param in self.net.parameters():
206
+ param.requires_grad = False
207
+
208
+ def _load_pretrained_weights(self, vgg_layers) -> None:
209
+ """Load pre-trained VGG19 weights into the network."""
210
+ for layer_idx in range(len(VGG19_LAYER_NAMES)):
211
+ layer_name = VGG19_LAYER_NAMES[layer_idx]
212
+ mat_layer_idx = VGG19_LAYER_INDICES[layer_idx]
213
+ channel_size = VGG19_CHANNEL_SIZES[layer_idx]
214
+
215
+ # Extract weights and biases from MATLAB format
216
+ layer_weights = torch.from_numpy(
217
+ vgg_layers[mat_layer_idx][0][0][2][0][0]
218
+ ).permute(3, 2, 0, 1)
219
+ layer_biases = torch.from_numpy(
220
+ vgg_layers[mat_layer_idx][0][0][2][0][1]
221
+ ).view(channel_size)
222
+
223
+ # Assign to network
224
+ getattr(self.net, layer_name).weight = nn.Parameter(layer_weights)
225
+ getattr(self.net, layer_name).bias = nn.Parameter(layer_biases)
226
+
227
+ def _compute_l1_error(self, truth: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
228
+ """
229
+ Compute L1 (Mean Absolute Error) between two tensors.
230
+
231
+ Args:
232
+ truth: Ground truth tensor
233
+ pred: Predicted tensor
234
+
235
+ Returns:
236
+ L1 error as a scalar tensor
237
+ """
238
+ return torch.mean(torch.abs(truth - pred))
239
+
240
+ def forward(self, pred_img: torch.Tensor, real_img: torch.Tensor) -> torch.Tensor:
241
+ """
242
+ Compute perceptual loss between predicted and real images.
243
+
244
+ Args:
245
+ pred_img: Predicted image tensor of shape [B, 3, H, W] in range [0, 1]
246
+ real_img: Real image tensor of shape [B, 3, H, W] in range [0, 1]
247
+
248
+ Returns:
249
+ Perceptual loss as a scalar tensor
250
+ """
251
+ # Convert to ImageNet normalization (scale to [0, 255] and subtract the per-channel mean)
252
+ imagenet_mean = torch.tensor(IMAGENET_MEAN, dtype=torch.float32, device=pred_img.device)
253
+ imagenet_mean = imagenet_mean.view(1, 3, 1, 1)
254
+
255
+ # Scale to [0, 255] and apply ImageNet normalization
256
+ real_img_normalized = real_img * 255.0 - imagenet_mean
257
+ pred_img_normalized = pred_img * 255.0 - imagenet_mean
258
+
259
+ # Extract features from both images
260
+ real_features = self.net(real_img_normalized, return_style=0)
261
+ pred_features = self.net(pred_img_normalized, return_style=0)
262
+
263
+ # Compute weighted L1 losses at different scales
264
+ losses = []
265
+
266
+ # Raw image loss
267
+ raw_loss = self._compute_l1_error(real_img_normalized, pred_img_normalized)
268
+ losses.append(raw_loss * LAYER_WEIGHTS[0])
269
+
270
+ # Feature losses at different VGG layers
271
+ for i, (real_feat, pred_feat) in enumerate(zip(real_features, pred_features)):
272
+ feature_loss = self._compute_l1_error(real_feat, pred_feat)
273
+ losses.append(feature_loss * LAYER_WEIGHTS[i + 1])
274
+
275
+ # Combine all losses and normalize
276
+ total_loss = sum(losses) / 255.0
277
+ return total_loss
278
+
279
+ class SsimLoss(nn.Module):
280
+ """
281
+ SSIM Loss module that computes 1 - SSIM for image similarity.
282
+
283
+ Args:
284
+ data_range: Range of input data (default: 1.0 for [0,1] range)
285
+ """
286
+
287
+ def __init__(self, data_range: float = 1.0) -> None:
288
+ super().__init__()
289
+ self.data_range = data_range
290
+ self.ssim_module = SSIM(
291
+ win_size=11,
292
+ win_sigma=1.5,
293
+ data_range=self.data_range,
294
+ size_average=True,
295
+ channel=3,
296
+ )
297
+
298
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
299
+ """
300
+ Compute SSIM loss between two image tensors.
301
+
302
+ Args:
303
+ x: Image tensor of shape (N, C, H, W)
304
+ y: Image tensor of shape (N, C, H, W)
305
+
306
+ Returns:
307
+ SSIM loss (1 - SSIM similarity)
308
+ """
309
+ return 1.0 - self.ssim_module(x, y)
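A minimal usage sketch for the two losses above. The VGG19 weight path and the 0.2 SSIM weight are illustrative assumptions, not values taken from this commit, and running it requires the MatConvNet `.mat` file referenced in `PerceptualLoss`.

```python
# Hedged usage sketch: combine perceptual and SSIM terms on [0, 1] RGB images.
import torch

perceptual_fn = PerceptualLoss(device="cpu",
                               weight_file="weights/imagenet-vgg-verydeep-19.mat")  # assumed path
ssim_fn = SsimLoss(data_range=1.0)

pred = torch.rand(2, 3, 128, 128)    # rendered images in [0, 1]
target = torch.rand(2, 3, 128, 128)  # ground-truth images in [0, 1]

loss = perceptual_fn(pred, target) + 0.2 * ssim_fn(pred, target)  # 0.2 is an illustrative weight
print(loss.item())
```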
gslrm/model/utils_transformer.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
9
+ """
10
+ Transformer utilities for GSLRM.
11
+
12
+ This module contains the core transformer components used by the GSLRM model,
13
+ including self-attention, MLP layers, and transformer blocks.
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from einops import rearrange
20
+
21
+ try:
22
+ import xformers.ops as xops
23
+ except ImportError as e:
24
+ print("Please install xformers to use flashatt v2")
25
+ raise e
26
+
27
+
28
+ def _init_weights(module):
29
+ """
30
+ Initialize weights for transformer modules.
31
+
32
+ Reference: https://github.com/karpathy/nanoGPT/blob/eba36e84649f3c6d840a93092cb779a260544d08/model.py#L162-L168
33
+
34
+ Args:
35
+ module: Neural network module to initialize
36
+ """
37
+ if isinstance(module, nn.Linear):
38
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
39
+ if module.bias is not None:
40
+ torch.nn.init.zeros_(module.bias)
41
+ elif isinstance(module, nn.Embedding):
42
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
43
+
44
+
45
+ class MLP(nn.Module):
46
+ """
47
+ Multi-layer perceptron with GELU activation.
48
+
49
+ Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ d,
55
+ mlp_ratio=4,
56
+ mlp_bias=False,
57
+ mlp_dropout=0.0,
58
+ mlp_dim=None,
59
+ ):
60
+ """
61
+ Initialize MLP layer.
62
+
63
+ Args:
64
+ d: Input/output dimension
65
+ mlp_ratio: Hidden dimension ratio (hidden_dim = d * mlp_ratio)
66
+ mlp_bias: Whether to use bias in linear layers
67
+ mlp_dropout: Dropout probability
68
+ mlp_dim: Explicit hidden dimension (overrides mlp_ratio if provided)
69
+ """
70
+ super().__init__()
71
+ if mlp_dim is None:
72
+ mlp_dim = d * mlp_ratio
73
+
74
+ self.mlp = nn.Sequential(
75
+ nn.Linear(d, mlp_dim, bias=mlp_bias),
76
+ nn.GELU(),
77
+ nn.Linear(mlp_dim, d, bias=mlp_bias),
78
+ nn.Dropout(mlp_dropout),
79
+ )
80
+
81
+ def forward(self, x):
82
+ """
83
+ Forward pass through MLP.
84
+
85
+ Args:
86
+ x: Input tensor of shape (batch, seq_len, d)
87
+
88
+ Returns:
89
+ Output tensor of shape (batch, seq_len, d)
90
+ """
91
+ return self.mlp(x)
92
+
93
+
94
+ class SelfAttention(nn.Module):
95
+ """
96
+ Multi-head self-attention with flash attention support.
97
+
98
+ Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ d,
104
+ d_head,
105
+ attn_qkv_bias=False,
106
+ attn_dropout=0.0,
107
+ attn_fc_bias=False,
108
+ attn_fc_dropout=0.0,
109
+ use_flashatt_v2=True,
110
+ ):
111
+ """
112
+ Initialize self-attention layer.
113
+
114
+ Args:
115
+ d: Token dimension
116
+ d_head: Head dimension
117
+ attn_qkv_bias: Whether to use bias in QKV projection
118
+ attn_dropout: Attention dropout probability
119
+ attn_fc_bias: Whether to use bias in output projection
120
+ attn_fc_dropout: Output projection dropout probability
121
+ use_flashatt_v2: Whether to use flash attention v2
122
+ """
123
+ super().__init__()
124
+ assert d % d_head == 0, f"Token dimension {d} should be divisible by head dimension {d_head}"
125
+
126
+ self.d = d
127
+ self.d_head = d_head
128
+ self.attn_dropout = attn_dropout
129
+ self.use_flashatt_v2 = use_flashatt_v2
130
+
131
+ # QKV projection (projects to 3*d for Q, K, V)
132
+ self.to_qkv = nn.Linear(d, 3 * d, bias=attn_qkv_bias)
133
+
134
+ # Output projection
135
+ self.fc = nn.Linear(d, d, bias=attn_fc_bias)
136
+ self.attn_fc_dropout = nn.Dropout(attn_fc_dropout)
137
+
138
+ def forward(self, x, subset_attention_size=None):
139
+ """
140
+ Forward pass through self-attention.
141
+
142
+ Args:
143
+ x: Input tensor of shape (batch, seq_len, d)
144
+ subset_attention_size: Optional size for subset attention
145
+
146
+ Returns:
147
+ Output tensor of shape (batch, seq_len, d)
148
+ """
149
+ # Generate Q, K, V
150
+ q, k, v = self.to_qkv(x).split(self.d, dim=2)
151
+
152
+ if self.use_flashatt_v2:
153
+ # Use xformers flash attention
154
+ q, k, v = map(
155
+ lambda t: rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.d_head),
156
+ (q, k, v),
157
+ )
158
+
159
+ if subset_attention_size is not None and subset_attention_size < q.shape[1]:
160
+ # Handle subset attention for memory efficiency
161
+ x_subset = xops.memory_efficient_attention(
162
+ q[:, :subset_attention_size, :, :].contiguous(),
163
+ k[:, :subset_attention_size, :, :].contiguous(),
164
+ v[:, :subset_attention_size, :, :].contiguous(),
165
+ attn_bias=None,
166
+ op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
167
+ )
168
+ x_rest = xops.memory_efficient_attention(
169
+ q[:, subset_attention_size:, :, :].contiguous(),
170
+ k,
171
+ v,
172
+ attn_bias=None,
173
+ op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
174
+ )
175
+ x = torch.cat([x_subset, x_rest], dim=1)
176
+ else:
177
+ # Standard flash attention
178
+ x = xops.memory_efficient_attention(
179
+ q, k, v,
180
+ attn_bias=None,
181
+ op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
182
+ )
183
+
184
+ x = rearrange(x, "b l nh dh -> b l (nh dh)")
185
+ else:
186
+ # Use PyTorch scaled dot product attention
187
+ q, k, v = (
188
+ rearrange(q, "b l (nh dh) -> b nh l dh", dh=self.d_head),
189
+ rearrange(k, "b l (nh dh) -> b nh l dh", dh=self.d_head),
190
+ rearrange(v, "b l (nh dh) -> b nh l dh", dh=self.d_head),
191
+ )
192
+
193
+ dropout_p = self.attn_dropout if self.training else 0.0
194
+
195
+ if subset_attention_size is not None and subset_attention_size < q.shape[2]:
196
+ # Handle subset attention
197
+ x_subset = F.scaled_dot_product_attention(
198
+ q[:, :, :subset_attention_size, :].contiguous(),
199
+ k[:, :, :subset_attention_size, :].contiguous(),
200
+ v[:, :, :subset_attention_size, :].contiguous(),
201
+ dropout_p=dropout_p,
202
+ )
203
+ x_rest = F.scaled_dot_product_attention(
204
+ q[:, :, subset_attention_size:, :].contiguous(),
205
+ k, v,
206
+ dropout_p=dropout_p,
207
+ )
208
+ x = torch.cat([x_subset, x_rest], dim=2)
209
+ else:
210
+ # Standard attention
211
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
212
+
213
+ x = rearrange(x, "b nh l dh -> b l (nh dh)")
214
+
215
+ # Apply output projection and dropout
216
+ return self.attn_fc_dropout(self.fc(x))
217
+
218
+
219
+ class TransformerBlock(nn.Module):
220
+ """
221
+ Standard transformer block with pre-normalization.
222
+
223
+ Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ d,
229
+ d_head,
230
+ ln_bias=False,
231
+ attn_qkv_bias=False,
232
+ attn_dropout=0.0,
233
+ attn_fc_bias=False,
234
+ attn_fc_dropout=0.0,
235
+ mlp_ratio=4,
236
+ mlp_bias=False,
237
+ mlp_dropout=0.0,
238
+ ):
239
+ """
240
+ Initialize transformer block.
241
+
242
+ Args:
243
+ d: Token dimension
244
+ d_head: Attention head dimension
245
+ ln_bias: Whether to use bias in layer norm
246
+ attn_qkv_bias: Whether to use bias in attention QKV projection
247
+ attn_dropout: Attention dropout probability
248
+ attn_fc_bias: Whether to use bias in attention output projection
249
+ attn_fc_dropout: Attention output dropout probability
250
+ mlp_ratio: MLP hidden dimension ratio
251
+ mlp_bias: Whether to use bias in MLP layers
252
+ mlp_dropout: MLP dropout probability
253
+ """
254
+ super().__init__()
255
+
256
+ # Layer normalization
257
+ self.norm1 = nn.LayerNorm(d, bias=ln_bias)
258
+ self.norm2 = nn.LayerNorm(d, bias=ln_bias)
259
+
260
+ # Self-attention
261
+ self.attn = SelfAttention(
262
+ d=d,
263
+ d_head=d_head,
264
+ attn_qkv_bias=attn_qkv_bias,
265
+ attn_dropout=attn_dropout,
266
+ attn_fc_bias=attn_fc_bias,
267
+ attn_fc_dropout=attn_fc_dropout,
268
+ )
269
+
270
+ # MLP
271
+ self.mlp = MLP(
272
+ d=d,
273
+ mlp_ratio=mlp_ratio,
274
+ mlp_bias=mlp_bias,
275
+ mlp_dropout=mlp_dropout,
276
+ )
277
+
278
+ def forward(self, x, subset_attention_size=None):
279
+ """
280
+ Forward pass through transformer block.
281
+
282
+ Args:
283
+ x: Input tensor of shape (batch, seq_len, d)
284
+ subset_attention_size: Optional size for subset attention
285
+
286
+ Returns:
287
+ Output tensor of shape (batch, seq_len, d)
288
+ """
289
+ # Pre-norm attention with residual connection
290
+ x = x + self.attn(self.norm1(x), subset_attention_size=subset_attention_size)
291
+
292
+ # Pre-norm MLP with residual connection
293
+ x = x + self.mlp(self.norm2(x))
294
+
295
+ return x
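The sketch below exercises `SelfAttention` with `use_flashatt_v2=False`, which falls back to `torch.nn.functional.scaled_dot_product_attention` and runs on CPU (xformers still has to be installed for the module import to succeed). `TransformerBlock` itself defaults to the xformers flash path, which expects a CUDA device and half-precision tensors. All sizes are illustrative.

```python
# Illustrative check of SelfAttention and its subset-attention option.
import torch

attn = SelfAttention(d=256, d_head=64, use_flashatt_v2=False)
attn.apply(_init_weights)

tokens = torch.randn(2, 1024, 256)                 # (batch, sequence, token dim)
out = attn(tokens)                                 # full attention over all 1024 tokens
partial = attn(tokens, subset_attention_size=256)  # first 256 tokens attend only among themselves
print(out.shape, partial.shape)                    # torch.Size([2, 1024, 256]) twice
```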
mvdiffusion/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
mvdiffusion/models/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
mvdiffusion/models/transformer_mv2d_image.py ADDED
@@ -0,0 +1,1016 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
24
+ from diffusers.utils import BaseOutput, deprecate
25
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
26
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
27
+ from diffusers.models.embeddings import PatchEmbed
28
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils.import_utils import is_xformers_available
31
+
32
+ from einops import rearrange, repeat
33
+ import pdb
34
+ import random
35
+
36
+
37
+ if is_xformers_available():
38
+ import xformers
39
+ import xformers.ops
40
+ else:
41
+ xformers = None
42
+
43
+ def my_repeat(tensor, num_repeats):
44
+ """
45
+ Repeat a tensor along a given dimension
46
+ """
47
+ if len(tensor.shape) == 3:
48
+ return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
49
+ elif len(tensor.shape) == 4:
50
+ return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
51
+
52
+
53
+ @dataclass
54
+ class TransformerMV2DModelOutput(BaseOutput):
55
+ """
56
+ The output of [`Transformer2DModel`].
57
+
58
+ Args:
59
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
60
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
61
+ distributions for the unnoised latent pixels.
62
+ """
63
+
64
+ sample: torch.FloatTensor
65
+
66
+
67
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
68
+ """
69
+ A 2D Transformer model for image-like data.
70
+
71
+ Parameters:
72
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
73
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
74
+ in_channels (`int`, *optional*):
75
+ The number of channels in the input and output (specify if the input is **continuous**).
76
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
77
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
78
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
79
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
80
+ This is fixed during training since it is used to learn a number of position embeddings.
81
+ num_vector_embeds (`int`, *optional*):
82
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
83
+ Includes the class for the masked latent pixel.
84
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
85
+ num_embeds_ada_norm ( `int`, *optional*):
86
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
87
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
88
+ added to the hidden states.
89
+
90
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
91
+ attention_bias (`bool`, *optional*):
92
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
93
+ """
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ num_attention_heads: int = 16,
99
+ attention_head_dim: int = 88,
100
+ in_channels: Optional[int] = None,
101
+ out_channels: Optional[int] = None,
102
+ num_layers: int = 1,
103
+ dropout: float = 0.0,
104
+ norm_num_groups: int = 32,
105
+ cross_attention_dim: Optional[int] = None,
106
+ attention_bias: bool = False,
107
+ sample_size: Optional[int] = None,
108
+ num_vector_embeds: Optional[int] = None,
109
+ patch_size: Optional[int] = None,
110
+ activation_fn: str = "geglu",
111
+ num_embeds_ada_norm: Optional[int] = None,
112
+ use_linear_projection: bool = False,
113
+ only_cross_attention: bool = False,
114
+ upcast_attention: bool = False,
115
+ norm_type: str = "layer_norm",
116
+ norm_elementwise_affine: bool = True,
117
+ num_views: int = 1,
118
+ cd_attention_last: bool=False,
119
+ cd_attention_mid: bool=False,
120
+ multiview_attention: bool=True,
121
+ sparse_mv_attention: bool = False,
122
+ ):
123
+ super().__init__()
124
+ self.use_linear_projection = use_linear_projection
125
+ self.num_attention_heads = num_attention_heads
126
+ self.attention_head_dim = attention_head_dim
127
+ inner_dim = num_attention_heads * attention_head_dim
128
+
129
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
130
+ # Define whether input is continuous or discrete depending on configuration
131
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
132
+ self.is_input_vectorized = num_vector_embeds is not None
133
+ self.is_input_patches = in_channels is not None and patch_size is not None
134
+
135
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
136
+ deprecation_message = (
137
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
138
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
139
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
140
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
141
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
142
+ )
143
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
144
+ norm_type = "ada_norm"
145
+
146
+ if self.is_input_continuous and self.is_input_vectorized:
147
+ raise ValueError(
148
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
149
+ " sure that either `in_channels` or `num_vector_embeds` is None."
150
+ )
151
+ elif self.is_input_vectorized and self.is_input_patches:
152
+ raise ValueError(
153
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
154
+ " sure that either `num_vector_embeds` or `num_patches` is None."
155
+ )
156
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
157
+ raise ValueError(
158
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
159
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
160
+ )
161
+
162
+ # 2. Define input layers
163
+ if self.is_input_continuous:
164
+ self.in_channels = in_channels
165
+
166
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
167
+ if use_linear_projection:
168
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
169
+ else:
170
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
171
+ elif self.is_input_vectorized:
172
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
173
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
174
+
175
+ self.height = sample_size
176
+ self.width = sample_size
177
+ self.num_vector_embeds = num_vector_embeds
178
+ self.num_latent_pixels = self.height * self.width
179
+
180
+ self.latent_image_embedding = ImagePositionalEmbeddings(
181
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
182
+ )
183
+ elif self.is_input_patches:
184
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
185
+
186
+ self.height = sample_size
187
+ self.width = sample_size
188
+
189
+ self.patch_size = patch_size
190
+ self.pos_embed = PatchEmbed(
191
+ height=sample_size,
192
+ width=sample_size,
193
+ patch_size=patch_size,
194
+ in_channels=in_channels,
195
+ embed_dim=inner_dim,
196
+ )
197
+
198
+ # 3. Define transformers blocks
199
+ self.transformer_blocks = nn.ModuleList(
200
+ [
201
+ BasicMVTransformerBlock(
202
+ inner_dim,
203
+ num_attention_heads,
204
+ attention_head_dim,
205
+ dropout=dropout,
206
+ cross_attention_dim=cross_attention_dim,
207
+ activation_fn=activation_fn,
208
+ num_embeds_ada_norm=num_embeds_ada_norm,
209
+ attention_bias=attention_bias,
210
+ only_cross_attention=only_cross_attention,
211
+ upcast_attention=upcast_attention,
212
+ norm_type=norm_type,
213
+ norm_elementwise_affine=norm_elementwise_affine,
214
+ num_views=num_views,
215
+ cd_attention_last=cd_attention_last,
216
+ cd_attention_mid=cd_attention_mid,
217
+ multiview_attention=multiview_attention,
218
+ sparse_mv_attention=sparse_mv_attention
219
+ )
220
+ for d in range(num_layers)
221
+ ]
222
+ )
223
+
224
+ # 4. Define output layers
225
+ self.out_channels = in_channels if out_channels is None else out_channels
226
+ if self.is_input_continuous:
227
+ # TODO: should use out_channels for continuous projections
228
+ if use_linear_projection:
229
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
230
+ else:
231
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
232
+ elif self.is_input_vectorized:
233
+ self.norm_out = nn.LayerNorm(inner_dim)
234
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
235
+ elif self.is_input_patches:
236
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
237
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
238
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
239
+
240
+ def forward(
241
+ self,
242
+ hidden_states: torch.Tensor,
243
+ encoder_hidden_states: Optional[torch.Tensor] = None,
244
+ dino_feature: Optional[torch.Tensor] = None,
245
+ timestep: Optional[torch.LongTensor] = None,
246
+ class_labels: Optional[torch.LongTensor] = None,
247
+ cross_attention_kwargs: Dict[str, Any] = None,
248
+ attention_mask: Optional[torch.Tensor] = None,
249
+ encoder_attention_mask: Optional[torch.Tensor] = None,
250
+ return_dict: bool = True,
251
+ ):
252
+ """
253
+ The [`TransformerMV2DModel`] forward method.
254
+
255
+ Args:
256
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
257
+ Input `hidden_states`.
258
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
259
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
260
+ self-attention.
261
+ timestep ( `torch.LongTensor`, *optional*):
262
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
263
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
264
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
265
+ `AdaLayerNormZero`.
266
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
267
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
268
+
269
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
270
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
271
+
272
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
273
+ above. This bias will be added to the cross-attention scores.
274
+ return_dict (`bool`, *optional*, defaults to `True`):
275
+ Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain
276
+ tuple.
277
+
278
+ Returns:
279
+ If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
280
+ `tuple` where the first element is the sample tensor.
281
+ """
282
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
283
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
284
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
285
+ # expects mask of shape:
286
+ # [batch, key_tokens]
287
+ # adds singleton query_tokens dimension:
288
+ # [batch, 1, key_tokens]
289
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
290
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
291
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
292
+ if attention_mask is not None and attention_mask.ndim == 2:
293
+ # assume that mask is expressed as:
294
+ # (1 = keep, 0 = discard)
295
+ # convert mask into a bias that can be added to attention scores:
296
+ # (keep = +0, discard = -10000.0)
297
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
298
+ attention_mask = attention_mask.unsqueeze(1)
299
+
300
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
301
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
302
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
303
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
304
+
305
+ # 1. Input
306
+ if self.is_input_continuous:
307
+ batch, _, height, width = hidden_states.shape
308
+ residual = hidden_states
309
+
310
+ hidden_states = self.norm(hidden_states)
311
+ if not self.use_linear_projection:
312
+ hidden_states = self.proj_in(hidden_states)
313
+ inner_dim = hidden_states.shape[1]
314
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
315
+ else:
316
+ inner_dim = hidden_states.shape[1]
317
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
318
+ hidden_states = self.proj_in(hidden_states)
319
+ elif self.is_input_vectorized:
320
+ hidden_states = self.latent_image_embedding(hidden_states)
321
+ elif self.is_input_patches:
322
+ hidden_states = self.pos_embed(hidden_states)
323
+
324
+ # 2. Blocks
325
+ for block in self.transformer_blocks:
326
+ hidden_states = block(
327
+ hidden_states,
328
+ attention_mask=attention_mask,
329
+ encoder_hidden_states=encoder_hidden_states,
330
+ encoder_attention_mask=encoder_attention_mask,
331
+ timestep=timestep,
332
+ cross_attention_kwargs=cross_attention_kwargs,
333
+ class_labels=class_labels,
334
+ )
335
+
336
+ # 3. Output
337
+ if self.is_input_continuous:
338
+ if not self.use_linear_projection:
339
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
340
+ hidden_states = self.proj_out(hidden_states)
341
+ else:
342
+ hidden_states = self.proj_out(hidden_states)
343
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
344
+
345
+ output = hidden_states + residual
346
+ elif self.is_input_vectorized:
347
+ hidden_states = self.norm_out(hidden_states)
348
+ logits = self.out(hidden_states)
349
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
350
+ logits = logits.permute(0, 2, 1)
351
+
352
+ # log(p(x_0))
353
+ output = F.log_softmax(logits.double(), dim=1).float()
354
+ elif self.is_input_patches:
355
+ # TODO: cleanup!
356
+ conditioning = self.transformer_blocks[0].norm1.emb(
357
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
358
+ )
359
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
360
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
361
+ hidden_states = self.proj_out_2(hidden_states)
362
+
363
+ # unpatchify
364
+ height = width = int(hidden_states.shape[1] ** 0.5)
365
+ hidden_states = hidden_states.reshape(
366
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
367
+ )
368
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
369
+ output = hidden_states.reshape(
370
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
371
+ )
372
+
373
+ if not return_dict:
374
+ return (output,)
375
+
376
+ return TransformerMV2DModelOutput(sample=output)
377
+
378
+
379
+ @maybe_allow_in_graph
380
+ class BasicMVTransformerBlock(nn.Module):
381
+ r"""
382
+ A basic Transformer block.
383
+
384
+ Parameters:
385
+ dim (`int`): The number of channels in the input and output.
386
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
387
+ attention_head_dim (`int`): The number of channels in each head.
388
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
389
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
390
+ only_cross_attention (`bool`, *optional*):
391
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
392
+ double_self_attention (`bool`, *optional*):
393
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
394
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
395
+ num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
399
+ """
400
+
401
+ def __init__(
402
+ self,
403
+ dim: int,
404
+ num_attention_heads: int,
405
+ attention_head_dim: int,
406
+ dropout=0.0,
407
+ cross_attention_dim: Optional[int] = None,
408
+ activation_fn: str = "geglu",
409
+ num_embeds_ada_norm: Optional[int] = None,
410
+ attention_bias: bool = False,
411
+ only_cross_attention: bool = False,
412
+ double_self_attention: bool = False,
413
+ upcast_attention: bool = False,
414
+ norm_elementwise_affine: bool = True,
415
+ norm_type: str = "layer_norm",
416
+ final_dropout: bool = False,
417
+ num_views: int = 1,
418
+ cd_attention_last: bool = False,
419
+ cd_attention_mid: bool = False,
420
+ multiview_attention: bool = True,
421
+ sparse_mv_attention: bool = False
422
+ ):
423
+ super().__init__()
424
+ self.only_cross_attention = only_cross_attention
425
+
426
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
427
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
428
+
429
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
430
+ raise ValueError(
431
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
432
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
433
+ )
434
+
435
+ # Define 3 blocks. Each block has its own normalization layer.
436
+ # 1. Self-Attn
437
+ if self.use_ada_layer_norm:
438
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
439
+ elif self.use_ada_layer_norm_zero:
440
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
441
+ else:
442
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
443
+
444
+ self.multiview_attention = multiview_attention
445
+ self.sparse_mv_attention = sparse_mv_attention
446
+
447
+ self.attn1 = CustomAttention(
448
+ query_dim=dim,
449
+ heads=num_attention_heads,
450
+ dim_head=attention_head_dim,
451
+ dropout=dropout,
452
+ bias=attention_bias,
453
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
454
+ upcast_attention=upcast_attention,
455
+ processor=MVAttnProcessor()
456
+ )
457
+
458
+ # 2. Cross-Attn
459
+ if cross_attention_dim is not None or double_self_attention:
460
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
461
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
462
+ # the second cross attention block.
463
+ self.norm2 = (
464
+ AdaLayerNorm(dim, num_embeds_ada_norm)
465
+ if self.use_ada_layer_norm
466
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
467
+ )
468
+ self.attn2 = Attention(
469
+ query_dim=dim,
470
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
471
+ heads=num_attention_heads,
472
+ dim_head=attention_head_dim,
473
+ dropout=dropout,
474
+ bias=attention_bias,
475
+ upcast_attention=upcast_attention,
476
+ ) # is self-attn if encoder_hidden_states is none
477
+ else:
478
+ self.norm2 = None
479
+ self.attn2 = None
480
+
481
+ # 3. Feed-forward
482
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
483
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
484
+
485
+ # let chunk size default to None
486
+ self._chunk_size = None
487
+ self._chunk_dim = 0
488
+
489
+ self.num_views = num_views
490
+
491
+ self.cd_attention_last = cd_attention_last
492
+
493
+ if self.cd_attention_last:
494
+ # Joint task -Attn
495
+ self.attn_joint_last = CustomJointAttention(
496
+ query_dim=dim,
497
+ heads=num_attention_heads,
498
+ dim_head=attention_head_dim,
499
+ dropout=dropout,
500
+ bias=attention_bias,
501
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
502
+ upcast_attention=upcast_attention,
503
+ processor=JointAttnProcessor()
504
+ )
505
+ nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
506
+ self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
507
+
508
+
509
+ self.cd_attention_mid = cd_attention_mid
510
+
511
+ if self.cd_attention_mid:
512
+ print("cross-domain attn in the middle")
513
+ # Joint task -Attn
514
+ self.attn_joint_mid = CustomJointAttention(
515
+ query_dim=dim,
516
+ heads=num_attention_heads,
517
+ dim_head=attention_head_dim,
518
+ dropout=dropout,
519
+ bias=attention_bias,
520
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
521
+ upcast_attention=upcast_attention,
522
+ processor=JointAttnProcessor()
523
+ )
524
+ nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
525
+ self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
526
+
527
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
528
+ # Sets chunk feed-forward
529
+ self._chunk_size = chunk_size
530
+ self._chunk_dim = dim
531
+
532
+ def forward(
533
+ self,
534
+ hidden_states: torch.FloatTensor,
535
+ attention_mask: Optional[torch.FloatTensor] = None,
536
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
537
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
538
+ timestep: Optional[torch.LongTensor] = None,
539
+ cross_attention_kwargs: Dict[str, Any] = None,
540
+ class_labels: Optional[torch.LongTensor] = None,
541
+ ):
542
+ assert attention_mask is None # not supported yet
543
+ # Notice that normalization is always applied before the real computation in the following blocks.
544
+ # 1. Self-Attention
545
+ if self.use_ada_layer_norm:
546
+ norm_hidden_states = self.norm1(hidden_states, timestep)
547
+ elif self.use_ada_layer_norm_zero:
548
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
549
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
550
+ )
551
+ else:
552
+ norm_hidden_states = self.norm1(hidden_states)
553
+
554
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
555
+
556
+ attn_output = self.attn1(
557
+ norm_hidden_states,
558
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
559
+ attention_mask=attention_mask,
560
+ num_views=self.num_views,
561
+ multiview_attention=self.multiview_attention,
562
+ sparse_mv_attention=self.sparse_mv_attention,
563
+ **cross_attention_kwargs,
564
+ )
565
+
566
+
567
+ if self.use_ada_layer_norm_zero:
568
+ attn_output = gate_msa.unsqueeze(1) * attn_output
569
+ hidden_states = attn_output + hidden_states
570
+
571
+ # joint attention twice
572
+ if self.cd_attention_mid:
573
+ norm_hidden_states = (
574
+ self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
575
+ )
576
+ hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
577
+
578
+ # 2. Cross-Attention
579
+ if self.attn2 is not None:
580
+ norm_hidden_states = (
581
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
582
+ )
583
+
584
+ attn_output = self.attn2(
585
+ norm_hidden_states,
586
+ encoder_hidden_states=encoder_hidden_states,
587
+ attention_mask=encoder_attention_mask,
588
+ **cross_attention_kwargs,
589
+ )
590
+ hidden_states = attn_output + hidden_states
591
+
592
+ # 3. Feed-forward
593
+ norm_hidden_states = self.norm3(hidden_states)
594
+
595
+ if self.use_ada_layer_norm_zero:
596
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
597
+
598
+ if self._chunk_size is not None:
599
+ # "feed_forward_chunk_size" can be used to save memory
600
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
601
+ raise ValueError(
602
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
603
+ )
604
+
605
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
606
+ ff_output = torch.cat(
607
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
608
+ dim=self._chunk_dim,
609
+ )
610
+ else:
611
+ ff_output = self.ff(norm_hidden_states)
612
+
613
+ if self.use_ada_layer_norm_zero:
614
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
615
+
616
+ hidden_states = ff_output + hidden_states
617
+
618
+ if self.cd_attention_last:
619
+ norm_hidden_states = (
620
+ self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
621
+ )
622
+ hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
623
+
624
+ return hidden_states
625
+
626
+
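+ # Illustrative sketch (not used by the model): the chunked feed-forward path in
+ # BasicMVTransformerBlock.forward splits the token dimension into equal chunks, runs the
+ # feed-forward on each chunk and concatenates the results, trading peak activation memory for a
+ # few extra kernel launches. `ff`, `hidden` and `chunk_size` below are hypothetical arguments.
+ def _example_chunked_feed_forward(ff, hidden, chunk_size, dim=1):
+     num_chunks = hidden.shape[dim] // chunk_size  # exact divisibility is checked in forward()
+     return torch.cat([ff(chunk) for chunk in hidden.chunk(num_chunks, dim=dim)], dim=dim)
+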
627
+ class CustomAttention(Attention):
628
+ def set_use_memory_efficient_attention_xformers(
629
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
630
+ ):
631
+ processor = XFormersMVAttnProcessor()
632
+ self.set_processor(processor)
633
+ # print("using xformers attention processor")
634
+
635
+
636
+ class CustomJointAttention(Attention):
637
+ def set_use_memory_efficient_attention_xformers(
638
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
639
+ ):
640
+ processor = XFormersJointAttnProcessor()
641
+ self.set_processor(processor)
642
+ # print("using xformers attention processor")
643
+
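+ # Usage sketch (hedged; `model` is a hypothetical module tree containing the attention layers above):
+ # both subclasses ignore the boolean flag and always install their xformers processors, so enabling
+ # memory-efficient attention anywhere in the tree switches them to the multi-view / joint variants.
+ def _example_enable_xformers(model):
+     for module in model.modules():
+         if isinstance(module, (CustomAttention, CustomJointAttention)):
+             module.set_use_memory_efficient_attention_xformers(True)
+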
644
+ class MVAttnProcessor:
645
+ r"""
646
+ Default processor for performing attention-related computations.
647
+ """
648
+
649
+ def __call__(
650
+ self,
651
+ attn: Attention,
652
+ hidden_states,
653
+ encoder_hidden_states=None,
654
+ attention_mask=None,
655
+ temb=None,
656
+ num_views=1,
657
+ multiview_attention=True
658
+ ):
659
+ residual = hidden_states
660
+
661
+ if attn.spatial_norm is not None:
662
+ hidden_states = attn.spatial_norm(hidden_states, temb)
663
+
664
+ input_ndim = hidden_states.ndim
665
+
666
+ if input_ndim == 4:
667
+ batch_size, channel, height, width = hidden_states.shape
668
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
669
+
670
+ batch_size, sequence_length, _ = (
671
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
672
+ )
673
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
674
+
675
+ if attn.group_norm is not None:
676
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
677
+
678
+ query = attn.to_q(hidden_states)
679
+
680
+ if encoder_hidden_states is None:
681
+ encoder_hidden_states = hidden_states
682
+ elif attn.norm_cross:
683
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
684
+
685
+ key = attn.to_k(encoder_hidden_states)
686
+ value = attn.to_v(encoder_hidden_states)
687
+
688
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
689
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
690
+ # pdb.set_trace()
691
+ # multi-view self-attention
692
+ if multiview_attention:
693
+ if num_views <= 6:
694
+ # after switching to xformers it became possible to train with 6 views
695
+ key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
696
+ value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
697
+ else:  # sparse attention for more than 6 views; only a commented-out sketch is kept here
+ pass
+ # # random sparse sampling seemed to cause problems,
+ # # so instead of sampling, the view indexes are fixed:
702
+ # onekey = rearrange(key, "(b t) d c -> b t d c", t=num_views)
703
+ # onevalue = rearrange(value, "(b t) d c -> b t d c", t=num_views)
704
+ # allkeys = []
705
+ # allvalues = []
706
+ # all_indexes = {
707
+ # 0 : [0, 2, 3, 4],
708
+ # 1: [0, 1, 3, 5],
709
+ # 2: [0, 2, 3, 4],
710
+ # 3: [0, 2, 3, 4],
711
+ # 4: [0, 2, 3, 4],
712
+ # 5: [0, 1, 3, 5]
713
+ # }
714
+ # for jj in range(num_views):
715
+ # # valid_index = [x for x in range(0, num_views) if x!= jj]
716
+ # # indexes = random.sample(valid_index, 3) + [jj] + [0]
717
+ # indexes = all_indexes[jj]
718
+
719
+ # indexes = torch.tensor(indexes).long().to(key.device)
720
+ # allkeys.append(onekey[:, indexes])
721
+ # allvalues.append(onevalue[:, indexes])
722
+ # keys = torch.stack(allkeys, dim=1) # checked, should be dim=1
723
+ # values = torch.stack(allvalues, dim=1)
724
+ # key = rearrange(keys, 'b t f d c -> (b t) (f d) c')
725
+ # value = rearrange(values, 'b t f d c -> (b t) (f d) c')
726
+
727
+
728
+ query = attn.head_to_batch_dim(query).contiguous()
729
+ key = attn.head_to_batch_dim(key).contiguous()
730
+ value = attn.head_to_batch_dim(value).contiguous()
731
+
732
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
733
+ hidden_states = torch.bmm(attention_probs, value)
734
+ hidden_states = attn.batch_to_head_dim(hidden_states)
735
+
736
+ # linear proj
737
+ hidden_states = attn.to_out[0](hidden_states)
738
+ # dropout
739
+ hidden_states = attn.to_out[1](hidden_states)
740
+
741
+ if input_ndim == 4:
742
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
743
+
744
+ if attn.residual_connection:
745
+ hidden_states = hidden_states + residual
746
+
747
+ hidden_states = hidden_states / attn.rescale_output_factor
748
+
749
+ return hidden_states
750
+
751
+
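+ # Shape sketch for the multi-view key/value expansion performed above (illustrative only, not
+ # called): with batch b=2, num_views t=3, tokens d=4 and channels c=8, keys of shape
+ # (b*t, d, c) = (6, 4, 8) are regrouped to (b, t*d, c) = (2, 12, 8) and then repeated per view back
+ # to (b*t, t*d, c) = (6, 12, 8), so every view attends over the tokens of all views of its sample.
+ def _example_mv_kv_expand():
+     key = torch.randn(2 * 3, 4, 8)                        # (b*t, d, c)
+     key = rearrange(key, "(b t) d c -> b (t d) c", t=3)   # (b, t*d, c)
+     return key.repeat_interleave(3, dim=0)                # (b*t, t*d, c)
+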
752
+ class XFormersMVAttnProcessor:
753
+ r"""
754
+ Default processor for performing attention-related computations.
755
+ """
756
+
757
+ def __call__(
758
+ self,
759
+ attn: Attention,
760
+ hidden_states,
761
+ encoder_hidden_states=None,
762
+ attention_mask=None,
763
+ temb=None,
764
+ num_views=1,
765
+ multiview_attention=True,
766
+ sparse_mv_attention=False,
767
+ ):
768
+ residual = hidden_states
769
+
770
+ if attn.spatial_norm is not None:
771
+ hidden_states = attn.spatial_norm(hidden_states, temb)
772
+
773
+ input_ndim = hidden_states.ndim
774
+
775
+ if input_ndim == 4:
776
+ batch_size, channel, height, width = hidden_states.shape
777
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
778
+
779
+ batch_size, sequence_length, _ = (
780
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
781
+ )
782
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
783
+
784
+ # from yuancheng; here attention_mask is None
785
+ if attention_mask is not None:
786
+ # expand our mask's singleton query_tokens dimension:
787
+ # [batch*heads, 1, key_tokens] ->
788
+ # [batch*heads, query_tokens, key_tokens]
789
+ # so that it can be added as a bias onto the attention scores that xformers computes:
790
+ # [batch*heads, query_tokens, key_tokens]
791
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
792
+ _, query_tokens, _ = hidden_states.shape
793
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
794
+
795
+ if attn.group_norm is not None:
796
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
797
+
798
+ query = attn.to_q(hidden_states)
799
+
800
+ if encoder_hidden_states is None:
801
+ encoder_hidden_states = hidden_states
802
+ elif attn.norm_cross:
803
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
804
+
805
+ key_raw = attn.to_k(encoder_hidden_states)
806
+ value_raw = attn.to_v(encoder_hidden_states)
807
+
808
+ # query, key_raw and value_raw all have shape (b * num_views, h * w, inner_dim) here,
+ # e.g. (b * 4, 1024, 320) for 32x32 latents with inner_dim=320
+ # multi-view self-attention
812
+ if multiview_attention:
813
+ if not sparse_mv_attention:
814
+ key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
815
+ value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
816
+ else:
817
+ key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
818
+ value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
819
+ key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
820
+ value = torch.cat([value_front, value_raw], dim=1)
821
+
822
+
823
+ else:
824
+ # plain self-attention without multi-view key/value expansion
825
+ key = key_raw
826
+ value = value_raw
827
+
828
+ query = attn.head_to_batch_dim(query)
829
+ key = attn.head_to_batch_dim(key)
830
+ value = attn.head_to_batch_dim(value)
831
+
832
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
833
+ hidden_states = attn.batch_to_head_dim(hidden_states)
834
+
835
+ # linear proj
836
+ hidden_states = attn.to_out[0](hidden_states)
837
+ # dropout
838
+ hidden_states = attn.to_out[1](hidden_states)
839
+
840
+ if input_ndim == 4:
841
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
842
+
843
+ if attn.residual_connection:
844
+ hidden_states = hidden_states + residual
845
+
846
+ hidden_states = hidden_states / attn.rescale_output_factor
847
+
848
+ return hidden_states
849
+
850
+
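+ # Sparse multi-view sketch (illustrative only; mirrors the `sparse_mv_attention` branch above with
+ # plain tensor ops instead of `my_repeat`): each view's keys are concatenated with the keys of the
+ # front view (view 0), so attention is restricted to "self + front" rather than all views.
+ # Shapes assume b=2, t=3, d=4, c=8.
+ def _example_sparse_mv_keys():
+     key_raw = torch.randn(2 * 3, 4, 8)                            # (b*t, d, c)
+     key_per_view = rearrange(key_raw, "(b t) d c -> b t d c", t=3)
+     key_front = key_per_view[:, 0].repeat_interleave(3, dim=0)    # front view replicated per view
+     return torch.cat([key_front, key_raw], dim=1)                 # (b*t, 2*d, c)
+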
851
+
852
+ class XFormersJointAttnProcessor:
853
+ r"""
854
+ Default processor for performing attention-related computations.
855
+ """
856
+
857
+ def __call__(
858
+ self,
859
+ attn: Attention,
860
+ hidden_states,
861
+ encoder_hidden_states=None,
862
+ attention_mask=None,
863
+ temb=None,
864
+ num_tasks=2
865
+ ):
866
+
867
+ residual = hidden_states
868
+
869
+ if attn.spatial_norm is not None:
870
+ hidden_states = attn.spatial_norm(hidden_states, temb)
871
+
872
+ input_ndim = hidden_states.ndim
873
+
874
+ if input_ndim == 4:
875
+ batch_size, channel, height, width = hidden_states.shape
876
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
877
+
878
+ batch_size, sequence_length, _ = (
879
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
880
+ )
881
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
882
+
883
+ # from yuancheng; here attention_mask is None
884
+ if attention_mask is not None:
885
+ # expand our mask's singleton query_tokens dimension:
886
+ # [batch*heads, 1, key_tokens] ->
887
+ # [batch*heads, query_tokens, key_tokens]
888
+ # so that it can be added as a bias onto the attention scores that xformers computes:
889
+ # [batch*heads, query_tokens, key_tokens]
890
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
891
+ _, query_tokens, _ = hidden_states.shape
892
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
893
+
894
+ if attn.group_norm is not None:
895
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
896
+
897
+ query = attn.to_q(hidden_states)
898
+
899
+ if encoder_hidden_states is None:
900
+ encoder_hidden_states = hidden_states
901
+ elif attn.norm_cross:
902
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
903
+
904
+ key = attn.to_k(encoder_hidden_states)
905
+ value = attn.to_v(encoder_hidden_states)
906
+
907
+ assert num_tasks == 2 # only support two tasks now
908
+
909
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
910
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
911
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
912
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
913
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
914
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
915
+
916
+
917
+ query = attn.head_to_batch_dim(query).contiguous()
918
+ key = attn.head_to_batch_dim(key).contiguous()
919
+ value = attn.head_to_batch_dim(value).contiguous()
920
+
921
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
922
+ hidden_states = attn.batch_to_head_dim(hidden_states)
923
+
924
+ # linear proj
925
+ hidden_states = attn.to_out[0](hidden_states)
926
+ # dropout
927
+ hidden_states = attn.to_out[1](hidden_states)
928
+
929
+ if input_ndim == 4:
930
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
931
+
932
+ if attn.residual_connection:
933
+ hidden_states = hidden_states + residual
934
+
935
+ hidden_states = hidden_states / attn.rescale_output_factor
936
+
937
+ return hidden_states
938
+
939
+
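+ # Joint (cross-domain) key/value sketch (illustrative only): the batch is assumed to stack two
+ # tasks along dim 0 (e.g. color and normal). Their keys are concatenated along the token dimension
+ # and the result is tiled back over the batch, so each task attends over both domains. The same
+ # construction is used by JointAttnProcessor below.
+ def _example_joint_kv():
+     key = torch.randn(2 * 4, 16, 8)                     # (2*b_t, d, c): two tasks of 4 samples each
+     key_0, key_1 = torch.chunk(key, chunks=2, dim=0)    # (b_t, d, c) each
+     key = torch.cat([key_0, key_1], dim=1)              # (b_t, 2*d, c)
+     return torch.cat([key, key], dim=0)                 # (2*b_t, 2*d, c)
+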
940
+ class JointAttnProcessor:
941
+ r"""
942
+ Default processor for performing attention-related computations.
943
+ """
944
+
945
+ def __call__(
946
+ self,
947
+ attn: Attention,
948
+ hidden_states,
949
+ encoder_hidden_states=None,
950
+ attention_mask=None,
951
+ temb=None,
952
+ num_tasks=2
953
+ ):
954
+
955
+ residual = hidden_states
956
+
957
+ if attn.spatial_norm is not None:
958
+ hidden_states = attn.spatial_norm(hidden_states, temb)
959
+
960
+ input_ndim = hidden_states.ndim
961
+
962
+ if input_ndim == 4:
963
+ batch_size, channel, height, width = hidden_states.shape
964
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
965
+
966
+ batch_size, sequence_length, _ = (
967
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
968
+ )
969
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
970
+
971
+
972
+ if attn.group_norm is not None:
973
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
974
+
975
+ query = attn.to_q(hidden_states)
976
+
977
+ if encoder_hidden_states is None:
978
+ encoder_hidden_states = hidden_states
979
+ elif attn.norm_cross:
980
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
981
+
982
+ key = attn.to_k(encoder_hidden_states)
983
+ value = attn.to_v(encoder_hidden_states)
984
+
985
+ assert num_tasks == 2 # only support two tasks now
986
+
987
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
988
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
989
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
990
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
991
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
992
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
993
+
994
+
995
+ query = attn.head_to_batch_dim(query).contiguous()
996
+ key = attn.head_to_batch_dim(key).contiguous()
997
+ value = attn.head_to_batch_dim(value).contiguous()
998
+
999
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1000
+ hidden_states = torch.bmm(attention_probs, value)
1001
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1002
+
1003
+ # linear proj
1004
+ hidden_states = attn.to_out[0](hidden_states)
1005
+ # dropout
1006
+ hidden_states = attn.to_out[1](hidden_states)
1007
+
1008
+ if input_ndim == 4:
1009
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1010
+
1011
+ if attn.residual_connection:
1012
+ hidden_states = hidden_states + residual
1013
+
1014
+ hidden_states = hidden_states / attn.rescale_output_factor
1015
+
1016
+ return hidden_states
mvdiffusion/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,932 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ # from diffusers.models.normalization import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+
27
+ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
28
+ from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ def get_down_block(
35
+ down_block_type,
36
+ num_layers,
37
+ in_channels,
38
+ out_channels,
39
+ temb_channels,
40
+ add_downsample,
41
+ resnet_eps,
42
+ resnet_act_fn,
43
+ transformer_layers_per_block=1,
44
+ num_attention_heads=None,
45
+ resnet_groups=None,
46
+ cross_attention_dim=None,
47
+ downsample_padding=None,
48
+ dual_cross_attention=False,
49
+ use_linear_projection=False,
50
+ only_cross_attention=False,
51
+ upcast_attention=False,
52
+ resnet_time_scale_shift="default",
53
+ resnet_skip_time_act=False,
54
+ resnet_out_scale_factor=1.0,
55
+ cross_attention_norm=None,
56
+ attention_head_dim=None,
57
+ downsample_type=None,
58
+ num_views=1,
59
+ cd_attention_last: bool = False,
60
+ cd_attention_mid: bool = False,
61
+ multiview_attention: bool = True,
62
+ sparse_mv_attention: bool = False,
63
+ selfattn_block: str = "custom",
64
+ ):
65
+ # If attn head dim is not defined, we default it to the number of heads
66
+ if attention_head_dim is None:
67
+ logger.warning(
68
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
69
+ )
70
+ attention_head_dim = num_attention_heads
71
+
72
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
73
+ if down_block_type == "DownBlock2D":
74
+ return DownBlock2D(
75
+ num_layers=num_layers,
76
+ in_channels=in_channels,
77
+ out_channels=out_channels,
78
+ temb_channels=temb_channels,
79
+ add_downsample=add_downsample,
80
+ resnet_eps=resnet_eps,
81
+ resnet_act_fn=resnet_act_fn,
82
+ resnet_groups=resnet_groups,
83
+ downsample_padding=downsample_padding,
84
+ resnet_time_scale_shift=resnet_time_scale_shift,
85
+ )
86
+ elif down_block_type == "ResnetDownsampleBlock2D":
87
+ return ResnetDownsampleBlock2D(
88
+ num_layers=num_layers,
89
+ in_channels=in_channels,
90
+ out_channels=out_channels,
91
+ temb_channels=temb_channels,
92
+ add_downsample=add_downsample,
93
+ resnet_eps=resnet_eps,
94
+ resnet_act_fn=resnet_act_fn,
95
+ resnet_groups=resnet_groups,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ skip_time_act=resnet_skip_time_act,
98
+ output_scale_factor=resnet_out_scale_factor,
99
+ )
100
+ elif down_block_type == "AttnDownBlock2D":
101
+ if add_downsample is False:
102
+ downsample_type = None
103
+ else:
104
+ downsample_type = downsample_type or "conv" # default to 'conv'
105
+ return AttnDownBlock2D(
106
+ num_layers=num_layers,
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ temb_channels=temb_channels,
110
+ resnet_eps=resnet_eps,
111
+ resnet_act_fn=resnet_act_fn,
112
+ resnet_groups=resnet_groups,
113
+ downsample_padding=downsample_padding,
114
+ attention_head_dim=attention_head_dim,
115
+ resnet_time_scale_shift=resnet_time_scale_shift,
116
+ downsample_type=downsample_type,
117
+ )
118
+ elif down_block_type == "CrossAttnDownBlock2D":
119
+ if cross_attention_dim is None:
120
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
121
+ return CrossAttnDownBlock2D(
122
+ num_layers=num_layers,
123
+ transformer_layers_per_block=transformer_layers_per_block,
124
+ in_channels=in_channels,
125
+ out_channels=out_channels,
126
+ temb_channels=temb_channels,
127
+ add_downsample=add_downsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ downsample_padding=downsample_padding,
132
+ cross_attention_dim=cross_attention_dim,
133
+ num_attention_heads=num_attention_heads,
134
+ dual_cross_attention=dual_cross_attention,
135
+ use_linear_projection=use_linear_projection,
136
+ only_cross_attention=only_cross_attention,
137
+ upcast_attention=upcast_attention,
138
+ resnet_time_scale_shift=resnet_time_scale_shift,
139
+ )
140
+ # custom MV2D attention block
141
+ elif down_block_type == "CrossAttnDownBlockMV2D":
142
+ if cross_attention_dim is None:
143
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
144
+ return CrossAttnDownBlockMV2D(
145
+ num_layers=num_layers,
146
+ transformer_layers_per_block=transformer_layers_per_block,
147
+ in_channels=in_channels,
148
+ out_channels=out_channels,
149
+ temb_channels=temb_channels,
150
+ add_downsample=add_downsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ downsample_padding=downsample_padding,
155
+ cross_attention_dim=cross_attention_dim,
156
+ num_attention_heads=num_attention_heads,
157
+ dual_cross_attention=dual_cross_attention,
158
+ use_linear_projection=use_linear_projection,
159
+ only_cross_attention=only_cross_attention,
160
+ upcast_attention=upcast_attention,
161
+ resnet_time_scale_shift=resnet_time_scale_shift,
162
+ num_views=num_views,
163
+ cd_attention_last=cd_attention_last,
164
+ cd_attention_mid=cd_attention_mid,
165
+ multiview_attention=multiview_attention,
166
+ sparse_mv_attention=sparse_mv_attention,
167
+ selfattn_block=selfattn_block,
168
+ )
169
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
170
+ if cross_attention_dim is None:
171
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
172
+ return SimpleCrossAttnDownBlock2D(
173
+ num_layers=num_layers,
174
+ in_channels=in_channels,
175
+ out_channels=out_channels,
176
+ temb_channels=temb_channels,
177
+ add_downsample=add_downsample,
178
+ resnet_eps=resnet_eps,
179
+ resnet_act_fn=resnet_act_fn,
180
+ resnet_groups=resnet_groups,
181
+ cross_attention_dim=cross_attention_dim,
182
+ attention_head_dim=attention_head_dim,
183
+ resnet_time_scale_shift=resnet_time_scale_shift,
184
+ skip_time_act=resnet_skip_time_act,
185
+ output_scale_factor=resnet_out_scale_factor,
186
+ only_cross_attention=only_cross_attention,
187
+ cross_attention_norm=cross_attention_norm,
188
+ )
189
+ elif down_block_type == "SkipDownBlock2D":
190
+ return SkipDownBlock2D(
191
+ num_layers=num_layers,
192
+ in_channels=in_channels,
193
+ out_channels=out_channels,
194
+ temb_channels=temb_channels,
195
+ add_downsample=add_downsample,
196
+ resnet_eps=resnet_eps,
197
+ resnet_act_fn=resnet_act_fn,
198
+ downsample_padding=downsample_padding,
199
+ resnet_time_scale_shift=resnet_time_scale_shift,
200
+ )
201
+ elif down_block_type == "AttnSkipDownBlock2D":
202
+ return AttnSkipDownBlock2D(
203
+ num_layers=num_layers,
204
+ in_channels=in_channels,
205
+ out_channels=out_channels,
206
+ temb_channels=temb_channels,
207
+ add_downsample=add_downsample,
208
+ resnet_eps=resnet_eps,
209
+ resnet_act_fn=resnet_act_fn,
210
+ attention_head_dim=attention_head_dim,
211
+ resnet_time_scale_shift=resnet_time_scale_shift,
212
+ )
213
+ elif down_block_type == "DownEncoderBlock2D":
214
+ return DownEncoderBlock2D(
215
+ num_layers=num_layers,
216
+ in_channels=in_channels,
217
+ out_channels=out_channels,
218
+ add_downsample=add_downsample,
219
+ resnet_eps=resnet_eps,
220
+ resnet_act_fn=resnet_act_fn,
221
+ resnet_groups=resnet_groups,
222
+ downsample_padding=downsample_padding,
223
+ resnet_time_scale_shift=resnet_time_scale_shift,
224
+ )
225
+ elif down_block_type == "AttnDownEncoderBlock2D":
226
+ return AttnDownEncoderBlock2D(
227
+ num_layers=num_layers,
228
+ in_channels=in_channels,
229
+ out_channels=out_channels,
230
+ add_downsample=add_downsample,
231
+ resnet_eps=resnet_eps,
232
+ resnet_act_fn=resnet_act_fn,
233
+ resnet_groups=resnet_groups,
234
+ downsample_padding=downsample_padding,
235
+ attention_head_dim=attention_head_dim,
236
+ resnet_time_scale_shift=resnet_time_scale_shift,
237
+ )
238
+ elif down_block_type == "KDownBlock2D":
239
+ return KDownBlock2D(
240
+ num_layers=num_layers,
241
+ in_channels=in_channels,
242
+ out_channels=out_channels,
243
+ temb_channels=temb_channels,
244
+ add_downsample=add_downsample,
245
+ resnet_eps=resnet_eps,
246
+ resnet_act_fn=resnet_act_fn,
247
+ )
248
+ elif down_block_type == "KCrossAttnDownBlock2D":
249
+ return KCrossAttnDownBlock2D(
250
+ num_layers=num_layers,
251
+ in_channels=in_channels,
252
+ out_channels=out_channels,
253
+ temb_channels=temb_channels,
254
+ add_downsample=add_downsample,
255
+ resnet_eps=resnet_eps,
256
+ resnet_act_fn=resnet_act_fn,
257
+ cross_attention_dim=cross_attention_dim,
258
+ attention_head_dim=attention_head_dim,
259
+ add_self_attention=not add_downsample,
260
+ )
261
+ raise ValueError(f"{down_block_type} does not exist.")
262
+
263
+
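+ # Usage sketch (hypothetical parameter values, not called anywhere): the multi-view UNet builds its
+ # encoder by calling get_down_block once per stage; a multi-view cross-attention stage would be
+ # created roughly like this.
+ def _example_build_mv_down_block():
+     return get_down_block(
+         "CrossAttnDownBlockMV2D",
+         num_layers=2, in_channels=320, out_channels=640, temb_channels=1280,
+         add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu", resnet_groups=32,
+         num_attention_heads=8, attention_head_dim=80, cross_attention_dim=1024,
+         num_views=6, multiview_attention=True,
+     )
+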
264
+ def get_up_block(
265
+ up_block_type,
266
+ num_layers,
267
+ in_channels,
268
+ out_channels,
269
+ prev_output_channel,
270
+ temb_channels,
271
+ add_upsample,
272
+ resnet_eps,
273
+ resnet_act_fn,
274
+ transformer_layers_per_block=1,
275
+ num_attention_heads=None,
276
+ resnet_groups=None,
277
+ cross_attention_dim=None,
278
+ dual_cross_attention=False,
279
+ use_linear_projection=False,
280
+ only_cross_attention=False,
281
+ upcast_attention=False,
282
+ resnet_time_scale_shift="default",
283
+ resnet_skip_time_act=False,
284
+ resnet_out_scale_factor=1.0,
285
+ cross_attention_norm=None,
286
+ attention_head_dim=None,
287
+ upsample_type=None,
288
+ num_views=1,
289
+ cd_attention_last: bool = False,
290
+ cd_attention_mid: bool = False,
291
+ multiview_attention: bool = True,
292
+ sparse_mv_attention: bool = False,
293
+ selfattn_block: str = "custom",
294
+ ):
295
+ # If attn head dim is not defined, we default it to the number of heads
296
+ if attention_head_dim is None:
297
+ logger.warn(
298
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
299
+ )
300
+ attention_head_dim = num_attention_heads
301
+
302
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
303
+ if up_block_type == "UpBlock2D":
304
+ return UpBlock2D(
305
+ num_layers=num_layers,
306
+ in_channels=in_channels,
307
+ out_channels=out_channels,
308
+ prev_output_channel=prev_output_channel,
309
+ temb_channels=temb_channels,
310
+ add_upsample=add_upsample,
311
+ resnet_eps=resnet_eps,
312
+ resnet_act_fn=resnet_act_fn,
313
+ resnet_groups=resnet_groups,
314
+ resnet_time_scale_shift=resnet_time_scale_shift,
315
+ )
316
+ elif up_block_type == "ResnetUpsampleBlock2D":
317
+ return ResnetUpsampleBlock2D(
318
+ num_layers=num_layers,
319
+ in_channels=in_channels,
320
+ out_channels=out_channels,
321
+ prev_output_channel=prev_output_channel,
322
+ temb_channels=temb_channels,
323
+ add_upsample=add_upsample,
324
+ resnet_eps=resnet_eps,
325
+ resnet_act_fn=resnet_act_fn,
326
+ resnet_groups=resnet_groups,
327
+ resnet_time_scale_shift=resnet_time_scale_shift,
328
+ skip_time_act=resnet_skip_time_act,
329
+ output_scale_factor=resnet_out_scale_factor,
330
+ )
331
+ elif up_block_type == "CrossAttnUpBlock2D":
332
+ if cross_attention_dim is None:
333
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
334
+ return CrossAttnUpBlock2D(
335
+ num_layers=num_layers,
336
+ transformer_layers_per_block=transformer_layers_per_block,
337
+ in_channels=in_channels,
338
+ out_channels=out_channels,
339
+ prev_output_channel=prev_output_channel,
340
+ temb_channels=temb_channels,
341
+ add_upsample=add_upsample,
342
+ resnet_eps=resnet_eps,
343
+ resnet_act_fn=resnet_act_fn,
344
+ resnet_groups=resnet_groups,
345
+ cross_attention_dim=cross_attention_dim,
346
+ num_attention_heads=num_attention_heads,
347
+ dual_cross_attention=dual_cross_attention,
348
+ use_linear_projection=use_linear_projection,
349
+ only_cross_attention=only_cross_attention,
350
+ upcast_attention=upcast_attention,
351
+ resnet_time_scale_shift=resnet_time_scale_shift,
352
+ )
353
+ # custom MV2D attention block
354
+ elif up_block_type == "CrossAttnUpBlockMV2D":
355
+ if cross_attention_dim is None:
356
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
357
+ return CrossAttnUpBlockMV2D(
358
+ num_layers=num_layers,
359
+ transformer_layers_per_block=transformer_layers_per_block,
360
+ in_channels=in_channels,
361
+ out_channels=out_channels,
362
+ prev_output_channel=prev_output_channel,
363
+ temb_channels=temb_channels,
364
+ add_upsample=add_upsample,
365
+ resnet_eps=resnet_eps,
366
+ resnet_act_fn=resnet_act_fn,
367
+ resnet_groups=resnet_groups,
368
+ cross_attention_dim=cross_attention_dim,
369
+ num_attention_heads=num_attention_heads,
370
+ dual_cross_attention=dual_cross_attention,
371
+ use_linear_projection=use_linear_projection,
372
+ only_cross_attention=only_cross_attention,
373
+ upcast_attention=upcast_attention,
374
+ resnet_time_scale_shift=resnet_time_scale_shift,
375
+ num_views=num_views,
376
+ cd_attention_last=cd_attention_last,
377
+ cd_attention_mid=cd_attention_mid,
378
+ multiview_attention=multiview_attention,
379
+ sparse_mv_attention=sparse_mv_attention,
380
+ selfattn_block=selfattn_block,
381
+ )
382
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
383
+ if cross_attention_dim is None:
384
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
385
+ return SimpleCrossAttnUpBlock2D(
386
+ num_layers=num_layers,
387
+ in_channels=in_channels,
388
+ out_channels=out_channels,
389
+ prev_output_channel=prev_output_channel,
390
+ temb_channels=temb_channels,
391
+ add_upsample=add_upsample,
392
+ resnet_eps=resnet_eps,
393
+ resnet_act_fn=resnet_act_fn,
394
+ resnet_groups=resnet_groups,
395
+ cross_attention_dim=cross_attention_dim,
396
+ attention_head_dim=attention_head_dim,
397
+ resnet_time_scale_shift=resnet_time_scale_shift,
398
+ skip_time_act=resnet_skip_time_act,
399
+ output_scale_factor=resnet_out_scale_factor,
400
+ only_cross_attention=only_cross_attention,
401
+ cross_attention_norm=cross_attention_norm,
402
+ )
403
+ elif up_block_type == "AttnUpBlock2D":
404
+ if add_upsample is False:
405
+ upsample_type = None
406
+ else:
407
+ upsample_type = upsample_type or "conv" # default to 'conv'
408
+
409
+ return AttnUpBlock2D(
410
+ num_layers=num_layers,
411
+ in_channels=in_channels,
412
+ out_channels=out_channels,
413
+ prev_output_channel=prev_output_channel,
414
+ temb_channels=temb_channels,
415
+ resnet_eps=resnet_eps,
416
+ resnet_act_fn=resnet_act_fn,
417
+ resnet_groups=resnet_groups,
418
+ attention_head_dim=attention_head_dim,
419
+ resnet_time_scale_shift=resnet_time_scale_shift,
420
+ upsample_type=upsample_type,
421
+ )
422
+ elif up_block_type == "SkipUpBlock2D":
423
+ return SkipUpBlock2D(
424
+ num_layers=num_layers,
425
+ in_channels=in_channels,
426
+ out_channels=out_channels,
427
+ prev_output_channel=prev_output_channel,
428
+ temb_channels=temb_channels,
429
+ add_upsample=add_upsample,
430
+ resnet_eps=resnet_eps,
431
+ resnet_act_fn=resnet_act_fn,
432
+ resnet_time_scale_shift=resnet_time_scale_shift,
433
+ )
434
+ elif up_block_type == "AttnSkipUpBlock2D":
435
+ return AttnSkipUpBlock2D(
436
+ num_layers=num_layers,
437
+ in_channels=in_channels,
438
+ out_channels=out_channels,
439
+ prev_output_channel=prev_output_channel,
440
+ temb_channels=temb_channels,
441
+ add_upsample=add_upsample,
442
+ resnet_eps=resnet_eps,
443
+ resnet_act_fn=resnet_act_fn,
444
+ attention_head_dim=attention_head_dim,
445
+ resnet_time_scale_shift=resnet_time_scale_shift,
446
+ )
447
+ elif up_block_type == "UpDecoderBlock2D":
448
+ return UpDecoderBlock2D(
449
+ num_layers=num_layers,
450
+ in_channels=in_channels,
451
+ out_channels=out_channels,
452
+ add_upsample=add_upsample,
453
+ resnet_eps=resnet_eps,
454
+ resnet_act_fn=resnet_act_fn,
455
+ resnet_groups=resnet_groups,
456
+ resnet_time_scale_shift=resnet_time_scale_shift,
457
+ temb_channels=temb_channels,
458
+ )
459
+ elif up_block_type == "AttnUpDecoderBlock2D":
460
+ return AttnUpDecoderBlock2D(
461
+ num_layers=num_layers,
462
+ in_channels=in_channels,
463
+ out_channels=out_channels,
464
+ add_upsample=add_upsample,
465
+ resnet_eps=resnet_eps,
466
+ resnet_act_fn=resnet_act_fn,
467
+ resnet_groups=resnet_groups,
468
+ attention_head_dim=attention_head_dim,
469
+ resnet_time_scale_shift=resnet_time_scale_shift,
470
+ temb_channels=temb_channels,
471
+ )
472
+ elif up_block_type == "KUpBlock2D":
473
+ return KUpBlock2D(
474
+ num_layers=num_layers,
475
+ in_channels=in_channels,
476
+ out_channels=out_channels,
477
+ temb_channels=temb_channels,
478
+ add_upsample=add_upsample,
479
+ resnet_eps=resnet_eps,
480
+ resnet_act_fn=resnet_act_fn,
481
+ )
482
+ elif up_block_type == "KCrossAttnUpBlock2D":
483
+ return KCrossAttnUpBlock2D(
484
+ num_layers=num_layers,
485
+ in_channels=in_channels,
486
+ out_channels=out_channels,
487
+ temb_channels=temb_channels,
488
+ add_upsample=add_upsample,
489
+ resnet_eps=resnet_eps,
490
+ resnet_act_fn=resnet_act_fn,
491
+ cross_attention_dim=cross_attention_dim,
492
+ attention_head_dim=attention_head_dim,
493
+ )
494
+
495
+ raise ValueError(f"{up_block_type} does not exist.")
496
+
497
+
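+ # Decoder counterpart of the sketch above (hypothetical values, not called anywhere): an up-block
+ # stage that also consumes encoder skip connections is created via get_up_block.
+ def _example_build_mv_up_block():
+     return get_up_block(
+         "CrossAttnUpBlockMV2D",
+         num_layers=3, in_channels=320, out_channels=640, prev_output_channel=1280,
+         temb_channels=1280, add_upsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
+         resnet_groups=32, num_attention_heads=8, attention_head_dim=80,
+         cross_attention_dim=1024, num_views=6,
+     )
+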
498
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
499
+ def __init__(
500
+ self,
501
+ in_channels: int,
502
+ temb_channels: int,
503
+ dropout: float = 0.0,
504
+ num_layers: int = 1,
505
+ transformer_layers_per_block: int = 1,
506
+ resnet_eps: float = 1e-6,
507
+ resnet_time_scale_shift: str = "default",
508
+ resnet_act_fn: str = "swish",
509
+ resnet_groups: int = 32,
510
+ resnet_pre_norm: bool = True,
511
+ num_attention_heads=1,
512
+ output_scale_factor=1.0,
513
+ cross_attention_dim=1280,
514
+ dual_cross_attention=False,
515
+ use_linear_projection=False,
516
+ upcast_attention=False,
517
+ num_views: int = 1,
518
+ cd_attention_last: bool = False,
519
+ cd_attention_mid: bool = False,
520
+ multiview_attention: bool = True,
521
+ sparse_mv_attention: bool = False,
522
+ selfattn_block: str = "custom",
523
+ ):
524
+ super().__init__()
525
+
526
+ self.has_cross_attention = True
527
+ self.num_attention_heads = num_attention_heads
528
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
529
+ if selfattn_block == "custom":
530
+ from .transformer_mv2d_image import TransformerMV2DModel
531
+ else:
532
+ raise NotImplementedError
533
+
534
+ # there is always at least one resnet
535
+ resnets = [
536
+ ResnetBlock2D(
537
+ in_channels=in_channels,
538
+ out_channels=in_channels,
539
+ temb_channels=temb_channels,
540
+ eps=resnet_eps,
541
+ groups=resnet_groups,
542
+ dropout=dropout,
543
+ time_embedding_norm=resnet_time_scale_shift,
544
+ non_linearity=resnet_act_fn,
545
+ output_scale_factor=output_scale_factor,
546
+ pre_norm=resnet_pre_norm,
547
+ )
548
+ ]
549
+ attentions = []
550
+
551
+ for _ in range(num_layers):
552
+ if not dual_cross_attention:
553
+ attentions.append(
554
+ TransformerMV2DModel(
555
+ num_attention_heads,
556
+ in_channels // num_attention_heads,
557
+ in_channels=in_channels,
558
+ num_layers=transformer_layers_per_block,
559
+ cross_attention_dim=cross_attention_dim,
560
+ norm_num_groups=resnet_groups,
561
+ use_linear_projection=use_linear_projection,
562
+ upcast_attention=upcast_attention,
563
+ num_views=num_views,
564
+ cd_attention_last=cd_attention_last,
565
+ cd_attention_mid=cd_attention_mid,
566
+ multiview_attention=multiview_attention,
567
+ sparse_mv_attention=sparse_mv_attention,
568
+ )
569
+ )
570
+ else:
571
+ raise NotImplementedError
572
+ resnets.append(
573
+ ResnetBlock2D(
574
+ in_channels=in_channels,
575
+ out_channels=in_channels,
576
+ temb_channels=temb_channels,
577
+ eps=resnet_eps,
578
+ groups=resnet_groups,
579
+ dropout=dropout,
580
+ time_embedding_norm=resnet_time_scale_shift,
581
+ non_linearity=resnet_act_fn,
582
+ output_scale_factor=output_scale_factor,
583
+ pre_norm=resnet_pre_norm,
584
+ )
585
+ )
586
+
587
+ self.attentions = nn.ModuleList(attentions)
588
+ self.resnets = nn.ModuleList(resnets)
589
+
590
+ def forward(
591
+ self,
592
+ hidden_states: torch.FloatTensor,
593
+ temb: Optional[torch.FloatTensor] = None,
594
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
595
+ attention_mask: Optional[torch.FloatTensor] = None,
596
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
597
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
598
+ ) -> torch.FloatTensor:
599
+ hidden_states = self.resnets[0](hidden_states, temb)
600
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
601
+ hidden_states = attn(
602
+ hidden_states,
603
+ encoder_hidden_states=encoder_hidden_states,
604
+ cross_attention_kwargs=cross_attention_kwargs,
605
+ attention_mask=attention_mask,
606
+ encoder_attention_mask=encoder_attention_mask,
607
+ return_dict=False,
608
+ )[0]
609
+ hidden_states = resnet(hidden_states, temb)
610
+
611
+ return hidden_states
612
+
613
+
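+ # Call sketch (hypothetical shapes; assumes the mid block was built with in_channels=1280,
+ # temb_channels=1280, cross_attention_dim=1024 and num_views=6): the six views of one scene are
+ # flattened into the batch dimension, and the block alternates attention and resnet layers on them.
+ def _example_mid_block_forward(mid_block):
+     sample = torch.randn(6, 1280, 8, 8)    # (num_views, C, H, W)
+     temb = torch.randn(6, 1280)            # time embedding per view
+     context = torch.randn(6, 77, 1024)     # hypothetical conditioning tokens
+     return mid_block(sample, temb=temb, encoder_hidden_states=context)
+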
614
+ class CrossAttnUpBlockMV2D(nn.Module):
615
+ def __init__(
616
+ self,
617
+ in_channels: int,
618
+ out_channels: int,
619
+ prev_output_channel: int,
620
+ temb_channels: int,
621
+ dropout: float = 0.0,
622
+ num_layers: int = 1,
623
+ transformer_layers_per_block: int = 1,
624
+ resnet_eps: float = 1e-6,
625
+ resnet_time_scale_shift: str = "default",
626
+ resnet_act_fn: str = "swish",
627
+ resnet_groups: int = 32,
628
+ resnet_pre_norm: bool = True,
629
+ num_attention_heads=1,
630
+ cross_attention_dim=1280,
631
+ output_scale_factor=1.0,
632
+ add_upsample=True,
633
+ dual_cross_attention=False,
634
+ use_linear_projection=False,
635
+ only_cross_attention=False,
636
+ upcast_attention=False,
637
+ num_views: int = 1,
638
+ cd_attention_last: bool = False,
639
+ cd_attention_mid: bool = False,
640
+ multiview_attention: bool = True,
641
+ sparse_mv_attention: bool = False,
642
+ selfattn_block: str = "custom",
643
+ ):
644
+ super().__init__()
645
+ resnets = []
646
+ attentions = []
647
+
648
+ self.has_cross_attention = True
649
+ self.num_attention_heads = num_attention_heads
650
+
651
+ if selfattn_block == "custom":
652
+ from .transformer_mv2d_image import TransformerMV2DModel
653
+ else:
654
+ raise NotImplementedError
655
+
656
+ for i in range(num_layers):
657
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
658
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
659
+
660
+ resnets.append(
661
+ ResnetBlock2D(
662
+ in_channels=resnet_in_channels + res_skip_channels,
663
+ out_channels=out_channels,
664
+ temb_channels=temb_channels,
665
+ eps=resnet_eps,
666
+ groups=resnet_groups,
667
+ dropout=dropout,
668
+ time_embedding_norm=resnet_time_scale_shift,
669
+ non_linearity=resnet_act_fn,
670
+ output_scale_factor=output_scale_factor,
671
+ pre_norm=resnet_pre_norm,
672
+ )
673
+ )
674
+ if not dual_cross_attention:
675
+ attentions.append(
676
+ TransformerMV2DModel(
677
+ num_attention_heads,
678
+ out_channels // num_attention_heads,
679
+ in_channels=out_channels,
680
+ num_layers=transformer_layers_per_block,
681
+ cross_attention_dim=cross_attention_dim,
682
+ norm_num_groups=resnet_groups,
683
+ use_linear_projection=use_linear_projection,
684
+ only_cross_attention=only_cross_attention,
685
+ upcast_attention=upcast_attention,
686
+ num_views=num_views,
687
+ cd_attention_last=cd_attention_last,
688
+ cd_attention_mid=cd_attention_mid,
689
+ multiview_attention=multiview_attention,
690
+ sparse_mv_attention=sparse_mv_attention,
691
+ )
692
+ )
693
+ else:
694
+ raise NotImplementedError
695
+ self.attentions = nn.ModuleList(attentions)
696
+ self.resnets = nn.ModuleList(resnets)
697
+
698
+ if add_upsample:
699
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
700
+ else:
701
+ self.upsamplers = None
702
+
703
+ self.gradient_checkpointing = False
704
+
705
+ def forward(
706
+ self,
707
+ hidden_states: torch.FloatTensor,
708
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
709
+ temb: Optional[torch.FloatTensor] = None,
710
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
711
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
712
+ upsample_size: Optional[int] = None,
713
+ attention_mask: Optional[torch.FloatTensor] = None,
714
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
715
+ ):
716
+ for resnet, attn in zip(self.resnets, self.attentions):
717
+ # pop res hidden states
718
+ res_hidden_states = res_hidden_states_tuple[-1]
719
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
720
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
721
+
722
+ if self.training and self.gradient_checkpointing:
723
+
724
+ def create_custom_forward(module, return_dict=None):
725
+ def custom_forward(*inputs):
726
+ if return_dict is not None and return_dict is not False:
727
+ return module(*inputs, return_dict=return_dict)
728
+ else:
729
+ return module(*inputs)
730
+
731
+ return custom_forward
732
+
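+ # torch >= 1.11 adds the `use_reentrant` flag; the non-reentrant path is generally more
+ # flexible (e.g. it tolerates inputs that do not require grad), so it is selected when available.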
733
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
734
+ hidden_states = torch.utils.checkpoint.checkpoint(
735
+ create_custom_forward(resnet),
736
+ hidden_states,
737
+ temb,
738
+ **ckpt_kwargs,
739
+ )
740
+ hidden_states = torch.utils.checkpoint.checkpoint(
741
+ create_custom_forward(attn, return_dict=False),
742
+ hidden_states,
743
+ encoder_hidden_states,
744
+ None, # timestep
745
+ None, # class_labels
746
+ cross_attention_kwargs,
747
+ attention_mask,
748
+ encoder_attention_mask,
749
+ **ckpt_kwargs,
750
+ )[0]
751
+ else:
752
+ hidden_states = resnet(hidden_states, temb)
753
+ hidden_states = attn(
754
+ hidden_states,
755
+ encoder_hidden_states=encoder_hidden_states,
756
+ cross_attention_kwargs=cross_attention_kwargs,
757
+ attention_mask=attention_mask,
758
+ encoder_attention_mask=encoder_attention_mask,
759
+ return_dict=False,
760
+ )[0]
761
+
762
+ if self.upsamplers is not None:
763
+ for upsampler in self.upsamplers:
764
+ hidden_states = upsampler(hidden_states, upsample_size)
765
+
766
+ return hidden_states
767
+
768
+
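+ # CrossAttnDownBlockMV2D is the multi-view counterpart of diffusers' CrossAttnDownBlock2D:
+ # each resnet/attention pair appends its hidden states to `output_states` so the matching
+ # up block can later consume them as skip connections.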
769
+ class CrossAttnDownBlockMV2D(nn.Module):
770
+ def __init__(
771
+ self,
772
+ in_channels: int,
773
+ out_channels: int,
774
+ temb_channels: int,
775
+ dropout: float = 0.0,
776
+ num_layers: int = 1,
777
+ transformer_layers_per_block: int = 1,
778
+ resnet_eps: float = 1e-6,
779
+ resnet_time_scale_shift: str = "default",
780
+ resnet_act_fn: str = "swish",
781
+ resnet_groups: int = 32,
782
+ resnet_pre_norm: bool = True,
783
+ num_attention_heads=1,
784
+ cross_attention_dim=1280,
785
+ output_scale_factor=1.0,
786
+ downsample_padding=1,
787
+ add_downsample=True,
788
+ dual_cross_attention=False,
789
+ use_linear_projection=False,
790
+ only_cross_attention=False,
791
+ upcast_attention=False,
792
+ num_views: int = 1,
793
+ cd_attention_last: bool = False,
794
+ cd_attention_mid: bool = False,
795
+ multiview_attention: bool = True,
796
+ sparse_mv_attention: bool = False,
797
+ selfattn_block: str = "custom",
798
+ ):
799
+ super().__init__()
800
+ resnets = []
801
+ attentions = []
802
+
803
+ self.has_cross_attention = True
804
+ self.num_attention_heads = num_attention_heads
805
+ if selfattn_block == "custom":
806
+ from .transformer_mv2d_image import TransformerMV2DModel
807
+ else:
808
+ raise NotImplementedError
809
+
810
+ for i in range(num_layers):
811
+ in_channels = in_channels if i == 0 else out_channels
812
+ resnets.append(
813
+ ResnetBlock2D(
814
+ in_channels=in_channels,
815
+ out_channels=out_channels,
816
+ temb_channels=temb_channels,
817
+ eps=resnet_eps,
818
+ groups=resnet_groups,
819
+ dropout=dropout,
820
+ time_embedding_norm=resnet_time_scale_shift,
821
+ non_linearity=resnet_act_fn,
822
+ output_scale_factor=output_scale_factor,
823
+ pre_norm=resnet_pre_norm,
824
+ )
825
+ )
826
+ if not dual_cross_attention:
827
+ attentions.append(
828
+ TransformerMV2DModel(
829
+ num_attention_heads,
830
+ out_channels // num_attention_heads,
831
+ in_channels=out_channels,
832
+ num_layers=transformer_layers_per_block,
833
+ cross_attention_dim=cross_attention_dim,
834
+ norm_num_groups=resnet_groups,
835
+ use_linear_projection=use_linear_projection,
836
+ only_cross_attention=only_cross_attention,
837
+ upcast_attention=upcast_attention,
838
+ num_views=num_views,
839
+ cd_attention_last=cd_attention_last,
840
+ cd_attention_mid=cd_attention_mid,
841
+ multiview_attention=multiview_attention,
842
+ sparse_mv_attention=sparse_mv_attention,
843
+ )
844
+ )
845
+ else:
846
+ raise NotImplementedError
847
+ self.attentions = nn.ModuleList(attentions)
848
+ self.resnets = nn.ModuleList(resnets)
849
+
850
+ if add_downsample:
851
+ self.downsamplers = nn.ModuleList(
852
+ [
853
+ Downsample2D(
854
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
855
+ )
856
+ ]
857
+ )
858
+ else:
859
+ self.downsamplers = None
860
+
861
+ self.gradient_checkpointing = False
862
+
863
+ def forward(
864
+ self,
865
+ hidden_states: torch.FloatTensor,
866
+ temb: Optional[torch.FloatTensor] = None,
867
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
868
+ attention_mask: Optional[torch.FloatTensor] = None,
869
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
870
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
871
+ additional_residuals=None,
872
+ ):
873
+ output_states = ()
874
+
875
+ blocks = list(zip(self.resnets, self.attentions))
876
+
877
+ for i, (resnet, attn) in enumerate(blocks):
878
+ if self.training and self.gradient_checkpointing:
879
+
880
+ def create_custom_forward(module, return_dict=None):
881
+ def custom_forward(*inputs):
882
+ if return_dict is not None and return_dict is not False:
884
+ return module(*inputs, return_dict=return_dict)
885
+ else:
886
+ return module(*inputs)
887
+
888
+ return custom_forward
889
+
890
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
891
+ hidden_states = torch.utils.checkpoint.checkpoint(
892
+ create_custom_forward(resnet),
893
+ hidden_states,
894
+ temb,
895
+ **ckpt_kwargs,
896
+ )
897
+ hidden_states = torch.utils.checkpoint.checkpoint(
898
+ create_custom_forward(attn, return_dict=False),
899
+ hidden_states,
900
+ encoder_hidden_states,
901
+ None, # timestep
902
+ None, # class_labels
903
+ cross_attention_kwargs,
904
+ attention_mask,
905
+ encoder_attention_mask,
906
+ **ckpt_kwargs,
907
+ )[0]
908
+ else:
909
+ hidden_states = resnet(hidden_states, temb)
910
+ hidden_states = attn(
911
+ hidden_states,
912
+ encoder_hidden_states=encoder_hidden_states,
913
+ cross_attention_kwargs=cross_attention_kwargs,
914
+ attention_mask=attention_mask,
915
+ encoder_attention_mask=encoder_attention_mask,
916
+ return_dict=False,
917
+ )[0]
918
+
919
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
920
+ if i == len(blocks) - 1 and additional_residuals is not None:
921
+ hidden_states = hidden_states + additional_residuals
922
+
923
+ output_states = output_states + (hidden_states,)
924
+
925
+ if self.downsamplers is not None:
926
+ for downsampler in self.downsamplers:
927
+ hidden_states = downsampler(hidden_states)
928
+
929
+ output_states = output_states + (hidden_states,)
930
+
931
+ return hidden_states, output_states
932
+
mvdiffusion/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1568 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unets.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ FLAX_WEIGHTS_NAME,
50
+ SAFETENSORS_WEIGHTS_NAME,
51
+ WEIGHTS_NAME,
52
+ _add_variant,
53
+ _get_model_file,
54
+ deprecate,
55
+ is_torch_version,
56
+ logging,
57
+ )
58
+ from diffusers.utils.import_utils import is_accelerate_available
59
+ from diffusers.utils.hub_utils import HF_HUB_OFFLINE
60
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
61
+ DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE
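+ # Compatibility shim: newer diffusers releases no longer expose `DIFFUSERS_CACHE`,
+ # so the Hugging Face hub cache path is aliased here instead.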
62
+
63
+ from diffusers import __version__
64
+ from .unet_mv2d_blocks import (
65
+ CrossAttnDownBlockMV2D,
66
+ CrossAttnUpBlockMV2D,
67
+ UNetMidBlockMV2DCrossAttn,
68
+ get_down_block,
69
+ get_up_block,
70
+ )
71
+ from einops import rearrange, repeat
72
+
73
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
74
+
75
+
76
+ @dataclass
77
+ class UNetMV2DConditionOutput(BaseOutput):
78
+ """
79
+ The output of [`UNetMV2DConditionModel`].
80
+
81
+ Args:
82
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
83
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
84
+ """
85
+
86
+ sample: torch.FloatTensor = None
87
+
88
+
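+ # ResidualBlock / ResidualLiner below are small SiLU MLP residual stacks
+ # (linear -> SiLU -> linear with a skip connection).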
89
+ class ResidualBlock(nn.Module):
90
+ def __init__(self, dim):
91
+ super(ResidualBlock, self).__init__()
92
+ self.linear1 = nn.Linear(dim, dim)
93
+ self.activation = nn.SiLU()
94
+ self.linear2 = nn.Linear(dim, dim)
95
+
96
+ def forward(self, x):
97
+ identity = x
98
+ out = self.linear1(x)
99
+ out = self.activation(out)
100
+ out = self.linear2(out)
101
+ out += identity
102
+ out = self.activation(out)
103
+ return out
104
+
105
+ class ResidualLiner(nn.Module):
106
+ def __init__(self, in_features, out_features, dim, act=None, num_block=1):
107
+ super(ResidualLiner, self).__init__()
108
+ self.linear_in = nn.Sequential(nn.Linear(in_features, dim), nn.SiLU())
109
+
110
+ blocks = nn.ModuleList()
111
+ for _ in range(num_block):
112
+ blocks.append(ResidualBlock(dim))
113
+ self.blocks = blocks
114
+
115
+ self.linear_out = nn.Linear(dim, out_features)
116
+ self.act = act
117
+
118
+ def forward(self, x):
119
+ out = self.linear_in(x)
120
+ for block in self.blocks:
121
+ out = block(out)
122
+ out = self.linear_out(out)
123
+ if self.act is not None:
124
+ out = self.act(out)
125
+ return out
126
+
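+ # BasicConvBlock: a GroupNorm/SiLU residual conv block used by the optional
+ # `addition_downsample` branch of UNetMV2DConditionModel below. GroupNorm with
+ # num_groups=8 assumes channel counts divisible by 8.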
127
+ class BasicConvBlock(nn.Module):
128
+ def __init__(self, in_channels, out_channels, stride=1):
129
+ super(BasicConvBlock, self).__init__()
130
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
131
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
132
+ self.act = nn.SiLU()
133
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
134
+ self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
135
+ self.downsample = nn.Sequential()
136
+ if stride != 1 or in_channels != out_channels:
137
+ self.downsample = nn.Sequential(
138
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
139
+ nn.GroupNorm(num_groups=8, num_channels=out_channels, affine=True)
140
+ )
141
+
142
+ def forward(self, x):
143
+ identity = x
144
+ out = self.conv1(x)
145
+ out = self.norm1(out)
146
+ out = self.act(out)
147
+ out = self.conv2(out)
148
+ out = self.norm2(out)
149
+ out += self.downsample(identity)
150
+ out = self.act(out)
151
+ return out
152
+
153
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
154
+ r"""
155
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
156
+ shaped output.
157
+
158
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
159
+ for all models (such as downloading or saving).
160
+
161
+ Parameters:
162
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
163
+ Height and width of input/output sample.
164
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
165
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
166
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
167
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
168
+ Whether to flip the sin to cos in the time embedding.
169
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
170
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D")`):
171
+ The tuple of downsample blocks to use.
172
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockMV2DCrossAttn"`):
173
+ Block type for the middle of the UNet; it can be `UNetMidBlockMV2DCrossAttn`, `UNetMidBlock2DCrossAttn` or
174
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
175
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D")`):
176
+ The tuple of upsample blocks to use.
177
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
178
+ Whether to include self-attention in the basic transformer blocks, see
179
+ [`~models.attention.BasicTransformerBlock`].
180
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
181
+ The tuple of output channels for each block.
182
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
183
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
184
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
185
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
186
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
187
+ If `None`, normalization and activation layers are skipped in post-processing.
188
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
189
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
190
+ The dimension of the cross attention features.
191
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
192
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
193
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
194
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
195
+ encoder_hid_dim (`int`, *optional*, defaults to None):
196
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
197
+ dimension to `cross_attention_dim`.
198
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
199
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
200
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
201
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
202
+ num_attention_heads (`int`, *optional*):
203
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
204
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
205
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
206
+ class_embed_type (`str`, *optional*, defaults to `None`):
207
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
208
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
209
+ addition_embed_type (`str`, *optional*, defaults to `None`):
210
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
211
+ "text". "text" will use the `TextTimeEmbedding` layer.
212
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
213
+ Dimension for the timestep embeddings.
214
+ num_class_embeds (`int`, *optional*, defaults to `None`):
215
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
216
+ class conditioning with `class_embed_type` equal to `None`.
217
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
218
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
219
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
220
+ An optional override for the dimension of the projected time embedding.
221
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
222
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
223
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
224
+ timestep_post_act (`str`, *optional*, defaults to `None`):
225
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
226
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
227
+ The dimension of `cond_proj` layer in the timestep embedding.
228
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
229
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
230
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
231
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
232
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
233
+ embeddings with the class embeddings.
234
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
235
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
236
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
237
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
238
+ otherwise.
239
+ """
240
+
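+ # Minimal usage sketch (hypothetical values; the real configuration is loaded from the
+ # pretrained checkpoint's config):
+ #   unet = UNetMV2DConditionModel(sample_size=32, in_channels=4, out_channels=4,
+ #                                 cross_attention_dim=1024, num_views=6)
+ #   out = unet(latents, timestep, encoder_hidden_states=cond_embeds)  # cond_embeds: placeholder conditioning
+ #   noise_pred = out.sample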
241
+ _supports_gradient_checkpointing = True
242
+
243
+ @register_to_config
244
+ def __init__(
245
+ self,
246
+ sample_size: Optional[int] = None,
247
+ in_channels: int = 4,
248
+ out_channels: int = 4,
249
+ center_input_sample: bool = False,
250
+ flip_sin_to_cos: bool = True,
251
+ freq_shift: int = 0,
252
+ down_block_types: Tuple[str] = (
253
+ "CrossAttnDownBlockMV2D",
254
+ "CrossAttnDownBlockMV2D",
255
+ "CrossAttnDownBlockMV2D",
256
+ "DownBlock2D",
257
+ ),
258
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
259
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
260
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
261
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
262
+ layers_per_block: Union[int, Tuple[int]] = 2,
263
+ downsample_padding: int = 1,
264
+ mid_block_scale_factor: float = 1,
265
+ act_fn: str = "silu",
266
+ norm_num_groups: Optional[int] = 32,
267
+ norm_eps: float = 1e-5,
268
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
269
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
270
+ encoder_hid_dim: Optional[int] = None,
271
+ encoder_hid_dim_type: Optional[str] = None,
272
+ attention_head_dim: Union[int, Tuple[int]] = 8,
273
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
274
+ dual_cross_attention: bool = False,
275
+ use_linear_projection: bool = False,
276
+ class_embed_type: Optional[str] = None,
277
+ addition_embed_type: Optional[str] = None,
278
+ addition_time_embed_dim: Optional[int] = None,
279
+ num_class_embeds: Optional[int] = None,
280
+ upcast_attention: bool = False,
281
+ resnet_time_scale_shift: str = "default",
282
+ resnet_skip_time_act: bool = False,
283
+ resnet_out_scale_factor: int = 1.0,
284
+ time_embedding_type: str = "positional",
285
+ time_embedding_dim: Optional[int] = None,
286
+ time_embedding_act_fn: Optional[str] = None,
287
+ timestep_post_act: Optional[str] = None,
288
+ time_cond_proj_dim: Optional[int] = None,
289
+ conv_in_kernel: int = 3,
290
+ conv_out_kernel: int = 3,
291
+ projection_class_embeddings_input_dim: Optional[int] = None,
292
+ class_embeddings_concat: bool = False,
293
+ mid_block_only_cross_attention: Optional[bool] = None,
294
+ cross_attention_norm: Optional[str] = None,
295
+ addition_embed_type_num_heads=64,
296
+ num_views: int = 1,
297
+ cd_attention_last: bool = False,
298
+ cd_attention_mid: bool = False,
299
+ multiview_attention: bool = True,
300
+ sparse_mv_attention: bool = False,
301
+ selfattn_block: str = "custom",
302
+ addition_downsample: bool = False,
303
+ addition_channels: Optional[Tuple[int]] = (1280, 1280, 1280),
304
+ ):
305
+ super().__init__()
306
+
307
+ self.sample_size = sample_size
308
+ self.num_views = num_views
309
+
310
+ if num_attention_heads is not None:
311
+ raise ValueError(
312
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
313
+ )
314
+
315
+ # If `num_attention_heads` is not defined (which is the case for most models)
316
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
317
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
318
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
319
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
320
+ # which is why we correct for the naming here.
321
+ num_attention_heads = num_attention_heads or attention_head_dim
322
+
323
+ # Check inputs
324
+ if len(down_block_types) != len(up_block_types):
325
+ raise ValueError(
326
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
327
+ )
328
+
329
+ if len(block_out_channels) != len(down_block_types):
330
+ raise ValueError(
331
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
332
+ )
333
+
334
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
335
+ raise ValueError(
336
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
337
+ )
338
+
339
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
340
+ raise ValueError(
341
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
342
+ )
343
+
344
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
345
+ raise ValueError(
346
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
347
+ )
348
+
349
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
350
+ raise ValueError(
351
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
352
+ )
353
+
354
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
355
+ raise ValueError(
356
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
357
+ )
358
+
359
+ # input
360
+ conv_in_padding = (conv_in_kernel - 1) // 2
361
+ self.conv_in = nn.Conv2d(
362
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
363
+ )
364
+
365
+ # time
366
+ if time_embedding_type == "fourier":
367
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
368
+ if time_embed_dim % 2 != 0:
369
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
370
+ self.time_proj = GaussianFourierProjection(
371
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
372
+ )
373
+ timestep_input_dim = time_embed_dim
374
+ elif time_embedding_type == "positional":
375
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
376
+
377
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
378
+ timestep_input_dim = block_out_channels[0]
379
+ else:
380
+ raise ValueError(
381
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
382
+ )
383
+
384
+ self.time_embedding = TimestepEmbedding(
385
+ timestep_input_dim,
386
+ time_embed_dim,
387
+ act_fn=act_fn,
388
+ post_act_fn=timestep_post_act,
389
+ cond_proj_dim=time_cond_proj_dim,
390
+ )
391
+
392
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
393
+ encoder_hid_dim_type = "text_proj"
394
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
395
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
396
+
397
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
398
+ raise ValueError(
399
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
400
+ )
401
+
402
+ if encoder_hid_dim_type == "text_proj":
403
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
404
+ elif encoder_hid_dim_type == "text_image_proj":
405
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
406
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
407
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
408
+ self.encoder_hid_proj = TextImageProjection(
409
+ text_embed_dim=encoder_hid_dim,
410
+ image_embed_dim=cross_attention_dim,
411
+ cross_attention_dim=cross_attention_dim,
412
+ )
413
+ elif encoder_hid_dim_type == "image_proj":
414
+ # Kandinsky 2.2
415
+ self.encoder_hid_proj = ImageProjection(
416
+ image_embed_dim=encoder_hid_dim,
417
+ cross_attention_dim=cross_attention_dim,
418
+ )
419
+ elif encoder_hid_dim_type is not None:
420
+ raise ValueError(
421
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
422
+ )
423
+ else:
424
+ self.encoder_hid_proj = None
425
+
426
+ # class embedding
427
+ if class_embed_type is None and num_class_embeds is not None:
428
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
429
+ elif class_embed_type == "timestep":
430
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
431
+ elif class_embed_type == "identity":
432
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
433
+ elif class_embed_type == "projection":
434
+ if projection_class_embeddings_input_dim is None:
435
+ raise ValueError(
436
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
437
+ )
438
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
439
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
440
+ # 2. it projects from an arbitrary input dimension.
441
+ #
442
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
443
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
444
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
445
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
446
+ elif class_embed_type == "simple_projection":
447
+ if projection_class_embeddings_input_dim is None:
448
+ raise ValueError(
449
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
450
+ )
451
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
452
+ else:
453
+ self.class_embedding = None
454
+
455
+ if addition_embed_type == "text":
456
+ if encoder_hid_dim is not None:
457
+ text_time_embedding_from_dim = encoder_hid_dim
458
+ else:
459
+ text_time_embedding_from_dim = cross_attention_dim
460
+
461
+ self.add_embedding = TextTimeEmbedding(
462
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
463
+ )
464
+ elif addition_embed_type == "text_image":
465
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
466
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
467
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
468
+ self.add_embedding = TextImageTimeEmbedding(
469
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
470
+ )
471
+ elif addition_embed_type == "text_time":
472
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
473
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
474
+ elif addition_embed_type == "image":
475
+ # Kandinsky 2.2
476
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
477
+ elif addition_embed_type == "image_hint":
478
+ # Kandinsky 2.2 ControlNet
479
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
480
+ elif addition_embed_type is not None:
481
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
482
+
483
+ if time_embedding_act_fn is None:
484
+ self.time_embed_act = None
485
+ else:
486
+ self.time_embed_act = get_activation(time_embedding_act_fn)
487
+
488
+ self.down_blocks = nn.ModuleList([])
489
+ self.up_blocks = nn.ModuleList([])
490
+
491
+ if isinstance(only_cross_attention, bool):
492
+ if mid_block_only_cross_attention is None:
493
+ mid_block_only_cross_attention = only_cross_attention
494
+
495
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
496
+
497
+ if mid_block_only_cross_attention is None:
498
+ mid_block_only_cross_attention = False
499
+
500
+ if isinstance(num_attention_heads, int):
501
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
502
+
503
+ if isinstance(attention_head_dim, int):
504
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
505
+
506
+ if isinstance(cross_attention_dim, int):
507
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
508
+
509
+ if isinstance(layers_per_block, int):
510
+ layers_per_block = [layers_per_block] * len(down_block_types)
511
+
512
+ if isinstance(transformer_layers_per_block, int):
513
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
514
+
515
+ if class_embeddings_concat:
516
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
517
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
518
+ # regular time embeddings
519
+ blocks_time_embed_dim = time_embed_dim * 2
520
+ else:
521
+ blocks_time_embed_dim = time_embed_dim
522
+
523
+ # down
524
+ output_channel = block_out_channels[0]
525
+ for i, down_block_type in enumerate(down_block_types):
526
+ input_channel = output_channel
527
+ output_channel = block_out_channels[i]
528
+ is_final_block = i == len(block_out_channels) - 1
529
+
530
+ down_block = get_down_block(
531
+ down_block_type,
532
+ num_layers=layers_per_block[i],
533
+ transformer_layers_per_block=transformer_layers_per_block[i],
534
+ in_channels=input_channel,
535
+ out_channels=output_channel,
536
+ temb_channels=blocks_time_embed_dim,
537
+ add_downsample=not is_final_block,
538
+ resnet_eps=norm_eps,
539
+ resnet_act_fn=act_fn,
540
+ resnet_groups=norm_num_groups,
541
+ cross_attention_dim=cross_attention_dim[i],
542
+ num_attention_heads=num_attention_heads[i],
543
+ downsample_padding=downsample_padding,
544
+ dual_cross_attention=dual_cross_attention,
545
+ use_linear_projection=use_linear_projection,
546
+ only_cross_attention=only_cross_attention[i],
547
+ upcast_attention=upcast_attention,
548
+ resnet_time_scale_shift=resnet_time_scale_shift,
549
+ resnet_skip_time_act=resnet_skip_time_act,
550
+ resnet_out_scale_factor=resnet_out_scale_factor,
551
+ cross_attention_norm=cross_attention_norm,
552
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
553
+ num_views=num_views,
554
+ cd_attention_last=cd_attention_last,
555
+ cd_attention_mid=cd_attention_mid,
556
+ multiview_attention=multiview_attention,
557
+ sparse_mv_attention=sparse_mv_attention,
558
+ selfattn_block=selfattn_block,
559
+ )
560
+ self.down_blocks.append(down_block)
561
+
562
+ # mid
563
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
564
+ self.mid_block = UNetMidBlock2DCrossAttn(
565
+ transformer_layers_per_block=transformer_layers_per_block[-1],
566
+ in_channels=block_out_channels[-1],
567
+ temb_channels=blocks_time_embed_dim,
568
+ resnet_eps=norm_eps,
569
+ resnet_act_fn=act_fn,
570
+ output_scale_factor=mid_block_scale_factor,
571
+ resnet_time_scale_shift=resnet_time_scale_shift,
572
+ cross_attention_dim=cross_attention_dim[-1],
573
+ num_attention_heads=num_attention_heads[-1],
574
+ resnet_groups=norm_num_groups,
575
+ dual_cross_attention=dual_cross_attention,
576
+ use_linear_projection=use_linear_projection,
577
+ upcast_attention=upcast_attention,
578
+ )
579
+ # custom MV2D attention block
580
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
581
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
582
+ transformer_layers_per_block=transformer_layers_per_block[-1],
583
+ in_channels=block_out_channels[-1],
584
+ temb_channels=blocks_time_embed_dim,
585
+ resnet_eps=norm_eps,
586
+ resnet_act_fn=act_fn,
587
+ output_scale_factor=mid_block_scale_factor,
588
+ resnet_time_scale_shift=resnet_time_scale_shift,
589
+ cross_attention_dim=cross_attention_dim[-1],
590
+ num_attention_heads=num_attention_heads[-1],
591
+ resnet_groups=norm_num_groups,
592
+ dual_cross_attention=dual_cross_attention,
593
+ use_linear_projection=use_linear_projection,
594
+ upcast_attention=upcast_attention,
595
+ num_views=num_views,
596
+ cd_attention_last=cd_attention_last,
597
+ cd_attention_mid=cd_attention_mid,
598
+ multiview_attention=multiview_attention,
599
+ sparse_mv_attention=sparse_mv_attention,
600
+ selfattn_block=selfattn_block,
601
+ )
602
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
603
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
604
+ in_channels=block_out_channels[-1],
605
+ temb_channels=blocks_time_embed_dim,
606
+ resnet_eps=norm_eps,
607
+ resnet_act_fn=act_fn,
608
+ output_scale_factor=mid_block_scale_factor,
609
+ cross_attention_dim=cross_attention_dim[-1],
610
+ attention_head_dim=attention_head_dim[-1],
611
+ resnet_groups=norm_num_groups,
612
+ resnet_time_scale_shift=resnet_time_scale_shift,
613
+ skip_time_act=resnet_skip_time_act,
614
+ only_cross_attention=mid_block_only_cross_attention,
615
+ cross_attention_norm=cross_attention_norm,
616
+ )
617
+ elif mid_block_type is None:
618
+ self.mid_block = None
619
+ else:
620
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
621
+
622
+ self.addition_downsample = addition_downsample
623
+ if self.addition_downsample:
624
+ inc = block_out_channels[-1]
625
+ self.downsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
626
+ self.conv_block = nn.ModuleList()
627
+ self.conv_block.append(BasicConvBlock(inc, addition_channels[0], stride=1))
628
+ for dim_ in addition_channels[1:-1]:
629
+ self.conv_block.append(BasicConvBlock(dim_, dim_, stride=1))
630
+ self.conv_block.append(BasicConvBlock(dim_, inc))
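+ # NOTE: the final append reuses `dim_` from the loop above, so `addition_channels` is
+ # assumed to contain at least three entries (as in the default (1280, 1280, 1280)).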
631
+ self.addition_conv_out = nn.Conv2d(inc, inc, kernel_size=1, bias=False)
632
+ nn.init.zeros_(self.addition_conv_out.weight.data)
633
+ self.addition_act_out = nn.SiLU()
634
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
635
+
636
+ # count how many layers upsample the images
637
+ self.num_upsamplers = 0
638
+
639
+ # up
640
+ reversed_block_out_channels = list(reversed(block_out_channels))
641
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
642
+ reversed_layers_per_block = list(reversed(layers_per_block))
643
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
644
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
645
+ only_cross_attention = list(reversed(only_cross_attention))
646
+
647
+ output_channel = reversed_block_out_channels[0]
648
+ for i, up_block_type in enumerate(up_block_types):
649
+ is_final_block = i == len(block_out_channels) - 1
650
+
651
+ prev_output_channel = output_channel
652
+ output_channel = reversed_block_out_channels[i]
653
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
654
+
655
+ # add upsample block for all BUT final layer
656
+ if not is_final_block:
657
+ add_upsample = True
658
+ self.num_upsamplers += 1
659
+ else:
660
+ add_upsample = False
661
+
662
+ up_block = get_up_block(
663
+ up_block_type,
664
+ num_layers=reversed_layers_per_block[i] + 1,
665
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
666
+ in_channels=input_channel,
667
+ out_channels=output_channel,
668
+ prev_output_channel=prev_output_channel,
669
+ temb_channels=blocks_time_embed_dim,
670
+ add_upsample=add_upsample,
671
+ resnet_eps=norm_eps,
672
+ resnet_act_fn=act_fn,
673
+ resnet_groups=norm_num_groups,
674
+ cross_attention_dim=reversed_cross_attention_dim[i],
675
+ num_attention_heads=reversed_num_attention_heads[i],
676
+ dual_cross_attention=dual_cross_attention,
677
+ use_linear_projection=use_linear_projection,
678
+ only_cross_attention=only_cross_attention[i],
679
+ upcast_attention=upcast_attention,
680
+ resnet_time_scale_shift=resnet_time_scale_shift,
681
+ resnet_skip_time_act=resnet_skip_time_act,
682
+ resnet_out_scale_factor=resnet_out_scale_factor,
683
+ cross_attention_norm=cross_attention_norm,
684
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
685
+ num_views=num_views,
686
+ cd_attention_last=cd_attention_last,
687
+ cd_attention_mid=cd_attention_mid,
688
+ multiview_attention=multiview_attention,
689
+ sparse_mv_attention=sparse_mv_attention,
690
+ selfattn_block=selfattn_block,
691
+ )
692
+ self.up_blocks.append(up_block)
693
+ prev_output_channel = output_channel
694
+
695
+ # out
696
+ if norm_num_groups is not None:
697
+ self.conv_norm_out = nn.GroupNorm(
698
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
699
+ )
700
+
701
+ self.conv_act = get_activation(act_fn)
702
+
703
+ else:
704
+ self.conv_norm_out = None
705
+ self.conv_act = None
706
+
707
+ conv_out_padding = (conv_out_kernel - 1) // 2
708
+ self.conv_out = nn.Conv2d(
709
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
710
+ )
711
+
712
+ @property
713
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
714
+ r"""
715
+ Returns:
716
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
717
+ indexed by their weight names.
718
+ """
719
+ # set recursively
720
+ processors = {}
721
+
722
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
723
+ if hasattr(module, "set_processor"):
724
+ processors[f"{name}.processor"] = module.processor
725
+
726
+ for sub_name, child in module.named_children():
727
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
728
+
729
+ return processors
730
+
731
+ for name, module in self.named_children():
732
+ fn_recursive_add_processors(name, module, processors)
733
+
734
+ return processors
735
+
736
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
737
+ r"""
738
+ Sets the attention processor to use to compute attention.
739
+
740
+ Parameters:
741
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
742
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
743
+ for **all** `Attention` layers.
744
+
745
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
746
+ processor. This is strongly recommended when setting trainable attention processors.
747
+
748
+ """
749
+ count = len(self.attn_processors.keys())
750
+
751
+ if isinstance(processor, dict) and len(processor) != count:
752
+ raise ValueError(
753
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
754
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
755
+ )
756
+
757
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
758
+ if hasattr(module, "set_processor"):
759
+ if not isinstance(processor, dict):
760
+ module.set_processor(processor)
761
+ else:
762
+ module.set_processor(processor.pop(f"{name}.processor"))
763
+
764
+ for sub_name, child in module.named_children():
765
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
766
+
767
+ for name, module in self.named_children():
768
+ fn_recursive_attn_processor(name, module, processor)
769
+
770
+ def set_default_attn_processor(self):
771
+ """
772
+ Disables custom attention processors and sets the default attention implementation.
773
+ """
774
+ self.set_attn_processor(AttnProcessor())
775
+
776
+ def set_attention_slice(self, slice_size):
777
+ r"""
778
+ Enable sliced attention computation.
779
+
780
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
781
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
782
+
783
+ Args:
784
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
785
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
786
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
787
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
788
+ must be a multiple of `slice_size`.
789
+ """
790
+ sliceable_head_dims = []
791
+
792
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
793
+ if hasattr(module, "set_attention_slice"):
794
+ sliceable_head_dims.append(module.sliceable_head_dim)
795
+
796
+ for child in module.children():
797
+ fn_recursive_retrieve_sliceable_dims(child)
798
+
799
+ # retrieve number of attention layers
800
+ for module in self.children():
801
+ fn_recursive_retrieve_sliceable_dims(module)
802
+
803
+ num_sliceable_layers = len(sliceable_head_dims)
804
+
805
+ if slice_size == "auto":
806
+ # half the attention head size is usually a good trade-off between
807
+ # speed and memory
808
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
809
+ elif slice_size == "max":
810
+ # make smallest slice possible
811
+ slice_size = num_sliceable_layers * [1]
812
+
813
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
814
+
815
+ if len(slice_size) != len(sliceable_head_dims):
816
+ raise ValueError(
817
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
818
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
819
+ )
820
+
821
+ for i in range(len(slice_size)):
822
+ size = slice_size[i]
823
+ dim = sliceable_head_dims[i]
824
+ if size is not None and size > dim:
825
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
826
+
827
+ # Recursively walk through all the children.
828
+ # Any children which exposes the set_attention_slice method
829
+ # gets the message
830
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
831
+ if hasattr(module, "set_attention_slice"):
832
+ module.set_attention_slice(slice_size.pop())
833
+
834
+ for child in module.children():
835
+ fn_recursive_set_attention_slice(child, slice_size)
836
+
837
+ reversed_slice_size = list(reversed(slice_size))
838
+ for module in self.children():
839
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
840
+
841
+ def _set_gradient_checkpointing(self, module, value=False):
842
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
843
+ module.gradient_checkpointing = value
844
+
845
+ def forward(
846
+ self,
847
+ sample: torch.FloatTensor,
848
+ timestep: Union[torch.Tensor, float, int],
849
+ encoder_hidden_states: torch.Tensor = None,
850
+ class_labels: Optional[torch.Tensor] = None,
851
+ timestep_cond: Optional[torch.Tensor] = None,
852
+ attention_mask: Optional[torch.Tensor] = None,
853
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
854
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
855
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
856
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
857
+ encoder_attention_mask: Optional[torch.Tensor] = None,
858
+ return_dict: bool = True,
859
+ vis_max_min: bool = False,
860
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
861
+ r"""
862
+ The [`UNetMV2DConditionModel`] forward method.
863
+
864
+ Args:
865
+ sample (`torch.FloatTensor`):
866
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
867
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
868
+ encoder_hidden_states (`torch.FloatTensor`):
869
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
870
+ encoder_attention_mask (`torch.Tensor`):
871
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
872
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
873
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
874
+ return_dict (`bool`, *optional*, defaults to `True`):
875
+ Whether or not to return a [`UNetMV2DConditionOutput`] instead of a plain
876
+ tuple.
877
+ cross_attention_kwargs (`dict`, *optional*):
878
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
879
+ added_cond_kwargs: (`dict`, *optional*):
880
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
881
+ are passed along to the UNet blocks.
882
+
883
+ Returns:
884
+ [`UNetMV2DConditionOutput`] or `tuple`:
885
+ If `return_dict` is True, a [`UNetMV2DConditionOutput`] is returned, otherwise
886
+ a `tuple` is returned where the first element is the sample tensor.
887
+ """
888
+ record_max_min = {}
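+ # Debugging aid: appears to collect activation ranges when `vis_max_min` is enabled.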
889
+ # By default samples have to be at least a multiple of the overall upsampling factor.
890
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
891
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
892
+ # on the fly if necessary.
893
+ default_overall_up_factor = 2**self.num_upsamplers
894
+
895
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
896
+ forward_upsample_size = False
897
+ upsample_size = None
898
+
899
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
900
+ logger.info("Forward upsample size to force interpolation output size.")
901
+ forward_upsample_size = True
902
+
903
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
904
+ # expects mask of shape:
905
+ # [batch, key_tokens]
906
+ # adds singleton query_tokens dimension:
907
+ # [batch, 1, key_tokens]
908
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
909
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
910
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
911
+ if attention_mask is not None:
912
+ # assume that mask is expressed as:
913
+ # (1 = keep, 0 = discard)
914
+ # convert mask into a bias that can be added to attention scores:
915
+ # (keep = +0, discard = -10000.0)
916
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
917
+ attention_mask = attention_mask.unsqueeze(1)
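+ # Illustrative example (not in the original source): a key-token mask of [1, 1, 0]
+ # becomes the additive bias [0.0, 0.0, -10000.0], and the unsqueeze adds a singleton
+ # query dimension so it broadcasts over [batch, heads, query_tokens, key_tokens].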
918
+
919
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
920
+ if encoder_attention_mask is not None:
921
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
922
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
923
+
924
+ # 0. center input if necessary
925
+ if self.config.center_input_sample:
926
+ sample = 2 * sample - 1.0
927
+ # 1. time
928
+ timesteps = timestep
929
+ if not torch.is_tensor(timesteps):
930
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
931
+ # This would be a good case for the `match` statement (Python 3.10+)
932
+ is_mps = sample.device.type == "mps"
933
+ if isinstance(timestep, float):
934
+ dtype = torch.float32 if is_mps else torch.float64
935
+ else:
936
+ dtype = torch.int32 if is_mps else torch.int64
937
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
938
+ elif len(timesteps.shape) == 0:
939
+ timesteps = timesteps[None].to(sample.device)
940
+
941
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
942
+ timesteps = timesteps.expand(sample.shape[0])
943
+
944
+ t_emb = self.time_proj(timesteps)
945
+
946
+ # `Timesteps` does not contain any weights and will always return f32 tensors
947
+ # but time_embedding might actually be running in fp16. so we need to cast here.
948
+ # there might be better ways to encapsulate this.
949
+ t_emb = t_emb.to(dtype=sample.dtype)
950
+
951
+ emb = self.time_embedding(t_emb, timestep_cond)
952
+ aug_emb = None
953
+ if self.class_embedding is not None:
954
+ if class_labels is None:
955
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
956
+
957
+ if self.config.class_embed_type == "timestep":
958
+ class_labels = self.time_proj(class_labels)
959
+
960
+ # `Timesteps` does not contain any weights and will always return f32 tensors
961
+ # there might be better ways to encapsulate this.
962
+ class_labels = class_labels.to(dtype=sample.dtype)
963
+
964
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
965
+ if self.config.class_embeddings_concat:
966
+ emb = torch.cat([emb, class_emb], dim=-1)
967
+ else:
968
+ emb = emb + class_emb
969
+
970
+ if self.config.addition_embed_type == "text":
971
+ aug_emb = self.add_embedding(encoder_hidden_states)
972
+ elif self.config.addition_embed_type == "text_image":
973
+ # Kandinsky 2.1 - style
974
+ if "image_embeds" not in added_cond_kwargs:
975
+ raise ValueError(
976
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
977
+ )
978
+
979
+ image_embs = added_cond_kwargs.get("image_embeds")
980
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
981
+ aug_emb = self.add_embedding(text_embs, image_embs)
982
+ elif self.config.addition_embed_type == "text_time":
983
+ # SDXL - style
984
+ if "text_embeds" not in added_cond_kwargs:
985
+ raise ValueError(
986
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
987
+ )
988
+ text_embeds = added_cond_kwargs.get("text_embeds")
989
+ if "time_ids" not in added_cond_kwargs:
990
+ raise ValueError(
991
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
992
+ )
993
+ time_ids = added_cond_kwargs.get("time_ids")
994
+ time_embeds = self.add_time_proj(time_ids.flatten())
995
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
996
+
997
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
998
+ add_embeds = add_embeds.to(emb.dtype)
999
+ aug_emb = self.add_embedding(add_embeds)
1000
+ elif self.config.addition_embed_type == "image":
1001
+ # Kandinsky 2.2 - style
1002
+ if "image_embeds" not in added_cond_kwargs:
1003
+ raise ValueError(
1004
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1005
+ )
1006
+ image_embs = added_cond_kwargs.get("image_embeds")
1007
+ aug_emb = self.add_embedding(image_embs)
1008
+ elif self.config.addition_embed_type == "image_hint":
1009
+ # Kandinsky 2.2 - style
1010
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1011
+ raise ValueError(
1012
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1013
+ )
1014
+ image_embs = added_cond_kwargs.get("image_embeds")
1015
+ hint = added_cond_kwargs.get("hint")
1016
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1017
+ sample = torch.cat([sample, hint], dim=1)
1018
+
1019
+ emb = emb + aug_emb if aug_emb is not None else emb
1020
+ emb_pre_act = emb
1021
+ if self.time_embed_act is not None:
1022
+ emb = self.time_embed_act(emb)
1023
+
1024
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1025
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1026
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1027
+ # Kandinsky 2.1 - style
1028
+ if "image_embeds" not in added_cond_kwargs:
1029
+ raise ValueError(
1030
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1031
+ )
1032
+
1033
+ image_embeds = added_cond_kwargs.get("image_embeds")
1034
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1035
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1036
+ # Kandinsky 2.2 - style
1037
+ if "image_embeds" not in added_cond_kwargs:
1038
+ raise ValueError(
1039
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1040
+ )
1041
+ image_embeds = added_cond_kwargs.get("image_embeds")
1042
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1043
+ # 2. pre-process
1044
+ sample = self.conv_in(sample)
1045
+ # 3. down
1046
+
1047
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1048
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1049
+
1050
+ down_block_res_samples = (sample,)
1051
+ for i, downsample_block in enumerate(self.down_blocks):
1052
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1053
+ # For t2i-adapter CrossAttnDownBlock2D
1054
+ additional_residuals = {}
1055
+ if is_adapter and len(down_block_additional_residuals) > 0:
1056
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1057
+
1058
+ sample, res_samples = downsample_block(
1059
+ hidden_states=sample,
1060
+ temb=emb,
1061
+ encoder_hidden_states=encoder_hidden_states,
1062
+ attention_mask=attention_mask,
1063
+ cross_attention_kwargs=cross_attention_kwargs,
1064
+ encoder_attention_mask=encoder_attention_mask,
1065
+ **additional_residuals,
1066
+ )
1067
+ else:
1068
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1069
+
1070
+ if is_adapter and len(down_block_additional_residuals) > 0:
1071
+ sample += down_block_additional_residuals.pop(0)
1072
+
1073
+ down_block_res_samples += res_samples
1074
+
1075
+ if is_controlnet:
1076
+ new_down_block_res_samples = ()
1077
+
1078
+ for down_block_res_sample, down_block_additional_residual in zip(
1079
+ down_block_res_samples, down_block_additional_residuals
1080
+ ):
1081
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1082
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1083
+
1084
+ down_block_res_samples = new_down_block_res_samples
1085
+
1086
+ if self.addition_downsample:
1087
+ global_sample = sample
1088
+ global_sample = self.downsample(global_sample)
1089
+ for layer in self.conv_block:
1090
+ global_sample = layer(global_sample)
1091
+ global_sample = self.addition_act_out(self.addition_conv_out(global_sample))
1092
+ global_sample = self.upsample(global_sample)
1093
+ # 4. mid
1094
+ if self.mid_block is not None:
1095
+ sample = self.mid_block(
1096
+ sample,
1097
+ emb,
1098
+ encoder_hidden_states=encoder_hidden_states,
1099
+ attention_mask=attention_mask,
1100
+ cross_attention_kwargs=cross_attention_kwargs,
1101
+ encoder_attention_mask=encoder_attention_mask,
1102
+ )
1103
+
1104
+ if is_controlnet:
1105
+ sample = sample + mid_block_additional_residual
1106
+
1107
+ if self.addition_downsample:
1108
+ sample = sample + global_sample
1109
+
1110
+ # 5. up
1111
+ for i, upsample_block in enumerate(self.up_blocks):
1112
+ is_final_block = i == len(self.up_blocks) - 1
1113
+
1114
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1115
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1116
+
1117
+ # if we have not reached the final block and need to forward the
1118
+ # upsample size, we do it here
1119
+ if not is_final_block and forward_upsample_size:
1120
+ upsample_size = down_block_res_samples[-1].shape[2:]
1121
+
1122
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1123
+ sample = upsample_block(
1124
+ hidden_states=sample,
1125
+ temb=emb,
1126
+ res_hidden_states_tuple=res_samples,
1127
+ encoder_hidden_states=encoder_hidden_states,
1128
+ cross_attention_kwargs=cross_attention_kwargs,
1129
+ upsample_size=upsample_size,
1130
+ attention_mask=attention_mask,
1131
+ encoder_attention_mask=encoder_attention_mask,
1132
+ )
1133
+ else:
1134
+ sample = upsample_block(
1135
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1136
+ )
1137
+ if torch.isnan(sample).any() or torch.isinf(sample).any():
1138
+ print("NAN in sample, stop training.")
1139
+ exit()
1140
+ # 6. post-process
1141
+ if self.conv_norm_out:
1142
+ sample = self.conv_norm_out(sample)
1143
+ sample = self.conv_act(sample)
1144
+ sample = self.conv_out(sample)
1145
+ if not return_dict:
1146
+ return sample
1147
+ return UNetMV2DConditionOutput(sample=sample)
1148
+
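+ # Minimal usage sketch (illustrative, not part of the original commit). The names below
+ # are assumed to be prepared by the caller, mirroring how the multi-view pipeline in this
+ # commit (pipeline_mvdiffusion_unclip.py) drives this forward pass:
+ #
+ #     noise_pred = unet(
+ #         latent_model_input,                   # noisy latents concatenated with image latents
+ #         t,                                    # current scheduler timestep
+ #         encoder_hidden_states=prompt_embeds,  # CLIP text hidden states
+ #         class_labels=image_embeds,            # noised CLIP image embeddings (unCLIP conditioning)
+ #         return_dict=False,                    # return the sample tensor directly
+ #     )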
1149
+
1150
+ @classmethod
1151
+ def from_pretrained_2d(
1152
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1153
+ num_views: int, sample_size: int,
1154
+ zero_init_conv_in: bool = True,
1155
+ cd_attention_last: bool = False,
1156
+ cd_attention_mid: bool = False, multiview_attention: bool = True,
1157
+ sparse_mv_attention: bool = False, selfattn_block: str = 'custom',
1158
+ in_channels: int = 8, out_channels: int = 4, unclip: bool = False,
1159
+ init_mvattn_with_selfattn: bool= False, addition_downsample: bool = False,
1160
+ **kwargs
1161
+ ):
1162
+ r"""
1163
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1164
+
1165
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1166
+ train the model, set it back in training mode with `model.train()`.
1167
+
1168
+ Parameters:
1169
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1170
+ Can be either:
1171
+
1172
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1173
+ the Hub.
1174
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1175
+ with [`~ModelMixin.save_pretrained`].
1176
+
1177
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1178
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1179
+ is not used.
1180
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1181
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1182
+ dtype is automatically derived from the model's weights.
1183
+ force_download (`bool`, *optional*, defaults to `False`):
1184
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1185
+ cached versions if they exist.
1186
+
1187
+ proxies (`Dict[str, str]`, *optional*):
1188
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1189
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1190
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1191
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1192
+ local_files_only(`bool`, *optional*, defaults to `False`):
1193
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1194
+ won't be downloaded from the Hub.
1195
+ use_auth_token (`str` or *bool*, *optional*):
1196
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1197
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1198
+ revision (`str`, *optional*, defaults to `"main"`):
1199
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1200
+ allowed by Git.
1201
+ from_flax (`bool`, *optional*, defaults to `False`):
1202
+ Load the model weights from a Flax checkpoint save file.
1203
+ subfolder (`str`, *optional*, defaults to `""`):
1204
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1205
+ mirror (`str`, *optional*):
1206
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1207
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1208
+ information.
1209
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1210
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1211
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1212
+ same device.
1213
+
1214
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1215
+ more information about each option see [designing a device
1216
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1217
+ max_memory (`Dict`, *optional*):
1218
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1219
+ each GPU and the available CPU RAM if unset.
1220
+ offload_folder (`str` or `os.PathLike`, *optional*):
1221
+ The path to offload weights if `device_map` contains the value `"disk"`.
1222
+ offload_state_dict (`bool`, *optional*):
1223
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1224
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1225
+ when there is some disk offload.
1226
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1227
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
1228
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1229
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1230
+ argument to `True` will raise an error.
1231
+ variant (`str`, *optional*):
1232
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1233
+ loading `from_flax`.
1234
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1235
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1236
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1237
+ weights. If set to `False`, `safetensors` weights are not loaded.
1238
+
1239
+ <Tip>
1240
+
1241
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1242
+ `huggingface-cli login`. You can also activate the special
1243
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1244
+ firewalled environment.
1245
+
1246
+ </Tip>
1247
+
1248
+ Example:
1249
+
1250
+ ```py
1251
+ from diffusers import UNet2DConditionModel
1252
+
1253
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1254
+ ```
1255
+
1256
+ If you get the error message below, you need to finetune the weights for your downstream task:
1257
+
1258
+ ```bash
1259
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1260
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1261
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1262
+ ```
1263
+ """
1264
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1265
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1266
+ force_download = kwargs.pop("force_download", False)
1267
+ from_flax = kwargs.pop("from_flax", False)
1268
+ resume_download = kwargs.pop("resume_download", False)
1269
+ proxies = kwargs.pop("proxies", None)
1270
+ output_loading_info = kwargs.pop("output_loading_info", False)
1271
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1272
+ use_auth_token = kwargs.pop("use_auth_token", None)
1273
+ revision = kwargs.pop("revision", None)
1274
+ torch_dtype = kwargs.pop("torch_dtype", None)
1275
+ subfolder = kwargs.pop("subfolder", None)
1276
+ device_map = kwargs.pop("device_map", None)
1277
+ max_memory = kwargs.pop("max_memory", None)
1278
+ offload_folder = kwargs.pop("offload_folder", None)
1279
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1280
+ variant = kwargs.pop("variant", None)
1281
+ use_safetensors = kwargs.pop("use_safetensors", None)
1282
+
1283
+ if use_safetensors:
1284
+ raise ValueError(
1285
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1286
+ )
1287
+
1288
+ allow_pickle = False
1289
+ if use_safetensors is None:
1290
+ use_safetensors = True
1291
+ allow_pickle = True
1292
+
1293
+ if device_map is not None and not is_accelerate_available():
1294
+ raise NotImplementedError(
1295
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1296
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1297
+ )
1298
+
1299
+ # Check if we can handle device_map and dispatching the weights
1300
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1301
+ raise NotImplementedError(
1302
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1303
+ " `device_map=None`."
1304
+ )
1305
+
1306
+ # Load config if we don't provide a configuration
1307
+ config_path = pretrained_model_name_or_path
1308
+
1309
+ user_agent = {
1310
+ "diffusers": __version__,
1311
+ "file_type": "model",
1312
+ "framework": "pytorch",
1313
+ }
1314
+
1315
+ # load config
1316
+ config, unused_kwargs, commit_hash = cls.load_config(
1317
+ config_path,
1318
+ cache_dir=cache_dir,
1319
+ return_unused_kwargs=True,
1320
+ return_commit_hash=True,
1321
+ force_download=force_download,
1322
+ resume_download=resume_download,
1323
+ proxies=proxies,
1324
+ local_files_only=local_files_only,
1325
+ use_auth_token=use_auth_token,
1326
+ revision=revision,
1327
+ subfolder=subfolder,
1328
+ device_map=device_map,
1329
+ max_memory=max_memory,
1330
+ offload_folder=offload_folder,
1331
+ offload_state_dict=offload_state_dict,
1332
+ user_agent=user_agent,
1333
+ **kwargs,
1334
+ )
1335
+
1336
+ # modify config
1337
+ config["_class_name"] = cls.__name__
1338
+ config['in_channels'] = in_channels
1339
+ config['out_channels'] = out_channels
1340
+ config['sample_size'] = sample_size # training resolution
1341
+ config['num_views'] = num_views
1342
+ config['cd_attention_last'] = cd_attention_last
1343
+ config['cd_attention_mid'] = cd_attention_mid
1344
+ config['multiview_attention'] = multiview_attention
1345
+ config['sparse_mv_attention'] = sparse_mv_attention
1346
+ config['selfattn_block'] = selfattn_block
1347
+ config["down_block_types"] = [
1348
+ "CrossAttnDownBlockMV2D",
1349
+ "CrossAttnDownBlockMV2D",
1350
+ "CrossAttnDownBlockMV2D",
1351
+ "DownBlock2D"
1352
+ ]
1353
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1354
+ config["up_block_types"] = [
1355
+ "UpBlock2D",
1356
+ "CrossAttnUpBlockMV2D",
1357
+ "CrossAttnUpBlockMV2D",
1358
+ "CrossAttnUpBlockMV2D"
1359
+ ]
1360
+
1361
+ config['addition_downsample'] = addition_downsample
1362
+ # load model
1363
+ model_file = None
1364
+ if from_flax:
1365
+ raise NotImplementedError
1366
+ else:
1367
+ if use_safetensors:
1368
+ try:
1369
+ model_file = _get_model_file(
1370
+ pretrained_model_name_or_path,
1371
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1372
+ cache_dir=cache_dir,
1373
+ force_download=force_download,
1374
+ proxies=proxies,
1375
+ local_files_only=local_files_only,
1376
+ use_auth_token=use_auth_token,
1377
+ revision=revision,
1378
+ subfolder=subfolder,
1379
+ user_agent=user_agent,
1380
+ commit_hash=commit_hash,
1381
+ )
1382
+ except IOError as e:
1383
+ if not allow_pickle:
1384
+ raise e
1385
+ pass
1386
+ if model_file is None:
1387
+ model_file = _get_model_file(
1388
+ pretrained_model_name_or_path,
1389
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1390
+ cache_dir=cache_dir,
1391
+ force_download=force_download,
1392
+ proxies=proxies,
1393
+ local_files_only=local_files_only,
1394
+ use_auth_token=use_auth_token,
1395
+ revision=revision,
1396
+ subfolder=subfolder,
1397
+ user_agent=user_agent,
1398
+ commit_hash=commit_hash,
1399
+ )
1400
+
1401
+ model = cls.from_config(config, **unused_kwargs)
1402
+ import copy
1403
+ state_dict_pretrain = load_state_dict(model_file, variant=variant)
1404
+ state_dict = copy.deepcopy(state_dict_pretrain)
1405
+
1406
+ if init_mvattn_with_selfattn:
1407
+ for key in state_dict_pretrain:
1408
+ if 'attn1' in key:
1409
+ key_mv = key.replace('attn1', 'attn_mv')
1410
+ state_dict[key_mv] = state_dict_pretrain[key]
1411
+ if 'to_out.0.weight' in key:
1412
+ nn.init.zeros_(state_dict[key_mv].data)
1413
+ if 'transformer_blocks' in key and 'norm1' in key: # only match norm1 inside transformer blocks, not the resnet norm layers
1414
+ key_mv = key.replace('norm1', 'norm_mv')
1415
+ state_dict[key_mv] = state_dict_pretrain[key]
1416
+ # del state_dict_pretrain
1417
+
1418
+ model._convert_deprecated_attention_blocks(state_dict)
1419
+
1420
+ conv_in_weight = state_dict['conv_in.weight']
1421
+ conv_out_weight = state_dict['conv_out.weight']
1422
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1423
+ model,
1424
+ state_dict,
1425
+ model_file,
1426
+ pretrained_model_name_or_path,
1427
+ ignore_mismatched_sizes=True,
1428
+ )
1429
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1430
+ # initialize from the original SD structure
1431
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1432
+
1433
+ # optionally zero-initialize the extra input channels of the new conv_in
1434
+ if zero_init_conv_in:
1435
+ model.conv_in.weight.data[:,4:] = 0.
1436
+
1437
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1438
+ # initialize from the original SD structure
1439
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1440
+ if out_channels == 8: # copy for the last 4 channels
1441
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1442
+
1443
+ loading_info = {
1444
+ "missing_keys": missing_keys,
1445
+ "unexpected_keys": unexpected_keys,
1446
+ "mismatched_keys": mismatched_keys,
1447
+ "error_msgs": error_msgs,
1448
+ }
1449
+
1450
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1451
+ raise ValueError(
1452
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1453
+ )
1454
+ elif torch_dtype is not None:
1455
+ model = model.to(torch_dtype)
1456
+
1457
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1458
+
1459
+ # Set model in evaluation mode to deactivate DropOut modules by default
1460
+ model.eval()
1461
+ if output_loading_info:
1462
+ return model, loading_info
1463
+ return model
1464
+
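+ # Minimal loading sketch (illustrative, not part of the original commit). The checkpoint id
+ # is a placeholder, the enclosing class name is assumed from this file, and the numeric
+ # values are illustrative only:
+ #
+ #     unet = UNetMV2DConditionModel.from_pretrained_2d(
+ #         "path-or-hub-id-of-base-unclip-model", subfolder="unet",
+ #         num_views=6, sample_size=64,
+ #         in_channels=8, out_channels=4,        # the extra 4 input channels take the VAE image latents
+ #         zero_init_conv_in=True,               # new conv_in channels start at zero
+ #     )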
1465
+ @classmethod
1466
+ def _load_pretrained_model_2d(
1467
+ cls,
1468
+ model,
1469
+ state_dict,
1470
+ resolved_archive_file,
1471
+ pretrained_model_name_or_path,
1472
+ ignore_mismatched_sizes=False,
1473
+ ):
1474
+ # Retrieve missing & unexpected_keys
1475
+ model_state_dict = model.state_dict()
1476
+ loaded_keys = list(state_dict.keys())
1477
+
1478
+ expected_keys = list(model_state_dict.keys())
1479
+
1480
+ original_loaded_keys = loaded_keys
1481
+
1482
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1483
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1484
+
1485
+ # Make sure we are able to load base models as well as derived models (with heads)
1486
+ model_to_load = model
1487
+
1488
+ def _find_mismatched_keys(
1489
+ state_dict,
1490
+ model_state_dict,
1491
+ loaded_keys,
1492
+ ignore_mismatched_sizes,
1493
+ ):
1494
+ mismatched_keys = []
1495
+ if ignore_mismatched_sizes:
1496
+ for checkpoint_key in loaded_keys:
1497
+ model_key = checkpoint_key
1498
+
1499
+ if (
1500
+ model_key in model_state_dict
1501
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1502
+ ):
1503
+ mismatched_keys.append(
1504
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1505
+ )
1506
+ del state_dict[checkpoint_key]
1507
+ return mismatched_keys
1508
+
1509
+ if state_dict is not None:
1510
+ # Whole checkpoint
1511
+ mismatched_keys = _find_mismatched_keys(
1512
+ state_dict,
1513
+ model_state_dict,
1514
+ original_loaded_keys,
1515
+ ignore_mismatched_sizes,
1516
+ )
1517
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1518
+
1519
+ if len(error_msgs) > 0:
1520
+ error_msg = "\n\t".join(error_msgs)
1521
+ if "size mismatch" in error_msg:
1522
+ error_msg += (
1523
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1524
+ )
1525
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1526
+
1527
+ if len(unexpected_keys) > 0:
1528
+ logger.warning(
1529
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1530
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1531
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1532
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1533
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1534
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1535
+ " identical (initializing a BertForSequenceClassification model from a"
1536
+ " BertForSequenceClassification model)."
1537
+ )
1538
+ else:
1539
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1540
+ if len(missing_keys) > 0:
1541
+ logger.warning(
1542
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1543
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1544
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1545
+ )
1546
+ elif len(mismatched_keys) == 0:
1547
+ logger.info(
1548
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1549
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1550
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1551
+ " without further training."
1552
+ )
1553
+ if len(mismatched_keys) > 0:
1554
+ mismatched_warning = "\n".join(
1555
+ [
1556
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1557
+ for key, shape1, shape2 in mismatched_keys
1558
+ ]
1559
+ )
1560
+ logger.warning(
1561
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1562
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1563
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1564
+ " able to use it for predictions and inference."
1565
+ )
1566
+
1567
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1568
+
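+ # Worked example of the mismatch handling above (illustrative, not in the original source):
+ # with in_channels=8, the checkpoint's conv_in.weight of shape [320, 4, 3, 3] does not match
+ # the instantiated [320, 8, 3, 3], so the key is recorded in `mismatched_keys`, dropped from
+ # the state dict, and from_pretrained_2d then copies the pretrained tensor into the first 4
+ # input channels and optionally zero-initializes the remaining ones.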
mvdiffusion/pipelines/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py ADDED
@@ -0,0 +1,627 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # Modified from https://github.com/pengHTYX/Era3D/blob/main/mvdiffusion/pipelines/pipeline_mvdiffusion_unclip.py
8
+ #
9
+ # For inquiries contact: wlyu3@ucmerced.edu
10
+
11
+ import inspect
12
+ import warnings
13
+ from typing import Callable, List, Optional, Union, Dict, Any
14
+ import PIL
15
+ import torch
16
+ from packaging import version
17
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPFeatureExtractor, CLIPTokenizer, CLIPTextModel
18
+ from diffusers.utils.import_utils import is_accelerate_available
19
+ from diffusers.configuration_utils import FrozenDict
20
+ from diffusers.image_processor import VaeImageProcessor
21
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
22
+ from diffusers.models.embeddings import get_timestep_embedding
23
+ from diffusers.schedulers import KarrasDiffusionSchedulers
24
+ from diffusers.utils import deprecate, logging
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
27
+ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
28
+ import os
29
+ import torchvision.transforms.functional as TF
30
+ from einops import rearrange
31
+ logger = logging.get_logger(__name__)
32
+
33
+ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
34
+ """
35
+ Pipeline for text-guided image to image generation using stable unCLIP.
36
+
37
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
38
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
39
+
40
+ Args:
41
+ feature_extractor ([`CLIPFeatureExtractor`]):
42
+ Feature extractor for image pre-processing before being encoded.
43
+ image_encoder ([`CLIPVisionModelWithProjection`]):
44
+ CLIP vision model for encoding images.
45
+ image_normalizer ([`StableUnCLIPImageNormalizer`]):
46
+ Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
47
+ embeddings after the noise has been applied.
48
+ image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
49
+ Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
50
+ by `noise_level` in `StableUnCLIPPipeline.__call__`.
51
+ tokenizer (`CLIPTokenizer`):
52
+ Tokenizer of class
53
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
54
+ text_encoder ([`CLIPTextModel`]):
55
+ Frozen text-encoder.
56
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
57
+ scheduler ([`KarrasDiffusionSchedulers`]):
58
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
59
+ vae ([`AutoencoderKL`]):
60
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
61
+ """
62
+ # image encoding components
63
+ feature_extractor: CLIPFeatureExtractor
64
+ image_encoder: CLIPVisionModelWithProjection
65
+ # image noising components
66
+ image_normalizer: StableUnCLIPImageNormalizer
67
+ image_noising_scheduler: KarrasDiffusionSchedulers
68
+ # regular denoising components
69
+ tokenizer: CLIPTokenizer
70
+ text_encoder: CLIPTextModel
71
+ unet: UNet2DConditionModel
72
+ scheduler: KarrasDiffusionSchedulers
73
+ vae: AutoencoderKL
74
+
75
+ def __init__(
76
+ self,
77
+ # image encoding components
78
+ feature_extractor: CLIPFeatureExtractor,
79
+ image_encoder: CLIPVisionModelWithProjection,
80
+ # image noising components
81
+ image_normalizer: StableUnCLIPImageNormalizer,
82
+ image_noising_scheduler: KarrasDiffusionSchedulers,
83
+ # regular denoising components
84
+ tokenizer: CLIPTokenizer,
85
+ text_encoder: CLIPTextModel,
86
+ unet: UNet2DConditionModel,
87
+ scheduler: KarrasDiffusionSchedulers,
88
+ # vae
89
+ vae: AutoencoderKL,
90
+ num_views: int = 6,
91
+ ):
92
+ super().__init__()
93
+
94
+ self.register_modules(
95
+ feature_extractor=feature_extractor,
96
+ image_encoder=image_encoder,
97
+ image_normalizer=image_normalizer,
98
+ image_noising_scheduler=image_noising_scheduler,
99
+ tokenizer=tokenizer,
100
+ text_encoder=text_encoder,
101
+ unet=unet,
102
+ scheduler=scheduler,
103
+ vae=vae,
104
+ )
105
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
106
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
107
+ self.num_views: int = num_views
108
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
109
+ def enable_vae_slicing(self):
110
+ r"""
111
+ Enable sliced VAE decoding.
112
+
113
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
114
+ steps. This is useful to save some memory and allow larger batch sizes.
115
+ """
116
+ self.vae.enable_slicing()
117
+
118
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
119
+ def disable_vae_slicing(self):
120
+ r"""
121
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
122
+ computing decoding in one step.
123
+ """
124
+ self.vae.disable_slicing()
125
+
126
+ def enable_sequential_cpu_offload(self, gpu_id=0):
127
+ r"""
128
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
129
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
130
+ when their specific submodule has its `forward` method called.
131
+ """
132
+ if is_accelerate_available():
133
+ from accelerate import cpu_offload
134
+ else:
135
+ raise ImportError("Please install accelerate via `pip install accelerate`")
136
+
137
+ device = torch.device(f"cuda:{gpu_id}")
138
+
139
+ # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fail if added to the list
140
+ models = [
141
+ self.image_encoder,
142
+ self.text_encoder,
143
+ self.unet,
144
+ self.vae,
145
+ ]
146
+ for cpu_offloaded_model in models:
147
+ if cpu_offloaded_model is not None:
148
+ cpu_offload(cpu_offloaded_model, device)
149
+
150
+ @property
151
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
152
+ def _execution_device(self):
153
+ r"""
154
+ Returns the device on which the pipeline's models will be executed. After calling
155
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
156
+ hooks.
157
+ """
158
+ if not hasattr(self.unet, "_hf_hook"):
159
+ return self.device
160
+ for module in self.unet.modules():
161
+ if (
162
+ hasattr(module, "_hf_hook")
163
+ and hasattr(module._hf_hook, "execution_device")
164
+ and module._hf_hook.execution_device is not None
165
+ ):
166
+ return torch.device(module._hf_hook.execution_device)
167
+ return self.device
168
+
169
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
170
+ def _encode_prompt(
171
+ self,
172
+ prompt,
173
+ device,
174
+ num_images_per_prompt,
175
+ do_classifier_free_guidance,
176
+ negative_prompt=None,
177
+ prompt_embeds: Optional[torch.FloatTensor] = None,
178
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
179
+ lora_scale: Optional[float] = None,
180
+ ):
181
+ r"""
182
+ Encodes the prompt into text encoder hidden states.
183
+
184
+ Args:
185
+ prompt (`str` or `List[str]`, *optional*):
186
+ prompt to be encoded
187
+ device: (`torch.device`):
188
+ torch device
189
+ num_images_per_prompt (`int`):
190
+ number of images that should be generated per prompt
191
+ do_classifier_free_guidance (`bool`):
192
+ whether to use classifier free guidance or not
193
+ negative_prompt (`str` or `List[str]`, *optional*):
194
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
195
+ `negative_prompt_embeds` instead.
196
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
197
+ prompt_embeds (`torch.FloatTensor`, *optional*):
198
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
199
+ provided, text embeddings will be generated from `prompt` input argument.
200
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
201
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
202
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
203
+ argument.
204
+ """
205
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
206
+
207
+ if do_classifier_free_guidance:
208
+ # For classifier free guidance, we need to do two forward passes.
209
+ # Here we concatenate the unconditional and text embeddings into a single batch
210
+ # to avoid doing two forward passes
211
+ # normal_prompt_embeds, color_prompt_embeds = torch.chunk(prompt_embeds, 2, dim=0)
212
+ color_prompt_embeds = prompt_embeds
213
+
214
+ prompt_embeds = torch.cat([color_prompt_embeds, color_prompt_embeds], 0)
215
+
216
+ return prompt_embeds
217
+
218
+ def _encode_image(
219
+ self,
220
+ image_pil,
221
+ device,
222
+ num_images_per_prompt,
223
+ do_classifier_free_guidance,
224
+ noise_level: int=0,
225
+ generator: Optional[torch.Generator] = None
226
+ ):
227
+ dtype = next(self.image_encoder.parameters()).dtype
228
+ # ______________________________clip image embedding______________________________
229
+ image = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
230
+ image = image.to(device=device, dtype=dtype)
231
+ image_embeds = self.image_encoder(image).image_embeds
232
+
233
+ image_embeds = self.noise_image_embeddings(
234
+ image_embeds=image_embeds,
235
+ noise_level=noise_level,
236
+ generator=generator,
237
+ )
238
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
239
+ # image_embeds = image_embeds.unsqueeze(1)
240
+ # note: the conditioning input is the same for all generations
241
+ image_embeds = image_embeds.repeat(num_images_per_prompt, 1)
242
+
243
+ if do_classifier_free_guidance:
244
+ # For classifier free guidance, we need to do two forward passes.
245
+ # Here we concatenate the unconditional and text embeddings into a single batch
246
+ # to avoid doing two forward passes
247
+ negative_prompt_embeds = torch.zeros_like(image_embeds)
248
+ image_embeds = torch.cat([negative_prompt_embeds, image_embeds])
249
+
250
+ # _____________________________vae input latents__________________________________________________
251
+ image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
252
+ image_pt = image_pt * 2.0 - 1.0
253
+ # Cast to the VAE's weight dtype so the conv input and bias dtypes match
+ # (avoids "RuntimeError: Input type (float) and bias type (c10::Half) should be the same").
+ image_pt = image_pt.to(dtype=self.vae.dtype)
256
+ image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
257
+ # Note: repeated differently from the official pipelines
258
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
259
+
260
+ if do_classifier_free_guidance:
261
+ image_latents = torch.cat([torch.zeros_like(image_latents), image_latents])
262
+
263
+ return image_embeds, image_latents
264
+
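+ # Shape sketch (illustrative, assuming a CLIP image encoder with 1024-d embeddings,
+ # 512x512 inputs and num_images_per_prompt=1): for 6 views, `image_embeds` is (6, 2048)
+ # after `noise_image_embeddings` appends the noise-level embedding, and `image_latents`
+ # is (6, 4, 64, 64); with classifier-free guidance both are doubled along the batch
+ # dimension, using zeros as the unconditional half.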
265
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
266
+ def decode_latents(self, latents):
267
+ latents = 1 / self.vae.config.scaling_factor * latents
268
+ image = self.vae.decode(latents).sample
269
+ image = (image / 2 + 0.5).clamp(0, 1)
270
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
271
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
272
+ return image
273
+
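+ # Note (assumes the standard SD VAE config): `scaling_factor` is typically 0.18215, so the
+ # latents are divided by it before `vae.decode`, and the decoded images are mapped from
+ # [-1, 1] to [0, 1] before conversion to numpy.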
274
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
275
+ def prepare_extra_step_kwargs(self, generator, eta):
276
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
277
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
278
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
279
+ # and should be between [0, 1]
280
+
281
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
282
+ extra_step_kwargs = {}
283
+ if accepts_eta:
284
+ extra_step_kwargs["eta"] = eta
285
+
286
+ # check if the scheduler accepts generator
287
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
288
+ if accepts_generator:
289
+ extra_step_kwargs["generator"] = generator
290
+ return extra_step_kwargs
291
+
292
+ def check_inputs(
293
+ self,
294
+ prompt,
295
+ image,
296
+ height,
297
+ width,
298
+ callback_steps,
299
+ noise_level,
300
+ ):
301
+ if height % 8 != 0 or width % 8 != 0:
302
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
303
+
304
+ if (callback_steps is None) or (
305
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
306
+ ):
307
+ raise ValueError(
308
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
309
+ f" {type(callback_steps)}."
310
+ )
311
+
312
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
313
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
314
+
315
+
316
+ if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
317
+ raise ValueError(
318
+ f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
319
+ )
320
+
321
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
322
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
323
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
324
+ if isinstance(generator, list) and len(generator) != batch_size:
325
+ raise ValueError(
326
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
327
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
328
+ )
329
+
330
+ if latents is None:
331
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
332
+ else:
333
+ latents = latents.to(device)
334
+
335
+ # scale the initial noise by the standard deviation required by the scheduler
336
+ latents = latents * self.scheduler.init_noise_sigma
337
+ return latents
338
+
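+ # Shape sketch (illustrative): with `vae_scale_factor = 2 ** (len(block_out_channels) - 1)`
+ # (8 for the usual 4-level SD VAE), a 512x512 target and batch_size=6 gives latents of shape
+ # (6, num_channels_latents, 64, 64), scaled by the scheduler's `init_noise_sigma`.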
339
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings
340
+ def noise_image_embeddings(
341
+ self,
342
+ image_embeds: torch.Tensor,
343
+ noise_level: int,
344
+ noise: Optional[torch.FloatTensor] = None,
345
+ generator: Optional[torch.Generator] = None,
346
+ ):
347
+ """
348
+ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
349
+ `noise_level` increases the variance in the final un-noised images.
350
+
351
+ The noise is applied in two ways:
352
+ 1. A noise schedule is applied directly to the embeddings
353
+ 2. A vector of sinusoidal time embeddings is appended to the output.
354
+
355
+ In both cases, the amount of noise is controlled by the same `noise_level`.
356
+
357
+ The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
358
+ """
359
+ if noise is None:
360
+ noise = randn_tensor(
361
+ image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
362
+ )
363
+
364
+ noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
365
+
366
+ image_embeds = self.image_normalizer.scale(image_embeds)
367
+
368
+ image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
369
+
370
+ image_embeds = self.image_normalizer.unscale(image_embeds)
371
+
372
+ noise_level = get_timestep_embedding(
373
+ timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
374
+ )
375
+
376
+ # `get_timestep_embedding` does not contain any weights and will always return f32 tensors,
377
+ # but we might actually be running in fp16. so we need to cast here.
378
+ # there might be better ways to encapsulate this.
379
+ noise_level = noise_level.to(image_embeds.dtype)
380
+
381
+ image_embeds = torch.cat((image_embeds, noise_level), 1)
382
+
383
+ return image_embeds
384
+
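+ # Usage sketch (illustrative, not in the original source; `pipe` stands for an instance of
+ # this pipeline): even with `noise_level=0` the sinusoidal level embedding is appended, so a
+ # (B, 1024) CLIP embedding comes back as (B, 2048):
+ #
+ #     embeds = pipe.noise_image_embeddings(image_embeds, noise_level=0, generator=generator)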
385
+ @torch.no_grad()
386
+ # @replace_example_docstring(EXAMPLE_DOC_STRING)
387
+ def __call__(
388
+ self,
389
+ image: Union[torch.FloatTensor, PIL.Image.Image],
390
+ prompt: Union[str, List[str]],
391
+ prompt_embeds: torch.FloatTensor = None,
392
+ height: Optional[int] = None,
393
+ width: Optional[int] = None,
394
+ num_inference_steps: int = 20,
395
+ guidance_scale: float = 10,
396
+ negative_prompt: Optional[Union[str, List[str]]] = None,
397
+ num_images_per_prompt: Optional[int] = 1,
398
+ eta: float = 0.0,
399
+ generator: Optional[torch.Generator] = None,
400
+ latents: Optional[torch.FloatTensor] = None,
401
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
402
+ output_type: Optional[str] = "pil",
403
+ return_dict: bool = True,
404
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
405
+ callback_steps: int = 1,
406
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
407
+ noise_level: int = 0,
408
+ image_embeds: Optional[torch.FloatTensor] = None,
409
+ gt_img_in: Optional[torch.FloatTensor] = None,
410
+ ):
411
+ r"""
412
+ Function invoked when calling the pipeline for generation.
413
+
414
+ Args:
415
+ prompt (`str` or `List[str]`, *optional*):
416
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
417
+ instead.
418
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
419
+ `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
420
+ the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
421
+ latents in the denoising process such as in the standard stable diffusion text guided image variation
422
+ process.
423
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
424
+ The height in pixels of the generated image.
425
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
426
+ The width in pixels of the generated image.
427
+ num_inference_steps (`int`, *optional*, defaults to 20):
428
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
429
+ expense of slower inference.
430
+ guidance_scale (`float`, *optional*, defaults to 10.0):
431
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
432
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
433
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
434
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
435
+ usually at the expense of lower image quality.
436
+ negative_prompt (`str` or `List[str]`, *optional*):
437
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
438
+ `negative_prompt_embeds` instead.
439
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
440
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
441
+ The number of images to generate per prompt.
442
+ eta (`float`, *optional*, defaults to 0.0):
443
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
444
+ [`schedulers.DDIMScheduler`], will be ignored for others.
445
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
446
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
447
+ to make generation deterministic.
448
+ latents (`torch.FloatTensor`, *optional*):
449
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
450
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
451
+ tensor will ge generated by sampling using the supplied random `generator`.
452
+ prompt_embeds (`torch.FloatTensor`, *optional*):
453
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
454
+ provided, text embeddings will be generated from `prompt` input argument.
455
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
456
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
457
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
458
+ argument.
459
+ output_type (`str`, *optional*, defaults to `"pil"`):
460
+ The output format of the generated image. Choose between
461
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
462
+ return_dict (`bool`, *optional*, defaults to `True`):
463
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
464
+ plain tuple.
465
+ callback (`Callable`, *optional*):
466
+ A function that will be called every `callback_steps` steps during inference. The function will be
467
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
468
+ callback_steps (`int`, *optional*, defaults to 1):
469
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
470
+ called at every step.
471
+ cross_attention_kwargs (`dict`, *optional*):
472
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
473
+ `self.processor` in
474
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
475
+ noise_level (`int`, *optional*, defaults to `0`):
476
+ The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
477
+ the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
478
+ image_embeds (`torch.FloatTensor`, *optional*):
479
+ Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in
480
+ the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as
481
+ `latents`.
482
+
483
+ Examples:
484
+
485
+ Returns:
486
+ [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if `return_dict` is
487
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
488
+ """
489
+ # 0. Default height and width to unet
490
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
491
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
492
+
493
+ # 1. Check inputs. Raise error if not correct
494
+ self.check_inputs(
495
+ prompt=prompt,
496
+ image=image,
497
+ height=height,
498
+ width=width,
499
+ callback_steps=callback_steps,
500
+ noise_level=noise_level
501
+ )
502
+
503
+ # 2. Define call parameters
504
+ if isinstance(image, list):
505
+ batch_size = len(image)
506
+ elif isinstance(image, torch.Tensor):
507
+ batch_size = image.shape[0]
508
+ assert batch_size >= self.num_views and batch_size % self.num_views == 0
509
+ elif isinstance(image, PIL.Image.Image):
510
+ image = [image]*self.num_views
511
+ batch_size = self.num_views
512
+
513
+ if isinstance(prompt, str):
514
+ prompt = [prompt] * self.num_views
515
+
516
+ device = self._execution_device
517
+
518
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
519
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
520
+ # corresponds to doing no classifier free guidance.
521
+ do_classifier_free_guidance = guidance_scale != 1.0
522
+
523
+ # 3. Encode input prompt
524
+ text_encoder_lora_scale = (
525
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
526
+ )
527
+ prompt_embeds = self._encode_prompt(
528
+ prompt=prompt,
529
+ device=device,
530
+ num_images_per_prompt=num_images_per_prompt,
531
+ do_classifier_free_guidance=do_classifier_free_guidance,
532
+ negative_prompt=negative_prompt,
533
+ prompt_embeds=prompt_embeds,
534
+ negative_prompt_embeds=negative_prompt_embeds,
535
+ lora_scale=text_encoder_lora_scale,
536
+ )
537
+
538
+
539
+ # 4. Encode input image
540
+ if isinstance(image, list):
541
+ image_pil = image
542
+ elif isinstance(image, torch.Tensor):
543
+ image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
544
+ noise_level = torch.tensor([noise_level], device=device)
545
+ image_embeds, image_latents = self._encode_image(
546
+ image_pil=image_pil,
547
+ device=device,
548
+ num_images_per_prompt=num_images_per_prompt,
549
+ do_classifier_free_guidance=do_classifier_free_guidance,
550
+ noise_level=noise_level,
551
+ generator=generator,
552
+ )
553
+
554
+ # 5. Prepare timesteps
555
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
556
+ timesteps = self.scheduler.timesteps
557
+
558
+ # 6. Prepare latent variables
559
+ num_channels_latents = self.unet.config.out_channels
560
+ if gt_img_in is not None:
561
+ latents = gt_img_in * self.scheduler.init_noise_sigma
562
+ else:
563
+ latents = self.prepare_latents(
564
+ batch_size=batch_size,
565
+ num_channels_latents=num_channels_latents,
566
+ height=height,
567
+ width=width,
568
+ dtype=prompt_embeds.dtype,
569
+ device=device,
570
+ generator=generator,
571
+ latents=latents,
572
+ )
573
+
574
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
575
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
576
+
577
+ eles, focals = [], []
578
+ # 8. Denoising loop
579
+ for i, t in enumerate(self.progress_bar(timesteps)):
580
+ if do_classifier_free_guidance:
581
+ latent_model_input = torch.cat([latents, latents], 0)
582
+ else:
583
+ latent_model_input = latents
584
+ latent_model_input = torch.cat([
585
+ latent_model_input, image_latents
586
+ ], dim=1)
587
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
588
+
589
+ # predict the noise residual
590
+ unet_out = self.unet(
591
+ latent_model_input,
592
+ t,
593
+ encoder_hidden_states=prompt_embeds,
594
+ class_labels=image_embeds,
595
+ cross_attention_kwargs=cross_attention_kwargs,
596
+ return_dict=False)
597
+
598
+ noise_pred = unet_out
599
+
600
+ # perform guidance
601
+ if do_classifier_free_guidance:
602
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
603
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
604
+
605
+ # compute the previous noisy sample x_t -> x_t-1
606
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
607
+
608
+ if callback is not None and i % callback_steps == 0:
609
+ callback(i, t, latents)
610
+
611
+ # 9. Post-processing
612
+ if not output_type == "latent":
613
+ if num_channels_latents == 8:
614
+ latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
615
+ with torch.no_grad():
616
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
617
+ else:
618
+ image = latents
619
+
620
+ image = self.image_processor.postprocess(image, output_type=output_type)
621
+
622
+ # Offload last model to CPU
623
+ # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
624
+ # self.final_offload_hook.offload()
625
+ if not return_dict:
626
+ return (image, )
627
+ return ImagePipelineOutput(images=image)
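
For context, the `__call__` above appears to be an UnCLIP-style, image-conditioned multi-view pipeline: CLIP image embeddings are passed to the UNet as `class_labels`, the encoded image latents are concatenated channel-wise with the noisy latents, and classifier-free guidance is applied whenever `guidance_scale != 1.0`. The following is only a usage sketch: the checkpoint path and example image path are placeholders, and the actual pipeline class is instantiated elsewhere in `app.py`.

```python
# Hedged usage sketch for the image-conditioned multi-view __call__ above.
# "checkpoints/mv_diffusion" and "examples/portrait.png" are placeholders.
import torch
from diffusers import DiffusionPipeline
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "checkpoints/mv_diffusion",           # placeholder local path with the pipeline weights
    torch_dtype=torch.float16,
).to("cuda")

portrait = Image.open("examples/portrait.png").convert("RGB").resize((512, 512))

# A single PIL image is replicated to `num_views` internally (see step 2 above);
# guidance_scale != 1.0 switches on classifier-free guidance.
views = pipe(
    image=portrait,
    prompt="",
    guidance_scale=3.0,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(42),
    output_type="pil",
).images   # list of generated view images
```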
requirements.txt ADDED
@@ -0,0 +1,36 @@
1
+ # Core dependencies
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ numpy>=1.24.0
5
+ pillow>=9.5.0
6
+
7
+ # Deep learning frameworks
8
+ diffusers>=0.27.0
9
+ transformers>=4.30.0
10
+ accelerate>=0.20.0
11
+
12
+ # Image processing
13
+ opencv-python>=4.8.0
14
+ rembg>=2.0.50
15
+ facenet-pytorch>=2.5.3
16
+
17
+ # 3D and rendering
18
+ diff-gaussian-rasterization
19
+ einops>=0.7.0
20
+ plyfile>=0.9
21
+
22
+ # Utilities
23
+ easydict>=1.10
24
+ pyyaml>=6.0
25
+ lpips>=0.1.4
26
+ huggingface-hub>=0.19.0
27
+
28
+ # Video processing
29
+ videoio>=0.2.0
30
+
31
+ # Gradio for UI
32
+ gradio>=5.0.0
33
+
34
+ # Optional performance
35
+ xformers>=0.0.20
36
+
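
An optional sanity check (a sketch, not part of the Space itself) that the core pinned dependencies above resolve and that a CUDA device is visible:

```python
# Verify the main pinned packages import and report their versions.
import torch
import diffusers
import gradio

print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("diffusers", diffusers.__version__, "| gradio", gradio.__version__)
```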
utils_folder/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
utils_folder/face_utils.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright (C) 2025, FaceLift Research Group
2
+ # https://github.com/weijielyu/FaceLift
3
+ #
4
+ # This software is free for non-commercial, research and evaluation use
5
+ # under the terms of the LICENSE.md file.
6
+ #
7
+ # For inquiries contact: wlyu3@ucmerced.edu
8
+
9
+ """
10
+ Face detection and cropping utilities for 3D face reconstruction.
11
+
12
+ This module provides functions for face detection, cropping, and preprocessing
13
+ to align faces with training data specifications.
14
+ """
15
+
16
+ from typing import Tuple, Optional, Dict, Any
17
+ import numpy as np
18
+ import torch
19
+ from PIL import Image
20
+ from facenet_pytorch import MTCNN
21
+ from rembg import remove
22
+
23
+ # Training set face parameters (derived from training data statistics)
24
+ TRAINING_SET_FACE_SIZE = 194.2749650813705
25
+ TRAINING_SET_FACE_CENTER = [251.83270369057132, 280.0133630862363]
26
+
27
+ # Public constants for external use
28
+ FACE_SIZE = TRAINING_SET_FACE_SIZE
29
+ FACE_CENTER = TRAINING_SET_FACE_CENTER
30
+ DEFAULT_BACKGROUND_COLOR = (255, 255, 255)
31
+ DEFAULT_IMG_SIZE = 512
32
+
33
+ # Device setup
34
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35
+
36
+ # Default face detector instance
37
+ FACE_DETECTOR = MTCNN(
38
+ image_size=512,
39
+ margin=0,
40
+ min_face_size=20,
41
+ thresholds=[0.6, 0.7, 0.7],
42
+ factor=0.709,
43
+ post_process=True,
44
+ device=DEVICE
45
+ )
46
+
47
+ def select_face(detected_bounding_boxes: Optional[np.ndarray], confidence_scores: Optional[np.ndarray]) -> Optional[np.ndarray]:
48
+ """
49
+ Select the largest face from detected faces with confidence above threshold.
50
+
51
+ Args:
52
+ detected_bounding_boxes: Detected bounding boxes in xyxy format
53
+ confidence_scores: Detection confidence probabilities
54
+
55
+ Returns:
56
+ Selected bounding box or None if no suitable face found
57
+ """
58
+ if detected_bounding_boxes is None or confidence_scores is None:
59
+ return None
60
+
61
+ # Filter faces with confidence > 0.8
62
+ high_confidence_faces = [
63
+ detected_bounding_boxes[i] for i in range(len(detected_bounding_boxes))
64
+ if confidence_scores[i] > 0.8
65
+ ]
66
+
67
+ if not high_confidence_faces:
68
+ return None
69
+
70
+ # Return the largest face (by area)
71
+ return max(high_confidence_faces, key=lambda bbox: (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]))
72
+
73
+ def crop_face(
74
+ input_image_array: np.ndarray,
75
+ face_detector: MTCNN = FACE_DETECTOR,
76
+ target_face_size: float = FACE_SIZE,
77
+ target_face_center: list = FACE_CENTER,
78
+ output_image_size: int = 512,
79
+ background_color: Tuple[int, int, int] = (255, 255, 255)
80
+ ) -> Tuple[Image.Image, Dict[str, Any]]:
81
+ """
82
+ Crop and align face in image to match training data specifications.
83
+
84
+ Args:
85
+ input_image_array: Input image as numpy array (H, W, C)
86
+ face_detector: MTCNN face detector instance
87
+ target_face_size: Target face size from training data
88
+ target_face_center: Target face center from training data
89
+ output_image_size: Output image size
90
+ background_color: Background color for padding
91
+
92
+ Returns:
93
+ Tuple of (cropped_image, crop_parameters)
94
+
95
+ Raises:
96
+ ValueError: If no face is detected in the image
97
+ """
98
+ image_height, image_width, _ = input_image_array.shape
99
+
100
+ # Handle RGBA images by compositing with background color
101
+ if input_image_array.shape[2] == 4:
102
+ rgba_pil_image = Image.fromarray(input_image_array)
103
+ background_image = Image.new("RGB", rgba_pil_image.size, background_color)
104
+ rgb_composite_image = Image.alpha_composite(background_image.convert("RGBA"), rgba_pil_image).convert("RGB")
105
+ processed_image_array = np.array(rgb_composite_image)
106
+ else:
107
+ processed_image_array = input_image_array[:, :, :3] # Ensure RGB format
108
+
109
+ # Detect and select face
110
+ detected_bounding_boxes, confidence_scores = face_detector.detect(processed_image_array)
111
+ selected_face_bbox = select_face(detected_bounding_boxes, confidence_scores)
112
+ if selected_face_bbox is None:
113
+ raise ValueError("No face detected in the image")
114
+
115
+ # Calculate detected face properties
116
+ detected_face_size = 0.5 * (selected_face_bbox[2] - selected_face_bbox[0] + selected_face_bbox[3] - selected_face_bbox[1])
117
+ detected_face_center = (
118
+ 0.5 * (selected_face_bbox[0] + selected_face_bbox[2]),
119
+ 0.5 * (selected_face_bbox[1] + selected_face_bbox[3])
120
+ )
121
+
122
+ # Scale image to match training face size
123
+ scale_ratio = target_face_size / detected_face_size
124
+ scaled_width, scaled_height = int(image_width * scale_ratio), int(image_height * scale_ratio)
125
+ scaled_pil_image = Image.fromarray(processed_image_array).resize((scaled_width, scaled_height))
126
+ scaled_face_center = (
127
+ int(detected_face_center[0] * scale_ratio),
128
+ int(detected_face_center[1] * scale_ratio)
129
+ )
130
+
131
+ # Create output image with background
132
+ output_image = Image.new("RGB", (output_image_size, output_image_size), color=background_color)
133
+
134
+ # Calculate alignment offsets
135
+ horizontal_offset = target_face_center[0] - scaled_face_center[0]
136
+ vertical_offset = target_face_center[1] - scaled_face_center[1]
137
+
138
+ # Calculate crop boundaries
139
+ crop_left_boundary = int(max(0, -horizontal_offset))
140
+ crop_top_boundary = int(max(0, -vertical_offset))
141
+ crop_right_boundary = int(min(scaled_width, output_image_size - horizontal_offset))
142
+ crop_bottom_boundary = int(min(scaled_height, output_image_size - vertical_offset))
143
+
144
+ # Crop and paste
145
+ cropped_face_image = scaled_pil_image.crop((crop_left_boundary, crop_top_boundary, crop_right_boundary, crop_bottom_boundary))
146
+ paste_coordinates = (int(max(0, horizontal_offset)), int(max(0, vertical_offset)))
147
+ output_image.paste(cropped_face_image, paste_coordinates)
148
+
149
+ crop_parameters = {
150
+ 'resize_ratio': scale_ratio,
151
+ 'x_offset_left': horizontal_offset,
152
+ 'y_offset_top': vertical_offset,
153
+ }
154
+
155
+ return output_image, crop_parameters
156
+
157
+ def prepare_foreground_with_rembg(input_image_array: np.ndarray) -> np.ndarray:
158
+ """
159
+ Prepare foreground image using rembg for background removal.
160
+
161
+ Args:
162
+ input_image_array: Input image as numpy array (H, W, C)
163
+
164
+ Returns:
165
+ RGBA image as numpy array with background removed
166
+ """
167
+ pil_image = Image.fromarray(input_image_array)
168
+ background_removed_image = remove(pil_image)
169
+ processed_image_array = np.array(background_removed_image)
170
+
171
+ # Ensure RGBA format
172
+ if processed_image_array.shape[2] == 4:
173
+ return processed_image_array
174
+ elif processed_image_array.shape[2] == 3:
175
+ height, width = processed_image_array.shape[:2]
176
+ alpha_channel = np.full((height, width), 255, dtype=np.uint8)
177
+ rgba_image = np.zeros((height, width, 4), dtype=np.uint8)
178
+ rgba_image[:, :, :3] = processed_image_array
179
+ rgba_image[:, :, 3] = alpha_channel
180
+ return rgba_image
181
+
182
+ return processed_image_array
183
+
184
+ def preprocess_image(
185
+ original_image_array: np.ndarray,
186
+ target_image_size: int = DEFAULT_IMG_SIZE,
187
+ background_color: Tuple[int, int, int] = DEFAULT_BACKGROUND_COLOR
188
+ ) -> Image.Image:
189
+ """
190
+ Preprocess image with background removal and face cropping.
191
+
192
+ Args:
193
+ original_image_array: Input image as numpy array
194
+ target_image_size: Target image size
195
+ background_color: Background color for compositing
196
+
197
+ Returns:
198
+ Processed PIL Image
199
+ """
200
+ processed_image_array = prepare_foreground_with_rembg(original_image_array)
201
+
202
+ # Convert RGBA to RGB with specified background
203
+ if processed_image_array.shape[2] == 4:
204
+ rgba_pil_image = Image.fromarray(processed_image_array)
205
+ background_image = Image.new("RGB", rgba_pil_image.size, background_color)
206
+ rgb_composite_image = Image.alpha_composite(background_image.convert("RGBA"), rgba_pil_image).convert("RGB")
207
+ processed_image_array = np.array(rgb_composite_image)
208
+
209
+ cropped_image, crop_parameters = crop_face(
210
+ processed_image_array,
211
+ FACE_DETECTOR,
212
+ FACE_SIZE,
213
+ FACE_CENTER,
214
+ target_image_size,
215
+ background_color
216
+ )
217
+ return cropped_image
218
+
219
+ def preprocess_image_without_cropping(
220
+ original_image_array: np.ndarray,
221
+ target_image_size: int = DEFAULT_IMG_SIZE,
222
+ background_color: Tuple[int, int, int] = DEFAULT_BACKGROUND_COLOR
223
+ ) -> Image.Image:
224
+ """
225
+ Preprocess image with background removal, without face cropping.
226
+
227
+ Args:
228
+ original_image_array: Input image as numpy array
229
+ target_image_size: Target image size
230
+ background_color: Background color for compositing
231
+
232
+ Returns:
233
+ Processed PIL Image
234
+ """
235
+ processed_image_array = prepare_foreground_with_rembg(original_image_array)
236
+
237
+ resized_image = Image.fromarray(processed_image_array).resize((target_image_size, target_image_size))
238
+ background_image = Image.new("RGBA", (target_image_size, target_image_size), background_color)
239
+ composite_image = Image.alpha_composite(background_image, resized_image).convert("RGB")
240
+ return composite_image
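
The helpers above implement the "Auto Cropping" option described in the README: rembg background removal, MTCNN face detection, then rescaling so the detected face matches the training-set face size and is recentered at `FACE_CENTER`. A short usage sketch follows; the example image path is a placeholder.

```python
# Sketch of the preprocessing path defined in utils_folder/face_utils.py.
import numpy as np
from PIL import Image
from utils_folder.face_utils import preprocess_image, preprocess_image_without_cropping

image_array = np.array(Image.open("examples/portrait.png").convert("RGB"))  # placeholder path

# Auto-cropping branch (raises ValueError if MTCNN finds no face above the 0.8 threshold)
cropped = preprocess_image(image_array, target_image_size=512)

# No-cropping branch: background removal and resize only
resized = preprocess_image_without_cropping(image_array, target_image_size=512)

cropped.save("cropped_512.png")
```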
utils_folder/opencv_cameras.json ADDED
@@ -0,0 +1,245 @@
1
+ {
2
+ "id": "sample_000",
3
+ "frames": [
4
+ {
5
+ "w": 512,
6
+ "h": 512,
7
+ "fx": 548.9937744140625,
8
+ "fy": 548.9937744140625,
9
+ "cx": 256.0,
10
+ "cy": 256.0,
11
+ "w2c": [
12
+ [
13
+ 7.549789865864094e-08,
14
+ -0.9999999999999908,
15
+ 5.960464618088506e-08,
16
+ 2.0384433030900552e-07
17
+ ],
18
+ [
19
+ -7.549789676401702e-08,
20
+ -5.960465047532283e-08,
21
+ -0.9999999999999908,
22
+ -2.0384432486286618e-07
23
+ ],
24
+ [
25
+ 0.9999999999999885,
26
+ 7.549789676401702e-08,
27
+ -7.549789865864095e-08,
28
+ 2.7000000476836847
29
+ ],
30
+ [
31
+ 0.0,
32
+ 0.0,
33
+ 0.0,
34
+ 1.0
35
+ ]
36
+ ],
37
+ "blender_camera_name": "lrm_cam.000",
38
+ "blender_camera_location": [
39
+ -2.7,
40
+ 3.3065463576978537e-16,
41
+ 0.0
42
+ ]
43
+ },
44
+ {
45
+ "w": 512,
46
+ "h": 512,
47
+ "fx": 548.9937744140625,
48
+ "fy": 548.9937744140625,
49
+ "cx": 256.0,
50
+ "cy": 256.0,
51
+ "w2c": [
52
+ [
53
+ 0.7071067932881387,
54
+ -0.7071067336835127,
55
+ -1.4907362623459278e-07,
56
+ 1.137964524414734e-07
57
+ ],
58
+ [
59
+ -1.5879604949837492e-07,
60
+ 5.202588738731513e-08,
61
+ -0.999999999999972,
62
+ -2.0384434114916973e-07
63
+ ],
64
+ [
65
+ 0.7071067336835085,
66
+ 0.7071067932881663,
67
+ -7.549790026365632e-08,
68
+ 2.6999998777741223
69
+ ],
70
+ [
71
+ 0.0,
72
+ 0.0,
73
+ 0.0,
74
+ 1.0
75
+ ]
76
+ ],
77
+ "blender_camera_name": "lrm_cam.001",
78
+ "blender_camera_location": [
79
+ -1.9091883092036788,
80
+ -1.9091883092036783,
81
+ 0.0
82
+ ]
83
+ },
84
+ {
85
+ "w": 512,
86
+ "h": 512,
87
+ "fx": 548.9937744140625,
88
+ "fy": 548.9937744140625,
89
+ "cx": 256.0,
90
+ "cy": 256.0,
91
+ "w2c": [
92
+ [
93
+ 1.0,
94
+ -1.8369704112266414e-16,
95
+ -1.8716961953007326e-23,
96
+ -7.758078222293693e-23
97
+ ],
98
+ [
99
+ -1.0687309218630199e-23,
100
+ 4.371138828673784e-08,
101
+ -0.9999999999999981,
102
+ 1.1802075045851359e-07
103
+ ],
104
+ [
105
+ 1.8369704112266542e-16,
106
+ 0.999999999999998,
107
+ 4.371138828673784e-08,
108
+ 2.7000000476837105
109
+ ],
110
+ [
111
+ 0.0,
112
+ 0.0,
113
+ 0.0,
114
+ 1.0
115
+ ]
116
+ ],
117
+ "blender_camera_name": "lrm_cam.002",
118
+ "blender_camera_location": [
119
+ -4.959819536546781e-16,
120
+ -2.7,
121
+ 0.0
122
+ ]
123
+ },
124
+ {
125
+ "w": 512,
126
+ "h": 512,
127
+ "fx": 548.9937744140625,
128
+ "fy": 548.9937744140625,
129
+ "cx": 256.0,
130
+ "cy": 256.0,
131
+ "w2c": [
132
+ [
133
+ 0.7071067932881387,
134
+ 0.7071067336835127,
135
+ 1.4907362623459278e-07,
136
+ -1.137964524414734e-07
137
+ ],
138
+ [
139
+ 1.5879604949837492e-07,
140
+ 5.202588738731513e-08,
141
+ -0.999999999999972,
142
+ -2.0384434114916973e-07
143
+ ],
144
+ [
145
+ -0.7071067336835085,
146
+ 0.7071067932881663,
147
+ -7.549790026365632e-08,
148
+ 2.6999998777741223
149
+ ],
150
+ [
151
+ 0.0,
152
+ 0.0,
153
+ 0.0,
154
+ 1.0
155
+ ]
156
+ ],
157
+ "blender_camera_name": "lrm_cam.003",
158
+ "blender_camera_location": [
159
+ 1.9091883092036779,
160
+ -1.9091883092036788,
161
+ 0.0
162
+ ]
163
+ },
164
+ {
165
+ "w": 512,
166
+ "h": 512,
167
+ "fx": 548.9937744140625,
168
+ "fy": 548.9937744140625,
169
+ "cx": 256.0,
170
+ "cy": 256.0,
171
+ "w2c": [
172
+ [
173
+ 7.549790126404288e-08,
174
+ 0.9999999999999943,
175
+ 2.1254750082671866e-22,
176
+ -2.0384433701293624e-07
177
+ ],
178
+ [
179
+ 7.549790126404245e-08,
180
+ -5.699933095275168e-15,
181
+ -0.9999999999999943,
182
+ -2.0384433701293508e-07
183
+ ],
184
+ [
185
+ -0.9999999999999885,
186
+ 7.549790126404245e-08,
187
+ -7.549790126404288e-08,
188
+ 2.7000000476836847
189
+ ],
190
+ [
191
+ 0.0,
192
+ 0.0,
193
+ 0.0,
194
+ 1.0
195
+ ]
196
+ ],
197
+ "blender_camera_name": "lrm_cam.004",
198
+ "blender_camera_location": [
199
+ 2.7,
200
+ 0.0,
201
+ 0.0
202
+ ]
203
+ },
204
+ {
205
+ "w": 512,
206
+ "h": 512,
207
+ "fx": 548.9937744140625,
208
+ "fy": 548.9937744140625,
209
+ "cx": 256.0,
210
+ "cy": 256.0,
211
+ "w2c": [
212
+ [
213
+ -0.9999999999999847,
214
+ -8.74227798575314e-08,
215
+ 8.742276618399183e-08,
216
+ 2.360415099493051e-07
217
+ ],
218
+ [
219
+ -8.742276564667607e-08,
220
+ -4.371139592947904e-08,
221
+ -0.9999999999999906,
222
+ 1.1802077109391521e-07
223
+ ],
224
+ [
225
+ 8.742278039484591e-08,
226
+ -0.9999999999999905,
227
+ 4.3711391302137594e-08,
228
+ 2.70000004768369
229
+ ],
230
+ [
231
+ 0.0,
232
+ 0.0,
233
+ 0.0,
234
+ 1.0
235
+ ]
236
+ ],
237
+ "blender_camera_name": "lrm_cam.005",
238
+ "blender_camera_location": [
239
+ 1.6532731788489269e-16,
240
+ 2.7,
241
+ 0.0
242
+ ]
243
+ }
244
+ ]
245
+ }
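
These six entries are fixed OpenCV-convention cameras for the generated views (512×512, shared intrinsics, world-to-camera matrices plus the originating Blender camera locations). A small sketch of how one entry could be turned into an intrinsics matrix `K` and a camera-to-world transform; the JSON path matches this repo, everything else is illustrative.

```python
# Load one camera from utils_folder/opencv_cameras.json and build K and c2w.
import json
import numpy as np

with open("utils_folder/opencv_cameras.json") as f:
    cameras = json.load(f)

frame = cameras["frames"][0]
K = np.array([
    [frame["fx"], 0.0,         frame["cx"]],
    [0.0,         frame["fy"], frame["cy"]],
    [0.0,         0.0,         1.0],
])
w2c = np.array(frame["w2c"])      # 4x4 world-to-camera (OpenCV convention)
c2w = np.linalg.inv(w2c)          # camera-to-world; last column holds the camera center
print("camera center:", c2w[:3, 3])  # matches blender_camera_location up to axis convention
```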