Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| import os.path as osp | |
| import sys | |
| from typing import Optional, Tuple | |
| import cv2 | |
| import mmcv | |
| import numpy as np | |
| from mmengine.config import Config, DictAction | |
| from mmengine.dataset import Compose | |
| from mmengine.registry import init_default_scope | |
| from mmengine.utils import ProgressBar | |
| from mmengine.visualization import Visualizer | |
| from mmocr.registry import DATASETS, VISUALIZERS | |
| # TODO: Support for printing the change in key of results | |
def parse_args():
    """Build the command-line interface of the script and parse ``sys.argv``.

    The options are declared as a table of ``(flags, keyword arguments)``
    pairs and registered in order, so the generated ``--help`` output lists
    them in the same order as they appear below.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    phase_help = (
        'Phase of dataset to visualize. Use "train", "test" or "val" if '
        "you just want to visualize the default split. It's also possible to "
        'be a dataset variable name, which might be useful when a dataset '
        'split has multiple variants in the config.')
    mode_help = (
        'Display mode: display original pictures or '
        'transformed pictures or comparison pictures. "original" '
        'only visualizes the original dataset & annotations; '
        '"transformed" shows the resulting images processed through all the '
        'transforms; "pipeline" shows all the intermediate images. '
        'Defaults to "transformed".')
    task_help = (
        'Specify the task type of the dataset. If "auto", the task type '
        'will be inferred from the config. If the script is unable to infer '
        'the task type, you need to specify it manually. Defaults to "auto".')
    show_number_help = (
        'number of images selected to visualize, '
        'must bigger than 0. if the number is bigger than length '
        'of dataset, show all the images in dataset; '
        'default "sys.maxsize", show all images in dataset')
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    # One entry per argument: (positional name or option flags, options).
    argument_table = [
        (('config', ), dict(help='Path to model or dataset config.')),
        (('--phase', '-p'), dict(default='train', type=str, help=phase_help)),
        (('--mode', '-m'),
         dict(
             default='transformed',
             type=str,
             choices=['original', 'transformed', 'pipeline'],
             help=mode_help)),
        (('--output-dir', '-o'),
         dict(
             default=None,
             type=str,
             help='If there is no display interface, you can save it.')),
        (('--task', '-t'),
         dict(
             default='auto',
             choices=['auto', 'textdet', 'textrecog'],
             type=str,
             help=task_help)),
        (('--not-show', ), dict(default=False, action='store_true')),
        (('--show-number', '-n'),
         dict(type=int, default=sys.maxsize, help=show_number_help)),
        (('--show-interval', '-i'),
         dict(type=float, default=3, help='the interval of show (s)')),
        (('--cfg-options', ),
         dict(nargs='+', action=DictAction, help=cfg_options_help)),
    ]

    cli = argparse.ArgumentParser(description='Browse a dataset')
    for flags, options in argument_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
| def _get_adaptive_scale(img_shape: Tuple[int, int], | |
| min_scale: float = 0.3, | |
| max_scale: float = 3.0) -> float: | |
| """Get adaptive scale according to image shape. | |
| The target scale depends on the the short edge length of the image. If the | |
| short edge length equals 224, the output is 1.0. And output linear | |
| scales according the short edge length. You can also specify the minimum | |
| scale and the maximum scale to limit the linear scale. | |
| Args: | |
| img_shape (Tuple[int, int]): The shape of the canvas image. | |
| min_scale (int): The minimum scale. Defaults to 0.3. | |
| max_scale (int): The maximum scale. Defaults to 3.0. | |
| Returns: | |
| int: The adaptive scale. | |
| """ | |
| short_edge_length = min(img_shape) | |
| scale = short_edge_length / 224. | |
| return min(max(scale, min_scale), max_scale) | |
def make_grid(imgs, infos):
    """Concatenate images side by side into one captioned picture.

    Every image is padded with a white border so that all of them share the
    same height, with extra room below for a caption. The padded images are
    joined horizontally and a three-line caption (execution index, transform
    name, image shape) is drawn centred under each one using the current
    ``Visualizer`` instance.

    The input list is left untouched; padding is applied to copies.

    Args:
        imgs (list[np.ndarray]): Images to concatenate, one per recorded
            pipeline step.
        infos (list[dict]): One entry per image. Each entry must provide
            ``'name'`` and ``'dataset_sample'``; the latter's
            ``metainfo['img_shape']`` is shown in the caption.

    Returns:
        np.ndarray: The image returned by ``visualizer.get_image()`` after
        the captions were drawn.

    Raises:
        ValueError: If ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError('make_grid requires at least one image.')

    visualizer = Visualizer.get_current_instance()
    names = [info['name'] for info in infos]
    shapes = [info['dataset_sample'].metainfo['img_shape'] for info in infos]

    # Common canvas height: 10% taller than the tallest image.
    max_height = int(max(img.shape[0] for img in imgs) * 1.1)
    min_width = min(img.shape[1] for img in imgs)
    horizontal_gap = min_width // 10
    # Loop-invariant layout values.
    pad_width = horizontal_gap // 2
    caption_height = int(
        _get_adaptive_scale((max_height, min_width)) * 30 * 2)

    padded_imgs = []
    texts = []
    text_positions = []
    start_x = 0
    for i, img in enumerate(imgs):
        height, width = img.shape[0], img.shape[1]
        pad_top = (max_height - height) // 2
        # The bottom border absorbs the rounding remainder and the caption.
        pad_bottom = max_height - height - pad_top + caption_height
        # Collect the padded copy instead of overwriting the caller's list.
        padded_imgs.append(
            cv2.copyMakeBorder(
                img,
                pad_top,
                pad_bottom,
                pad_width,
                pad_width,
                cv2.BORDER_CONSTANT,
                value=(255, 255, 255)))
        texts.append(f'execution: {i}\n{names[i]}\n{shapes[i]}')
        # Caption anchor: horizontal centre of this image, just below the
        # common image area.
        text_positions.append([start_x + width // 2 + pad_width, max_height])
        start_x += width + horizontal_gap

    display_img = np.concatenate(padded_imgs, axis=1)
    visualizer.set_image(display_img)
    # Font size follows the size of the final concatenated canvas.
    img_scale = _get_adaptive_scale(display_img.shape[:2])
    visualizer.draw_texts(
        texts,
        positions=np.array(text_positions),
        font_sizes=img_scale * 7,
        colors='black',
        horizontal_alignments='center',
        font_families='monospace')
    return visualizer.get_image()
class InspectCompose(Compose):
    """Compose transforms sequentially and record every intermediate result.

    All transforms except the last one are applied in order. After each of
    them, the last transform (expected to be a packing transform such as
    ``PackTextDetInputs`` or ``PackTextRecogInputs``) is applied to the
    current results, and the packed data sample is appended to a list shared
    with the caller.

    Args:
        transforms (list): The transforms forwarded to ``Compose``.
        intermediate_imgs (list): List that receives one dict per recorded
            step, with the keys ``'name'`` (class name of the transform that
            was just applied) and ``'dataset_sample'`` (the packed
            ``data_samples``).
    """

    def __init__(self, transforms, intermediate_imgs):
        super().__init__(transforms=transforms)
        self.intermediate_imgs = intermediate_imgs

    def __call__(self, data):
        """Run all transforms but the last one, recording each step.

        Args:
            data (dict): The results dict to transform.

        Returns:
            dict | None: The results after the second-to-last transform (the
            packing transform is only used for recording), or ``None`` as
            soon as a transform returns ``None``.
        """
        # The final transform is used for packing only, never as a step.
        self.ptransforms = list(self.transforms[:-1])
        for t in self.ptransforms:
            data = t(data)
            # Bail out before touching ``data``: iterating or packing
            # ``None`` would raise instead of signalling a dropped sample.
            if data is None:
                return None
            pack_transform = self.transforms[-1]
            # Keep the same meta_keys in the PackTextDetInputs
            # or PackTextRecogInputs
            pack_transform.meta_keys = list(data)
            data_sample = pack_transform(data)
            if 'img' in data:
                self.intermediate_imgs.append({
                    'name': t.__class__.__name__,
                    'dataset_sample': data_sample['data_samples'],
                })
        return data
def infer_dataset_task(task: str,
                       dataset_cfg: Config,
                       var_name: Optional[str] = None) -> str:
    """Try to infer the dataset's task type from the config and the variable
    name.

    The task is resolved in this order:

    1. If ``task`` is not ``'auto'`` it is returned unchanged.
    2. If the dataset config has a non-empty pipeline, the type of its last
       transform decides (``PackTextDetInputs`` or ``PackTextRecogInputs``).
    3. If ``var_name`` is given, a ``'det'`` or ``'rec'`` substring decides,
       with ``'det'`` checked first.

    A missing, ``None`` or empty pipeline is treated as "no information"
    and falls through to step 3.

    Args:
        task (str): The requested task type, or ``'auto'``.
        dataset_cfg (Config): The dataset config to inspect.
        var_name (str, optional): Name of the config variable holding the
            dataset. Defaults to None.

    Returns:
        str: ``task`` itself, ``'textdet'`` or ``'textrecog'``.

    Raises:
        ValueError: If the task type cannot be inferred.
    """
    if task != 'auto':
        return task

    # ``get`` tolerates configs without a ``pipeline`` key, and the
    # truthiness check also skips an empty pipeline.
    pipeline = dataset_cfg.get('pipeline', None)
    if pipeline:
        last_transform_type = pipeline[-1]['type']
        if last_transform_type == 'PackTextDetInputs':
            return 'textdet'
        if last_transform_type == 'PackTextRecogInputs':
            return 'textrecog'

    if var_name is not None:
        if 'det' in var_name:
            return 'textdet'
        if 'rec' in var_name:
            return 'textrecog'

    raise ValueError(
        'Unable to infer the task type from dataset pipeline '
        'or variable name. Please specify the task type with --task argument '
        'explicitly.')
def obtain_dataset_cfg(cfg: Config, phase: str, mode: str, task: str) -> Tuple:
    """Pick the dataset config and the visualizer config to use.

    Two kinds of input config are supported:

    1. Model Config Mode:
       The config defines ``{phase}_dataloader``. Its ``dataset`` entry and
       the config's ``visualizer`` entry are used.
    2. Dataset Config Mode:
       The config has no such dataloader. The first variable whose name
       ends with ``phase`` and whose ``type`` ends with ``'Dataset'`` is
       used. A built-in visualizer config is chosen according to the task,
       and a built-in pipeline is filled in when the dataset's pipeline is
       None.

    Examples:
        Typically, the model config files are stored in
        `configs/textdet/dbnet/xxx.py` and should look like:
        >>> train_dataloader = dict(
        >>>     batch_size=16,
        >>>     num_workers=8,
        >>>     persistent_workers=True,
        >>>     sampler=dict(type='DefaultSampler', shuffle=True),
        >>>     dataset=icdar2015_textdet_train)

        while the dataset config files are stored in
        `configs/textdet/_base_/datasets/xxx.py` and should be like:
        >>> icdar2015_textdet_train = dict(
        >>>     type='OCRDataset',
        >>>     data_root=ic15_det_data_root,
        >>>     ann_file='textdet_train.json',
        >>>     filter_cfg=dict(filter_empty_gt=True, min_size=32),
        >>>     pipeline=None)

    Args:
        cfg (Config): Config object.
        phase (str): The dataset phase to visualize.
        mode (str): Script mode.
        task (str): The current task type.

    Returns:
        Tuple: Tuple of (dataset config, visualizer config).

    Raises:
        ValueError: If neither a matching dataloader nor a matching dataset
            variable is found in ``cfg``.
    """
    textdet_defaults = dict(
        visualizer=dict(
            type='TextDetLocalVisualizer',
            name='visualizer',
            vis_backends=[dict(type='LocalVisBackend')]),
        pipeline=[
            dict(
                type='LoadImageFromFile',
                color_type='color_ignore_orientation'),
            dict(
                type='LoadOCRAnnotations',
                with_polygon=True,
                with_bbox=True,
                with_label=True,
            ),
            dict(
                type='PackTextDetInputs',
                meta_keys=('img_path', 'ori_shape', 'img_shape'))
        ])
    textrecog_defaults = dict(
        visualizer=dict(
            type='TextRecogLocalVisualizer',
            name='visualizer',
            vis_backends=[dict(type='LocalVisBackend')]),
        pipeline=[
            dict(type='LoadImageFromFile', ignore_empty=True, min_size=2),
            dict(type='LoadOCRAnnotations', with_text=True),
            dict(
                type='PackTextRecogInputs',
                meta_keys=('img_path', 'ori_shape', 'img_shape',
                           'valid_ratio'))
        ])
    default_cfgs = dict(textdet=textdet_defaults, textrecog=textrecog_defaults)

    # ---- Model config mode ----
    dataloader_name = f'{phase}_dataloader'
    if dataloader_name in cfg:
        dataset = cfg.get(dataloader_name).dataset
        visualizer = cfg.visualizer

        if mode == 'original':
            default_cfg = default_cfgs[infer_dataset_task(task, dataset)]
            # Image can be stored in other methods, like LMDB, which
            # LoadImageFromFile can not handle, so reuse the dataset's own
            # LoadImageFrom** transform when it has one (the last match
            # wins).
            if dataset.pipeline is not None:
                for tfm in dataset.pipeline:
                    if tfm['type'].startswith('LoadImageFrom'):
                        default_cfg['pipeline'][0] = tfm
            dataset.pipeline = default_cfg['pipeline']
        elif phase == 'test':
            # In test_pipeline LoadOCRAnnotations is placed behind other
            # transforms, so they would not be applied on the gt
            # annotation. Move it to index 1 of the pipeline.
            transform_types = [tfm['type'] for tfm in dataset.pipeline]
            ann_index = transform_types.index('LoadOCRAnnotations')
            ann_transform = dataset.pipeline.pop(ann_index)
            dataset.pipeline.insert(1, ann_transform)

        return dataset, visualizer

    # ---- Dataset config mode ----
    for key in cfg.keys():
        is_candidate = key.endswith(phase) and cfg[key]['type'].endswith(
            'Dataset')
        if not is_candidate:
            continue
        dataset = cfg[key]
        inferred_task = infer_dataset_task(task, dataset, key.lower())
        default_cfg = default_cfgs[inferred_task]
        visualizer = default_cfg['visualizer']
        current_pipeline = dataset['pipeline']
        dataset['pipeline'] = (
            default_cfg['pipeline']
            if current_pipeline is None else current_pipeline)
        return dataset, visualizer

    raise ValueError(
        f'Unable to find "{phase}_dataloader" or any dataset variable ending '
        f'with "{phase}". Please check your config file or --phase argument '
        'and try again. More details can be found in the docstring of '
        'obtain_dataset_cfg function. Or, you may visit the documentation via '
        'https://mmocr.readthedocs.io/en/dev-1.x/user_guides/useful_tools.html#dataset-visualization-tool'  # noqa: E501
    )
def main():
    """Browse a dataset: draw ground truth on its images, then show/save them.

    The dataset and visualizer are built from the config given on the
    command line. The dataset's pipeline is replaced by ``InspectCompose``
    so that fetching a sample records the result of every transform. For
    each sample, either the last recorded image or (in ``pipeline`` mode) a
    grid of all recorded images is written to ``--output-dir`` and/or shown
    in a window.

    Raises:
        ValueError: If ``--show-number`` is not greater than 0.
    """
    import warnings

    args = parse_args()
    # Validate explicitly: an ``assert`` is stripped under ``python -O``,
    # and failing here avoids building the dataset for nothing.
    if args.show_number <= 0:
        raise ValueError('--show-number must be greater than 0, but got '
                         f'{args.show_number}.')

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmocr'))

    dataset_cfg, visualizer_cfg = obtain_dataset_cfg(cfg, args.phase,
                                                     args.mode, args.task)
    dataset = DATASETS.build(dataset_cfg)
    visualizer = VISUALIZERS.build(visualizer_cfg)
    visualizer.dataset_meta = dataset.metainfo

    # Shared list that every InspectCompose instance appends to.
    intermediate_imgs = []
    if dataset_cfg.type == 'ConcatDataset':
        for sub_dataset in dataset.datasets:
            sub_dataset.pipeline = InspectCompose(
                sub_dataset.pipeline.transforms, intermediate_imgs)
    else:
        dataset.pipeline = InspectCompose(dataset.pipeline.transforms,
                                          intermediate_imgs)

    # init visualization image number
    display_number = min(args.show_number, len(dataset))
    progress_bar = ProgressBar(display_number)

    # fetching items from dataset is a must for visualization
    for i, _ in zip(range(display_number), dataset):
        image_i = []
        datasample = None
        for record in intermediate_imgs:
            datasample = record['dataset_sample']
            image = datasample.img
            if len(image.shape) == 3:
                image = image[..., [2, 1, 0]]  # bgr to rgb
            image_show = visualizer.add_datasample(
                'result',
                image,
                datasample,
                draw_pred=False,
                draw_gt=True,
                show=False)
            image_i.append(image_show)

        if not image_i:
            # Nothing was recorded for this sample, so there is no image
            # to show. Skip it instead of failing on an empty list.
            warnings.warn(
                f'No intermediate image was recorded for sample {i}; '
                'skipping it.')
            intermediate_imgs.clear()
            progress_bar.update()
            continue

        if args.mode == 'pipeline':
            image = make_grid(image_i, intermediate_imgs)
        else:
            image = image_i[-1]

        if hasattr(datasample, 'img_path'):
            filename = osp.basename(datasample.img_path)
        else:
            # some dataset have not image path
            filename = f'{i}.jpg'

        if args.output_dir is not None:
            # Reverse the channel order back before writing to disk.
            mmcv.imwrite(image[..., ::-1], osp.join(args.output_dir,
                                                    filename))

        if not args.not_show:
            visualizer.show(
                image, win_name=filename, wait_time=args.show_interval)

        # Start the next sample with an empty record list.
        intermediate_imgs.clear()
        progress_bar.update()
# Run the dataset browser only when this file is executed as a script.
if __name__ == '__main__':
    main()