Spaces:
Runtime error
Runtime error
| from open_flamingo.eval.vqa_metric import compute_vqa_accuracy | |
| import sys | |
| import json | |
| from bert_score import BERTScorer | |
| from tqdm.contrib.concurrent import process_map | |
| from tqdm import tqdm | |
| import random | |
| import time | |
# Number of data shards / worker processes; each worker loads its own
# BERTScorer. Despite the name this is not the GPU count (see below).
NUM_GPU = 128
# Number of CUDA devices the workers are spread over: worker ``idx`` runs on
# ``cuda:{idx % NUM_CUDA_DEVICES}``.
NUM_CUDA_DEVICES = 6


def single_job(args, num_devices=NUM_CUDA_DEVICES, max_attempts=None):
    """Snap every non-empty answer in one data shard to its closest allowed answer.

    A BERTScorer is created on ``cuda:{idx % num_devices}``. Then, for each
    record whose ``"answer"`` is not the empty string, the answer is replaced
    by the element of ``refs`` with the highest BERTScore F1 against it.
    Records with an empty answer are left untouched.

    Args:
        args: ``(data, refs, idx)`` packed into one argument so the function
            can be mapped with ``process_map``. ``data`` is a list of dicts
            with an ``"answer"`` key, ``refs`` is the list of allowed answers,
            and ``idx`` is the worker index. Only worker 0 shows a progress bar.
        num_devices: number of CUDA devices to distribute workers over.
        max_attempts: maximum number of attempts to construct the scorer.
            ``None`` (default) retries indefinitely.

    Returns:
        ``data``, with the answers rewritten in place.

    Raises:
        ValueError: if there is something to score but ``refs`` is empty.
        Exception: the last scorer-construction error, re-raised once
            ``max_attempts`` attempts have failed.
    """
    data, refs, idx = args

    # Only records with a non-empty answer need scoring.
    todo = [i for i, d in enumerate(data) if d["answer"] != ""]
    if not todo:
        # Nothing to do: avoid loading a model onto the GPU for nothing.
        return data
    if not refs:
        raise ValueError("refs (the candidate answer list) must be non-empty")

    scorer = None
    attempt = 0
    while scorer is None:
        attempt += 1
        try:
            # Random stagger so the many workers don't all initialise at once.
            time.sleep(random.random() * 10)
            scorer = BERTScorer(
                lang="en",
                rescale_with_baseline=True,
                # A larger alternative is "microsoft/deberta-xlarge-mnli".
                model_type="bert-base-uncased",
                batch_size=4096,
                device=f"cuda:{idx % num_devices}",
            )
        except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
            if max_attempts is not None and attempt >= max_attempts:
                raise
            print(
                f"[worker {idx}] BERTScorer init failed "
                f"(attempt {attempt}): {err!r}; retrying",
                file=sys.stderr,
                flush=True,
            )
            time.sleep(random.random() * 5)

    for i in tqdm(todo, disable=(idx != 0)):
        record = data[i]
        # Score the answer against every allowed answer and keep the best F1.
        cands = [record["answer"]] * len(refs)
        _, _, f1 = scorer.score(cands, refs, verbose=False)
        record["answer"] = refs[int(f1.argmax())]
    return data
| if __name__ == "__main__": | |
| if sys.argv[1] == "vqav2": | |
| question_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_OpenEnded_mscoco_val2014_questions.json" | |
| annotation_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_mscoco_val2014_annotations.json" | |
| else: | |
| raise NotImplementedError | |
| answer_list = json.load(open("answer_list.json")) | |
| data = json.load(open(sys.argv[2])) | |
| cands = [] | |
| refs = [] | |
| data_parts = [] | |
| for i in range(NUM_GPU): | |
| data_parts.append([[], answer_list, i]) | |
| for i, d in enumerate(data): | |
| data_parts[i % NUM_GPU][0].append(d) | |
| datas = process_map(single_job, data_parts, max_workers=NUM_GPU, disable=True) | |
| all_data = [] | |
| for data in datas: | |
| all_data.extend(data) | |
| json.dump(all_data, open("temp_result", "w")) | |
| acc = compute_vqa_accuracy( | |
| result_json_path="temp_result", | |
| question_json_path=question_json_path, | |
| annotation_json_path=annotation_json_path, | |
| ) | |
| print(acc) | |