from torch.utils.data import DataLoader, Dataset, Sampler
from pathlib import Path
import json
from multiprocessing import Pool
from tqdm import tqdm
from PIL import Image

import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as T

from torch.utils.data.distributed import DistributedSampler

from transformers import T5Tokenizer, BertTokenizer, BertTokenizerFast, CLIPTokenizer

import text_utils

project_dir = Path(__file__).parent.resolve()
workspace_dir = project_dir.parent.parent
dataset_dir = workspace_dir.joinpath('datasets/').resolve()
# coco_dir = dataset_dir.joinpath('COCO')
# vg_dir = dataset_dir.joinpath('VG')
coco_img_dir = dataset_dir.joinpath('COCO/images/')
coco_data_dir = project_dir.parent.joinpath('CLIP-ViL/CLIP-ViL-Direct/caption/data/')
# coco_feature_dir = coco_dir.joinpath('features')


class COCORetrievalDataset(Dataset):
    def __init__(self, split='karpathy_train', rank=-1, topk=-1, verbose=True, args=None, mode='train'):
        super().__init__()

        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.rank = rank
        self.mode = mode

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        # if self.args.tokenizer is None:
        #     self.args.tokenizer = self.args.decoder_backbone

        # if 'bert' in self.args.tokenizer:
        #     self.tokenizer = BertTokenizerFast.from_pretrained(
        #         self.args.tokenizer,
        #         # max_length=self.args.max_text_length,
        #         # do_lower_case=self.args.do_lower_case
        #     )
        # elif 'clip' in self.args.tokenizer:
        #     self.tokenizer = CLIPTokenizer.from_pretrained(
        #         self.args.tokenizer,
        #         # max_length=self.args.max_text_length,
        #         # do_lower_case=self.args.do_lower_case
        #     )

        self.tokenizer = CLIPTokenizer.from_pretrained(
            self.args.tokenizer,
            # max_length=self.args.max_text_length,
            # do_lower_case=self.args.do_lower_case
        )
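
        # Illustrative call (hypothetical caption string):
        #   self.tokenizer(['a man riding a horse'], padding=True, return_tensors='pt')
        # returns a BatchEncoding whose 'input_ids' and 'attention_mask' tensors
        # are exactly what collate_fn below packs into batch_entry['text'].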

        with open(coco_data_dir.joinpath('cocotalk.json')) as f:
            self.vocab = list(json.load(f)['ix_to_word'].values())
        popped = self.vocab.pop(-1)
        assert popped == 'UNK'

        if self.verbose:
            print('vocab size: ', len(self.vocab))

        data_info_path = coco_data_dir.joinpath('dataset_coco.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        split_rename = {
            'train': 'train',
            'restval': 'train',
            'val': 'val',
            'test': 'test'
        }

        n_images = 0
        data = []
        # self.vocab = set()
        for datum in karpathy_data['images']:
            re_split = split_rename[datum['split']]

            # if re_split == 'train':
            #     for d in datum['sentences']:
            #         self.vocab = self.vocab.union(set(d['tokens']))

            if re_split != self.source.split('_')[-1]:
                continue

            if re_split == 'train':
                # for d in datum['sentences']:
                #     img_id = datum['filename'].split('.')[0]
                #     new_datum = {
                #         'filename': datum['filename'],
                #         'img_id': img_id,
                #         'sent': d['raw'].strip(),
                #         'targets': [d['raw'].strip() for d in datum['sentences']],
                #         'is_train': True,
                #         'cocoid': datum['cocoid']
                #     }
                #     data.append(new_datum)
                img_id = datum['filename'].split('.')[0]
                new_datum = {
                    'filename': datum['filename'],
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    # 'targets': [d['raw'].strip() for d in datum['sentences']],
                    'targets': [" ".join(d['tokens']) for d in datum['sentences']],
                    'is_train': True,
                    'cocoid': datum['cocoid']
                }
                data.append(new_datum)
            else:
                img_id = datum['filename'].split('.')[0]
                new_datum = {
                    'filename': datum['filename'],
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    # 'targets': [d['raw'].strip() for d in datum['sentences']],
                    'targets': [" ".join(d['tokens']) for d in datum['sentences']],
                    'is_train': False,
                    'cocoid': datum['cocoid']
                }
                data.append(new_datum)

            n_images += 1

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            # print(f"Loaded {len(data)} data from", split)

        self.n_gpus = torch.cuda.device_count()

        if self.topk > 0:
            data = data[:self.topk]
            if self.verbose:
                print(f"Use only {self.topk} data")

        self.data = data

        # if self.verbose:
        #     print("# all sentences:", len(self.data))

        if self.args.load_feat:
            # feat_dir = coco_dir.joinpath(''
            # self.feat_loader = HybridLoader('/scratch-space/CLIP-ViL/CLIP-ViL-Direct/caption/data/cocotalk_clipscore_vis', ext='.npy', in_memory=False)
            self.feat_loader = HybridLoader(
                coco_data_dir.joinpath('cocotalk_clipscore_vis'),
                ext='.npy', in_memory=False)
        else:
            if 'openai/clip' in self.args.encoder_backbone:
                # from transformers import CLIPProcessor
                # self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32",
                #                                                size=args.image_size,
                #                                                do_resize=True,
                #                                                do_center_crop=False,
                #                                                )
                # self.img_transform = lambda image: self.processor.feature_extractor(
                #     image,
                #     return_tensors='pt')['pixel_values'][0]

                self.image_mean = [0.48145466, 0.4578275, 0.40821073]
                self.image_std = [0.26862954, 0.26130258, 0.27577711]

                # captioning
                # self.img_transform = T.Compose([
                #     T.Resize((self.args.image_size, self.args.image_size))
                # ])

                # retrieval
                self.img_transform = T.Compose([
                    T.Resize(self.args.image_size, interpolation=T.functional.InterpolationMode.BICUBIC),
                    T.CenterCrop(self.args.image_size)
                ])

                self.img_tensor_transform = T.Compose([
                    # T.RandomCrop(224),
                    # T.RandomHorizontalFlip(p=0.3),
                    T.ConvertImageDtype(torch.float),
                    T.Normalize(self.image_mean, self.image_std)
                ])
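
                # Shape walk-through (assuming image_size=224): read_image yields a
                # uint8 (3, H, W) tensor; img_transform resizes the short side to 224
                # and center-crops to (3, 224, 224); img_tensor_transform is applied
                # batch-wise in collate_fn, converting to float in [0, 1] and
                # normalizing with the CLIP mean/std above.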

            # elif 'google/vit' in self.args.encoder_backbone:
            #     self.image_mean = [0.5, 0.5, 0.5]
            #     self.image_std = [0.5, 0.5, 0.5]
            #     self.img_transform = T.Compose([
            #         # T.PILToTensor(),
            #         T.Resize((self.args.image_size, self.args.image_size))
            #     ])
            #     self.img_tensor_transform = T.Compose([
            #         # T.RandomCrop(224),
            #         # T.RandomHorizontalFlip(p=0.3),
            #         T.ConvertImageDtype(torch.float),
            #         T.Normalize(self.image_mean, self.image_std)
            #     ])

    def get_negative_text(self, text):
        neg_type = random.choice(['repeat', 'remove', 'insert', 'swap', 'shuffle'])

        if neg_type == 'repeat':
            text = text_utils.repeat(text)
        elif neg_type == 'remove':
            text = text_utils.remove(text)
        elif neg_type == 'insert':
            text = text_utils.insert(text, self.vocab)
        elif neg_type == 'swap':
            text = text_utils.swap(text, self.vocab)
        elif neg_type == 'shuffle':
            text = text_utils.shuffle(text)

        return text, neg_type
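
    # Illustrative (hypothetical) outputs for the caption 'a man riding a horse';
    # the exact strings depend on text_utils and the sampled perturbation:
    #   ('a man riding riding a horse', 'repeat')
    #   ('a man riding horse', 'remove')
    #   ('riding a horse a man', 'shuffle')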

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        datum = self.data[idx]
        return self.process_datum(datum)

    def process_datum(self, datum):
        out_dict = {}

        ###### Image ######
        if self.args.load_feat:
            cocoid = datum['cocoid']
            out_dict['cocoid'] = str(cocoid)
            img_feat = self.feat_loader.get(str(cocoid))
            out_dict['img_feat'] = torch.from_numpy(img_feat)
        else:
            img_id = datum['img_id']
            out_dict['img_id'] = img_id

            if 'train' in datum['filename']:
                img_split = 'train2014'
            elif 'val' in datum['filename']:
                img_split = 'val2014'
            else:
                raise ValueError(f"Cannot infer COCO split from filename: {datum['filename']}")
            img_path = coco_img_dir.joinpath(img_split).joinpath(datum['filename']).with_suffix('.jpg')
            assert img_path.exists()
            img_path = str(img_path)
            out_dict['img_path'] = img_path

            # force 3 channels; a few COCO images are grayscale
            img_tensor = torchvision.io.read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)
            # out_dict['img_tensor'] = img

            # img = Image.open(img_path).convert('RGB')
            # img_tensor = torch.as_tensor(np.asarray(img))

            # resize + center-crop here (still uint8); dtype conversion and
            # normalization happen batch-wise in collate_fn
            out_dict['img_tensor'] = self.img_transform(img_tensor)
            # out_dict['img_tensor'] = self.img_transform(img)

        ###### Text #####
        # if datum['is_train']:
        #     sent = datum['sent'].strip()
        sent = random.choice(datum['targets'])

        # target_ids = self.tokenizer.encode(
        #     sent, max_length=self.args.gen_max_length, truncation=True)
        # assert len(target_ids) <= self.args.gen_max_length, len(target_ids)

        out_dict['sent'] = sent
        # out_dict['target_ids'] = torch.LongTensor(target_ids)
        # out_dict['target_length'] = len(target_ids)

        # negative sample
        neg_sent, neg_type = self.get_negative_text(sent)
        # neg_target_ids = self.tokenizer.encode(
        #     neg_sent, max_length=self.args.gen_max_length, truncation=True)
        # assert len(neg_target_ids) <= self.args.gen_max_length, len(neg_target_ids)

        out_dict['neg_sent'] = neg_sent
        out_dict['neg_type'] = neg_type
        # out_dict['neg_target_ids'] = torch.LongTensor(neg_target_ids)
        # out_dict['neg_target_length'] = len(neg_target_ids)

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # if 'target_ids' in batch[0]:
        #     T_W_L = max(entry['target_length'] for entry in batch)
        #     target_ids = torch.ones(
        #         B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
        targets = []
        img_ids = []
        img_paths = []
        coco_ids = []

        if self.args.load_feat:
            img_feats = torch.zeros(B, 512, dtype=torch.float)
        else:
            # imgs = []
            img_tensor = torch.zeros(B, 3, self.args.image_size, self.args.image_size, dtype=torch.uint8)

        for i, entry in enumerate(batch):
            if self.args.load_feat:
                coco_ids.append(entry['cocoid'])
                img_feats[i] = entry['img_feat']
            else:
                img_ids.append(entry['img_id'])
                img_paths.append(entry['img_path'])
                img_tensor[i] = entry['img_tensor']

            # if 'target_ids' in entry:
            #     target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'targets' in entry:
                targets.append(entry['targets'])

        if 'sent' in batch[0]:
            # word_mask = target_ids != self.tokenizer.pad_token_id
            # target_ids[~word_mask] = -100
            # batch_entry['target_ids'] = target_ids
            tokenized = self.tokenizer([entry['sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
            neg_tokenized = self.tokenizer([entry['neg_sent'] for entry in batch], truncation=True, padding=True, return_tensors='pt')
            # sent, max_length=self.args.gen_max_length, truncation=True)

            batch_entry['text'] = (tokenized.input_ids, tokenized.attention_mask)
            batch_entry['neg_text'] = (neg_tokenized.input_ids, neg_tokenized.attention_mask)

        if self.args.load_feat:
            batch_entry['coco_ids'] = coco_ids
            batch_entry['img_feats'] = img_feats
        else:
            img_tensor = self.img_tensor_transform(img_tensor)
            batch_entry['img_id'] = img_ids
            batch_entry['img_paths'] = img_paths
            batch_entry['img_tensor'] = img_tensor

        batch_entry['targets'] = targets

        # print('batch created')
        # batch_entry['task'] = 'caption'

        return batch_entry
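

# Minimal usage sketch (hypothetical config; the SimpleNamespace fields mirror
# the attributes the class actually reads, and assume the COCO/cocotalk files
# above are in place):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       tokenizer='openai/clip-vit-base-patch32',
#       encoder_backbone='openai/clip-vit-base-patch32',
#       image_size=224,
#       load_feat=False,
#   )
#   dataset = COCORetrievalDataset('karpathy_val', args=args, mode='val')
#   loader = DataLoader(dataset, batch_size=8, shuffle=False,
#                       collate_fn=dataset.collate_fn)
#   batch = next(iter(loader))
#   input_ids, attention_mask = batch['text']   # positive captions
#   neg_ids, neg_mask = batch['neg_text']       # perturbed captions
#   images = batch['img_tensor']                # (8, 3, 224, 224), normalized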


# def get_loader(args, split='karpathy_train', mode='train',
#                batch_size=32, workers=4, distributed=False, gpu=0,
#                topk=-1):

#     verbose = (gpu == 0)

#     dataset = COCORetrievalDataset(
#         split,
#         rank=gpu,
#         topk=topk,
#         verbose=verbose,
#         args=args,
#         mode=mode)

#     # sampler must be defined before the DataLoaders below reference it
#     if distributed:
#         sampler = DistributedSampler(dataset)
#     else:
#         sampler = None

#     if mode == 'train':
#         loader = DataLoader(
#             dataset, batch_size=batch_size, shuffle=(sampler is None),
#             num_workers=workers, pin_memory=True, sampler=sampler,
#             collate_fn=dataset.collate_fn)
#     else:
#         loader = DataLoader(
#             dataset,
#             batch_size=batch_size, shuffle=False,
#             num_workers=workers, pin_memory=True,
#             sampler=sampler,
#             collate_fn=dataset.collate_fn,
#             drop_last=False)

#     # if verbose:
#     #     loader.evaluator = COCOCaptionEvaluator()
#     #     loader.task = 'caption'

#     return loader
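
# If get_loader is revived for DDP training, call sampler.set_epoch(epoch) at
# the start of each epoch so DistributedSampler reshuffles across epochs.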


# class COCOCaptionEvaluator:
#     def __init__(self):
#         import language_evaluation
#         self.evaluator = language_evaluation.CocoEvaluator(verbose=False)

#     def evaluate(self, predicts, answers):
#         results = self.evaluator.run_evaluation(predicts, answers)
#         return results


import six
import os
import h5py


class HybridLoader:
    """
    If db_path is a directory, then use normal file loading.
    If lmdb, then load from lmdb.
    The loading method depends on the extension.

    in_memory: if in_memory is True, we save all the features in memory
        For individual np(y|z)s, we don't need to do that because the system will do this for us.
        Should be useful for lmdb or h5.
        (Copied this idea from vilbert)
    """

    def __init__(self, db_path, ext='.npy', in_memory=False):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = lambda x: np.load(six.BytesIO(x))
        else:
            self.loader = lambda x: np.load(six.BytesIO(x))['feat']
        # if db_path.endswith('.lmdb'):
        #     self.db_type = 'lmdb'
        #     self.lmdb = lmdbdict(db_path, unsafe=True)
        #     self.lmdb._key_dumps = DUMPS_FUNC['ascii']
        #     self.lmdb._value_loads = LOADS_FUNC['identity']
        # elif db_path.endswith('.pth'):  # Assume a key,value dictionary
        #     self.db_type = 'pth'
        #     self.feat_file = torch.load(db_path)
        #     self.loader = lambda x: x
        #     print('HybridLoader: ext is ignored')
        # elif db_path.endswith('h5'):
        #     self.db_type = 'h5'
        #     self.loader = lambda x: np.array(x).astype('float32')
        # else:
        #     self.db_type = 'dir'

        self.in_memory = in_memory
        if self.in_memory:
            self.features = {}

    def get(self, key):
        # if self.in_memory and key in self.features:
        #     # We save f_input because we want to save the
        #     # compressed bytes to save memory
        #     f_input = self.features[key]
        # elif self.db_type == 'lmdb':
        #     f_input = self.lmdb[key]
        # elif self.db_type == 'pth':
        #     f_input = self.feat_file[key]
        # elif self.db_type == 'h5':
        #     f_input = h5py.File(self.db_path, 'r')[key]
        # else:
        #     f_input = open(os.path.join(
        #         self.db_path, key + self.ext), 'rb').read()
        f_input = open(os.path.join(
            self.db_path, key + self.ext), 'rb').read()

        if self.in_memory and key not in self.features:
            self.features[key] = f_input

        # decode the feature array from the raw bytes
        feat = self.loader(f_input)

        return feat
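

# Minimal usage sketch for HybridLoader (hypothetical key; assumes the
# 'cocotalk_clipscore_vis' directory holds one .npy feature file per COCO id):
#
#   feat_loader = HybridLoader(coco_data_dir.joinpath('cocotalk_clipscore_vis'),
#                              ext='.npy', in_memory=False)
#   feat = feat_loader.get('391895')   # -> np.ndarray, e.g. shape (512,)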