from lavis.datasets.builders import load_dataset
import torch
import more_itertools
from tqdm import tqdm
from coco_metric import compute_cider, postprocess_captioning_generation
import json
import time
import os
from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
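

# Evaluation script for grounded image captioning on COCO / Flickr30k.
# VisualLogitsProcessor steers greedy decoding through the grounding tokens
# (<|#object#|>, <|#previsual#|>, <|#visual#|>, <|#prebox#|>, <|#box#|>):
# it forces a stop whenever a box prediction is needed, so the evaluation
# loop below can alternate between text generation and box prediction.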
class VisualLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        # Look up each special token once; [-1] takes the last sub-token id.
        self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
        self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
        self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
        self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
        self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        self.topk = 2
    def __call__(self, input_ids, scores):
        # Token ids of the top-k candidates, highest score first (computed once
        # instead of re-sorting for every check).
        sorted_ids = scores.sort(descending=True).indices.tolist()[0]
        topk_ids = sorted_ids[:self.topk]
        # print("decoding===>", self.tokenizer.decode(topk_ids))
        # Promote <|#object#|> when it is a runner-up candidate, EOS is not
        # among the top candidates, and the <|#endofobject#|> count is twice
        # the <|#object#|> count.
        if (
            self.object_token_id in topk_ids[1:]
            and self.eos_token_id not in topk_ids
            and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum()
        ):
            scores[0, self.object_token_id] = 1000
        if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
                # Generate a previsual token next.
                scores[0, self.previsual_token_id] = 1000
        elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
            # Stop generation so the caller can run bbox prediction for the
            # pending previsual/visual token.
            scores[0, self.eos_token_id] = 1000
        elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
            # Generate a visual token next.
            scores[0, self.visual_token_id] = 1000
        return scores


def prepare_batch_images(batch, image_processor):
    # Each processed image becomes a (1, 1, 1, C, H, W) tensor; concatenate
    # along the batch dimension, Flamingo-style (batch, images, frames, ...).
    return torch.cat(
        [image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0) for b in batch],
        dim=0,
    )
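

# Note: the grounding loop in evaluate_coco_flickr tracks a single prompt
# string and indexes sample 0 of each batch, so it effectively assumes
# batch_size == 1.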
def evaluate_coco_flickr(
    model,
    tokenizer,
    image_processor,
    batch_size,
    is_flickr=False,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    debug=False,
):
    """Evaluate a model on the COCO (or Flickr30k) captioning test set.

    Returns:
        float: CIDEr score (scaled by 100) on rank 0; 0.0 on other ranks.
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    coco_dataset = load_dataset("coco_caption")
    eval_dataset = coco_dataset["test"]
    model.eval().cuda()
    predictions = dict()
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"
    cnt = 0
    if world_size > 1:
        torch.distributed.barrier()
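    # Generation runs in rounds: generate until the logits processor forces a
    # stop at a previsual/visual token, predict a box for it, splice the box
    # token back into the prompt, and resume; a round that ends without such a
    # stop token terminates the loop.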
| desc = "Running inference Flickr30" if is_flickr else "Running inference COCO" | |
| for ii, batch in enumerate(more_itertools.chunked( | |
| tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size | |
| )): | |
| if ii % world_size != rank: | |
| continue | |
        cnt += len(batch)
        batch_images = prepare_batch_images(
            batch=batch,
            image_processor=image_processor,
        ).cuda()
        # The pad tokens reserve vis_embed_size slots for the visual embeddings.
        prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
        added_bbox_list = []
        batch_text = [prompt for _ in batch]
        encodings = tokenizer(
            batch_text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        ori_prompt_length = len(encodings["input_ids"][0])
        have_prebox = False
        while True:
            batch_text = [prompt for _ in batch]
            encodings = tokenizer(
                batch_text,
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=2000,
            )
            input_ids = encodings["input_ids"].cuda()
            attention_mask = encodings["attention_mask"].cuda()
            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
            image_start_index_list = [[x] for x in image_start_index_list]
            image_nums = [1] * len(input_ids)
            if debug:
                print("input--->", tokenizer.decode(input_ids[0]))
            # Enforce a minimum generation length; the stop token here is the
            # BOS id (which may coincide with EOS for this tokenizer).
            p1 = MinNewTokensLengthLogitsProcessor(
                prompt_length_to_skip=input_ids.shape[-1],
                min_new_tokens=5,
                eos_token_id=bos_token_id,
            )
            # `with A() and B():` only enters the second context manager;
            # combine the contexts with a comma instead.
            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
                outputs = model.generate(
                    batch_images,
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=20,
                    # min_new_tokens=8,
                    num_beams=1,
                    # length_penalty=0,
                    image_start_index_list=image_start_index_list,
                    image_nums=image_nums,
                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                    logits_processor_list=[p1, visual_logits_processor],
                )
            if debug:
                print("outputs--->", tokenizer.decode(outputs[0]))
            if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
                # Generation stopped right after a previsual/visual token:
                # run a forward pass to predict the corresponding box.
                prompt = tokenizer.decode(outputs.clone()[0])
                is_visual = (outputs[0, -2] == visual_token_id)
                batch_text = tokenizer.batch_decode(outputs[:, :-1])
                encodings = tokenizer(
                    batch_text,
                    padding="longest",
                    truncation=True,
                    return_tensors="pt",
                    max_length=2000,
                )
                input_ids = encodings["input_ids"].cuda()
                attention_mask = encodings["attention_mask"].cuda()
                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
                image_start_index_list = [[x] for x in image_start_index_list]
                image_nums = [1] * len(input_ids)
                if debug:
                    print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
                # Keep the box-prediction forward pass in a separate variable
                # so the generated caption tokens in `outputs` stay available
                # for decoding below.
                with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
                    box_outputs = model(
                        vision_x=batch_images,
                        lang_x=input_ids,
                        attention_mask=attention_mask,
                        image_nums=image_nums,
                        image_start_index_list=image_start_index_list,
                        added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                        add_box=len(added_bbox_list) != 0,
                    )
                boxes = box_outputs["boxes"]
                scores = box_outputs["scores"]
                if boxes is not None:
                    if is_visual:
                        if have_prebox:
                            # A prebox was speculatively inserted earlier; drop
                            # it now that the real visual box is known.
                            added_bbox_list.pop()
                            prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                            have_prebox = False
                            if debug:
                                print("find previsual and remove it--->", prompt)
                        first_box = boxes[scores.argmax()]
                        # Normalize pixel coordinates by the 224x224 input size.
                        added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
                        prompt = prompt[:-len(tokenizer.eos_token)]
                        prompt += box_token + endofobject_token
                        if debug:
                            print("after inserting visual---->", prompt)
                    else:
                        # To visualize the predicted preboxes:
                        # import numpy as np
                        # import cv2
                        # open_cv_image = np.array(batch[0]["image"])
                        # open_cv_image = open_cv_image[:, :, ::-1].copy()
                        # for pre_box in boxes:
                        #     open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
                        # cv2.imwrite("Atest.png", open_cv_image)
                        pre_box = boxes[scores.argmax()]
                        added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
                        prompt = prompt[:-len(tokenizer.eos_token)]
                        prompt += prebox_token + object_token
                        have_prebox = True
                        if debug:
                            print("after inserting previsual---->", prompt)
                else:
                    # No box was predicted; drop the trailing visual/stop
                    # tokens and continue generating.
                    prompt = tokenizer.decode(outputs[0, :-2])
            else:
                break
        # Keep only the newly generated caption tokens.
        outputs = outputs[:, ori_prompt_length:]
        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "")
            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ]
        if rank == 0:
            tqdm.write(new_predictions[0])
        for i, sample in enumerate(batch):
            predictions[int(sample["image_id"])] = {
                "caption": new_predictions[i],
            }
    # Each rank writes its own partial results file.
    results_path = (
        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
        if is_flickr
        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
    )
    with open(results_path, "w") as f:
        f.write(
            json.dumps(
                [
                    {"image_id": k, "caption": predictions[k]["caption"]}
                    for k in predictions
                ],
                indent=2,
            )
        )
    print("save to", results_path)
    del predictions
    time.sleep(10)
    if world_size > 1:
        torch.distributed.barrier()
    if rank == 0:
        print(f"evaluate on rank {rank}. world size is {world_size}")
        # Merge the per-rank partial results, then score with CIDEr.
        predictions = []
        for rank_i in range(world_size):
            part_results_path = (
                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
                if is_flickr
                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
            )
            print("load", part_results_path)
            predictions.extend(json.load(open(part_results_path)))
            os.remove(part_results_path)
        print("num:", len(predictions))
        results_path = (
            f"flickrresults_{lang_encoder_name}.json"
            if is_flickr
            else f"cocoresults_{lang_encoder_name}.json"
        )
        json.dump(predictions, open(results_path, "w"), indent=2)
        metrics = compute_cider(
            result_path=results_path,
            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
        )
        metrics["CIDEr"] *= 100
        os.makedirs("eval_results", exist_ok=True)
        acc = metrics["CIDEr"]
        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
            f.write(json.dumps(predictions, indent=2))
        # Delete the temporary merged file.
        os.remove(results_path)
    else:
        metrics = {}
        metrics["CIDEr"] = 0.0
    return metrics["CIDEr"]
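

# A minimal usage sketch (not part of the original script). How the model,
# tokenizer, and image processor are constructed depends on the surrounding
# repo; `create_model_and_transforms` below is a hypothetical factory name.
#
#     model, image_processor, tokenizer = create_model_and_transforms(...)  # hypothetical
#     cider = evaluate_coco_flickr(
#         model=model,
#         tokenizer=tokenizer,
#         image_processor=image_processor,
#         batch_size=1,        # the grounding loop assumes one sample per batch
#         vis_embed_size=64,   # number of visual embedding slots; assumed value
#     )
#     print("CIDEr:", cider)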