from typing import List, Tuple

import numpy as np
import torch
from tqdm import tqdm

from mmscan_utils.box_metric import (get_average_precision,
                                     get_general_topk_scores,
                                     subset_get_average_precision)
from mmscan_utils.box_utils import index_box, to_9dof_box


class VisualGroundingEvaluator:
    """Evaluator for the MMScan Visual Grounding benchmark.

    The evaluation metrics include "AP", "AP_C", "AR" and "gTop-k".

    Attributes:
        save_buffer (list[dict]): Buffer of the raw inputs.
        records (list[dict]): Metric results for each sample.
        category_records (dict): Metric results for each category
            (average over all samples with the same category).

    Args:
        show_results (bool): Whether to print the evaluation results.
            Defaults to True.
    """

    def __init__(self, show_results: bool = True) -> None:
        self.show_results = show_results
        self.eval_metric_type = ['AP', 'AR']
        self.top_k_visible = [1, 3, 5]
        self.call_for_category_mode = True

        for top_k in [1, 3, 5, 10]:
            self.eval_metric_type.append(f'gTop-{top_k}')

        self.iou_thresholds = [0.25, 0.50]
        self.eval_metric = []
        for iou_thr in self.iou_thresholds:
            for eval_type in self.eval_metric_type:
                self.eval_metric.append(eval_type + '@' + str(iou_thr))

        self.reset()

    def reset(self) -> None:
        """Reset the evaluator, clear the buffer and records."""
        self.save_buffer = []
        self.records = []
        self.category_records = {}

    def update(self, raw_batch_input: List[dict]) -> None:
        """Update the buffer with a batch of results.

        Args:
            raw_batch_input (list[dict]): Batch of raw original inputs.
        """
        self.__check_format__(raw_batch_input)
        self.save_buffer.extend(raw_batch_input)

    def start_evaluation(self) -> dict:
        """Iterate over the saved buffer and evaluate each item.

        Returns:
            dict: Metric results per category.
        """
        category_collect = {}

        for data_item in tqdm(self.save_buffer):
            metric_for_single = {}

            # (1) No ground-truth boxes: skip the sample.
            if self.__is_zero__(data_item['gt_bboxes']):
                continue

            # (2) No predicted boxes: the model failed on this sample,
            #     so every metric is zero.
            if self.__is_zero__(data_item['pred_bboxes']):
                for iou_thr in self.iou_thresholds:
                    metric_for_single[f'AP@{iou_thr}'] = 0
                    metric_for_single[f'AR@{iou_thr}'] = 0
                    for topk in [1, 3, 5, 10]:
                        metric_for_single[f'gTop-{topk}@{iou_thr}'] = 0
                data_item['num_gts'] = len(data_item['gt_bboxes'])
                data_item.update(metric_for_single)
                self.records.append(data_item)
                continue

            iou_array, pred_score = self.__calculate_iou_array_(data_item)

            if self.call_for_category_mode:
                category = self.__category_mapping__(data_item['subclass'])

                if category not in category_collect:
                    category_collect[category] = {
                        'ious': [],
                        'scores': [],
                        'sample_indices': [],
                        'cnt': 0,
                    }

                category_collect[category]['ious'].extend(iou_array)
                category_collect[category]['scores'].extend(pred_score)
                category_collect[category]['sample_indices'].extend(
                    [data_item['index']] * len(iou_array))
                category_collect[category]['cnt'] += 1

            for iou_thr in self.iou_thresholds:
                # AP/AR metric
                AP, AR = get_average_precision(iou_array, iou_thr)
                metric_for_single[f'AP@{iou_thr}'] = AP
                metric_for_single[f'AR@{iou_thr}'] = AR

                # topk metric
                metric_for_single.update(
                    get_general_topk_scores(iou_array, iou_thr))

            data_item['num_gts'] = iou_array.shape[1]
            data_item.update(metric_for_single)
            self.records.append(data_item)

        self.collect_result()

        if self.call_for_category_mode:
            for iou_thr in self.iou_thresholds:
                self.category_records['overall'][f'AP_C@{iou_thr}'] = 0
                self.category_records['overall'][f'AR_C@{iou_thr}'] = 0

                for category in category_collect:
                    AP_C, AR_C = subset_get_average_precision(
                        category_collect[category], iou_thr)
                    self.category_records[category][f'AP_C@{iou_thr}'] = AP_C
                    self.category_records[category][f'AR_C@{iou_thr}'] = AR_C
                    self.category_records['overall'][f'AP_C@{iou_thr}'] += (
                        AP_C * category_collect[category]['cnt'] /
                        len(self.records))
                    self.category_records['overall'][f'AR_C@{iou_thr}'] += (
                        AR_C * category_collect[category]['cnt'] /
                        len(self.records))

        return self.category_records

    def collect_result(self) -> dict:
        """Collect the results from the evaluation process.

        Stores them based on their subclass.

        Returns:
            dict: Average results per category.
        """
        category_results = {}
        category_results['overall'] = {}
        for metric_name in self.eval_metric:
            category_results['overall'][metric_name] = []
        category_results['overall']['num_gts'] = 0

        for data_item in self.records:
            category = self.__category_mapping__(data_item['subclass'])

            if category not in category_results:
                category_results[category] = {}
                for metric_name in self.eval_metric:
                    category_results[category][metric_name] = []
                category_results[category]['num_gts'] = 0

            for metric_name in self.eval_metric:
                category_results[category][metric_name].append(
                    data_item[metric_name])
                category_results['overall'][metric_name].append(
                    data_item[metric_name])

            category_results['overall']['num_gts'] += data_item['num_gts']
            category_results[category]['num_gts'] += data_item['num_gts']

        for category in category_results:
            for metric_name in self.eval_metric:
                category_results[category][metric_name] = np.mean(
                    category_results[category][metric_name])

        self.category_records = category_results

        return category_results

    def print_result(self) -> list:
        """Show the result table.

        Returns:
            list[list[str]]: The metric result table data.
        """
        assert len(self.category_records) > 0, 'No result yet.'

        self.category_records = {
            key: self.category_records[key]
            for key in sorted(self.category_records.keys(), reverse=True)
        }

        header = ['Type']
        header.extend(self.category_records.keys())
        table_columns = [[] for _ in range(len(header))]

        # Metric rows.
        for iou_thr in self.iou_thresholds:
            show_in_table = (['AP', 'AR'] +
                             [f'gTop-{k}' for k in self.top_k_visible]
                             if not self.call_for_category_mode else
                             ['AP', 'AR', 'AP_C', 'AR_C'] +
                             [f'gTop-{k}' for k in self.top_k_visible])

            for metric_type in show_in_table:
                table_columns[0].append(metric_type + ' ' + str(iou_thr))

            for i, category in enumerate(self.category_records.keys()):
                ap = self.category_records[category][f'AP@{iou_thr}']
                ar = self.category_records[category][f'AR@{iou_thr}']
                table_columns[i + 1].append(f'{float(ap):.4f}')
                table_columns[i + 1].append(f'{float(ar):.4f}')

                # AP_C/AR_C only exist when category mode is enabled.
                if self.call_for_category_mode:
                    ap = self.category_records[category][f'AP_C@{iou_thr}']
                    ar = self.category_records[category][f'AR_C@{iou_thr}']
                    table_columns[i + 1].append(f'{float(ap):.4f}')
                    table_columns[i + 1].append(f'{float(ar):.4f}')

                for k in self.top_k_visible:
                    top_k = self.category_records[category][
                        f'gTop-{k}@{iou_thr}']
                    table_columns[i + 1].append(f'{float(top_k):.4f}')

        # Number of ground truths.
        table_columns[0].append('Num GT')
        for i, category in enumerate(self.category_records.keys()):
            table_columns[i + 1].append(
                f'{int(self.category_records[category]["num_gts"])}')

        table_data = [header]
        table_rows = list(zip(*table_columns))
        table_data += table_rows
        table_data = [list(row) for row in zip(*table_data)]

        if self.show_results:
            for row in table_data:
                print('\t'.join(str(item) for item in row))

        return table_data

    def __category_mapping__(self, sub_class: str) -> str:
        """Map the subclass name to the category name.

        Args:
            sub_class (str): The subclass name in the original samples.

        Returns:
            str: The category name.
        """
        sub_class = sub_class.lower()
        sub_class = sub_class.replace('single', 'sngl')
        sub_class = sub_class.replace('inter', 'int')
        sub_class = sub_class.replace('unique', 'uniq')
        sub_class = sub_class.replace('common', 'cmn')
        sub_class = sub_class.replace('attribute', 'attr')
        if 'sngl' in sub_class and ('attr' in sub_class or 'eq' in sub_class):
            sub_class = 'vg_sngl_attr'
        return sub_class
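
    # For example, a hypothetical subclass string 'VG_Single_Attribute_Unique'
    # would be lower-cased and abbreviated to 'vg_sngl_attr_uniq', which the
    # final rule above collapses to 'vg_sngl_attr'.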

    def __calculate_iou_array_(
            self, data_item: dict) -> Tuple[np.ndarray, np.ndarray]:
        """Calculate the IoU information needed for evaluation.

        Args:
            data_item (dict): The original sample with predicted and
                ground-truth boxes.

        Returns:
            Tuple[np.ndarray, np.ndarray]:
                The IoU array sorted by confidence and the confidence
                scores themselves.
        """
        pred_bboxes = data_item['pred_bboxes']
        gt_bboxes = data_item['gt_bboxes']

        # Sort the bounding boxes by their confidence scores.
        pred_scores = data_item['pred_scores']
        top_idxs = torch.argsort(pred_scores, descending=True)
        pred_scores = pred_scores[top_idxs]
        pred_bboxes = to_9dof_box(index_box(pred_bboxes, top_idxs))
        gt_bboxes = to_9dof_box(gt_bboxes)

        # Pairwise IoU between predictions and ground truths,
        # used later for the AP/gTop-k calculation.
        iou_matrix = pred_bboxes.overlaps(pred_bboxes,
                                          gt_bboxes)  # (num_query, num_gt)

        pred_scores = pred_scores.cpu().numpy()
        iou_array = iou_matrix.cpu().numpy()

        return iou_array, pred_scores

    def __is_zero__(self, box):
        """Check whether a box container (tensor or [center, size, rot]
        list) is empty."""
        if isinstance(box, (list, tuple)):
            return len(box[0]) == 0
        return len(box) == 0

    def __check_format__(self, raw_input: List[dict]) -> None:
        """Check whether the input conforms to the mmscan evaluation format
        and transform the input box format.

        Args:
            raw_input (list[dict]): The input of the VG evaluator.
        """
        assert isinstance(
            raw_input,
            list), 'The input of VG evaluator should be a list of dict. '

        for _index in tqdm(range(len(raw_input))):
            if 'index' not in raw_input[_index]:
                raw_input[_index]['index'] = len(self.save_buffer) + _index

            if 'subclass' not in raw_input[_index]:
                raw_input[_index]['subclass'] = 'non-class'

            assert 'gt_bboxes' in raw_input[_index]
            assert 'pred_bboxes' in raw_input[_index]
            assert 'pred_scores' in raw_input[_index]

            # Convert dict-style boxes into a [center, size, rot] tensor list.
            for mode in ['pred_bboxes', 'gt_bboxes']:
                if (isinstance(raw_input[_index][mode], dict)
                        and 'center' in raw_input[_index][mode]):
                    raw_input[_index][mode] = [
                        torch.tensor(raw_input[_index][mode]['center']),
                        torch.tensor(raw_input[_index][mode]['size']).to(
                            torch.float32),
                        torch.tensor(raw_input[_index][mode]['rot']).to(
                            torch.float32)
                    ]
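
# Note: each sample dict passed to VisualGroundingEvaluator.update() is
# expected to carry 'gt_bboxes', 'pred_bboxes' and 'pred_scores' (with
# optional 'index' and 'subclass' fields); see __check_format__ above.
# The helpers below build such dicts from challenge-format gt/pred data.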


def trun_box(box_list):
    """Round each box coordinate to two decimal places."""
    trun_box_list = []
    for box in box_list:
        trun_box_list.append([round(x, 2) for x in box])
    return trun_box_list


def evaluation_for_challenge(gt_data, pred_data):
    """Run the MMScan VG evaluation on challenge-format gt/pred data and
    return the overall metric results."""
    inputs = []
    for sample_ID in gt_data:
        batch_result = {}
        if sample_ID not in pred_data:
            # No prediction for this sample: use empty tensors so the
            # evaluator scores it as zero.
            batch_result['pred_scores'] = torch.zeros(0)
            batch_result['pred_bboxes'] = torch.zeros(0, 9)
        else:
            batch_result['pred_scores'] = torch.tensor(
                pred_data[sample_ID]['score'][:100])
            batch_result['pred_bboxes'] = torch.tensor(
                trun_box(pred_data[sample_ID]['pred_bboxes'][:100]))
        batch_result['gt_bboxes'] = torch.tensor(gt_data[sample_ID])
        batch_result['subclass'] = sample_ID.split('__')[0]
        inputs.append(batch_result)

    vg_evaluator = VisualGroundingEvaluator()
    vg_evaluator.update(inputs)
    results = vg_evaluator.start_evaluation()
    # vg_evaluator.print_result()

    return results['overall']
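

# A minimal usage sketch (an assumption, not part of the original module):
# gt_data is taken to map sample IDs of the form '<subclass>__<id>' to lists
# of 9-DoF ground-truth boxes, and pred_data[sample_ID] to hold 'score' and
# 'pred_bboxes' lists as read by evaluation_for_challenge above; the JSON
# file names below are hypothetical placeholders.
if __name__ == '__main__':
    import json

    with open('gt_data.json') as f:  # hypothetical path
        gt_data = json.load(f)
    with open('pred_data.json') as f:  # hypothetical path
        pred_data = json.load(f)

    overall = evaluation_for_challenge(gt_data, pred_data)
    for metric_name, value in overall.items():
        print(f'{metric_name}: {value}')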