# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from collections import OrderedDict
import os
import os.path as osp
import pickle
import time

import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video

from lavila.data import datasets
from lavila.data.video_transforms import Permute
from lavila.models import models
from lavila.utils.preprocess import generate_tokenizer
from lavila.utils import distributed as dist_utils
from eval_narrator import decode_one
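

# Wraps a dataset so that each item is returned together with its index, letting the
# generated captions be mapped back to their source clips after (distributed) inference.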
class IndexedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        return index, self.dataset[index]

    def __len__(self):
        return len(self.dataset)


def get_args_parser():
    parser = argparse.ArgumentParser(description='lavila infer narrator', add_help=False)
    parser.add_argument('--dataset', default='ego4d', type=str, choices=['ego4d'])
    parser.add_argument('--root', default='datasets/Ego4D/video_5min_chunks_288px/',
                        type=str, help='path to dataset root')
    parser.add_argument('--metadata', default='datasets/Ego4D/ego4d_train.pkl',
                        type=str, help='path to metadata file')
    parser.add_argument('--output-dir', default='./', type=str, help='output dir')
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--use-half', action='store_true')
    parser.add_argument('--clip-length', default=4, type=int, help='clip length')
    parser.add_argument('--clip-stride', default=16, type=int, help='clip stride')
    parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
    parser.add_argument('--caption-sample', default='multinomial_sample',
                        choices=['multinomial_sample', 'beam_sample', 'group_beam_search'])
    parser.add_argument('--caption-top-k', default=None, type=int)
    parser.add_argument('--caption-top-p', default=0.95, type=float)
    parser.add_argument('--caption-num-beams', default=1, type=int)
    parser.add_argument('--caption-num-beam-groups', default=1, type=int)
    parser.add_argument('--caption-temperature', default=0.7, type=float)
    parser.add_argument('--caption-length-penalty', default=1.0, type=float)
    parser.add_argument('--caption-num-return-sequences', default=10, type=int)
    parser.add_argument('--caption-max-len', default=77, type=int)
    parser.add_argument('--caption-early-stop', action='store_true', help='early stopping to save computation')
    # System
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
                        help='number of data loading workers per process')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str)
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
    return parser


def main(args):
    dist_utils.init_distributed_mode(args)
    print(args)

    if args.resume:
        ckpt_path = args.resume
    elif osp.isfile(osp.join(args.output_dir, 'checkpoint_best.pt')):
        ckpt_path = osp.join(args.output_dir, 'checkpoint_best.pt')
    else:
        raise Exception('no checkpoint found')

    ckpt = torch.load(ckpt_path, map_location='cpu')
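    # Checkpoints saved under DistributedDataParallel prefix parameter names with
    # 'module.'; strip it so the weights load into a plain (non-DDP) model.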
    state_dict = OrderedDict()
    for k, v in ckpt['state_dict'].items():
        state_dict[k.replace('module.', '')] = v

    # create model
    old_args = ckpt['args']
    print('=> creating model: {}'.format(old_args.model))
    model = getattr(models, old_args.model)(
        text_use_cls_token=old_args.use_cls_token,
        gated_xattn=old_args.gated_xattn,
        timesformer_gated_xattn=old_args.timesformer_gated_xattn,
        num_frames=old_args.clip_length,
        drop_path_rate=0,
    )
    model.cuda()
    model.load_state_dict(state_dict, strict=True)
    print("=> loaded checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))

    torch.backends.cudnn.benchmark = True

    # Data loading
    print("=> creating dataset")
    tokenizer = generate_tokenizer(old_args.model)
    crop_size = 224 if '336PX' not in old_args.model else 336
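    # Pick normalization statistics to match how the visual backbone was pretrained:
    # ImageNet-style stats by default, and a separate set for models whose name marks
    # them as initialized from OpenAI CLIP weights.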
    val_transform = transforms.Compose([
        Permute([3, 0, 1, 2]),  # T H W C -> C T H W
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        (transforms_video.NormalizeVideo(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
         if 'OPENAI' not in old_args.model else
         transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])),
    ])

    val_dataset = datasets.VideoCaptionDatasetCLIP(
        args.dataset,
        args.root,
        args.metadata,
        transform=val_transform,
        is_training=False,
        tokenizer=tokenizer,
        clip_length=args.clip_length,
        clip_stride=args.clip_stride,
        sparse_sample=False,
        subsample_stride=1,
    )
    val_dataset = IndexedDataset(val_dataset)
    print(len(val_dataset))
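    # With shuffle=False, DistributedSampler assigns samples to ranks in a strided
    # (round-robin) order; the rank-interleaved merge at the end of main() relies on this.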
    if args.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    else:
        val_sampler = None

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False
    )
    print('len(val_loader) = {}'.format(len(val_loader)))

    model.eval()
    if args.use_half:
        model.half()

    id_offset = 0
    all_captions_cache = []
    end = time.time()
    with torch.no_grad():
        for data_iter, (indices, inputs) in enumerate(val_loader):
            indices = indices.tolist()
            if data_iter % args.print_freq == 0:
                print("finished {}/{} in {}".format(data_iter, len(val_loader), time.time() - end))
                end = time.time()
            if len(inputs) == 2 or len(inputs) == 3:
                images = inputs[0].cuda(non_blocking=True)
                if args.use_half:
                    images = images.half()
                image_features = dist_utils.get_model(model).encode_image(images)
                if not isinstance(image_features, (list, tuple)):
                    image_tokens = image_features
                else:
                    image_tokens = image_features[1]
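                # Decode captions from the visual tokens with the sampling strategy
                # selected via --caption-sample.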
                if args.caption_sample == 'multinomial_sample':
                    generated_text_ids, ppls = dist_utils.get_model(model).generate(
                        image_tokens,
                        tokenizer,
                        target=None,
                        max_text_length=args.caption_max_len,
                        top_k=args.caption_top_k,
                        top_p=args.caption_top_p,
                        num_return_sequences=args.caption_num_return_sequences,
                        temperature=args.caption_temperature,
                        early_stopping=args.caption_early_stop,
                    )
                elif args.caption_sample == 'beam_sample':
                    generated_text_ids, ppls = dist_utils.get_model(model).beam_sample(
                        image_tokens,
                        tokenizer,
                        target=None,
                        max_text_length=args.caption_max_len,
                        top_k=args.caption_top_k,
                        top_p=args.caption_top_p,
                        temperature=args.caption_temperature,
                        length_penalty=args.caption_length_penalty,
                        num_beams=args.caption_num_beams,
                        num_return_sequences=args.caption_num_return_sequences,
                    )
                elif args.caption_sample == 'group_beam_search':
                    assert args.caption_num_beam_groups > 1 and args.caption_num_beams % args.caption_num_beam_groups == 0
                    generated_text_ids, ppls = dist_utils.get_model(model).group_beam_search(
                        image_tokens,
                        tokenizer,
                        target=None,
                        max_text_length=args.caption_max_len,
                        top_k=args.caption_top_k,
                        top_p=args.caption_top_p,
                        temperature=args.caption_temperature,
                        length_penalty=args.caption_length_penalty,
                        num_beams=args.caption_num_beams,
                        num_beam_groups=args.caption_num_beam_groups,
                        num_return_sequences=args.caption_num_return_sequences,
                    )
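                # Each clip yields caption_num_return_sequences candidate captions; collect
                # the decoded strings and their perplexities, then look up the source clip.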
                for j in range(generated_text_ids.shape[0] // args.caption_num_return_sequences):
                    generated_text_str_list = []
                    ppls_list = []
                    for k in range(args.caption_num_return_sequences):
                        jj = j * args.caption_num_return_sequences + k
                        generated_text_str = decode_one(generated_text_ids[jj], tokenizer)
                        generated_text_str_list.append(generated_text_str)
                        ppls_list.append(ppls[jj].item())
                    video_uid, t_start, t_end, _ = val_loader.dataset.dataset.samples[indices[j]]
                    if args.caption_num_return_sequences == 1:
                        all_captions_cache.append((video_uid, t_start, t_end, generated_text_str, ppls[jj].item()))
                    else:
                        all_captions_cache.append((video_uid, t_start, t_end, generated_text_str_list, ppls_list))
                id_offset += generated_text_ids.shape[0]
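
    # Each rank dumps its own shard of captions; rank 0 then interleaves the shards back
    # into dataset order, trims the duplicated tail samples added by DistributedSampler
    # padding, and writes the merged result to total.pkl.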
    pickle.dump(all_captions_cache, open(osp.join(args.output_dir, 'cache.{}.pkl'.format(args.rank)), 'wb'))

    torch.distributed.barrier()
    disordered_list = []
    total_num = 0
    if args.rank == 0:
        for i in range(args.world_size):
            print('=> reading {}'.format(osp.join(args.output_dir, f'cache.{i}.pkl')))
            sublist = pickle.load(open(osp.join(args.output_dir, f'cache.{i}.pkl'), 'rb'))
            disordered_list.append(sublist)
            total_num += len(sublist)
        ordered_list = []
        for i in range(total_num):
            ordered_list.append(disordered_list[i % args.world_size][i // args.world_size])
        print(f"{len(val_dataset)}/{len(ordered_list)}")
        ordered_list = ordered_list[:len(val_dataset)]
        pickle.dump(ordered_list, open(osp.join(args.output_dir, 'total.pkl'), 'wb'))
        for i in range(args.world_size):
            print('=> deleting {}'.format(osp.join(args.output_dir, f'cache.{i}.pkl')))
            os.remove(osp.join(args.output_dir, f'cache.{i}.pkl'))
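

# Example invocation (script name, paths, and GPU count below are illustrative, not
# taken from this file): launching with torchrun populates the env:// rendezvous used
# by dist_utils.init_distributed_mode, e.g.
#   torchrun --nproc_per_node=8 main_infer_narrator.py \
#       --root datasets/Ego4D/video_5min_chunks_288px/ \
#       --metadata datasets/Ego4D/ego4d_train.pkl \
#       --resume checkpoints/narrator_checkpoint.pt \
#       --output-dir outputs/ --use-half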
if __name__ == '__main__':
    parser = argparse.ArgumentParser('lavila infer narrator', parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)