| """Vector quantizer. | |
| Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| Reference: | |
| https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py | |
| https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py | |
| """ | |
from typing import Mapping, Text, Tuple

import torch
from einops import rearrange
from torch.cuda.amp import autocast
class VectorQuantizer(torch.nn.Module):
    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 ):
        super().__init__()
        self.commitment_cost = commitment_cost

        self.embedding = torch.nn.Embedding(codebook_size, token_size)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.use_l2_norm = use_l2_norm
    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        z = z.float()
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = rearrange(z, 'b h w c -> (b h w) c')
        if self.use_l2_norm:
            z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
            embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            embedding = self.embedding.weight

        # Squared L2 distance between each latent vector and every codebook entry,
        # expanded as ||z||^2 + ||e||^2 - 2 * z.e to avoid materializing differences.
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, embedding.T)
        min_encoding_indices = torch.argmin(d, dim=1)  # shape (b*h*w,)
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
            z = torch.nn.functional.normalize(z, dim=-1)
        # compute loss for embedding
        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) ** 2)
        codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
        loss = commitment_loss + codebook_loss

        # preserve gradients with the straight-through estimator
        z_quantized = z + (z_quantized - z).detach()
        # reshape back to match original input shape
        z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(
                z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict
    def get_codebook_entry(self, indices):
        if len(indices.shape) == 1:
            # Hard indices: direct embedding lookup.
            z_quantized = self.embedding(indices)
        elif len(indices.shape) == 2:
            # Soft (e.g. one-hot) codes over the codebook: weighted sum of embeddings.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
        else:
            raise NotImplementedError
        return z_quantized
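

# Minimal usage sketch (not part of the original file): exercises the quantizer on a
# dummy 4D latent and decodes the returned indices. The batch size, spatial size, and
# hyperparameters below are illustrative assumptions, not values fixed by the source.
if __name__ == "__main__":
    quantizer = VectorQuantizer(codebook_size=1024, token_size=256, use_l2_norm=True)

    z = torch.randn(2, 256, 16, 16)  # (batch, token_size, height, width)
    z_quantized, results = quantizer(z)

    print(z_quantized.shape)                      # torch.Size([2, 256, 16, 16])
    print(results["min_encoding_indices"].shape)  # torch.Size([2, 16, 16])
    print(results["quantizer_loss"].item())

    # Map the hard indices back to codebook embeddings via the 1D branch of
    # get_codebook_entry.
    tokens = results["min_encoding_indices"].reshape(-1)
    decoded = quantizer.get_codebook_entry(tokens)
    print(decoded.shape)                          # torch.Size([512, 256])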