################################################## PACKAGES ############################################################
# PyTorch for deep learning operations
import torch
import torch.nn as nn
# PyTorch multiprocessing utilities (used by data loader workers)
import torch.multiprocessing
# Numerical operations
import numpy as np
# Hugging Face Transformers: BERT text models and the VideoMAE video model/processor
from transformers import BertModel, BertTokenizer, AutoImageProcessor, VideoMAEModel
# Hugging Face Datasets for loading sample videos
from datasets import load_dataset
# PyAV for video decoding (pip install av)
import av
# Tabular data handling
import pandas as pd
# Project configuration and the pretrained text-image encoder
from configs import CFG
from text_image import OneEncoder as TextImageEncoder


def read_video_pyav(container, indices):
    """
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    """
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """
    Sample a given number of frame indices from the video.
    Args:
        clip_len (`int`): Total number of frames to sample.
        frame_sample_rate (`int`): Sample every n-th frame.
        seg_len (`int`): Maximum allowed index of sample's last frame.
    Returns:
        indices (`List[int]`): List of sampled frame indices.
    """
    converted_len = int(clip_len * frame_sample_rate)
    end_idx = np.random.randint(converted_len, seg_len)
    start_idx = end_idx - converted_len
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices
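

# Illustrative usage sketch (not part of the original module): sample and decode a
# 16-frame clip from a local video file. The path "sample.mp4" is a placeholder
# assumption; the video must contain more than clip_len * frame_sample_rate frames.
def _demo_decode_clip(video_path="sample.mp4"):
    container = av.open(video_path)
    # VideoMAE expects 16 frames; sample every 4th frame within the stream length.
    indices = sample_frame_indices(clip_len=16, frame_sample_rate=4,
                                   seg_len=container.streams.video[0].frames)
    clip = read_video_pyav(container, indices)  # (16, height, width, 3) uint8 array
    return clip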


class AlignmentLayer(nn.Module):
    def __init__(self, input_dim=768, projection_dim=CFG.projection_dim, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(AlignmentLayer, self).__init__(*args, **kwargs)
        # Attributes
        self.input_dim = input_dim
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        # Layers
        self.linear_layer1 = nn.Linear(self.input_dim, self.projection_dim)
        self.gelu = nn.GELU()
        self.linear_layer2 = nn.Linear(self.projection_dim, self.projection_dim)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.normalization_layer = nn.LayerNorm(self.projection_dim)

    def forward(self, inputs):
        x = inputs
        x = self.linear_layer1(x)
        x = self.gelu(x)
        x = self.linear_layer2(x)
        x = self.dropout(x)
        x = self.normalization_layer(x)
        return x

    def __call__(self, inputs):
        return self.forward(inputs)
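

# Illustrative sketch (not part of the original module): project a dummy
# VideoMAE-sized feature tensor through the AlignmentLayer. The shapes below are
# assumptions chosen for demonstration only.
def _demo_alignment_layer():
    layer = AlignmentLayer(input_dim=768)
    dummy = torch.randn(2, 1568, 768)  # (batch, tokens, hidden_size)
    out = layer(dummy)                 # (batch, tokens, CFG.projection_dim)
    return out.shape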


class VideoEncoder(nn.Module):
    def __init__(self, model_name=CFG.video_name, projection_dim=CFG.projection_dim,
                 trainable=False, dropout_rate=CFG.dropout_rate, *args, **kwargs):
        super(VideoEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.dropout_rate = dropout_rate
        self.trainable = trainable
        # Models
        self.pretrained_encoder = VideoMAEModel.from_pretrained(self.model_name)
        self.alignment_layer = AlignmentLayer(
            input_dim=self.pretrained_encoder.config.hidden_size,
            projection_dim=self.projection_dim,
            dropout_rate=self.dropout_rate)
        # Freeze the pretrained VideoMAE weights unless trainable=True
        for parameter in self.pretrained_encoder.parameters():
            parameter.requires_grad = self.trainable

    def forward(self, inputs):
        x = self.pretrained_encoder(inputs).last_hidden_state
        x = self.alignment_layer(x)
        return x

    def __call__(self, inputs):
        return self.forward(inputs)
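

# Illustrative sketch (not part of the original module): run a dummy batch of pixel
# values through the VideoEncoder. Note that this downloads the VideoMAE weights on
# first use; the batch shape is an assumption matching the 16-frame, 224x224 setup.
def _demo_video_encoder():
    encoder = VideoEncoder()
    pixel_values = torch.randn(1, 16, 3, 224, 224)  # (batch, num_frames, channels, height, width)
    features = encoder(pixel_values)                # (batch, 1568, CFG.projection_dim)
    return features.shape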


class ModalityTokenEncoder(nn.Module):
    def __init__(self, projection_dim=CFG.projection_dim, token_size=CFG.token_size, device='cpu', *args, **kwargs):
        super(ModalityTokenEncoder, self).__init__(*args, **kwargs)
        # Attributes
        self.projection_dim = projection_dim
        self.device = device
        self.token_size = token_size
        # Models
        # Learnable video modality token, initialized from a normal distribution
        # whose standard deviation is drawn uniformly from [0.1, 0.6)
        video_variance = torch.rand(1) * 0.5 + 0.1
        self.video_token = nn.Parameter(torch.normal(mean=0, std=video_variance.item(),
                                                     size=(self.token_size, self.projection_dim)).to(self.device))

    def forward(self):
        return self.video_token

    def __call__(self):
        return self.forward()
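

# Illustrative sketch (not part of the original module): the modality token is a
# learnable parameter of shape (CFG.token_size, CFG.projection_dim).
def _demo_modality_token():
    token_encoder = ModalityTokenEncoder()
    return token_encoder().shape  # torch.Size([CFG.token_size, CFG.projection_dim])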


class OneEncoder(nn.Module):
    def __init__(self, device='cpu', modality_token_encoder=ModalityTokenEncoder(), checkpoint="bilalfaye/OneEncoder-text-image",
                 video_processor=AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base"),
                 video_encoder=VideoEncoder(), *args, **kwargs):
        super(OneEncoder, self).__init__(*args, **kwargs)
        self.device = device
        self.checkpoint = checkpoint
        self.modality_token_encoder = modality_token_encoder
        self.modality_token_encoder.device = self.device
        self.text_image_encoder = TextImageEncoder(device=self.device)
        self.text_image_encoder.from_pretrained(self.checkpoint)
        self.video_processor = video_processor
        self.video_encoder = video_encoder
        self.temperature = nn.Parameter(torch.tensor(0.07).to(self.device))
        # Freeze the pretrained text-image encoder
        for parameter in self.text_image_encoder.parameters():
            parameter.requires_grad = False

    def load_video(self, video_path):
        container = av.open(video_path)
        return container

    def read_video_pyav(self, container, indices):
        """
        Decode the video with PyAV decoder.
        Args:
            container (`av.container.input.InputContainer`): PyAV container.
            indices (`List[int]`): List of frame indices to decode.
        Returns:
            result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        """
        frames = []
        container.seek(0)
        start_index = indices[0]
        end_index = indices[-1]
        for i, frame in enumerate(container.decode(video=0)):
            if i > end_index:
                break
            if i >= start_index and i in indices:
                frames.append(frame)
        return np.stack([x.to_ndarray(format="rgb24") for x in frames])

    def sample_frame_indices(self, clip_len, frame_sample_rate, seg_len):
        """
        Sample a given number of frame indices from the video.
        Args:
            clip_len (`int`): Total number of frames to sample.
            frame_sample_rate (`int`): Sample every n-th frame.
            seg_len (`int`): Maximum allowed index of sample's last frame.
        Returns:
            indices (`List[int]`): List of sampled frame indices.
        """
        converted_len = int(clip_len * frame_sample_rate)
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
        indices = np.linspace(start_idx, end_idx, num=clip_len)
        indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        return indices

    def encode_video(self, videos):
        """
        Encode a batch of video clips and fuse them with the video modality token.
        :param videos: preprocessed pixel values, torch.Size([batch, 16, 3, 224, 224])
        :return: torch.Size([batch, 1568, 768])
        """
        video_features = self.video_encoder(videos.to(self.device))
        modality_token_features = self.modality_token_encoder()
        outputs = self.text_image_encoder.universal_projection_encoder(
            [video_features, modality_token_features]).last_hidden_state
        return outputs
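
    # Illustrative end-to-end sketch (not part of the original class), assuming a local
    # "sample.mp4" and that the CFG settings and text-image checkpoint are available:
    #   model = OneEncoder(device="cpu")
    #   container = model.load_video("sample.mp4")
    #   indices = model.sample_frame_indices(clip_len=16, frame_sample_rate=4,
    #                                        seg_len=container.streams.video[0].frames)
    #   clip = model.read_video_pyav(container, indices)
    #   pixel_values = model.video_processor(list(clip), return_tensors="pt").pixel_values
    #   video_embeddings = model.encode_video(pixel_values)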