# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import math
import random
from collections import defaultdict
from itertools import accumulate

import nncore
import numpy as np
import termplotlib as tpl
import torch
from tabulate import tabulate
from torch.utils.data import Dataset

from videomind.constants import IGNORE_INDEX
from videomind.dataset.utils import preprocess, process_vision_info
from videomind.utils.parser import parse_span

DATASETS = nncore.Registry('datasets')


class HybridDataset(Dataset):

    def __init__(self, processor, model_config, model_args, data_args, training_args):
        super().__init__()

        datasets = []
        for key in data_args.datasets.split(','):
            datasets.append(DATASETS.get(key)(processor, model_args, data_args, training_args))
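
        # Collect per-sample data types and build the [start, end) index range
        # of each sub-dataset within the concatenated ordering.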
        data_types = [a['data_type'] for d in datasets for a in d.annos]

        cum_length = [0] + list(accumulate([len(d) for d in datasets]))
        idx_ranges = [[cum_length[i], cum_length[i + 1]] for i in range(len(cum_length) - 1)]
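
        # Report dataset statistics on the main process only
        # (local_rank 0, or -1 when not running distributed).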
        if training_args.local_rank in (0, -1):
            raw_length = sum(d.raw_length for d in datasets)
            cur_length = idx_ranges[-1][-1]

            ratio = round(cur_length / raw_length * 100, 2)
            print(f'Number of samples: {raw_length} (original) -> {cur_length} (filtered) {ratio}%')

            data_type_cnt = ' '.join([f'{data_types.count(t)} ({t})' for t in list(set(data_types))])
            print(f'Data types: {data_type_cnt}')

            tab = defaultdict(int)
            for dataset in datasets:
                for anno in dataset.annos:
                    tab[anno.get('source', 'unknown')] += 1
            tab = [[k, v, round(v / cur_length, 3)] for k, v in tab.items()]
            print(tabulate(tab, headers=['Source', '#Samples', 'Ratio'], tablefmt='pretty', stralign='left'))
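
            # Summarize video durations (d is sorted ascending, r is the
            # descending view used for the max-duration report).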
            d, _ = torch.Tensor([a['duration'] for d in datasets for a in d.annos if 'duration' in a]).sort()
            if d.size(0) > 0:
                n, r = min(d.size(0), 10), d.flip(0)
                print(f'Top-{n} max video durations: {[round(r[i].item(), 1) for i in range(n)]}')
                print(f'Top-{n} min video durations: {[round(d[i].item(), 1) for i in range(n)]}')
                print(f'Average video duration ({d.size(0)} samples): {round(d.mean().item(), 1)}s')

                print('Video duration histogram:')
                counts, edges = np.histogram(d)
                labels = [f'{edges[i]:.2f}s - {edges[i + 1]:.2f}s' for i in range(len(edges) - 1)]
                fig = tpl.figure()
                fig.barh(counts, labels)
                fig.show()
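
            # Same summary for the annotated span durations.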
            d, _ = torch.Tensor([abs(b[0] - b[1]) for d in datasets for a in d.annos if 'span' in a
                                 for b in a['span']]).sort()
            if d.size(0) > 0:
                n, r = min(d.size(0), 10), d.flip(0)
                print(f'Top-{n} max span durations: {[round(r[i].item(), 1) for i in range(n)]}')
                print(f'Top-{n} min span durations: {[round(d[i].item(), 1) for i in range(n)]}')
                print(f'Average span duration ({d.size(0)} samples): {round(d.mean().item(), 1)}s')

                print('Span duration histogram:')
                counts, edges = np.histogram(d)
                labels = [f'{edges[i]:.2f}s - {edges[i + 1]:.2f}s' for i in range(len(edges) - 1)]
                fig = tpl.figure()
                fig.barh(counts, labels)
                fig.show()

        self.datasets = datasets
        self.data_types = data_types
        self.idx_ranges = idx_ranges
        self.processor = processor
        self.model_config = model_config
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

    def __len__(self):
        return self.idx_ranges[-1][-1]
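
    # On a loading failure, retry with a randomly drawn sample of the same
    # data type so the sampled data distribution is preserved.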
    def __getitem__(self, idx):
        for retry in range(self.data_args.max_retries + 1):
            try:
                return self.fetch_data(idx)
            except Exception as e:
                print(f'Error loading sample {idx}: {type(e).__name__}({e})')
                idx = random.choice([i for i, t in enumerate(self.data_types) if t == self.data_types[idx]])
        raise RuntimeError(f'Data loading failed after {retry} retries')
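
    # No-op placeholder; returns self so callers expecting a datasets-style
    # map() interface keep working.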
    def map(self, *args, **kwargs):
        return self

    def fetch_data(self, idx):
        # Map the global index to the owning sub-dataset and its local index
        for (s, e), dataset in zip(self.idx_ranges, self.datasets):
            if s <= idx < e:
                meta = dataset[idx - s]
                break

        text = self.processor.apply_chat_template(meta['messages'])
        text = [text.strip()]

        images, videos = process_vision_info(meta['messages'], sanity_check=True)

        data = self.processor(text=text, images=images, videos=videos, return_tensors='pt')

        assert data['input_ids'].size(0) == 1
        data['input_ids'] = data['input_ids'][0]
        data['labels'] = preprocess(data['input_ids'], text[0], self.processor.tokenizer, self.model_args.conv_type)

        # insert segment start/end tokens
        if 'ss' in meta and 'se' in meta:
            video_grid_thw = data['video_grid_thw'][0]

            # window: number of video tokens per frame after the 2x2 spatial merge
            num_frames, window = int(video_grid_thw[0]), int(video_grid_thw[1] * video_grid_thw[2] / 4)
            assert num_frames * window * 4 == data['pixel_values_videos'].size(0)

            # ss/se are normalized segment boundaries in [0, 1]
            pos_s, pos_e = round(meta['ss'] * num_frames), round(meta['se'] * num_frames)
            pos_s, pos_e = min(max(0, pos_s), num_frames), min(max(0, pos_e), num_frames)
            assert pos_s <= pos_e, (num_frames, meta['ss'], meta['se'])

            # token positions relative to the vision start token; pos_e gets an
            # extra offset because the seg_s token is inserted before it
            base_idx = torch.nonzero(data['input_ids'] == self.model_config.vision_start_token_id).item()
            pos_s, pos_e = pos_s * window + base_idx + 1, pos_e * window + base_idx + 2

            input_ids = data['input_ids'].tolist()
            input_ids.insert(pos_s, self.model_config.seg_s_token_id)
            input_ids.insert(pos_e, self.model_config.seg_e_token_id)
            data['input_ids'] = torch.LongTensor(input_ids)

            # the two inserted tokens are excluded from the loss
            labels = data['labels'].tolist()
            labels.insert(pos_s, IGNORE_INDEX)
            labels.insert(pos_e, IGNORE_INDEX)
            data['labels'] = torch.LongTensor(labels)
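
        # Build grounding supervision: normalized target timestamps, a
        # per-frame saliency mask, and one sampled positive frame index.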
        if 'span' in meta:
            span, duration = meta['span'], meta['duration']

            pixel_values_videos, video_grid_thw = data['pixel_values_videos'], data['video_grid_thw']

            num_frames = int(video_grid_thw[0][0])
            assert video_grid_thw.size(0) == 1
            assert video_grid_thw.prod() == pixel_values_videos.size(0)

            # actual fps would be 1/2 of config (temporal patch size = 2)
            fps = num_frames / duration

            # sanitize each span against the video duration (min length: one frame)
            safe_span = [parse_span(b, duration, 1 / fps) for b in span]

            # num_reg_tokens -> num_bnds -> s & e
            timestamps = [[[s / duration, e / duration] for s, e in safe_span]]

            saliency, pos_inds = torch.zeros(num_frames), []
            for s, e in safe_span:
                span_ind = max(0, s * fps), min(e * fps, num_frames)
                pos_inds = list(range(math.ceil(span_ind[0]), math.ceil(span_ind[1])))
                assert len(pos_inds) > 0, f'empty pos_inds ({idx}): {fps} {num_frames} {duration} {span}'
                saliency[pos_inds] = 1

            assert saliency.any(), f'empty saliency ({idx}): {pos_inds} {fps} {num_frames} {duration} {span}'

            # randomly sample a single positive frame for saliency supervision
            pos_clip = random.sample(saliency.nonzero()[:, 0].tolist(), 1)
            pos_clip = torch.LongTensor(pos_clip)

            data['timestamps'] = timestamps
            data['saliency'] = saliency
            data['pos_clip'] = pos_clip

        return data
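
# Usage sketch (hypothetical; the argument objects come from the training entry
# point and the processor from Hugging Face Transformers, so the names below
# are illustrative, not defined in this file):
#
#   processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
#   dataset = HybridDataset(processor, model.config, model_args, data_args, training_args)
#   sample = dataset[0]  # dict with input_ids, labels, pixel_values_videos, ...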