Spaces:
Build error
Build error
| import random | |
| from math import ceil | |
| from functools import partial | |
| from itertools import zip_longest | |
| from random import randrange | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| # from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize | |
| from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA | |
| from einops import rearrange, repeat, pack, unpack | |
| # helper functions | |
def exists(val):
    """Tell whether ``val`` carries a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``."""
    if exists(val):
        return val
    return d
def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``.

    Integer inputs use exact integer ceiling division, so the result stays
    correct for values too large to be represented exactly as a float (where
    ``num / mult`` would lose precision). Any other numeric input falls back to
    ``ceil(num / mult) * mult``.

    Raises:
        ZeroDivisionError: if ``mult`` is zero.
    """
    if isinstance(num, int) and isinstance(mult, int):
        # -(-a // b) is ceil(a / b) computed without going through float.
        return -(-num // mult) * mult
    return ceil(num / mult) * mult
| # main class | |
class ResidualVQ(nn.Module):
    """ Residual vector quantizer: a stack of quantizer layers, each one
    quantizing the residual left over by the layers before it.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(
        self,
        num_quantizers,
        shared_codebook=False,
        quantize_dropout_prob=0.5,
        quantize_dropout_cutoff_index=0,
        **kwargs
    ):
        """
        Args:
            num_quantizers: number of quantizer layers in the stack.
            shared_codebook: when True, one ``QuantizeEMAReset`` instance is
                reused for every layer instead of creating one per layer.
            quantize_dropout_prob: probability, per training-mode forward call,
                of dropping the trailing quantizer layers.
            quantize_dropout_cutoff_index: lower bound of the randomly drawn
                layer index after which layers are dropped.
            **kwargs: forwarded unchanged to ``QuantizeEMAReset``.
        """
        super().__init__()

        self.num_quantizers = num_quantizers

        if shared_codebook:
            # The very same module object is registered at every position.
            layer = QuantizeEMAReset(**kwargs)
            self.layers = nn.ModuleList([layer for _ in range(num_quantizers)])
        else:
            self.layers = nn.ModuleList([QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)])

        assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0

        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_prob = quantize_dropout_prob

    @property
    def codebooks(self):
        """Every layer's ``codebook`` stacked along a new leading axis ('q c d').

        Exposed as a property because the rest of the class reads it as an
        attribute (``self.codebooks``), not as a call.
        """
        return torch.stack([layer.codebook for layer in self.layers], dim=0)

    def get_codes_from_indices(self, indices):
        """Look up (dequantize) the code vector for each index, per quantizer.

        Args:
            indices: integer tensor of shape 'b n q'. ``q`` may be smaller than
                ``num_quantizers``; missing layers are padded with ``-1``.
                Entries equal to ``-1`` mark dropped layers.

        Returns:
            Tensor of shape 'q b n d'; positions whose index was ``-1`` are
            filled with zeros.
        """
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct
        if quantize_dim < self.num_quantizers:
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # get ready for gathering; expand() yields broadcast views so the
        # codebooks are not physically copied once per batch element
        codebooks = self.codebooks  # 'q c d'
        code_dim = codebooks.shape[-1]
        codebooks = codebooks.unsqueeze(1).expand(-1, batch, -1, -1)  # 'q b c d'
        gather_indices = indices.permute(2, 0, 1).unsqueeze(-1).expand(-1, -1, -1, code_dim)  # 'q b n d'

        # take care of quantizer dropout
        mask = gather_indices == -1
        # have it fetch a dummy code (entry 0) to be masked out later
        gather_indices = gather_indices.masked_fill(mask, 0)

        all_codes = codebooks.gather(2, gather_indices)  # gather all codes

        # mask out any codes that were dropout-ed
        all_codes = all_codes.masked_fill(mask, 0.)

        return all_codes  # 'q b n d'

    def get_codebook_entry(self, indices):
        """Reconstruct the latent from code indices by summing the per-layer codes.

        Args:
            indices: integer tensor of shape 'b n q'.

        Returns:
            Tensor of shape 'b d n' (the 'b n d' sum with its last two axes swapped).
        """
        all_codes = self.get_codes_from_indices(indices)  # 'q b n d'
        latent = torch.sum(all_codes, dim=0)  # 'b n d'
        return latent.permute(0, 2, 1)

    def forward(self, x, return_all_codes=False, sample_codebook_temp=None, force_dropout_index=-1):
        """Quantize ``x`` through the layer stack, optionally dropping trailing layers.

        Args:
            x: input tensor; its first axis is used as batch and its last axis
                as sequence position when building the placeholder indices
                (presumably 'b d n' - confirm against the quantizer layer).
                The tensor is not modified.
            return_all_codes: when True, also return the per-layer code vectors.
            sample_codebook_temp: passed to each layer as ``temperature``.
            force_dropout_index: when >= 0, layers with an index greater than
                this value are skipped regardless of training mode or the
                random dropout draw.

        Returns:
            ``(quantized_out, all_indices, loss, perplexity)`` where
            ``all_indices`` stacks the per-layer indices on a new last axis
            (``-1`` for skipped layers) and ``loss`` / ``perplexity`` are means
            over the layers that actually ran. With ``return_all_codes`` a
            fifth element of shape 'q b n d' is appended.
        """
        num_quant, device = self.num_quantizers, x.device

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []
        all_perplexity = []

        should_quantize_dropout = self.training and random.random() < self.quantize_dropout_prob
        start_drop_quantize_index = num_quant

        # To ensure the first-k layers learn things as much as possible,
        # we randomly dropout the last q - k layers
        if should_quantize_dropout:
            # keep quant layers <= the drawn index, TODO vary in batch
            start_drop_quantize_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

        if force_dropout_index >= 0:
            should_quantize_dropout = True
            start_drop_quantize_index = force_dropout_index

        null_indices = None
        if should_quantize_dropout:
            # placeholder indices ('b n') recorded for every skipped layer
            null_indices = torch.full((x.shape[0], x.shape[-1]), -1, device=device, dtype=torch.long)

        # go through the layers
        for quantizer_index, layer in enumerate(self.layers):
            if should_quantize_dropout and quantizer_index > start_drop_quantize_index:
                all_indices.append(null_indices)
                continue

            # single quantizer
            quantized, *rest = layer(residual, return_idx=True, temperature=sample_codebook_temp)

            # Out-of-place on purpose: `residual` starts out as an alias of the
            # caller's `x`, and an in-place subtraction would overwrite it.
            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            embed_indices, loss, perplexity = rest
            all_indices.append(embed_indices)
            all_losses.append(loss)
            all_perplexity.append(perplexity)

        # stack all indices, average losses and perplexities over active layers
        all_indices = torch.stack(all_indices, dim=-1)
        all_losses = sum(all_losses) / len(all_losses)
        all_perplexity = sum(all_perplexity) / len(all_perplexity)

        ret = (quantized_out, all_indices, all_losses, all_perplexity)

        if return_all_codes:
            # whether to return all codes from all codebooks across layers,
            # in shape (quantizer, batch, sequence length, codebook dimension)
            all_codes = self.get_codes_from_indices(all_indices)
            ret = (*ret, all_codes)

        return ret

    def quantize(self, x, return_latent=False):
        """Run every layer (no dropout) and return the code indices.

        Args:
            x: input tensor handed to the first layer; it is not modified.
            return_latent: when True, also return the per-layer quantized
                outputs stacked along a new leading axis.

        Returns:
            ``code_idx`` (per-layer indices stacked on a new last axis), or
            ``(code_idx, all_codes)`` when ``return_latent`` is True.
        """
        all_indices = []
        all_codes = []
        residual = x

        for layer in self.layers:
            # single quantizer; loss and perplexity are not needed here
            quantized, *rest = layer(residual, return_idx=True)
            embed_indices, _loss, _perplexity = rest

            residual = residual - quantized.detach()

            all_indices.append(embed_indices)
            all_codes.append(quantized)

        code_idx = torch.stack(all_indices, dim=-1)
        all_codes = torch.stack(all_codes, dim=0)

        if return_latent:
            return code_idx, all_codes
        return code_idx