Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License. | |
| import argparse | |
| import copy | |
| import nncore | |
| import torch | |
| from videomind.constants import GROUNDER_PROMPT | |
| from videomind.dataset.hybrid import DATASETS | |
| from videomind.dataset.utils import process_vision_info | |
| from videomind.model.builder import build_model | |
| from videomind.utils.io import get_duration | |
def parse_args():
    """Parse command-line options for grounder inference.

    Returns:
        argparse.Namespace: the parsed options. ``--chunk``/``--index`` select
        the ``index``-th of ``chunk`` interleaved shards of the annotations.

    Raises:
        SystemExit: (via ``parser.error``) when a required option is missing,
            ``--chunk`` is smaller than 1, or ``--index`` is outside
            ``[0, chunk)``.
    """
    parser = argparse.ArgumentParser()
    # These three are consumed unconditionally by the script, so omitting them
    # should be reported up front instead of crashing deep inside inference.
    parser.add_argument('--dataset', required=True, help='name of the dataset to evaluate')
    parser.add_argument('--pred_path', required=True, help='directory the predictions are written to')
    parser.add_argument('--model_gnd_path', required=True, help='path of the grounder model')
    parser.add_argument('--split', default='test', choices=['train', 'valid', 'test'])
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--device', default='auto')
    parser.add_argument('--chunk', type=int, default=1, help='total number of shards')
    parser.add_argument('--index', type=int, default=0, help='shard handled by this process')
    args = parser.parse_args()

    # Reject impossible sharding early: chunk=0 or index>=chunk used to crash
    # later with an IndexError, and a negative index silently picked a shard
    # counted from the end while writing a misleading output file name.
    if args.chunk < 1:
        parser.error('--chunk must be >= 1')
    if not 0 <= args.index < args.chunk:
        parser.error('--index must satisfy 0 <= index < chunk')

    return args
def postprocess_predictions(blob, duration, unit=0.001):
    """Turn the grounder's raw regression output into sorted, rounded windows.

    Args:
        blob (torch.Tensor): raw output read from ``model.reg[0]``. Columns
            0-1 are treated as (start, end) and scaled by ``duration``;
            presumably they are normalized to [0, 1] — TODO confirm. The
            remaining columns are carried through unchanged as confidences.
        duration (float): video duration used to scale and clamp timestamps.
        unit (float): rounding granularity applied to the timestamps.

    Returns:
        torch.Tensor: float tensor with the same number of rows as ``blob``.
        Column 0 is never greater than column 1; both are clamped to
        ``[0, duration]`` and rounded to multiples of ``unit``, followed by
        the confidence columns.
    """
    # 1. extract timestamps and confidences
    blob = blob.cpu().float()
    pred, conf = blob[:, :2] * duration, blob[:, 2:]

    # 2. clamp timestamps
    pred = pred.clamp(min=0, max=duration)

    # 3. round timestamps to units
    pred = torch.round(pred / unit).long() * unit

    # 4. sort timestamps: swap (start, end) on rows where end < start.
    # flip(1) swaps each row independently; a bare roll(1) would flatten the
    # selected rows first and shuffle values between windows.
    inds = (pred[:, 1] - pred[:, 0] < 0).nonzero()[:, 0]
    pred[inds] = pred[inds].flip(1)

    # 5. merge timestamps back with confidences
    return torch.cat((pred, conf), dim=1)


if __name__ == '__main__':
    args = parse_args()

    # Validate sharding here as well so this entry point fails loudly on its
    # own instead of producing an empty or mislabeled shard.
    if args.chunk < 1 or not 0 <= args.index < args.chunk:
        raise ValueError(f'invalid sharding: chunk={args.chunk}, index={args.index}')

    if args.chunk > 1:
        pred_path = nncore.join(args.pred_path, f'output_{args.index}.jsonl')
    else:
        pred_path = nncore.join(args.pred_path, 'output.jsonl')

    print(f'Dataset: {args.dataset}({args.split}) Chunk: {args.chunk} Index: {args.index} Output Path: {pred_path}')

    model, processor = build_model(args.model_gnd_path, device=args.device)
    device = next(model.parameters()).device

    # Look the dataset class up once; its optional UNIT attribute sets the
    # timestamp rounding granularity (0.001 when absent).
    dataset_cls = DATASETS.get(args.dataset)
    unit = getattr(dataset_cls, 'UNIT', 0.001)

    annos = dataset_cls.load_annos(split=args.split)
    # Interleaved sharding: shard k holds annos[k], annos[k + chunk], ...
    annos = annos[args.index::args.chunk]

    dumps = []
    for i in nncore.ProgressBar(range(len(annos))):
        anno = copy.deepcopy(annos[i])

        video_path, query, duration, span = anno['video_path'], anno['query'], anno.get('duration'), anno.get('span')

        if duration is None:
            duration = get_duration(video_path, num_threads=args.num_threads)

        print()
        print(video_path)
        print(duration)
        print(query)

        messages = [{
            'role':
            'user',
            'content': [{
                'type': 'video',
                'video': video_path,
                'num_threads': args.num_threads,
                'min_pixels': 36 * 28 * 28,
                'max_pixels': 64 * 28 * 28,
                'max_frames': 150,
                'fps': 1.0
            }, {
                'type': 'text',
                'text': GROUNDER_PROMPT.format(query)
            }]
        }]

        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        print(text)

        images, videos = process_vision_info(messages)
        data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
        data = data.to(device)

        output_ids = model.generate(
            **data,
            do_sample=False,
            temperature=None,
            top_p=None,
            top_k=None,
            repetition_penalty=None,
            max_new_tokens=256)

        assert data.input_ids.size(0) == output_ids.size(0) == 1
        # keep only the newly generated tokens
        output_ids = output_ids[0, data.input_ids.size(1):]
        # guard against an empty generation before peeking at the last token
        if output_ids.numel() > 0 and output_ids[-1] == processor.tokenizer.eos_token_id:
            output_ids = output_ids[:-1]
        response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
        print(response)

        # model.reg is read after generate(); an empty list is treated as a
        # failed grounding and replaced by a whole-video window.
        if len(model.reg) > 0:
            pred = postprocess_predictions(model.reg[0], duration, unit=unit)
        else:
            print('WARNING: Failed to parse grounder response')
            pred = torch.Tensor([[0, duration, 1]])

        print(pred[0], span, duration)

        dumps.append(dict(vid=anno['vid'], qid=anno['qid'], pred_relevant_windows=pred.tolist()))

    nncore.dump(dumps, pred_path)