Spaces:
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import pickle | |
class Game_Cache:
    """Container for one task's precomputed game state, pickled by ``create_cache``.

    Attributes:
        question_dict: Question data to restore (``create_cache`` passes the
            question model's ``question_bank``).
        image_files: Image names shown for the task.
        yn: Whether the task was built with the yes/no model set.
        hard_setting: Whether the task's images come from the hard image list.
    """

    def __init__(self, question_dict, image_files, yn, hard_setting):
        # Targets are assigned left to right, i.e. in the same order as the
        # individual assignments they replace, so the instance __dict__ (and
        # therefore the pickled layout) is unchanged.
        self.question_dict, self.image_files, self.yn, self.hard_setting = (
            question_dict,
            image_files,
            yn,
            hard_setting,
        )
def _read_stripped_lines(path):
    """Return every line of the text file at *path*, with surrounding whitespace stripped."""
    with open(path, 'r') as f:
        return [line.strip() for line in f]


# Image names for the regular image pool; create_cache reads ten consecutive
# entries per task.
image_list = _read_stripped_lines('./mscoco/mscoco_images.txt')
# Image names for the "hard" image pool, indexed the same way.
image_list_hard = _read_stripped_lines('./mscoco/mscoco_images_attribute_n=1.txt')

# Global task ids that use the yes/no model set.
yn_indices = list(range(40, 80)) + list(range(120, 160))
# Global task ids (after create_cache removes the yes/no offset) that draw
# from the hard image pool.
hard_setting_indices = list(range(80, 160))

from model.run_question_asking_model import return_modules, return_modules_yn

# Models are loaded once at import time and read (never rebound) by
# create_cache, so no `global` declarations are needed at module scope.
question_model, response_model_simul, _, caption_model = return_modules()
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
def create_cache(taskid):
    """Precompute and pickle the game state for one global task id.

    Task ids map onto model set and image pool via ``yn_indices`` and
    ``hard_setting_indices``:

    * 0-39    non-yes/no models, regular image list
    * 40-79   yes/no models,     regular image list
    * 80-119  non-yes/no models, hard image list
    * 120-159 yes/no models,     hard image list

    After removing those offsets, local task ``k`` uses entries
    ``10*k .. 10*k+9`` of the chosen image list. The images are copied into
    ``./mscoco-images/val2014/`` (best effort), loaded as arrays, captioned,
    turned into candidate questions, and a first question is selected. A
    ``Game_Cache`` built from the question model's ``question_bank`` is then
    written to ``./cache-soft/<taskid>.p``.

    Args:
        taskid: Global task id; anything ``int()`` accepts (e.g. ``"45"``).

    Raises:
        IndexError: If the selected image list has fewer entries than the
            task needs.
    """
    import os
    import shutil
    import sys

    num_images = 10  # images per task; also the size of the uniform prior
    # Coerce first so string ids take the same yes/no / hard-setting branches
    # as the equivalent int.
    taskid = int(taskid)
    original_taskid = taskid

    if taskid in yn_indices:
        yn = True
        curr_question_model = question_model_yn
        curr_response_model_simul = response_model_simul_yn
        curr_caption_model = caption_model_yn
        taskid -= 40
    else:
        yn = False
        curr_question_model = question_model
        curr_response_model_simul = response_model_simul
        curr_caption_model = caption_model

    if taskid in hard_setting_indices:
        hard_setting = True
        image_list_curr = image_list_hard
        taskid -= 80
    else:
        hard_setting = False
        image_list_curr = image_list

    start = taskid * num_images
    # Index rather than slice, so a short list raises IndexError instead of
    # silently yielding fewer than num_images images.
    image_names = [image_list_curr[start + i] for i in range(num_images)]
    # Paths relative to ./mscoco-images, keeping the leading "/" that the old
    # `"./mscoco-images/val2014/..."[15:]` slicing produced. These exact
    # strings are handed to the caption and question models.
    image_files = [f"/val2014/{name}" for name in image_names]

    os.makedirs("./mscoco-images/val2014", exist_ok=True)
    for rel_path in image_files:
        src = f"./../../../data/ms-coco/images{rel_path}"
        try:
            shutil.copy(src, "./mscoco-images/val2014/")
        except OSError as err:
            # Best effort, matching the previous `cp` call, whose exit status
            # was ignored. If the image is not already present locally,
            # Image.open below fails loudly.
            print(f"warning: could not copy {src}: {err}", file=sys.stderr)

    images_np = []
    for rel_path in image_files:
        # Context manager closes the file handle deterministically. The array
        # stays valid after close because np.asarray has already pulled the
        # decoded pixel data.
        with Image.open(f"./mscoco-images{rel_path}") as img:
            arr = np.asarray(img)
        # Stack single-channel 2-D arrays into three identical channels.
        images_np.append(np.dstack([arr] * 3) if arr.ndim == 2 else arr)

    # Uniform prior over the task's images.
    p_y_x = (torch.ones(num_images) / num_images).to(curr_question_model.device)
    captions = curr_caption_model.get_captions(image_files)
    questions, _target_questions = curr_question_model.get_questions(image_files, captions, 0)
    curr_question_model.reset_question_bank()
    # Return value unused. NOTE(review): presumably called for its effect on
    # question_bank before it is captured below -- verify before removing.
    curr_question_model.select_best_question(
        p_y_x, questions, images_np, captions, curr_response_model_simul
    )

    game_cache = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
    os.makedirs("./cache-soft", exist_ok=True)
    with open(f"./cache-soft/{original_taskid}.p", "wb") as fp:
        pickle.dump(game_cache, fp, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
    # Build one cache file per global task id; create_cache maps ids 0-159
    # onto the four yes/no x hard-setting task groups.
    for task_id in range(160):
        create_cache(task_id)