import math
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint

from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
from .util import timestep_embedding


def init_linear(l, stddev):
    nn.init.normal_(l.weight, std=stddev)
    if l.bias is not None:
        nn.init.constant_(l.bias, 0.0)


class MultiheadAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
        init_linear(self.c_qkv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        x = self.c_qkv(x)
        x = checkpoint(self.attention, (x,), (), True)
        x = self.c_proj(x)
        return x


class MLP(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
        super().__init__()
        self.width = width
        self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
        self.gelu = nn.GELU()
        init_linear(self.c_fc, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))


class QKVMultiheadAttention(nn.Module):
    def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx

    def forward(self, qkv):
        bs, n_ctx, width = qkv.shape
        attn_ch = width // self.heads // 3
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
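        # Scaling q and k each by attn_ch ** -0.25 yields the usual 1 / sqrt(attn_ch)
        # factor on the attention logits while keeping intermediates small for float16.
        # The view/split below treats the last dimension as [q | k | v] packed per head.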
        qkv = qkv.view(bs, n_ctx, self.heads, -1)
        q, k, v = torch.split(qkv, attn_ch, dim=-1)
        weight = torch.einsum(
            "bthc,bshc->bhts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        wdtype = weight.dtype
        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        heads: int,
        init_scale: float = 1.0,
    ):
        super().__init__()
        self.attn = MultiheadAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
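
    # Pre-norm residual layout: LayerNorm is applied to the input of attention and of
    # the MLP, and each sub-block's output is added back onto the residual stream.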

    def forward(self, x: torch.Tensor):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        init_scale = init_scale * math.sqrt(1.0 / width)
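        # The per-layer initialization std is scaled by 1 / sqrt(width), so the initial
        # magnitude of each residual update stays roughly independent of model width.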
        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        for block in self.resblocks:
            x = block(x)
        return x


class PointDiffusionTransformer(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        input_channels: int = 3,
        output_channels: int = 3,
        n_ctx: int = 1024,
        width: int = 512,
        layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
        time_token_cond: bool = False,
        use_pos_emb: bool = False,
        pos_emb_init_scale: float = 1.0,
        pos_emb_n_ctx: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.n_ctx = n_ctx
        self.time_token_cond = time_token_cond
        self.use_pos_emb = use_pos_emb
        self.time_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.backbone = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + int(time_token_cond),
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
        # with torch.no_grad():
        #     self.output_proj.weight.zero_()
        #     self.output_proj.bias.zero_()
        if self.use_pos_emb:
            self.register_parameter(
                "pos_emb",
                nn.Parameter(
                    pos_emb_init_scale
                    * torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)
                ),
            )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])

    def _forward_with_cond(
        self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
    ) -> torch.Tensor:
        h = self.input_proj(x.permute(0, 2, 1))  # NCL -> NLC
        for emb, as_token in cond_as_token:
            if not as_token:
                h = h + emb[:, None]
        if self.use_pos_emb:
            h = h + self.pos_emb
        extra_tokens = [
            (emb[:, None] if len(emb.shape) == 2 else emb)
            for emb, as_token in cond_as_token
            if as_token
        ]
        if len(extra_tokens):
            h = torch.cat(extra_tokens + [h], dim=1)

        h = self.ln_pre(h)
        h = self.backbone(h)
        h = self.ln_post(h)
        if len(extra_tokens):
            h = h[:, sum(h.shape[1] for h in extra_tokens):]
        h = self.output_proj(h)
        return h.permute(0, 2, 1)  # NLC -> NCL
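

# A minimal usage sketch for the base model (hypothetical values, default
# hyperparameters; shapes follow the forward() docstring above):
#
#     model = PointDiffusionTransformer(device=torch.device("cpu"), dtype=torch.float32)
#     x = torch.randn(2, 3, 1024)        # [N x C x T] noisy point clouds
#     t = torch.randint(0, 1000, (2,))   # [N] diffusion timesteps
#     out = model(x, t)                  # [N x C' x T], here [2, 3, 1024]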


class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        token_cond: bool = False,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        super().__init__(
            device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs
        )
        self.n_ctx = n_ctx
        self.token_cond = token_cond
        self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        self.clip_embed = nn.Linear(
            self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            return dict(embeddings=self.clip(batch_size, **model_kwargs))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param texts: a batch of texts to condition on.
        :param embeddings: a batch of CLIP embeddings to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
        assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]
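
        # During training, each CLIP embedding is zeroed out with probability
        # cond_drop_prob, which enables classifier-free guidance at sampling time.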
        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None].to(clip_out)

        # Rescale the features to have unit variance
        clip_out = math.sqrt(clip_out.shape[1]) * clip_out

        clip_embed = self.clip_embed(clip_out)

        cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
        return self._forward_with_cond(x, cond)


class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx + clip.grid_size**2,
            pos_emb_n_ctx=n_ctx,
            **kwargs,
        )
        self.n_ctx = n_ctx
        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        _ = batch_size
        with torch.no_grad():
            return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C' x T] tensor.
        """
        assert images is not None or embeddings is not None, "must specify images or embeddings"
        assert images is None or embeddings is None, "cannot specify both images and embeddings"
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
        return self._forward_with_cond(x, cond)


class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        cond_input_channels: Optional[int] = None,
        cond_ctx: int = 1024,
        n_ctx: int = 4096 - 1024,
        channel_scales: Optional[Sequence[float]] = None,
        channel_biases: Optional[Sequence[float]] = None,
        **kwargs,
    ):
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
        self.n_ctx = n_ctx
        self.cond_input_channels = cond_input_channels or self.input_channels
        self.cond_point_proj = nn.Linear(
            self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
        )

        self.register_buffer(
            "channel_scales",
            torch.tensor(channel_scales, dtype=dtype, device=device)
            if channel_scales is not None
            else None,
        )
        self.register_buffer(
            "channel_biases",
            torch.tensor(channel_biases, dtype=dtype, device=device)
            if channel_biases is not None
            else None,
        )
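
    # `channel_scales` / `channel_biases`, when provided, are applied per channel in
    # `_embed_low_res` below to normalize the low-res conditioning points (presumably
    # xyz and color channels) before they are projected to the transformer width.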

    def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)
        cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)

    def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
        if self.channel_scales is not None:
            x = x * self.channel_scales[None, :, None]
        if self.channel_biases is not None:
            x = x + self.channel_biases[None, :, None]
        return self.cond_point_proj(x.permute(0, 2, 1))


class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 4096 - 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
        self.n_ctx = n_ctx

        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        _ = batch_size
        with torch.no_grad():
            return dict(
                embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
                low_res=model_kwargs["low_res"],
            )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        *,
        low_res: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)
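

# A minimal call sketch for the CLIP-grid upsampler (hypothetical channel counts and
# shapes; constructing the model typically downloads CLIP weights via FrozenImageCLIP):
#
#     model = CLIPImageGridUpsamplePointDiffusionTransformer(
#         device=torch.device("cpu"), dtype=torch.float32,
#         input_channels=6, output_channels=6,
#     )
#     x = torch.randn(1, 6, 3072)        # [N x C1 x T] noisy high-res points
#     low_res = torch.randn(1, 6, 1024)  # [N x C2 x T'] low-res conditioning points
#     grid = torch.randn(1, model.clip.grid_feature_dim, model.clip.grid_size ** 2)
#     out = model(x, torch.tensor([10]), low_res=low_res, embeddings=grid)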