| """ | |
| Precompute ngram counts of captions, to accelerate cider computation during training time. | |
| """ | |
| import os | |
| import json | |
| import argparse | |
| from six.moves import cPickle | |
| import captioning.utils.misc as utils | |
| from collections import defaultdict | |
| import sys | |
| sys.path.append("cider") | |
| from pyciderevalcap.ciderD.ciderD_scorer import CiderScorer | |
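# Note: the sys.path hack above assumes the cider scorer code lives in a
# "cider" subdirectory next to this script (upstream captioning repos
# typically vendor it as a git submodule).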
def get_doc_freq(refs, params):
    tmp = CiderScorer(df_mode="corpus")
    for ref in refs:
        tmp.cook_append(None, ref)
    tmp.compute_doc_freq()
    return tmp.document_frequency, len(tmp.crefs)
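# A rough sketch (toy data, not from the corpus) of what get_doc_freq returns:
# document_frequency maps each ngram tuple to the number of images whose
# reference captions contain it, and the second value is the corpus size.
#
#   df, n = get_doc_freq([['a cat <eos>', 'the cat <eos>']], None)
#   # df[('cat',)] == 1.0   ('cat' occurs in the reference set of 1 image)
#   # n == 1                (one reference set was cooked)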
def build_dict(imgs, wtoi, params):
    wtoi['<eos>'] = 0

    count_imgs = 0

    refs_words = []
    refs_idxs = []
    for img in imgs:
        if (params['split'] == img['split']) or \
                (params['split'] == 'train' and img['split'] == 'restval') or \
                (params['split'] == 'all'):
            # (params['split'] == 'val' and img['split'] == 'restval') or \
            ref_words = []
            ref_idxs = []
            for sent in img['sentences']:
                # params is a plain dict (vars(args)), so test membership with
                # 'in' and index it; hasattr/attribute access never worked here.
                if 'bpe' in params:
                    sent['tokens'] = params['bpe'].segment(' '.join(sent['tokens'])).strip().split(' ')
                tmp_tokens = sent['tokens'] + ['<eos>']
                # Map out-of-vocabulary tokens to UNK before joining.
                tmp_tokens = [_ if _ in wtoi else 'UNK' for _ in tmp_tokens]
                ref_words.append(' '.join(tmp_tokens))
                ref_idxs.append(' '.join([str(wtoi[_]) for _ in tmp_tokens]))
            refs_words.append(ref_words)
            refs_idxs.append(ref_idxs)
            count_imgs += 1

    print('total imgs:', count_imgs)

    ngram_words, count_refs = get_doc_freq(refs_words, params)
    ngram_idxs, count_refs = get_doc_freq(refs_idxs, params)
    print('count_refs:', count_refs)
    return ngram_words, ngram_idxs, count_refs
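# Toy illustration (hypothetical vocabulary) of the two parallel views built
# by build_dict:
#   refs_words == [['a cat <eos>', 'the UNK <eos>'], ...]  # one list per image
#   refs_idxs  == [['3 9 0', '5 1 0'], ...]                # same captions as vocab indices
# Document frequencies are precomputed for both so cider can later be scored
# on either representation.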
def main(params):
    imgs = json.load(open(params['input_json'], 'r'))
    dict_json = json.load(open(params['dict_json'], 'r'))
    itow = dict_json['ix_to_word']
    wtoi = {w: i for i, w in itow.items()}

    # Load bpe
    if 'bpe' in dict_json:
        import tempfile
        import codecs
        # Dump the stored BPE codes to a temporary file so apply_bpe.BPE can
        # read them back as a file object.
        codes_f = tempfile.NamedTemporaryFile(delete=False)
        codes_f.close()
        with open(codes_f.name, 'w') as f:
            f.write(dict_json['bpe'])
        with codecs.open(codes_f.name, encoding='UTF-8') as codes:
            bpe = apply_bpe.BPE(codes)
        params['bpe'] = bpe  # stash on the params dict; build_dict checks 'bpe' in params

    imgs = imgs['images']

    ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params)

    utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len},
                      open(params['output_pkl'] + '-words.p', 'wb'))
    utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len},
                      open(params['output_pkl'] + '-idxs.p', 'wb'))
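# A rough note on how these pickles are consumed (based on the cider scorer
# this script imports; the exact training-side wiring is an assumption): when
# CiderScorer is created with df_mode pointing at a saved file instead of
# "corpus", it reads 'document_frequency' and 'ref_len' back, taking
# log(ref_len) as the corpus-size normalizer. The '-idxs' pickle matches
# captions stored as vocabulary indices, the '-words' pickle raw words.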
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # input json
    parser.add_argument('--input_json', default='data/dataset_coco.json', help='input json file with the dataset and raw captions')
    parser.add_argument('--dict_json', default='data/cocotalk.json', help='input json file with the vocabulary (ix_to_word)')
    parser.add_argument('--output_pkl', default='data/coco-all', help='prefix for the output pickle files')
    parser.add_argument('--split', default='all', help='test, val, train, all')
    args = parser.parse_args()
    params = vars(args)  # convert to ordinary dict

    main(params)
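# Example invocation (defaults shown; adjust the script path and split to
# your setup):
#   python prepro_ngrams.py --input_json data/dataset_coco.json \
#       --dict_json data/cocotalk.json --output_pkl data/coco-all --split all
# This writes data/coco-all-words.p and data/coco-all-idxs.p.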