Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import sys | |
| import json | |
| import tqdm | |
| import copy | |
| from queue import PriorityQueue | |
| import functools | |
| import spacy | |
| nlp = spacy.load("en_core_web_sm") | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, CLIPImageProcessor | |
| from transformers import OwlViTProcessor | |
| from VisualSearch.model.VSM import VSMForCausalLM | |
| from VisualSearch.model.llava import conversation as conversation_lib | |
| from VisualSearch.model.llava.mm_utils import tokenizer_image_token | |
| from VisualSearch.utils.utils import expand2square | |
| from VisualSearch.utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, | |
| DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX) | |
def parse_args(args):
    """Build the command-line parser for the evaluation script and parse ``args``.

    Args:
        args: Sequence of argument strings (``main`` passes ``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` holding every option below.
    """
    cli = argparse.ArgumentParser(description="Visual Search Evaluation")

    # Checkpoint, benchmark location and visualisation output.
    cli.add_argument("--version", default="craigwu/seal_vsm_7b")
    cli.add_argument("--benchmark-folder", default="vstar_bench", type=str)
    cli.add_argument("--visualization", action="store_true", default=False)
    cli.add_argument("--output_path", default="", type=str)

    # Plain numeric options: (flag, default, type), registered in this order.
    numeric_options = (
        ("--confidence_low", 0.3, float),
        ("--confidence_high", 0.5, float),
        ("--target_cue_threshold", 6.0, float),
        ("--target_cue_threshold_decay", 0.7, float),
        ("--target_cue_threshold_minimum", 3.0, float),
        ("--minimum_size_scale", 4.0, float),
        ("--minimum_size", 224, int),
        ("--model_max_length", 512, int),
    )
    for flag, default_value, value_type in numeric_options:
        cli.add_argument(flag, default=default_value, type=value_type)

    cli.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    # NOTE(review): store_true combined with default=True means this option is
    # True whether or not the flag is passed.
    cli.add_argument("--use_mm_start_end", action="store_true", default=True)
    cli.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return cli.parse_args(args)
def tranverse(token):
    """Return the (lowest, highest) token index found in ``token``'s subtree.

    The subtree is ``token`` itself plus everything reachable through
    ``.children``; a token without children yields ``(token.i, token.i)``.
    """
    lowest = highest = token.i
    for child in token.children:
        child_lowest, child_highest = tranverse(child)
        if child_lowest < lowest:
            lowest = child_lowest
        if child_highest > highest:
            highest = child_highest
    return lowest, highest
def get_noun_chunks(token):
    """Return the (start, end) token indices of the chunk headed by ``token``.

    The chunk grows leftwards over the children nearest to the head while they
    are ``amod``/``compound``/``poss`` dependents, and rightwards over the
    nearest children while they are ``relcl``/``prep`` dependents. Each
    accepted child contributes its whole subtree (see ``tranverse``).
    """
    head_i = token.i
    children = list(token.children)
    before_head = [child for child in children if child.i < head_i]
    after_head = [child for child in children if not child.i < head_i]

    # Walk outwards from the head; stop at the first child that does not qualify.
    start_token_i = head_i
    for child in reversed(before_head):
        if child.dep_ not in ('amod', 'compound', 'poss'):
            break
        start_token_i, _ = tranverse(child)

    end_token_i = head_i
    for child in after_head:
        if child.dep_ not in ('relcl', 'prep'):
            break
        _, end_token_i = tranverse(child)

    return start_token_i, end_token_i
def filter_chunk_list(chunks):
    """Drop chunks that overlap a longer (or earlier, equally long) chunk.

    Args:
        chunks: Iterable of ``(start, end)`` inclusive index pairs.

    Returns:
        The surviving chunks sorted by start index. Chunks are considered
        longest first, and one is kept only if it shares no index with any
        chunk already kept.
    """
    kept = []
    longest_first = sorted(chunks, key=lambda span: span[1] - span[0], reverse=True)
    for candidate in longest_first:
        # A non-negative value means the two inclusive ranges share an index.
        collides = any(
            min(prev[1], candidate[1]) - max(prev[0], candidate[0]) >= 0
            for prev in kept
        )
        if not collides:
            kept.append(candidate)
    kept.sort(key=lambda span: span[0])
    return kept
