from typing import Literal, Optional

import json
import logging

import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
from transformers import AutoModel, AutoProcessor, AutoTokenizer, T5EncoderModel

from think_sound.models.factory import create_model_from_config
from think_sound.models.utils import load_ckpt_state_dict
from think_sound.training.utils import copy_state_dict
from data_utils.ext.synchformer import Synchformer

log = logging.getLogger()
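
# NOTE: this module bundles the frozen feature extractors used for conditioning:
# a MetaCLIP ViT-H/14 model (frame and text features), a T5-v1.1-XL text encoder,
# a Synchformer video encoder for synchronization features, and an optional audio
# VAE built from a think_sound config/checkpoint.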

def patch_clip(clip_model):
    # a hack to make it output last hidden states
    # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
    def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
                              output_attentions: Optional[bool] = None,
                              output_hidden_states: Optional[bool] = None,
                              return_dict: Optional[bool] = None):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = text_outputs[0]
        pooled_output = text_outputs[1]
        text_features = self.text_projection(pooled_output)
        return text_features, last_hidden_state

    clip_model.get_text_features = new_get_text_features.__get__(clip_model)
    return clip_model
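
# After patch_clip, `get_text_features` no longer returns only the pooled,
# projected embedding: it returns a (text_features, last_hidden_state) tuple,
# so per-token hidden states are also available for sequence conditioning.
# A minimal sketch of the difference (variable names are illustrative only):
#
#     tokens = processor(text=["a dog barking"], truncation=True, max_length=77,
#                        padding="max_length", return_tensors="pt")
#     pooled, per_token = clip_model.get_text_features(**tokens)
#     # pooled:    (1, projection_dim)  -- the stock HF output
#     # per_token: (1, 77, hidden_size) -- last hidden state of the text tower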

class FeaturesUtils(nn.Module):

    def __init__(
        self,
        *,
        vae_ckpt: Optional[str] = None,
        vae_config: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        need_vae_encoder: bool = True,
    ):
        super().__init__()

        if enable_conditions:
            self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
            self.clip_model = patch_clip(self.clip_model)
            self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
            self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl")
            self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
            # self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
            #                                  std=[0.26862954, 0.26130258, 0.27577711])
            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
            # self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'
        else:
            self.clip_model = None
            self.synchformer = None
            self.tokenizer = None

        if vae_ckpt is not None:
            with open(vae_config) as f:
                vae_config = json.load(f)
            self.vae = create_model_from_config(vae_config)
            print(f"Loading model checkpoint from {vae_ckpt}")
            # load checkpoint; state dict entries are stored under the 'autoencoder.' prefix
            copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.'))
        else:
            # was `self.tod = None` (a leftover attribute name); encode_audio reads self.vae
            self.vae = None
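
    # NOTE: `need_vae_encoder` is accepted but not used in this constructor;
    # the full VAE is loaded whenever `vae_ckpt` is given.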

    def compile(self):
        if self.clip_model is not None:
            # the HF CLIPModel exposes get_image_features / get_text_features
            # (not open_clip's encode_image / encode_text), so compile those
            self.clip_model.get_image_features = torch.compile(self.clip_model.get_image_features)
            self.clip_model.get_text_features = torch.compile(self.clip_model.get_text_features)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)

    def train(self, mode: bool = True) -> 'FeaturesUtils':
        # the feature extractors are frozen; ignore `mode` and always stay in eval mode
        return super().train(False)

    def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        # x: (B, T, C, H, W), with H = W = 224
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224
        # x = self.clip_preprocess(x)
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        outputs = []
        if batch_size < 0:
            batch_size = b * t
        for i in range(0, b * t, batch_size):
            outputs.append(self.clip_model.get_image_features(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0)
        # x = self.clip_model.encode_image(x, normalize=True)
        x = rearrange(x, '(b t) d -> b t d', b=b)
        return x
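
    # Shape walk-through (illustrative): frames arrive as (B, T, 3, 224, 224),
    # are flattened to (B*T, 3, 224, 224), pushed through CLIP's image tower in
    # chunks of `batch_size` frames, and reshaped back to (B, T, D), where D is
    # the model's projection dim.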

    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.synchformer is not None, 'Synchformer is not loaded'
        # x: (B, T, C, H, W), with H = W = 224
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # partition the video into overlapping segments of 16 frames with a stride of 8
        segment_size = 16
        step_size = 8
        num_segments = (t - segment_size) // step_size + 1
        segments = []
        for i in range(num_segments):
            segments.append(x[:, i * step_size:i * step_size + segment_size])
        x = torch.stack(segments, dim=1)  # (B, S, T, C, H, W)

        outputs = []
        if batch_size < 0:
            batch_size = b
        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        for i in range(0, b * num_segments, batch_size):
            outputs.append(self.synchformer(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x
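
    # Segmentation arithmetic (illustrative): with segment_size = 16 and
    # step_size = 8, a clip of t frames yields (t - 16) // 8 + 1 segments,
    # e.g. t = 64 -> 7 overlapping 16-frame windows. Each window produces a
    # short token sequence from Synchformer, and the windows are concatenated
    # along the time axis into one (B, S * tokens_per_segment, d) sequence.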

    def encode_text(self, text: list[str]) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        # tokens: (B, L)
        tokens = self.clip_processor(text=text, truncation=True, max_length=77,
                                     padding="max_length", return_tensors="pt").to(self.device)
        return self.clip_model.get_text_features(**tokens)

    def encode_t5_text(self, text: list[str]) -> torch.Tensor:
        assert self.t5_model is not None, 'T5 model is not loaded'
        assert self.t5_tokenizer is not None, 'T5 tokenizer is not loaded'
        # inputs: (B, L)
        inputs = self.t5_tokenizer(text,
                                   truncation=True,
                                   max_length=77,
                                   padding="max_length",
                                   return_tensors="pt").to(self.device)
        return self.t5_model(**inputs).last_hidden_state

    def encode_audio(self, x) -> torch.Tensor:
        assert self.vae is not None, 'VAE is not loaded'
        x = self.vae.encode(x)
        return x

    # exposed as properties so `self.device` / `self.dtype` return the parameters'
    # device and dtype rather than bound methods (encode_text relies on `self.device`)
    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
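

if __name__ == "__main__":
    # Minimal usage sketch (hedged): the checkpoint paths below are placeholders,
    # not files shipped with this repo, and the frame tensor is random data used
    # only to show the expected shapes.
    utils = FeaturesUtils(
        vae_ckpt="ckpts/vae.ckpt",                            # hypothetical path
        vae_config="ckpts/vae_config.json",                   # hypothetical path
        synchformer_ckpt="ckpts/synchformer_state_dict.pth",  # hypothetical path
        enable_conditions=True,
    ).eval().to("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        # text conditioning: the patched CLIP returns (pooled features, last hidden state)
        clip_feat, clip_hidden = utils.encode_text(["a dog barking in the distance"])
        t5_hidden = utils.encode_t5_text(["a dog barking in the distance"])

        # video conditioning: frames as (B, T, 3, 224, 224)
        frames = torch.randn(1, 64, 3, 224, 224, device=utils.device)
        clip_vid = utils.encode_video_with_clip(frames, batch_size=32)
        sync_vid = utils.encode_video_with_sync(frames, batch_size=1)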