Spaces:
Runtime error
Runtime error
Upload 14 files
Browse files- .gitattributes +1 -0
- Dockerfile +54 -0
- __pycache__/app.cpython-310.pyc +0 -0
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-310.pyc +0 -0
- core/__pycache__/gs.cpython-310.pyc +0 -0
- core/__pycache__/options.cpython-310.pyc +0 -0
- core/attention.py +156 -0
- core/gs.py +190 -0
- core/models.py +174 -0
- core/options.py +120 -0
- core/provider_objaverse.py +172 -0
- core/unet.py +319 -0
- core/utils.py +109 -0
- data_test/catstatue.ply +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data_test/catstatue.ply filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04

# Configure environment: keep apt from prompting during the image build
ENV DEBIAN_FRONTEND=noninteractive

# Install the package that provides add-apt-repository.
# `apt-get update` and `apt-get install` always share a RUN layer so a cached
# layer can never pair a stale package index with a new install list.
RUN apt-get update && apt-get install -y \
    software-properties-common \
    && rm -rf /var/lib/apt/lists/*

# Add the deadsnakes PPA
RUN add-apt-repository ppa:deadsnakes/ppa

# Install Python 3.10
RUN apt-get update && apt-get install -y \
    python3.10 \
    python3.10-dev \
    python3.10-distutils \
    python3.10-venv \
    python3-pip \
    && rm -rf /var/lib/apt/lists/*

# Install other dependencies (build toolchain and runtime libraries)
RUN apt-get update && apt-get install -y \
    git \
    gcc \
    g++ \
    libgl1 \
    libglib2.0.0 \
    ffmpeg \
    cmake \
    libgtk2.0.0 \
    && rm -rf /var/lib/apt/lists/*

# Create an unprivileged user (uid 1000) and work from its home directory
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH
WORKDIR $HOME/app

# Install the required Python packages (as `user`, so they land in ~/.local)
RUN pip install wheel
RUN pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 torchtext==0.16.0 torchdata==0.7.0 --extra-index-url https://download.pytorch.org/whl/cu121 -U
# Patch the pybind11 header bundled with torch; must run after torch is installed
RUN sed -i 's/return caster.operator typename make_caster<T>::template cast_op_type<T>();/return caster;/' /home/user/.local/lib/python3.10/site-packages/torch/include/pybind11/cast.h
RUN pip install tyro kiui PyMCubes nerfacc trimesh pymeshlab ninja plyfile xatlas pygltflib gradio opencv-python scikit-learn
RUN pip install https://github.com/dylanebert/wheels/releases/download/1.0.0/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
RUN pip install https://github.com/dylanebert/wheels/releases/download/1.0.0/nvdiffrast-0.3.1-py3-none-any.whl
RUN pip install git+https://github.com/ashawkey/kiuikit.git

# Copy all files to the working directory
COPY --chown=user . $HOME/app

EXPOSE 7860

# Run the gradio app
CMD ["python3.10", "app.py"]
|
__pycache__/app.cpython-310.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
core/__init__.py
ADDED
|
File without changes
|
core/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (135 Bytes). View file
|
|
|
core/__pycache__/gs.cpython-310.pyc
ADDED
|
Binary file (5.45 kB). View file
|
|
|
core/__pycache__/options.cpython-310.pyc
ADDED
|
Binary file (2.49 kB). View file
|
|
|
core/attention.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import warnings
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
# xFormers is opt-out: defining XFORMERS_DISABLED (with any value) turns it off.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None

# Stays False unless the import below succeeds.
XFORMERS_AVAILABLE = False

if not XFORMERS_ENABLED:
    # A deliberately disabled xFormers is reported both as disabled and as
    # unavailable.
    warnings.warn("xFormers is disabled (Attention)")
    warnings.warn("xFormers is not available (Attention)")
else:
    try:
        from xformers.ops import memory_efficient_attention, unbind
    except ImportError:
        warnings.warn("xFormers is not available (Attention)")
    else:
        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: channel width of the input and output tokens.
        num_heads: number of attention heads; must divide ``dim`` evenly.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        proj_bias: whether the output projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # (or ZeroDivisionError) the first time forward() runs.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to ``x`` of shape [B, N, C]; returns [B, N, C]."""
        B, N, C = x.shape
        # [B, N, 3C] -> [3, B, heads, N, C/heads]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        # Scale the queries up front so the logits are q.k / sqrt(head_dim).
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # [B, heads, N, C/heads] -> [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MemEffAttention(Attention):
    """Self-attention that uses the xFormers kernel when it was imported.

    Without xFormers this defers to ``Attention.forward``; in that case a
    non-None ``attn_bias`` is rejected with ``AssertionError``.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            # [B, N, 3C] -> [B, N, 3, heads, C/heads], then split along the "3" axis.
            packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
            query, key, value = unbind(packed, 2)

            # Note: self.attn_drop is not applied on this path.
            out = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
            out = out.reshape([batch, tokens, channels])

            return self.proj_drop(self.proj(out))

        # Fallback: plain attention, which has no way to honour an attention bias.
        if attn_bias is not None:
            raise AssertionError("xFormers is required for using nested tensors")
        return super().forward(x)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class CrossAttention(nn.Module):
    """Multi-head cross-attention with separately sized query/key/value inputs.

    Args:
        dim: internal and output channel width.
        dim_q: channel width of the query input.
        dim_k: channel width of the key input.
        dim_v: channel width of the value input.
        num_heads: number of attention heads; must divide ``dim`` evenly.
        qkv_bias: whether the q/k/v projections have bias terms.
        proj_bias: whether the output projection has a bias term.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(
        self,
        dim: int,
        dim_q: int,
        dim_k: int,
        dim_v: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # (or ZeroDivisionError) the first time forward() runs.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Attend from ``q`` to ``k``/``v``.

        q: [B, N, Cq], k: [B, M, Ck], v: [B, M, Cv] -> returns [B, N, C].
        """
        B, N, _ = q.shape
        M = k.shape[1]
        head_dim = self.dim // self.num_heads

        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)  # [B, nh, N, C/nh]
        k = self.to_k(k).reshape(B, M, self.num_heads, head_dim).permute(0, 2, 1, 3)  # [B, nh, M, C/nh]
        v = self.to_v(v).reshape(B, M, self.num_heads, head_dim).permute(0, 2, 1, 3)  # [B, nh, M, C/nh]

        attn = q @ k.transpose(-2, -1)  # [B, nh, N, M]

        attn = attn.softmax(dim=-1)  # [B, nh, N, M]
        attn = self.attn_drop(attn)

        # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class MemEffCrossAttention(CrossAttention):
    """Cross-attention that uses the xFormers kernel when it was imported.

    Without xFormers this defers to ``CrossAttention.forward``; in that case a
    non-None ``attn_bias`` is rejected with ``AssertionError``.
    """

    def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
        """Attend from ``q`` to ``k``/``v``.

        q: [B, N, Cq], k: [B, M, Ck], v: [B, M, Cv] -> returns [B, N, C].
        """
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            # Forward the real inputs; the previous code passed an undefined
            # name `x` here, which raised NameError on every fallback call.
            return super().forward(q, k, v)

        B, N, _ = q.shape
        M = k.shape[1]
        head_dim = self.dim // self.num_heads

        # NOTE(review): q is pre-multiplied by self.scale here, and
        # memory_efficient_attention presumably applies its own default
        # 1/sqrt(head_dim) scaling as well -- confirm against the xFormers docs
        # and the fallback path before relying on numerical equivalence.
        q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, head_dim)  # [B, N, nh, C/nh]
        k = self.to_k(k).reshape(B, M, self.num_heads, head_dim)  # [B, M, nh, C/nh]
        v = self.to_v(v).reshape(B, M, self.num_heads, head_dim)  # [B, M, nh, C/nh]

        # Note: self.attn_drop is not applied on this path.
        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
core/gs.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from diff_gaussian_rasterization import (
|
| 8 |
+
GaussianRasterizationSettings,
|
| 9 |
+
GaussianRasterizer,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from core.options import Options
|
| 13 |
+
|
| 14 |
+
import kiui
|
| 15 |
+
|
| 16 |
+
class GaussianRenderer:
    """Rasterize batches of 3D Gaussians and read/write them as PLY files.

    Each Gaussian is a 14-vector laid out as
    ``[xyz (3), opacity (1), scale (3), rotation (4), rgb (3)]``.
    """

    def __init__(self, opt: Options):
        self.opt = opt
        # Default (white) background; note that it is created on the CUDA device.
        self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

        # intrinsics
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
        # Perspective projection built from fovy / znear / zfar. It is stored on
        # the instance but not read by any method of this class.
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1

    def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1):
        """Render every (batch, view) pair.

        Args:
            gaussians: [B, N, 14]
            cam_view, cam_view_proj: [B, V, 4, 4]
            cam_pos: [B, V, 3]
            bg_color: optional background colour; defaults to ``self.bg_color``.
            scale_modifier: forwarded to the rasterizer settings.

        Returns:
            dict with ``"image"`` [B, V, 3, H, W] (clamped to [0, 1]) and
            ``"alpha"`` [B, V, 1, H, W], where H = W = ``opt.output_size``.
        """
        device = gaussians.device
        B, V = cam_view.shape[:2]
        size = self.opt.output_size
        # Loop-invariant: resolve the background once instead of per view.
        background = self.bg_color if bg_color is None else bg_color

        # The rasterizer handles one view at a time, hence the nested loops.
        images = []
        alphas = []
        for b in range(B):

            # pos, opacity, scale, rotation, rgb
            means3D = gaussians[b, :, 0:3].contiguous().float()
            opacity = gaussians[b, :, 3:4].contiguous().float()
            scales = gaussians[b, :, 4:7].contiguous().float()
            rotations = gaussians[b, :, 7:11].contiguous().float()
            rgbs = gaussians[b, :, 11:].contiguous().float()  # [N, 3]

            # Zero screen-space placeholder; identical for every view of this
            # batch element, so it is allocated once per element.
            means2D = torch.zeros_like(means3D, dtype=torch.float32, device=device)

            for v in range(V):

                # render novel views
                raster_settings = GaussianRasterizationSettings(
                    image_height=size,
                    image_width=size,
                    tanfovx=self.tan_half_fov,
                    tanfovy=self.tan_half_fov,
                    bg=background,
                    scale_modifier=scale_modifier,
                    viewmatrix=cam_view[b, v].float(),
                    projmatrix=cam_view_proj[b, v].float(),
                    sh_degree=0,
                    campos=cam_pos[b, v].float(),
                    prefiltered=False,
                    debug=False,
                )

                rasterizer = GaussianRasterizer(raster_settings=raster_settings)

                # Rasterize visible Gaussians to image, obtain their radii (on screen).
                rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
                    means3D=means3D,
                    means2D=means2D,
                    shs=None,
                    colors_precomp=rgbs,
                    opacities=opacity,
                    scales=scales,
                    rotations=rotations,
                    cov3D_precomp=None,
                )

                images.append(rendered_image.clamp(0, 1))
                alphas.append(rendered_alpha)

        images = torch.stack(images, dim=0).view(B, V, 3, size, size)
        alphas = torch.stack(alphas, dim=0).view(B, V, 1, size, size)

        return {
            "image": images,  # [B, V, 3, H, W]
            "alpha": alphas,  # [B, V, 1, H, W]
        }

    def save_ply(self, gaussians, path, compatible=True):
        """Write a single set of Gaussians to ``path`` as a PLY vertex element.

        Args:
            gaussians: [1, N, 14]; only batch size 1 is supported.
            path: destination handed to ``PlyData.write``.
            compatible: when True, store pre-activation values (inverse sigmoid
                of opacity, log of scale, colour mapped to an SH DC term) as in
                the original paper's PLY layout.
        """
        assert gaussians.shape[0] == 1, 'only support batch size 1'

        from plyfile import PlyData, PlyElement

        means3D = gaussians[0, :, 0:3].contiguous().float()
        opacity = gaussians[0, :, 3:4].contiguous().float()
        scales = gaussians[0, :, 4:7].contiguous().float()
        rotations = gaussians[0, :, 7:11].contiguous().float()
        shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float()  # [N, 1, 3]

        # prune by opacity
        mask = opacity.squeeze(-1) >= 0.005
        means3D = means3D[mask]
        opacity = opacity[mask]
        scales = scales[mask]
        rotations = rotations[mask]
        shs = shs[mask]

        # invert activation to make it compatible with the original ply format
        if compatible:
            opacity = kiui.op.inverse_sigmoid(opacity)
            scales = torch.log(scales + 1e-8)  # epsilon guards against log(0)
            shs = (shs - 0.5) / 0.28209479177387814

        xyzs = means3D.detach().cpu().numpy()
        f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = opacity.detach().cpu().numpy()
        scales = scales.detach().cpu().numpy()
        rotations = rotations.detach().cpu().numpy()

        # Property names, in the same column order as the concatenation below.
        attribute_names = ['x', 'y', 'z']
        attribute_names += [f'f_dc_{i}' for i in range(f_dc.shape[1])]
        attribute_names.append('opacity')
        attribute_names += [f'scale_{i}' for i in range(scales.shape[1])]
        attribute_names += [f'rot_{i}' for i in range(rotations.shape[1])]

        dtype_full = [(attribute, 'f4') for attribute in attribute_names]

        elements = np.empty(xyzs.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')

        PlyData([el]).write(path)

    @staticmethod
    def _indexed_property_names(vertex, prefix):
        """Return the vertex property names starting with ``prefix``.

        When every suffix is numeric (``scale_0``, ``scale_1``, ...) the names
        are ordered by that index rather than by their position in the file;
        otherwise file order is kept.
        """
        names = [p.name for p in vertex.properties if p.name.startswith(prefix)]
        if all(name[len(prefix):].isdigit() for name in names):
            names.sort(key=lambda name: int(name[len(prefix):]))
        return names

    def load_ply(self, path, compatible=True):
        """Read Gaussians from a PLY file into a CPU float tensor of shape [N, C].

        Columns are ``[xyz, opacity, scale_*, rot_*, f_dc_0..2]``. With
        ``compatible=True`` the stored values are passed back through the
        activations that ``save_ply`` inverted.
        """
        from plyfile import PlyData

        plydata = PlyData.read(path)
        vertex = plydata.elements[0]

        xyz = np.stack((np.asarray(vertex["x"]),
                        np.asarray(vertex["y"]),
                        np.asarray(vertex["z"])), axis=1)
        print("Number of points at loading : ", xyz.shape[0])

        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        shs = np.zeros((xyz.shape[0], 3))
        shs[:, 0] = np.asarray(vertex["f_dc_0"])
        shs[:, 1] = np.asarray(vertex["f_dc_1"])
        shs[:, 2] = np.asarray(vertex["f_dc_2"])

        scale_names = self._indexed_property_names(vertex, "scale_")
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(vertex[attr_name])

        rot_names = self._indexed_property_names(vertex, "rot_")
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(vertex[attr_name])

        gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
        gaussians = torch.from_numpy(gaussians).float()  # cpu

        if compatible:
            gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
            gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
            gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5

        return gaussians
|
core/models.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import kiui
|
| 7 |
+
from kiui.lpips import LPIPS
|
| 8 |
+
|
| 9 |
+
from core.unet import UNet
|
| 10 |
+
from core.options import Options
|
| 11 |
+
from core.gs import GaussianRenderer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LGM(nn.Module):
    """U-Net that predicts 3D Gaussians from multi-view input features.

    ``forward_gaussians`` turns the input views into a [B, N, 14] tensor of
    activated Gaussian parameters; ``forward`` additionally renders them with
    ``GaussianRenderer`` and computes the training losses and PSNR.
    """

    def __init__(
        self,
        opt: Options,
    ):
        super().__init__()

        self.opt = opt

        # unet: 9 input channels, 14 output channels (one Gaussian per pixel)
        self.unet = UNet(
            9, 14,
            down_channels=self.opt.down_channels,
            down_attention=self.opt.down_attention,
            mid_attention=self.opt.mid_attention,
            up_channels=self.opt.up_channels,
            up_attention=self.opt.up_attention,
        )

        # last conv
        self.conv = nn.Conv2d(14, 14, kernel_size=1)  # NOTE: maybe remove it if train again

        # Gaussian Renderer
        self.gs = GaussianRenderer(opt)

        # activations, one per slice of the 14-channel Gaussian vector
        self.pos_act = lambda x: x.clamp(-1, 1)
        self.scale_act = lambda x: 0.1 * F.softplus(x)
        self.opacity_act = lambda x: torch.sigmoid(x)
        self.rot_act = F.normalize
        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5  # NOTE: may use sigmoid if train again

        # LPIPS loss (only instantiated when it contributes to the loss)
        if self.opt.lambda_lpips > 0:
            self.lpips_loss = LPIPS(net='vgg')
            self.lpips_loss.requires_grad_(False)

    def state_dict(self, **kwargs):
        """Return the module state dict without any ``lpips_loss`` entries."""
        state_dict = super().state_dict(**kwargs)
        for k in list(state_dict.keys()):
            if 'lpips_loss' in k:
                del state_dict[k]
        return state_dict

    def prepare_default_rays(self, device, elevation=0):
        """Build ray embeddings for four orbit cameras at azimuth 0/90/180/270.

        Returns a tensor of shape [V, 6, h, w] on ``device`` holding
        ``cross(rays_o, rays_d)`` concatenated with ``rays_d`` for each view.
        """
        from kiui.cam import orbit_camera
        from core.utils import get_rays

        cam_poses = np.stack([
            orbit_camera(elevation, azimuth, radius=self.opt.cam_radius)
            for azimuth in (0, 90, 180, 270)
        ], axis=0)  # [4, 4, 4]
        cam_poses = torch.from_numpy(cam_poses)

        rays_embeddings = []
        for i in range(cam_poses.shape[0]):
            rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy)  # [h, w, 3]
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
            rays_embeddings.append(rays_plucker)

        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device)  # [V, 6, h, w]

        return rays_embeddings

    def forward_gaussians(self, images):
        """Predict activated Gaussians from the input views.

        Args:
            images: [B, V, 9, H, W]

        Returns:
            Gaussians of shape [B, N, 14] with N = V * splat_size * splat_size.
        """
        B, V, C, H, W = images.shape
        images = images.view(B * V, C, H, W)

        x = self.unet(images)  # [B*V, 14, h, w]
        x = self.conv(x)  # [B*V, 14, h, w]

        # Use the actual view count rather than a hard-coded 4.
        x = x.reshape(B, V, 14, self.opt.splat_size, self.opt.splat_size)

        # Channels last, then flatten views and pixels into one Gaussian axis.
        x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)

        pos = self.pos_act(x[..., 0:3])  # [B, N, 3]
        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:7])
        rotation = self.rot_act(x[..., 7:11])
        rgbs = self.rgb_act(x[..., 11:])

        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)  # [B, N, 14]

        return gaussians

    def forward(self, data, step_ratio=1):
        """Predict Gaussians, render the supervision views and compute losses.

        Args:
            data: dataloader output; reads the keys ``'input'``, ``'cam_view'``,
                ``'cam_view_proj'``, ``'cam_pos'``, ``'images_output'`` and
                ``'masks_output'``.
            step_ratio: accepted for interface compatibility; not used here.

        Returns:
            dict with ``'gaussians'``, the renderer outputs (``'image'``,
            ``'alpha'``), ``'images_pred'``, ``'alphas_pred'``, ``'loss'``,
            ``'psnr'`` and, when LPIPS is enabled, ``'loss_lpips'``.
        """
        results = {}
        loss = 0

        images = data['input']  # [B, 4, 9, h, W], input features

        # predict gaussians from the input views
        gaussians = self.forward_gaussians(images)  # [B, N, 14]

        results['gaussians'] = gaussians

        # random bg for training, white bg otherwise
        if self.training:
            bg_color = torch.rand(3, dtype=torch.float32, device=gaussians.device)
        else:
            bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)

        # Render the supervision views. Merge into `results` instead of
        # rebinding it, so the 'gaussians' entry stored above is not lost.
        rendered = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
        results.update(rendered)
        pred_images = rendered['image']  # [B, V, C, output_size, output_size]
        pred_alphas = rendered['alpha']  # [B, V, 1, output_size, output_size]

        results['images_pred'] = pred_images
        results['alphas_pred'] = pred_alphas

        gt_images = data['images_output']  # [B, V, 3, output_size, output_size], ground-truth novel views
        gt_masks = data['masks_output']  # [B, V, 1, output_size, output_size], ground-truth masks

        # Composite the ground truth over the same background used for rendering.
        gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)

        loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
        loss = loss + loss_mse

        if self.opt.lambda_lpips > 0:
            size = self.opt.output_size
            # Inputs are mapped to [-1, 1] and always resized to 256x256 to
            # bound the memory cost of the perceptual loss.
            loss_lpips = self.lpips_loss(
                F.interpolate(gt_images.view(-1, 3, size, size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
                F.interpolate(pred_images.view(-1, 3, size, size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
            ).mean()
            results['loss_lpips'] = loss_lpips
            loss = loss + self.opt.lambda_lpips * loss_lpips

        results['loss'] = loss

        # metric
        with torch.no_grad():
            psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
            results['psnr'] = psnr

        return results
|
core/options.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tyro
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Tuple, Literal, Dict, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class Options:
|
| 8 |
+
### model
|
| 9 |
+
# Unet image input size
|
| 10 |
+
input_size: int = 256
|
| 11 |
+
# Unet definition
|
| 12 |
+
down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
|
| 13 |
+
down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
|
| 14 |
+
mid_attention: bool = True
|
| 15 |
+
up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
|
| 16 |
+
up_attention: Tuple[bool, ...] = (True, True, True, False)
|
| 17 |
+
# Unet output size, dependent on the input_size and U-Net structure!
|
| 18 |
+
splat_size: int = 64
|
| 19 |
+
# gaussian render size
|
| 20 |
+
output_size: int = 256
|
| 21 |
+
|
| 22 |
+
### dataset
|
| 23 |
+
# data mode (only support s3 now)
|
| 24 |
+
data_mode: Literal['s3'] = 's3'
|
| 25 |
+
# fovy of the dataset
|
| 26 |
+
fovy: float = 49.1
|
| 27 |
+
# camera near plane
|
| 28 |
+
znear: float = 0.5
|
| 29 |
+
# camera far plane
|
| 30 |
+
zfar: float = 2.5
|
| 31 |
+
# number of all views (input + output)
|
| 32 |
+
num_views: int = 12
|
| 33 |
+
# number of views
|
| 34 |
+
num_input_views: int = 4
|
| 35 |
+
# camera radius
|
| 36 |
+
cam_radius: float = 1.5 # to better use [-1, 1]^3 space
|
| 37 |
+
# num workers
|
| 38 |
+
num_workers: int = 8
|
| 39 |
+
|
| 40 |
+
### training
|
| 41 |
+
# workspace
|
| 42 |
+
workspace: str = './workspace'
|
| 43 |
+
# resume
|
| 44 |
+
resume: Optional[str] = None
|
| 45 |
+
# batch size (per-GPU)
|
| 46 |
+
batch_size: int = 8
|
| 47 |
+
# gradient accumulation
|
| 48 |
+
gradient_accumulation_steps: int = 1
|
| 49 |
+
# training epochs
|
| 50 |
+
num_epochs: int = 30
|
| 51 |
+
# lpips loss weight
|
| 52 |
+
lambda_lpips: float = 1.0
|
| 53 |
+
# gradient clip
|
| 54 |
+
gradient_clip: float = 1.0
|
| 55 |
+
# mixed precision
|
| 56 |
+
mixed_precision: str = 'bf16'
|
| 57 |
+
# learning rate
|
| 58 |
+
lr: float = 4e-4
|
| 59 |
+
# augmentation prob for grid distortion
|
| 60 |
+
prob_grid_distortion: float = 0.5
|
| 61 |
+
# augmentation prob for camera jitter
|
| 62 |
+
prob_cam_jitter: float = 0.5
|
| 63 |
+
|
| 64 |
+
### testing
|
| 65 |
+
# test image path
|
| 66 |
+
test_path: Optional[str] = None
|
| 67 |
+
|
| 68 |
+
### misc
|
| 69 |
+
# nvdiffrast backend setting
|
| 70 |
+
force_cuda_rast: bool = False
|
| 71 |
+
# render fancy video with gaussian scaling effect
|
| 72 |
+
fancy_video: bool = False
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# all the default settings
|
| 76 |
+
# Named presets; both registries share the same keys in the same order.
config_doc: Dict[str, str] = {
    'lrm': 'the default settings for LGM',
    'small': 'small model with lower resolution Gaussians',
    'big': 'big model with higher resolution Gaussians',
    'tiny': 'tiny model for ablation',
}

config_defaults: Dict[str, Options] = {
    'lrm': Options(),
    'small': Options(
        input_size=256,
        splat_size=64,
        output_size=256,
        batch_size=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
    'big': Options(
        input_size=256,
        up_channels=(1024, 1024, 512, 256, 128),  # one more decoder
        up_attention=(True, True, True, False, False),
        splat_size=128,
        output_size=512,  # render & supervise Gaussians at a higher resolution.
        batch_size=8,
        num_views=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
    'tiny': Options(
        input_size=256,
        down_channels=(32, 64, 128, 256, 512),
        down_attention=(False, False, False, False, True),
        up_channels=(512, 256, 128),
        # NOTE(review): four attention flags for three up_channels -- the last
        # entry looks unused; confirm against the UNet constructor.
        up_attention=(True, False, False, False),
        splat_size=64,
        output_size=256,
        batch_size=16,
        num_views=8,
        gradient_accumulation_steps=1,
        mixed_precision='bf16',
    ),
}

# Type built by tyro from the two registries above (one subcommand per preset).
AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
|
core/provider_objaverse.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms.functional as TF
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
|
| 12 |
+
import kiui
|
| 13 |
+
from core.options import Options
|
| 14 |
+
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
|
| 15 |
+
|
| 16 |
+
# Per-channel mean passed to TF.normalize when normalising the input views.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
| 17 |
+
# Per-channel standard deviation passed to TF.normalize when normalising the input views.
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ObjaverseDataset(Dataset):
    """Example multi-view dataset of rendered objects.

    This class is a template: ``__init__`` raises ``NotImplementedError`` until
    the places marked TODO (object list, storage client, view selection and
    camera convention) have been adapted to the local rendering setup.

    Each item is a dict with:
        input: [num_input_views, 9, input_size, input_size], normalised RGB
            concatenated with the 6-channel ray embedding of each input view.
        images_output: [V, 3, output_size, output_size] images in [0, 1].
        masks_output: [V, 1, output_size, output_size] alpha masks.
        cam_view, cam_view_proj, cam_pos: camera tensors for the gaussian rasterizer.
    """

    def _warn(self):
        raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')

    def __init__(self, opt: Options, training=True):

        self.opt = opt
        self.training = training

        # TODO: remove this barrier
        self._warn()

        # TODO: load the list of objects for training
        with open('TODO: file containing the list', 'r') as f:
            self.items = [line.strip() for line in f]

        # naive split: the last batch_size objects are held out for evaluation
        if self.training:
            self.items = self.items[:-self.opt.batch_size]
        else:
            self.items = self.items[-self.opt.batch_size:]

        # default camera intrinsics
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[2, 3] = 1

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Load, pad and pre-process the views of object ``idx``.

        Raises:
            RuntimeError: if not a single view of the object could be loaded.
        """

        uid = self.items[idx]
        results = {}

        # load num_views images
        images = []
        masks = []
        cam_poses = []

        vid_cnt = 0

        # TODO: choose views, based on your rendering settings
        if self.training:
            # input views are in (36, 72), other views are randomly selected
            vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist()
        else:
            # fixed views
            vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist()

        for vid in vids:

            image_path = os.path.join(uid, 'rgb', f'{vid:03d}.png')
            camera_path = os.path.join(uid, 'pose', f'{vid:03d}.txt')

            try:
                # TODO: load data (modify self.client here)
                # NOTE: self.client is not assigned anywhere in this class; until it is,
                # every load attempt fails here and the view is skipped.
                image = np.frombuffer(self.client.get(image_path), np.uint8)
                image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1]
                c2w = [float(t) for t in self.client.get(camera_path).decode().strip().split(' ')]
                c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4)
            except Exception:
                # best effort: an unreadable view is skipped and the next candidate is tried
                continue

            # TODO: you may have a different camera system
            # blender world + opencv cam --> opengl world & cam
            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1 # invert up and forward direction

            # scale up radius to fully use the [-1, 1]^3 space!
            c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default scale

            image = image.permute(2, 0, 1) # [4, 512, 512]
            mask = image[3:4] # [1, 512, 512]
            image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg
            image = image[[2,1,0]].contiguous() # bgr to rgb

            images.append(image)
            masks.append(mask.squeeze(0))
            cam_poses.append(c2w)

            vid_cnt += 1
            if vid_cnt == self.opt.num_views:
                break

        if vid_cnt == 0:
            # Nothing to pad from. An IndexError escaping __getitem__ would also be
            # read as "end of sequence" by plain iteration, so fail explicitly.
            raise RuntimeError(f'dataset {uid}: no valid views could be loaded')

        if vid_cnt < self.opt.num_views:
            print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
            # pad by repeating the last successfully loaded view
            n = self.opt.num_views - vid_cnt
            images = images + [images[-1]] * n
            masks = masks + [masks[-1]] * n
            cam_poses = cam_poses + [cam_poses[-1]] * n

        images = torch.stack(images, dim=0) # [V, C, H, W]
        masks = torch.stack(masks, dim=0) # [V, H, W]
        cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]

        # normalized camera feats as in paper (transform the first pose to a fixed position)
        transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
        cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]

        images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
        cam_poses_input = cam_poses[:self.opt.num_input_views].clone()

        # data augmentation (the first input view is never perturbed)
        if self.training:
            # apply random grid distortion to simulate 3D inconsistency
            if random.random() < self.opt.prob_grid_distortion:
                images_input[1:] = grid_distortion(images_input[1:])
            # apply camera jittering (only to input!)
            if random.random() < self.opt.prob_cam_jitter:
                cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:])

        images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

        # resize render ground-truth images, range still in [0, 1]
        results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
        results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]

        # build rays for input views
        rays_embeddings = []
        for i in range(self.opt.num_input_views):
            rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
            rays_embeddings.append(rays_plucker)

        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
        final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]
        results['input'] = final_input

        # opengl to colmap camera for gaussian renderer
        cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
        # NOTE(review): the translation is negated here; kept exactly as in the original.
        cam_pos = - cam_poses[:, :3, 3] # [V, 3]

        results['cam_view'] = cam_view
        results['cam_view_proj'] = cam_view_proj
        results['cam_pos'] = cam_pos

        return results
|
core/unet.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import Tuple, Literal
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
from core.attention import MemEffAttention
|
| 10 |
+
|
| 11 |
+
class MVAttention(nn.Module):
    """Attention over the tokens of all views of a sample.

    The input is group-normalised, the ``num_frames`` views of every sample are
    flattened into one token sequence, passed through ``MemEffAttention`` and
    folded back into the input layout. An optional scaled residual connection
    is applied at the end.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
        num_frames: int = 4, # WARN: hardcoded!
    ):
        super().__init__()

        self.residual = residual
        self.skip_scale = skip_scale
        self.num_frames = num_frames

        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)

    def forward(self, x):
        # x: [B*V, C, H, W]
        total, channels, height, width = x.shape
        views = self.num_frames
        batch = total // views # total is expected to be a multiple of views (not checked)

        skip = x

        # [B*V, C, H, W] -> [B, V*H*W, C]: every pixel of every view is one token
        tokens = self.norm(x).reshape(batch, views, channels, height, width)
        tokens = tokens.permute(0, 1, 3, 4, 2).reshape(batch, -1, channels)
        tokens = self.attn(tokens)

        # [B, V*H*W, C] -> [B*V, C, H, W]
        out = tokens.reshape(batch, views, height, width, channels)
        out = out.permute(0, 1, 4, 2, 3).reshape(total, channels, height, width)

        if not self.residual:
            return out
        return (out + skip) * self.skip_scale
|
| 50 |
+
|
| 51 |
+
class ResnetBlock(nn.Module):
    """Pre-activation residual block (GroupNorm -> SiLU -> 3x3 conv, twice).

    ``resample`` selects an optional 2x nearest-neighbour upsampling ('up') or
    2x average pooling ('down') that is applied to both the residual branch and
    the skip branch. When the channel count changes, the skip branch goes
    through a 1x1 convolution. The summed output is multiplied by ``skip_scale``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal['default', 'up', 'down'] = 'default',
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1, # multiplied to output
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.act = F.silu

        if resample == 'up':
            resampler = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        elif resample == 'down':
            resampler = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            resampler = None
        self.resample = resampler

        # a 1x1 projection is only needed when the channel count changes
        if self.in_channels != self.out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        skip = x

        hidden = self.act(self.norm1(x))

        if self.resample:
            # both branches must end up at the same spatial size
            skip = self.resample(skip)
            hidden = self.resample(hidden)

        hidden = self.conv1(hidden)
        hidden = self.conv2(self.act(self.norm2(hidden)))

        return (hidden + self.shortcut(skip)) * self.skip_scale
|
| 104 |
+
|
| 105 |
+
class DownBlock(nn.Module):
    """Encoder stage: ``num_layers`` ResnetBlocks, each optionally followed by
    MVAttention, then an optional stride-2 convolution.

    ``forward`` returns the final feature map together with the list of every
    intermediate output (one per layer, plus the downsampled map if any), which
    the decoder consumes as skip connections.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        resnets = []
        attentions = []
        channels = in_channels
        for _ in range(num_layers):
            resnets.append(ResnetBlock(channels, out_channels, skip_scale=skip_scale))
            # a None entry keeps attns aligned with nets when attention is off
            attentions.append(
                MVAttention(out_channels, attention_heads, skip_scale=skip_scale) if attention else None
            )
            channels = out_channels
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

        if downsample:
            self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        else:
            self.downsample = None

    def forward(self, x):
        skips = []

        for net, attn in zip(self.nets, self.attns):
            x = net(x)
            if attn is not None:
                x = attn(x)
            skips.append(x)

        if self.downsample is not None:
            x = self.downsample(x)
            skips.append(x)

        return x, skips
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class MidBlock(nn.Module):
    """Bottleneck stage: one ResnetBlock followed by ``num_layers`` pairs of
    (optional MVAttention, ResnetBlock), all at ``in_channels`` width.
    """

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        # nets holds one more entry than attns: the leading block has no attention
        resnets = [ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)]
        attentions = []
        for _ in range(num_layers):
            resnets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            attentions.append(
                MVAttention(in_channels, attention_heads, skip_scale=skip_scale) if attention else None
            )
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

    def forward(self, x):
        x = self.nets[0](x)
        # attention (when present) runs before its paired ResnetBlock
        for offset, attn in enumerate(self.attns, start=1):
            if attn is not None:
                x = attn(x)
            x = self.nets[offset](x)
        return x
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class UpBlock(nn.Module):
    """Decoder stage: ``num_layers`` ResnetBlocks, each fed the running feature
    map concatenated with one skip tensor and optionally followed by
    MVAttention, then an optional 2x nearest upsampling plus 3x3 convolution.

    ``forward(x, xs)`` consumes the skip tensors in ``xs`` from last to first,
    one per layer.
    """

    def __init__(
        self,
        in_channels: int,
        prev_out_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        resnets = []
        attentions = []
        final_layer = num_layers - 1
        for layer in range(num_layers):
            main_width = in_channels if layer == 0 else out_channels
            # only the last layer receives the skip with prev_out_channels channels
            skip_width = prev_out_channels if layer == final_layer else out_channels

            resnets.append(ResnetBlock(main_width + skip_width, out_channels, skip_scale=skip_scale))
            attentions.append(
                MVAttention(out_channels, attention_heads, skip_scale=skip_scale) if attention else None
            )
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

        if upsample:
            self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.upsample = None

    def forward(self, x, xs):

        for used, (net, attn) in enumerate(zip(self.nets, self.attns)):
            # take the skips from the end of xs, moving one step back per layer
            skip = xs[-(used + 1)]
            x = net(torch.cat([x, skip], dim=1))
            if attn is not None:
                x = attn(x)

        if self.upsample is not None:
            x = F.interpolate(x, scale_factor=2.0, mode='nearest')
            x = self.upsample(x)

        return x
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# it could be asymmetric!
|
| 234 |
+
class UNet(nn.Module):
    """Multi-view UNet with a possibly shorter decoder than encoder.

    The encoder is one DownBlock per entry of ``down_channels`` (the last one
    without downsampling), followed by a MidBlock and one UpBlock per entry of
    ``up_channels`` (the last one without upsampling). ``down_attention`` and
    ``up_attention`` are indexed per block to switch attention on or off.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 512, 256),
        up_attention: Tuple[bool, ...] = (True, True, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
    ):
        super().__init__()

        # first
        self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        encoder = []
        width_in = down_channels[0]
        last_down = len(down_channels) - 1
        for level, width_out in enumerate(down_channels):
            encoder.append(DownBlock(
                width_in, width_out,
                num_layers=layers_per_block,
                downsample=(level != last_down), # not final layer
                attention=down_attention[level],
                skip_scale=skip_scale,
            ))
            width_in = width_out
        self.down_blocks = nn.ModuleList(encoder)

        # mid
        self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale)

        # up
        decoder = []
        width_in = up_channels[0]
        last_up = len(up_channels) - 1
        for level, width_out in enumerate(up_channels):
            # for asymmetric: clamp the index so it never runs past the first encoder level
            skip_width = down_channels[max(-2 - level, -len(down_channels))]

            decoder.append(UpBlock(
                width_in, skip_width, width_out,
                num_layers=layers_per_block + 1, # one more layer for up
                upsample=(level != last_up), # not final layer
                attention=up_attention[level],
                skip_scale=skip_scale,
            ))
            width_in = width_out
        self.up_blocks = nn.ModuleList(decoder)

        # last
        self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
        self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x: [B, Cin, H, W]

        # first
        x = self.conv_in(x)

        # down: collect every intermediate output as a skip connection
        skips = [x]
        for block in self.down_blocks:
            x, block_skips = block(x)
            skips.extend(block_skips)

        # mid
        x = self.mid_block(x)

        # up: each block takes as many trailing skips as it has layers
        for block in self.up_blocks:
            take = len(block.nets)
            x = block(x, skips[-take:])
            skips = skips[:-take]

        # last
        x = F.silu(self.norm_out(x))
        return self.conv_out(x) # [B, Cout, H', W']
|
core/utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
import roma
|
| 8 |
+
from kiui.op import safe_normalize
|
| 9 |
+
|
| 10 |
+
def get_rays(pose, h, w, fovy, opengl=True):
    """Build per-pixel ray origins and unit directions for one camera.

    Args:
        pose: camera-to-world matrix; its top-left 3x3 block rotates the
            camera-space directions and its last column is the ray origin.
        h, w: image height and width in pixels.
        fovy: vertical field of view in degrees.
        opengl: if True the y and z components of the camera-space direction
            are negated.

    Returns:
        (rays_o, rays_d), both of shape [h, w, 3].
    """

    cols, rows = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    cols = cols.flatten()
    rows = rows.flatten()

    half_w = w * 0.5
    half_h = h * 0.5
    axis_sign = -1.0 if opengl else 1.0

    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # pixel centres (+0.5) relative to the image centre, in units of the focal length
    dir_x = (cols - half_w + 0.5) / focal
    dir_y = (rows - half_h + 0.5) / focal * axis_sign
    camera_dirs = F.pad(
        torch.stack([dir_x, dir_y], dim=-1),
        (0, 1),
        value=axis_sign,
    ) # [hw, 3]

    rotation = pose[:3, :3]
    rays_d = camera_dirs @ rotation.transpose(0, 1) # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]

    rays_o = rays_o.view(h, w, 3)
    rays_d = safe_normalize(rays_d).view(h, w, 3)

    return rays_o, rays_d
|
| 44 |
+
|
| 45 |
+
def orbit_camera_jitter(poses, strength=0.1):
    """Randomly rotate a batch of camera poses.

    Args:
        poses: [B, 4, 4], assume orbit camera in opengl format.
        strength: scales the random rotation angles.

    Returns:
        A new [B, 4, 4] tensor in which the same random rotation has been
        applied to the rotation block and to the translation of every pose.
    """

    count = poses.shape[0]

    # one uniform factor in [-1, 1) per pose; the first draw drives rotvec_x
    factor_x = torch.rand(count, 1, device=poses.device) * 2 - 1
    rotvec_x = poses[:, :3, 1] * strength * np.pi * factor_x
    factor_y = torch.rand(count, 1, device=poses.device) * 2 - 1
    rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * factor_y

    rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)

    jittered = poses.clone()
    jittered[:, :3, :3] = rot @ poses[:, :3, :3]
    jittered[:, :3, 3:] = rot @ poses[:, :3, 3:]

    return jittered
|
| 62 |
+
|
| 63 |
+
def _distorted_axis(length, num_steps, grid_steps, strength):
    """Return ``length`` sampling coordinates along one axis, built piecewise between randomly shifted knots."""
    knots = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
    knots = (knots + strength * (torch.rand_like(knots) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
    knots = (knots * length).long() # [num_steps]
    # pin the ends so the segments cover exactly `length` samples
    knots[0] = 0
    knots[-1] = length
    pieces = [
        torch.linspace(grid_steps[i], grid_steps[i + 1], knots[i + 1] - knots[i])
        for i in range(num_steps - 1)
    ]
    return torch.cat(pieces, dim=0) # [length]


def grid_distortion(images, strength=0.5):
    """Warp every image with its own random piecewise-linear sampling grid.

    Args:
        images: [B, C, H, W]
        strength: float in [0, 1], strength of distortion

    Returns:
        The resampled images, same shape as the input.

    The grid resolution (number of knots per axis) is drawn once per call with
    ``np.random.randint(8, 17)`` and shared by all images of the batch.
    """

    B, C, H, W = images.shape

    num_steps = np.random.randint(8, 17)
    grid_steps = torch.linspace(-1, 1, num_steps)

    # have to loop batch: every image gets independent knot perturbations
    per_image = []
    for _ in range(B):
        xs = _distorted_axis(W, num_steps, grid_steps, strength) # [W]
        ys = _distorted_axis(H, num_steps, grid_steps, strength) # [H]

        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W]
        per_image.append(torch.stack([grid_x, grid_y], dim=-1)) # [H, W, 2]

    grids = torch.stack(per_image, dim=0).to(images.device) # [B, H, W, 2]

    return F.grid_sample(images, grids, align_corners=False)
|
| 109 |
+
|
data_test/catstatue.ply
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:57dc6f5902301d7577c53a73ce4c9d1bbff2fca86bf93d015b6cdfa1d3de9b18
|
| 3 |
+
size 2390497
|