import os
from glob import glob

import joblib
import numpy as np
import torch
from sklearn.cluster import MiniBatchKMeans
from torch.utils.data import DataLoader
from tqdm import tqdm

from train import instantiate_from_config


class FeatClusterStage(object):
    '''Non-trainable conditioning stage that quantizes per-frame features with a MiniBatchKMeans clusterer.'''

    def __init__(self, num_clusters=None, cached_kmeans_path=None, feats_dataset_config=None, num_workers=None):
        if cached_kmeans_path is not None and os.path.exists(cached_kmeans_path):
            print(f'Precalculated Clusterer already exists, loading from {cached_kmeans_path}')
            self.clusterer = joblib.load(cached_kmeans_path)
        elif feats_dataset_config is not None:
            self.clusterer = self.load_or_precalculate_kmeans(num_clusters, feats_dataset_config, num_workers)
        else:
            raise Exception('Neither `feats_dataset_config` nor `cached_kmeans_path` is defined')

    def eval(self):
        # Mimics `nn.Module.eval()` so this stage can stand in for a trainable conditioning stage.
        return self

    def encode(self, c):
        # c_quant: cluster centers, c_ind: cluster indices
        B, D, T = c.shape
        # (B*T, D) <- (B, T, D) <- (B, D, T); `reshape` (not `view`) because the permuted tensor is non-contiguous
        c_flat = c.permute(0, 2, 1).reshape(B*T, D).cpu().numpy()
        c_ind = self.clusterer.predict(c_flat)
        c_quant = self.clusterer.cluster_centers_[c_ind]
        c_ind = torch.from_numpy(c_ind).to(c.device)
        c_quant = torch.from_numpy(c_quant).to(c.device)
        c_ind = c_ind.long().unsqueeze(-1)
        c_quant = c_quant.view(B, T, D).permute(0, 2, 1)
        info = None, None, c_ind
        # c_quant: (B, D, T) cluster centers; info carries the (B*T, 1) cluster indices in its last slot
        return c_quant, None, info

    def decode(self, c):
        # The quantized features are used as-is; there is nothing to decode.
        return c

    def get_input(self, batch, k):
        x = batch[k]
        # (B, D, T) <- (B, T, D)
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        return x.float()

    def load_or_precalculate_kmeans(self, num_clusters, dataset_cfg, num_workers):
        print(f'Calculating clustering K={num_clusters}')
        batch_size = 64
        dataset_name = dataset_cfg.target.split('.')[-1]
        cached_path = os.path.join('./specvqgan/modules/misc/', f'kmeans_K{num_clusters}_{dataset_name}.sklearn')
        feat_depth = dataset_cfg.params.condition_dataset_cfg.feat_depth
        feat_crop_len = dataset_cfg.params.condition_dataset_cfg.feat_crop_len
        feat_loading_dset = instantiate_from_config(dataset_cfg)
        feat_loading_dset = DataLoader(feat_loading_dset, batch_size, num_workers=num_workers, shuffle=True)
        # each dataset item yields (feat_crop_len, feat_depth) features, so every loader batch
        # contributes batch_size*feat_crop_len samples to one partial_fit step
        clusterer = MiniBatchKMeans(num_clusters, batch_size=batch_size*feat_crop_len, random_state=0)
        for item in tqdm(feat_loading_dset):
            batch = item['feature'].reshape(-1, feat_depth).float().numpy()
            clusterer.partial_fit(batch)
        joblib.dump(clusterer, cached_path)
        print(f'Saved the calculated Clusterer @ {cached_path}')
        return clusterer


if __name__ == '__main__':
    from omegaconf import OmegaConf

    config = OmegaConf.load('./configs/vggsound_featcluster_transformer.yaml')
    config.model.params.first_stage_config.params.ckpt_path = './logs/2021-05-19T22-16-54_vggsound_specs_vqgan/checkpoints/epoch_39.ckpt'
    model = instantiate_from_config(config.model.params.cond_stage_config)
    print(model)
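
    # A minimal usage sketch of the clustering stage (assuming `cond_stage_config` above
    # instantiates the `FeatClusterStage` defined in this file). The sizes below are
    # placeholders, not read from the config: `D` must match the feature depth the clusterer
    # was fit on and `T` the crop length from the dataset config. `get_input` flips (B, T, D)
    # features to (B, D, T); `encode` replaces every frame feature with its nearest cluster center.
    B, T, D = 2, 192, 1024  # hypothetical sizes
    dummy_batch = {'feature': torch.rand(B, T, D)}
    c = model.get_input(dummy_batch, 'feature')   # (B, D, T)
    c_quant, _, (_, _, c_ind) = model.encode(c)   # (B, D, T) centers, (B*T, 1) indices
    print(c_quant.shape, c_ind.shape)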