import os.path as osp
import random
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
from torchvision.transforms import Lambda

from ....dataset.transform import ToTensorVideo, CenterCropVideo
from ....utils.dataset_utils import DecordInit

def TemporalRandomCrop(total_frames, size):
    """
    Perform a random temporal crop on a video sequence.

    Randomly selects a contiguous window of up to `size` frames from a video
    containing `total_frames` frames.

    Parameters:
    - total_frames (int): Total number of frames in the video sequence.
    - size (int): Length of the frame window to crop.

    Returns:
    - (int, int): A tuple (begin_index, end_index), where begin_index is the
      starting frame index of the window and end_index is the ending frame
      index (exclusive).
    """
    rand_end = max(0, total_frames - size - 1)
    begin_index = random.randint(0, rand_end)
    end_index = min(begin_index + size, total_frames)
    return begin_index, end_index
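
# Illustrative usage sketch (not part of the original module): with a
# 100-frame clip and size=16, each call yields a window of at most 16
# frames with an exclusive end index, e.g.:
#   start, end = TemporalRandomCrop(total_frames=100, size=16)
#   assert 0 <= start <= end <= 100 and end - start <= 16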

def resize(x, resolution):
    height, width = x.shape[-2:]
    # Clamp the target short side so it never exceeds 2x the requested
    # resolution or the original frame dimensions.
    resolution = min(2 * resolution, height, width)
    aspect_ratio = width / height
    if width <= height:
        new_width = resolution
        new_height = int(resolution / aspect_ratio)
    else:
        new_height = resolution
        new_width = int(resolution * aspect_ratio)
    resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
    return resized_x
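
# Shape sketch (illustrative, assuming a T C H W float tensor): resize keeps
# the aspect ratio and maps the short side to min(2 * resolution, H, W).
# For a (16, 3, 720, 1280) clip with resolution=256, the target short side
# is min(512, 720, 1280) = 512, so the output is (16, 3, 512, 910).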

class VideoDataset(data.Dataset):
    """Generic dataset for video files stored in folders.
    Returns CTHW videos in the range [-1, 1]."""
    video_exts = ['avi', 'mp4', 'webm']

    def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64, sample_rate=1, dynamic_sample=True):
        self.train = train
        self.sequence_length = sequence_length
        self.sample_rate = sample_rate
        self.resolution = resolution
        self.v_decoder = DecordInit()
        self.video_folder = video_folder
        self.dynamic_sample = dynamic_sample
        self.transform = transforms.Compose([
            ToTensorVideo(),
            # Lambda(lambda x: resize(x, self.resolution)),
            CenterCropVideo(self.resolution),
            Lambda(lambda x: 2.0 * x - 1.0),  # [0, 1] -> [-1, 1]
        ])
        print('Building datasets...')
        self.samples = self._make_dataset()

    def _make_dataset(self):
        # Recursively collect every file with a supported video extension.
        samples = []
        samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True)
                        for ext in self.video_exts], [])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path = self.samples[idx]
        try:
            video = self.decord_read(video_path)
            video = self.transform(video)  # normalize; shape stays T C H W
            video = video.transpose(0, 1)  # T C H W -> C T H W
            return dict(video=video, label="")
        except Exception as e:
            print(f'Error with {e}, {video_path}')
            # Fall back to a random sample when a video fails to decode.
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sample frames, optionally with a random temporal stride.
        if self.dynamic_sample:
            sample_rate = random.randint(1, self.sample_rate)
        else:
            sample_rate = self.sample_rate
        size = self.sequence_length * sample_rate
        start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int)
        video_data = decord_vr.get_batch(frame_indices).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return video_data
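
if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # a hypothetical local folder './videos' holding a few .mp4/.avi/.webm
    # clips; adjust the path and hyperparameters for your data.
    # (Because of the relative imports above, run this via `python -m` from
    # the package root rather than as a standalone script.)
    dataset = VideoDataset('./videos', sequence_length=16, resolution=64, sample_rate=2)
    loader = data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    # Each item is C T H W in [-1, 1]; the DataLoader prepends the batch dim,
    # e.g. torch.Size([2, 3, 16, 64, 64]) for RGB clips.
    print(batch['video'].shape)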