def extract_noun_chunks(expression):
    """Parse ``expression`` with the module-level spaCy pipeline and return chunk texts.

    A chunk is built around every token tagged ``NOUN`` or ``PRON``;
    overlapping chunks are resolved by ``filter_chunk_list`` and the survivors
    are returned as strings in left-to-right order.
    """
    doc = nlp(expression)
    spans = [
        get_noun_chunks(token)
        for token in doc
        if token.pos_ in ("NOUN", "PRON")
    ]
    # Chunk ends are inclusive, hence the +1 when slicing the doc.
    return [doc[start:end + 1].text for start, end in filter_chunk_list(spans)]
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input.

    Args:
        x: Image tensor whose last two dimensions are height and width.
        pixel_mean: Per-channel mean subtracted from ``x``.
        pixel_std: Per-channel standard deviation ``x`` is divided by.
        img_size: Target height and width after padding.

    Returns:
        The normalised tensor, padded on the bottom and right so that its last
        two dimensions are ``img_size`` x ``img_size``.
    """
    normalized = (x - pixel_mean) / pixel_std
    height, width = normalized.shape[-2:]
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(normalized, (0, img_size - width, 0, img_size - height))
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1).

    ``x`` holds one box per row, with the four components along dimension 1.
    """
    center_x, center_y, width, height = x.unbind(1)
    corners = (
        center_x - 0.5 * width,
        center_y - 0.5 * height,
        center_x + 0.5 * width,
        center_y + 0.5 * height,
    )
    return torch.stack(corners, dim=1)
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format and scale them to an image size.

    Args:
        out_bbox: Boxes as rows of (center_x, center_y, width, height).
        size: ``(image_width, image_height)`` used as the scale factors.

    Returns:
        Boxes as rows of (x0, y0, x1, y1) multiplied by the image size.
    """
    img_w, img_h = size
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return box_cxcywh_to_xyxy(out_bbox) * scale
class VSM:
    """Wrapper that loads the visual-search model and runs single-image inference.

    Attributes set by ``__init__``:
        model: The ``VSMForCausalLM`` instance, in eval mode.
        vsm_tokenizer: Tokenizer loaded from the same checkpoint.
        vsm_image_processor: Image processor taken from the model's vision tower.
        clip_image_processor: ``CLIPImageProcessor`` for the configured vision tower.
        transform: ``OwlViTProcessor`` used to prepare the second image input.
        conv_type: Name of the conversation template used to build prompts.
        use_mm_start_end: Whether the image token is wrapped in start/end tokens.
    """

    def __init__(self, args):
        """Load tokenizer, model and image processors.

        Args:
            args: Parsed options; ``version``, ``model_max_length``,
                ``vision_tower``, ``conv_type`` and ``use_mm_start_end`` are read.
        """
        kwargs = {
            'torch_dtype': torch.bfloat16,
            'device_map': 'cuda',
            'is_eval': True,
        }
        vsm_tokenizer = AutoTokenizer.from_pretrained(
            args.version,
            cache_dir=None,
            model_max_length=args.model_max_length,
            padding_side="right",
            use_fast=False,
        )
        vsm_tokenizer.pad_token = vsm_tokenizer.unk_token
        # Id of the first token produced for "[LOC]"; handed to the model below.
        loc_token_idx = vsm_tokenizer("[LOC]", add_special_tokens=False).input_ids[0]
        vsm_model = VSMForCausalLM.from_pretrained(
            args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, loc_token_idx=loc_token_idx, **kwargs
        )
        vsm_model.get_model().initialize_vision_modules(vsm_model.get_model().config)
        vision_tower = vsm_model.get_model().get_vision_tower().cuda().to(dtype=torch.bfloat16)
        vsm_image_processor = vision_tower.image_processor
        vsm_model.eval()
        clip_image_processor = CLIPImageProcessor.from_pretrained(vsm_model.config.vision_tower)
        transform = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")

        self.model = vsm_model
        self.vsm_tokenizer = vsm_tokenizer
        self.vsm_image_processor = vsm_image_processor
        self.clip_image_processor = clip_image_processor
        self.transform = transform
        self.conv_type = args.conv_type
        self.use_mm_start_end = args.use_mm_start_end

    def inference(self, image, question, mode='segmentation'):
        """Run the model on one image/question pair.

        Args:
            image: Input image (callers pass a PIL image; it is converted with
                ``np.array`` internally).
            question: Text appended after the image token in the prompt.
            mode: ``'segmentation'``, ``'vqa'`` or ``'detection'``.

        Returns:
            * ``'segmentation'``: the last entry of the first predicted mask,
              clamped to be non-negative.
            * ``'vqa'``: the decoded answer text with newlines removed and
              double spaces collapsed.
            * ``'detection'``: ``(pred_boxes, sigmoid(pred_logits), mask)``
              where the first two are moved to the CPU and ``mask`` is
              computed as in segmentation mode.
            Any other ``mode`` falls through and returns ``None``.
        """
        conv = conversation_lib.conv_templates[self.conv_type].copy()
        conv.messages = []

        prompt = DEFAULT_IMAGE_TOKEN + "\n" + question
        if self.use_mm_start_end:
            replace_token = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN)
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        # CLIP branch: pad to a square using the processor's mean colour.
        background_color = tuple(int(x*255) for x in self.clip_image_processor.image_mean)
        image_clip = self.clip_image_processor.preprocess(expand2square(image, background_color), return_tensors="pt")["pixel_values"][0].unsqueeze(0).cuda()
        image_clip = image_clip.bfloat16()

        # Second branch: the OwlViT processor output of the un-padded image.
        image = np.array(image)
        original_size_list = [image.shape[:2]]
        image = self.transform(images=image, return_tensors="pt")['pixel_values'].cuda()
        # NOTE(review): `image` is the processor's tensor at this point, so
        # shape[:2] is its first two dimensions, not a height/width pair —
        # confirm what the model expects in resize_list.
        resize_list = [image.shape[:2]]
        image = image.bfloat16()

        input_ids = tokenizer_image_token(prompt, self.vsm_tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks, det_result = self.model.inference(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=100,
            tokenizer=self.vsm_tokenizer,
            mode=mode,
        )

        if mode == 'segmentation':
            pred_mask = pred_masks[0]
            pred_mask = torch.clamp(pred_mask, min=0)
            return pred_mask[-1]

        elif mode == 'vqa':
            input_token_len = input_ids.shape[1]
            n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
            if n_diff_input_output > 0:
                print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
            text_output = self.vsm_tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
            # Collapse double spaces; replacing " " with " " (as this line read
            # before) did nothing.
            text_output = text_output.replace("\n", "").replace("  ", " ").strip()
            return text_output

        elif mode == 'detection':
            pred_mask = pred_masks[0]
            pred_mask = torch.clamp(pred_mask, min=0)
            return det_result['pred_boxes'][0].cpu(), det_result['pred_logits'][0].sigmoid().cpu(), pred_mask[-1]
def refine_bbox(bbox, image_width, image_height):
    """Clamp an ``[x, y, w, h]`` box in place and return it.

    The origin is raised to at least zero, then width and height are capped at
    the space remaining between that origin and the image border.
    """
    x0 = max(0, bbox[0])
    y0 = max(0, bbox[1])
    bbox[0], bbox[1] = x0, y0
    bbox[2] = min(bbox[2], image_width - x0)
    bbox[3] = min(bbox[3], image_height - y0)
    return bbox
def split_4subpatches(current_patch_bbox):
    """Choose how to cut an ``[x, y, w, h]`` patch into four pieces.

    Returns:
        ``(columns, rows)``: ``(1, 4)`` for a patch at least twice as tall as
        wide, ``(4, 1)`` for one at least twice as wide as tall, otherwise
        ``(2, 2)``.
    """
    height_over_width = current_patch_bbox[3] / current_patch_bbox[2]
    if height_over_width >= 2:
        return 1, 4
    if height_over_width <= 0.5:
        return 4, 1
    return 2, 2
def get_sub_patches(current_patch_bbox, num_of_width_patches, num_of_height_patches):
    """Tile an ``[x, y, w, h]`` patch into a grid of sub-patches.

    Returns:
        ``(sub_patches, width_stride, height_stride)``. ``sub_patches`` lists
        ``[x, y, w, h]`` tiles row by row; the last column and last row absorb
        whatever remains after the integer strides.
    """
    left = current_patch_bbox[0]
    top = current_patch_bbox[1]
    full_width = current_patch_bbox[2]
    full_height = current_patch_bbox[3]

    width_stride = int(full_width // num_of_width_patches)
    height_stride = int(full_height / num_of_height_patches)
    last_col = num_of_width_patches - 1
    last_row = num_of_height_patches - 1

    sub_patches = []
    for row in range(num_of_height_patches):
        y_offset = row * height_stride
        tile_height = full_height - y_offset if row == last_row else height_stride
        for col in range(num_of_width_patches):
            x_offset = col * width_stride
            tile_width = full_width - x_offset if col == last_col else width_stride
            sub_patches.append([left + x_offset, top + y_offset, tile_width, tile_height])
    return sub_patches, width_stride, height_stride
def get_subpatch_scores(score_heatmap, current_patch_bbox, sub_patches):
    """Return, for each sub-patch, its share of the heatmap's total mass.

    Args:
        score_heatmap: Array indexed as ``[row, column, ...]`` covering
            ``current_patch_bbox``.
        current_patch_bbox: ``[x, y, w, h]`` of the patch the heatmap belongs to.
        sub_patches: ``[x, y, w, h]`` boxes in the same coordinate frame as
            ``current_patch_bbox``.

    Returns:
        One score per sub-patch: the area-normalised sum over the sub-patch
        window divided by the area-normalised sum over the whole heatmap, or a
        zero of the same type when that total is not positive.
    """
    patch_area = current_patch_bbox[2] * current_patch_bbox[3]
    total_sum = (score_heatmap / patch_area).sum()
    has_mass = total_sum > 0

    sub_scores = []
    for sub_patch in sub_patches:
        # Translate the sub-patch into heatmap-local coordinates.
        col0 = sub_patch[0] - current_patch_bbox[0]
        row0 = sub_patch[1] - current_patch_bbox[1]
        window = score_heatmap[row0:row0 + sub_patch[3], col0:col0 + sub_patch[2]]
        score = (window / patch_area).sum()
        sub_scores.append(score / total_sum if has_mass else score * 0)
    return sub_scores
def normalize_score(score_heatmap):
    """Min-max scale ``score_heatmap`` into ``[0, 1]``.

    A heatmap whose minimum equals its maximum is returned multiplied by zero
    instead of being divided by a zero range.
    """
    lowest = score_heatmap.min()
    highest = score_heatmap.max()
    if highest != lowest:
        return (score_heatmap - lowest) / (highest - lowest)
    return score_heatmap * 0
def iou(bbox1, bbox2):
    """Intersection-over-union of two ``[x, y, w, h]`` boxes."""
    left = max(bbox1[0], bbox2[0])
    top = max(bbox1[1], bbox2[1])
    right = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2])
    bottom = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3])

    # Negative extents mean the boxes do not overlap along that axis.
    overlap_width = max(0, right - left)
    overlap_height = max(0, bottom - top)
    inter_area = overlap_width * overlap_height

    union_area = bbox1[2] * bbox1[3] + bbox2[2] * bbox2[3] - inter_area
    return inter_area / union_area
# Default colours for visualize_bbox, as 3-tuples of 0-255 channel values:
# BOX_COLOR is the default box outline, TEXT_COLOR is the label text.
BOX_COLOR = (255, 0, 0) # Red
TEXT_COLOR = (255, 255, 255) # White
| import cv2 | |
| from matplotlib import pyplot as plt | |
def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
    """Draw a single labelled bounding box on ``img`` and return the image.

    Args:
        img: Image array the OpenCV drawing calls write into (modified in place).
        bbox: ``(x_min, y_min, width, height)`` of the box.
        class_name: Label text drawn above the box's top-left corner.
        color: Colour of the box outline and of the label background.
        thickness: Outline thickness passed to ``cv2.rectangle``.

    Returns:
        The same ``img`` object with the box and label drawn on it.
    """
    x_min, y_min, w, h = bbox
    x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)

    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)

    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    # Filled label background. It follows `color` so the label matches its box;
    # it used to be hard-coded to BOX_COLOR regardless of the argument.
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.5,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img
def show_heatmap_on_image(img: np.ndarray,
                          mask: np.ndarray,
                          use_rgb: bool = False,
                          colormap: int = cv2.COLORMAP_JET,
                          image_weight: float = 0.5) -> np.ndarray:
    """Blend a colour-mapped heatmap over an image.

    Args:
        img: Image with values in ``[0, 1]``.
        mask: Heatmap; values are clipped to ``[0, 1]`` before colour mapping.
        use_rgb: Convert the colour-mapped heatmap from BGR to RGB before blending.
        colormap: OpenCV colormap id passed to ``cv2.applyColorMap``.
        image_weight: Weight of ``img`` in the blend; the heatmap gets
            ``1 - image_weight``.

    Returns:
        The blended image as ``uint8``, rescaled so its maximum is 255.

    Raises:
        ValueError: If ``img`` has values above 1 or ``image_weight`` is
            outside ``[0, 1]``.
    """
    if np.max(img) > 1:
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")
    if image_weight < 0 or image_weight > 1:
        raise ValueError(
            f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    mask = np.clip(mask, 0, 1)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    peak = np.max(cam)
    # An all-zero blend has nothing to rescale; skip the 0/0 division.
    if peak > 0:
        cam = cam / peak
    return np.uint8(255 * cam)
def vis_heatmap(image, heatmap, use_rgb=False):
    """Overlay ``heatmap`` on a 0-255 ``image`` after min-max scaling the heatmap.

    A constant heatmap is passed through unscaled. The image is divided by 255
    before being handed to ``show_heatmap_on_image``.
    """
    lowest = np.min(heatmap)
    highest = np.max(heatmap)
    if highest != lowest:
        heatmap = (heatmap - lowest) / (highest - lowest)
    unit_image = image.astype(float) / 255.
    return show_heatmap_on_image(unit_image, heatmap, use_rgb=use_rgb)
def visualize_search_path(image, search_path, search_length, target_bbox, label, save_path):
    """Write images and notes describing the first ``search_length`` search steps.

    Files written into ``save_path`` (created if missing):
        * ``whole_image.jpg`` - the input image.
        * ``step_<k>.jpg`` - the whole image with step k's patch outlined
          (and the ground-truth box when ``target_bbox`` is given).
        * ``step_<k>_heatmap.jpg`` - the step's heatmap over its patch, when
          the step has a ``'final_heatmap'``.
        * ``final_patch_image.jpg`` / ``search_result.jpg`` - the patch of a
          step that has a ``'detection_result'``, without and with that box.
        * ``context_cue.txt`` - one line per step that has a ``'context_cue'``.

    Args:
        image: PIL image the search ran on.
        search_path: List of step dicts, each with a ``'bbox'`` in ``[x, y, w, h]``.
        search_length: Number of leading steps to visualise.
        target_bbox: Ground-truth box, or ``None`` to skip drawing it.
        label: Text appended to ``"gt: "`` for the ground-truth box.
        save_path: Output directory.
    """
    os.makedirs(save_path, exist_ok=True)
    image.save(os.path.join(save_path, 'whole_image.jpg'))

    canvas = np.array(image)
    if target_bbox is not None:
        canvas = visualize_bbox(canvas.copy(), target_bbox, class_name="gt: "+label, color=(255,0,0))

    cue_lines = []
    for step_no, node in enumerate(search_path, start=1):
        if step_no > search_length:
            break

        patch_box = node['bbox']
        crop_box = (patch_box[0], patch_box[1], patch_box[0]+patch_box[2], patch_box[1]+patch_box[3])

        if 'detection_result' in node:
            detected_patch = image.crop(crop_box)
            detected_patch.save(os.path.join(save_path, 'final_patch_image.jpg'))
            annotated = visualize_bbox(np.array(detected_patch), node['detection_result'], class_name='search result', color=(255,0,0))
            # Arrays are RGB here; cv2.imwrite expects BGR.
            cv2.imwrite(os.path.join(save_path, 'search_result.jpg'), cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))

        step_canvas = visualize_bbox(canvas.copy(), patch_box, class_name=f"step-{step_no}", color=(0,0,255))
        cv2.imwrite(os.path.join(save_path, f'step_{step_no}.jpg'), cv2.cvtColor(step_canvas, cv2.COLOR_RGB2BGR))

        step_patch = image.crop(crop_box)
        if 'context_cue' in node:
            cue_lines.append(f"step{step_no}: {node['context_cue']}" + '\n')
        if 'final_heatmap' in node:
            overlay = vis_heatmap(np.array(step_patch), node['final_heatmap'], use_rgb=True)
            cv2.imwrite(os.path.join(save_path, f'step_{step_no}_heatmap.jpg'), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

    with open(os.path.join(save_path, 'context_cue.txt'),"w") as f:
        f.writelines(cue_lines)
class Prioritize:
    """Pair an arbitrary item with a priority so it can live in a priority queue.

    Only ``priority`` takes part in comparisons, so items that cannot be
    compared with each other (for example dicts) can still be ordered.
    Comparing with anything that is not a ``Prioritize`` returns
    ``NotImplemented`` so Python applies its normal fallback rules.
    """

    def __init__(self, priority, item):
        self.priority = priority
        self.item = item

    def __repr__(self):
        return f"{type(self).__name__}(priority={self.priority!r}, item={self.item!r})"

    def __eq__(self, other):
        if not isinstance(other, Prioritize):
            return NotImplemented
        return self.priority == other.priority

    def __lt__(self, other):
        if not isinstance(other, Prioritize):
            return NotImplemented
        return self.priority < other.priority
def visual_search_queue(vsm, image, target_object_name, current_patch, search_path, queue, smallest_size=224, confidence_high=0.5, target_cue_threshold=6.0, target_cue_threshold_decay=0.7, target_cue_threshold_minimum=3.0):
    """Search ``current_patch`` for the target, then recurse into queued sub-patches.

    One call does the following:
      1. Runs detection on the patch; if the top score exceeds
         ``confidence_high`` the box is stored on ``search_path[-1]`` and the
         call returns success.
      2. Otherwise, unless the patch is already at or below ``smallest_size``,
         builds a heatmap (the target cue when its maximum exceeds the
         threshold, else a context cue derived from a VQA answer), splits the
         patch into sub-patches, scores them and pushes them onto ``queue``.
      3. Pops patches from ``queue`` and recurses until one succeeds or the
         queue is empty.

    Args:
        vsm: Object whose ``inference(image, question, mode=...)`` is called
            with modes ``'detection'``, ``'vqa'`` and ``'segmentation'``.
        image: Full PIL image; patches are cropped from it.
        target_object_name: Name inserted into the prompts.
        current_patch: Dict with ``'bbox'`` (``[x, y, w, h]``),
            ``'scale_level'`` and ``'parent_index'``; expected to be the last
            element of ``search_path``.
        search_path: List of visited patch dicts; extended and annotated in place.
        queue: Shared queue of ``Prioritize`` entries (mutated in place).
        smallest_size: Patches whose shorter side is at or below this are not split.
        confidence_high: Detection score above which the search stops.
        target_cue_threshold: Base threshold for trusting the target-cue heatmap.
        target_cue_threshold_decay: Per-scale-level multiplier on that threshold.
        target_cue_threshold_minimum: Lower bound of the decayed threshold.

    Returns:
        ``(success, search_path, all_valid_boxes)``. ``all_valid_boxes`` is
        only non-``None`` when the target is found on the very first patch.
    """
    current_patch_bbox = current_patch['bbox']
    current_patch_scale_level = current_patch['scale_level']
    image_patch = image.crop((int(current_patch_bbox[0]), int(current_patch_bbox[1]), int(current_patch_bbox[0]+current_patch_bbox[2]), int(current_patch_bbox[1]+current_patch_bbox[3])))

    # whether we can detect the target object on the current image patch
    question = "Please locate the {} in this image.".format(target_object_name)
    pred_bboxes, pred_logits, target_cue_heatmap = vsm.inference(copy.deepcopy(image_patch), question, mode='detection')
    if len(pred_logits) > 0:
        top_index = pred_logits.view(-1).argmax()
        top_logit = pred_logits.view(-1).max()
        final_bbox = pred_bboxes[top_index].view(4)
        # Scale by the patch size, then shift the first two values by half of
        # the last two — presumably normalised (cx, cy, w, h) -> pixel
        # (x, y, w, h); confirm against the model's box format.
        final_bbox = final_bbox * torch.Tensor([image_patch.width, image_patch.height, image_patch.width, image_patch.height])
        final_bbox[:2] -= final_bbox[2:] / 2
        if top_logit > confidence_high:
            search_path[-1]['detection_result'] = final_bbox
            # only return multiple detected instances on the whole image
            if len(search_path) == 1:
                # NOTE(review): this filter uses a literal 0.5 rather than
                # confidence_high; the two only agree at the default value.
                all_valid_boxes = pred_bboxes[pred_logits.view(-1)>0.5].view(-1, 4)
                all_valid_boxes = all_valid_boxes * torch.Tensor([[image_patch.width, image_patch.height, image_patch.width, image_patch.height]])
                all_valid_boxes[:, :2] -= all_valid_boxes[:, 2:] / 2
                return True, search_path, all_valid_boxes
            return True, search_path, None
        else:
            # Remember the best low-confidence box; visual_search reads it
            # when the whole search fails.
            search_path[-1]['temp_detection_result'] = (top_logit, final_bbox)

    ### current patch is already the smallest unit
    if min(current_patch_bbox[2], current_patch_bbox[3]) <= smallest_size:
        return False, search_path, None

    # Reshape the heatmap to (patch height, patch width, 1).
    target_cue_heatmap = target_cue_heatmap.view(current_patch_bbox[3], current_patch_bbox[2], 1)
    score_max = target_cue_heatmap.max().item()
    # check whether the target cue is prominent
    threshold = max(target_cue_threshold_minimum, target_cue_threshold*(target_cue_threshold_decay)**(current_patch_scale_level-1))
    if score_max > threshold:
        target_cue_heatmap = normalize_score(target_cue_heatmap)
        final_heatmap = target_cue_heatmap
    else:
        # Target cue too weak: ask where the object is likely to be and
        # segment that location instead.
        question = "According to the common sense knowledge and possible visual cues, what is the most likely location of the {} in the image?".format(target_object_name)
        vqa_results = vsm.inference(copy.deepcopy(image_patch), question, mode='vqa')

        # Keep what follows 'most likely to appear' and the target name,
        # without a trailing full stop.
        possible_location_phrase = vqa_results.split('most likely to appear')[-1].strip()
        if possible_location_phrase.endswith('.'):
            possible_location_phrase = possible_location_phrase[:-1]
        possible_location_phrase = possible_location_phrase.split(target_object_name)[-1]
        noun_chunks = extract_noun_chunks(possible_location_phrase)
        if len(noun_chunks) == 1:
            possible_location_phrase = noun_chunks[0]
        else:
            possible_location_phrase = "region {}".format(possible_location_phrase)
        question = "Please locate the {} in this image.".format(possible_location_phrase)
        context_cue_heatmap = vsm.inference(copy.deepcopy(image_patch), question, mode='segmentation').view(current_patch_bbox[3], current_patch_bbox[2], 1)
        context_cue_heatmap = normalize_score(context_cue_heatmap)
        final_heatmap = context_cue_heatmap

    current_patch_index = len(search_path)-1
    if score_max <= threshold:
        # Same condition as the else-branch above, so vqa_results is defined here.
        search_path[current_patch_index]['context_cue'] = vqa_results + "#" + possible_location_phrase
    search_path[current_patch_index]['final_heatmap'] = final_heatmap.cpu().numpy()

    ### split the current patch into 4 sub-patches
    basic_sub_patches, sub_patch_width, sub_patch_height = get_sub_patches(current_patch_bbox, *split_4subpatches(current_patch_bbox))

    # Score each sub-patch with the heatmap of the current patch and of every
    # ancestor, each contribution divided by 4**scale_level of that patch.
    tmp_patch = current_patch
    basic_sub_scores = [0]*len(basic_sub_patches)
    while True:
        tmp_score_heatmap = tmp_patch['final_heatmap']
        tmp_sub_scores = get_subpatch_scores(tmp_score_heatmap, tmp_patch['bbox'], basic_sub_patches)
        basic_sub_scores = [basic_sub_scores[patch_i]+tmp_sub_scores[patch_i]/(4**tmp_patch['scale_level']) for patch_i in range(len(basic_sub_scores))]
        if tmp_patch['parent_index'] == -1:
            break
        else:
            tmp_patch = search_path[tmp_patch['parent_index']]

    sub_patches = basic_sub_patches
    sub_scores = basic_sub_scores

    for sub_patch, sub_score in zip(sub_patches, sub_scores):
        new_patch_info = dict()
        new_patch_info['bbox'] = sub_patch
        new_patch_info['scale_level'] = current_patch_scale_level + 1
        new_patch_info['score'] = sub_score
        new_patch_info['parent_index'] = current_patch_index
        # Priority is the negated score; with the PriorityQueue that
        # visual_search passes in, the highest-scoring patch is popped first.
        queue.put(Prioritize(-new_patch_info['score'], new_patch_info))

    # The queue is shared across recursion levels, so the next patch may come
    # from any previously split patch, not only the current one.
    while(not queue.empty()):
        patch_chosen = queue.get().item
        search_path.append(patch_chosen)
        success, search_path, all_valid_boxes = visual_search_queue(vsm, image, target_object_name, patch_chosen, search_path, queue, smallest_size=smallest_size, confidence_high=confidence_high, target_cue_threshold=target_cue_threshold, target_cue_threshold_decay=target_cue_threshold_decay, target_cue_threshold_minimum=target_cue_threshold_minimum)
        if success:
            return success, search_path, all_valid_boxes
    return False, search_path, None
def visual_search(vsm, image, target_object_name, target_bbox, smallest_size, confidence_high=0.5, confidence_low=0.3, target_cue_threshold=6.0, target_cue_threshold_decay=0.7, target_cue_threshold_minimum=3.0, visualize=False, save_path=None):
    """Run the visual search for one target on one image.

    Args:
        vsm: Model wrapper handed to ``visual_search_queue``.
        image: PIL image to search.
        target_object_name: Name of the object to look for.
        target_bbox: Ground-truth box, used only for visualisation.
        smallest_size: Patches whose shorter side is at or below this are not split.
        confidence_high: Score above which a detection ends the search.
        confidence_low: Score the best low-confidence detection must reach for
            a failed search to count as successful after all.
        target_cue_threshold: Forwarded to ``visual_search_queue``.
        target_cue_threshold_decay: Forwarded to ``visual_search_queue``.
        target_cue_threshold_minimum: Forwarded to ``visual_search_queue``.
        visualize: Write the search path to ``save_path`` when true.
        save_path: Output directory; required when ``visualize`` is true.

    Returns:
        ``(final_step, path_length, search_successful, all_valid_boxes)``.
        ``final_step`` is the step dict holding ``'detection_result'``; it is
        ``None`` (with ``path_length`` 0 and ``search_successful`` False) when
        no step produced any usable detection.
    """
    if visualize:
        assert save_path is not None

    init_patch = {
        'bbox': [0,0,image.width,image.height],
        'scale_level': 1,
        'score': None,
        'parent_index': -1,
    }
    search_path = [init_patch]

    queue = PriorityQueue()
    search_successful, search_path, all_valid_boxes = visual_search_queue(vsm, image, target_object_name, init_patch, search_path, queue, smallest_size=smallest_size, confidence_high=confidence_high, target_cue_threshold=target_cue_threshold, target_cue_threshold_decay=target_cue_threshold_decay, target_cue_threshold_minimum=target_cue_threshold_minimum)

    path_length = len(search_path)
    final_step = search_path[-1]
    if not search_successful:
        # if no target is found with confidence passing confidence_high, select the target with the highest confidence during search and compare its confidence with confidence_low
        max_logit = 0
        final_step = None
        path_length = 0
        for i, search_step in enumerate(search_path):
            if 'temp_detection_result' in search_step:
                if search_step['temp_detection_result'][0] > max_logit:
                    max_logit = search_step['temp_detection_result'][0]
                    final_step = search_step
                    path_length = i+1
        # No step may have recorded a candidate (the detector returned nothing
        # anywhere); without this guard the lookup below fails on None.
        if final_step is not None:
            final_step['detection_result'] = final_step['temp_detection_result'][1]
            if max_logit >= confidence_low:
                search_successful = True

    if visualize:
        vis_path_length = path_length if search_successful else len(search_path)
        visualize_search_path(image, search_path, vis_path_length, target_bbox, target_object_name, save_path)
    return final_step, path_length, search_successful, all_valid_boxes
def main(args):
    """Evaluate visual search on the benchmark and print the summary metrics.

    For each of the ``direct_attributes`` and ``relative_position`` sub-folders
    of ``--benchmark-folder``, every non-JSON file is treated as an image whose
    annotation sits next to it under the same name with a ``.json`` extension.
    Each annotated target is searched for; a search counts as correct when the
    IoU between the found box and the ground-truth box exceeds 0.5.

    Args:
        args: Command-line argument strings, parsed with ``parse_args``.
    """
    args = parse_args(args)
    vsm = VSM(args)

    benchmark_folder = args.benchmark_folder
    acc_list = []
    search_path_length_list = []

    for test_type in ['direct_attributes', 'relative_position']:
        folder = os.path.join(benchmark_folder, test_type)
        output_folder = None
        if args.visualization:
            output_folder = os.path.join(args.output_path, test_type)
            os.makedirs(output_folder, exist_ok=True)

        image_files = [file for file in os.listdir(folder) if '.json' not in file]
        for image_file in tqdm.tqdm(image_files):
            image_path = os.path.join(folder, image_file)
            # Strip only the file extension; splitting the whole path on '.'
            # breaks as soon as a directory name contains a dot (e.g. "./bench").
            image_stem = os.path.splitext(image_file)[0]
            annotation_path = os.path.join(folder, image_stem + '.json')
            with open(annotation_path) as annotation_file:
                annotation = json.load(annotation_file)
            bboxs = annotation['bbox']
            object_names = annotation['target_object']

            for i, (gt_bbox, object_name) in enumerate(zip(bboxs, object_names)):
                image = Image.open(image_path).convert('RGB')
                smallest_size = max(int(np.ceil(min(image.width, image.height)/args.minimum_size_scale)), args.minimum_size)
                if args.visualization:
                    vis_path = os.path.join(output_folder, "{}_{}".format(image_stem, i))
                else:
                    vis_path = None

                final_step, path_length, search_successful, all_valid_boxes = visual_search(vsm, image, object_name, target_bbox=gt_bbox, smallest_size=smallest_size, confidence_high=args.confidence_high, confidence_low=args.confidence_low, target_cue_threshold=args.target_cue_threshold, target_cue_threshold_decay=args.target_cue_threshold_decay, target_cue_threshold_minimum=args.target_cue_threshold_minimum, save_path=vis_path, visualize=args.visualization)

                if search_successful:
                    # The detection is relative to the final patch; shift it
                    # into whole-image coordinates before comparing.
                    search_bbox = final_step['detection_result']
                    search_final_patch = final_step['bbox']
                    search_bbox[0] += search_final_patch[0]
                    search_bbox[1] += search_final_patch[1]
                    iou_i = iou(search_bbox, gt_bbox).item()
                    det_acc = 1.0 if iou_i > 0.5 else 0.0
                    acc_list.append(det_acc)
                    search_path_length_list.append(path_length)
                else:
                    acc_list.append(0)
                    search_path_length_list.append(0)

    # Path length is averaged over correctly localised targets only. Empty
    # lists report nan directly rather than going through np.mean.
    correct_lengths = [length for length, correct in zip(search_path_length_list, acc_list) if correct]
    print('Avg search path length:', np.mean(correct_lengths) if correct_lengths else float('nan'))
    print('Top 1 Acc:', np.mean(acc_list) if acc_list else float('nan'))
# Script entry point: run the evaluation with the command-line arguments
# (everything after the program name).
if __name__ == "__main__":
    main(sys.argv[1:])