# copy from https://github.com/Lyken17/Efficient-PyTorch/tools
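"""Dump a folder of per-image .npz/.npy feature files into a single LMDB
database, recording per-file success in a tsv cache so that interrupted
runs can be resumed."""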
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import os.path as osp
import sys

import six

from lmdbdict import lmdbdict
from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC
import tqdm
import numpy as np
import argparse
import json

import torch.utils.data as data
from torch.utils.data import DataLoader

import csv
csv.field_size_limit(sys.maxsize)

FIELDNAMES = ['image_id', 'status']
class FolderLMDB(data.Dataset):
    """Read features back out of an LMDB produced by folder2lmdb."""

    def __init__(self, db_path, fn_list=None):
        self.db_path = db_path
        self.lmdb = lmdbdict(db_path, unsafe=True)
        self.lmdb._key_dumps = DUMPS_FUNC['ascii']
        self.lmdb._value_loads = LOADS_FUNC['identity']
        if fn_list is not None:
            self.length = len(fn_list)
            self.keys = fn_list
        else:
            raise ValueError('fn_list must be provided')

    def __getitem__(self, index):
        byteflow = self.lmdb[self.keys[index]]

        # decode the stored bytes back into a numpy array
        buf = six.BytesIO()
        buf.write(byteflow)
        buf.seek(0)
        try:
            if args.extension == '.npz':
                feat = np.load(buf)['feat']
            else:
                feat = np.load(buf)
        except Exception as e:
            print(self.keys[index], e)
            return None
        return feat

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
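# A minimal read-back sketch (the ids below are hypothetical COCO ids; the
# keys are the stringified ids that were used as file names):
#
#   dataset = FolderLMDB('./data/cocobu_att.lmdb', fn_list=['391895', '522418'])
#   feat = dataset[0]   # numpy array, or None if the entry failed to decode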
def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the given extensions."""
    return filename.lower().endswith(tuple(ext.lower() for ext in extensions))


def make_dataset(dir, extension):
    images = []
    dir = os.path.expanduser(dir)
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if has_file_allowed_extension(fname, [extension]):
                path = os.path.join(root, fname)
                images.append(path)
    return images
def raw_reader(path):
    with open(path, 'rb') as f:
        bin_data = f.read()
    return bin_data


def raw_npz_reader(path):
    with open(path, 'rb') as f:
        bin_data = f.read()
    try:
        npz_data = np.load(six.BytesIO(bin_data))['feat']
    except Exception as e:
        print(path)
        npz_data = None
        print(e)
    return bin_data, npz_data


def raw_npy_reader(path):
    with open(path, 'rb') as f:
        bin_data = f.read()
    try:
        npy_data = np.load(six.BytesIO(bin_data))
    except Exception as e:
        print(path)
        npy_data = None
        print(e)
    return bin_data, npy_data
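# Each npz/npy reader returns (raw_bytes, decoded_array_or_None): the raw
# bytes are what folder2lmdb stores, while the decoded array only serves to
# verify the file is readable before its bytes are committed to the database.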
class Folder(data.Dataset):
    def __init__(self, root, loader, extension, fn_list=None):
        super(Folder, self).__init__()
        self.root = root
        if fn_list:
            samples = [os.path.join(root, str(_) + extension) for _ in fn_list]
        else:
            samples = make_dataset(self.root, extension)
        self.loader = loader
        self.extension = extension
        self.samples = samples

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (file_id, raw_bytes, decoded_array) where file_id is the
            file name without directory or extension.
        """
        path = self.samples[index]
        sample = self.loader(path)
        return (path.split('/')[-1].split('.')[0],) + sample

    def __len__(self):
        return len(self.samples)
def folder2lmdb(dpath, fn_list, write_frequency=5000):
    directory = osp.expanduser(dpath)
    print("Loading dataset from %s" % directory)
    if args.extension == '.npz':
        dataset = Folder(directory, loader=raw_npz_reader, extension='.npz',
                         fn_list=fn_list)
    else:
        dataset = Folder(directory, loader=raw_npy_reader, extension='.npy',
                         fn_list=fn_list)
    data_loader = DataLoader(dataset, num_workers=16, collate_fn=lambda x: x)

    lmdb_path = "%s.lmdb" % directory
    print("Generate LMDB to %s" % lmdb_path)
    db = lmdbdict(lmdb_path, mode='w', key_method='ascii', value_method='identity')

    tsvfile = open(args.output_file, 'a')
    writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
    names = []
    all_keys = []
    for idx, data in enumerate(tqdm.tqdm(data_loader)):
        name, byte, npz = data[0]
        # only store entries whose bytes decoded successfully
        if npz is not None:
            db[name] = byte
            all_keys.append(name)
        names.append({'image_id': name, 'status': str(npz is not None)})
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            print('writing')
            db.flush()
            # record processed ids in the tsv cache
            for name in names:
                writer.writerow(name)
            names = []
            tsvfile.flush()
            print('writing finished')

    # write out any rows left over from the last partial batch
    for name in names:
        writer.writerow(name)
    tsvfile.flush()
    tsvfile.close()

    print("Flushing database ...")
    db.flush()
    del db
def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(
        description='Dump a folder of .npz/.npy feature files into an LMDB database')
    parser.add_argument('--input_json', default='./data/dataset_coco.json', type=str)
    parser.add_argument('--output_file', default='.dump_cache.tsv', type=str)
    parser.add_argument('--folder', default='./data/cocobu_att', type=str)
    parser.add_argument('--extension', default='.npz', type=str)
    args = parser.parse_args()
    return args
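# Example invocation (the script name is assumed; use the actual file name):
#
#   python dump_to_lmdb.py --folder ./data/cocobu_att \
#       --input_json ./data/dataset_coco.json --extension .npz
#
# With these arguments the dump is written to ./data/cocobu_att.lmdb and the
# cache tsv to ./data/.dump_cache.tsvcocobu_att (folder name appended below).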
if __name__ == "__main__":
    args = parse_args()

    # place the cache tsv next to the feature folder, suffixed with its name
    args.output_file += args.folder.split('/')[-1]
    if args.folder.find('/') > 0:
        args.output_file = args.folder[:args.folder.rfind('/') + 1] + args.output_file
    print(args.output_file)

    img_list = json.load(open(args.input_json, 'r'))['images']
    fn_list = [str(_['cocoid']) for _ in img_list]

    # skip ids already dumped successfully in a previous run
    found_ids = set()
    try:
        with open(args.output_file, 'r') as tsvfile:
            reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
            for item in reader:
                if item['status'] == 'True':
                    found_ids.add(item['image_id'])
    except IOError:
        # no cache yet; this is the first run
        pass
    fn_list = [_ for _ in fn_list if _ not in found_ids]
    folder2lmdb(args.folder, fn_list)

    # Test existing.
    found_ids = set()
    with open(args.output_file, 'r') as tsvfile:
        reader = csv.DictReader(tsvfile, delimiter='\t', fieldnames=FIELDNAMES)
        for item in reader:
            if item['status'] == 'True':
                found_ids.add(item['image_id'])

    folder_dataset = FolderLMDB(args.folder + '.lmdb', list(found_ids))
    data_loader = DataLoader(folder_dataset, num_workers=16, collate_fn=lambda x: x)
    for data in tqdm.tqdm(data_loader):
        assert data[0] is not None
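# Note: the script is resumable. Ids already marked 'True' in the cache tsv
# are filtered out before dumping, so an interrupted run can be restarted
# with the same arguments and will pick up where it left off.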