# Copyright (c) OpenMMLab. All rights reserved.
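"""Visualize how a hyper-parameter scheduler (learning rate, momentum or
weight decay) evolves over a training run, by simulating the train loop with
a dummy model instead of doing real training.

A minimal usage sketch (the script and config paths below are placeholders,
not files shipped with this snippet):

    python vis_scheduler.py path/to/config.py -p lr -d 1000 -s lr_curve.png
"""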
import argparse
import json
import os.path as osp
import re
from pathlib import Path
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
import rich
import torch.nn as nn
from mmengine.config import Config, DictAction
from mmengine.hooks import Hook
from mmengine.model import BaseModel
from mmengine.registry import init_default_scope
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn

from mmocr.registry import DATASETS


class SimpleModel(BaseModel):
    """A dummy model that does nothing in ``train_step``.

    It only needs to own at least one parameter (the 1x1 conv) so that an
    optimizer and its parameter schedulers can be constructed.
    """

    def __init__(self):
        super().__init__()
        self.data_preprocessor = nn.Identity()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, inputs, data_samples, mode='tensor'):
        pass

    def train_step(self, data, optim_wrapper):
        pass


class ParamRecordHook(Hook):
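    """A hook that records the learning rate, momentum and weight decay at
    every training iteration, while driving a rich progress bar over epochs
    or iterations."""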

    def __init__(self, by_epoch):
        super().__init__()
        self.by_epoch = by_epoch
        self.lr_list = []
        self.momentum_list = []
        self.wd_list = []
        self.task_id = 0
        self.progress = Progress(BarColumn(), MofNCompleteColumn(),
                                 TextColumn('{task.description}'))

    def before_train(self, runner):
        if self.by_epoch:
            total = runner.train_loop.max_epochs
            self.task_id = self.progress.add_task(
                'epochs', start=True, total=total)
        else:
            total = runner.train_loop.max_iters
            self.task_id = self.progress.add_task(
                'iters', start=True, total=total)
        self.progress.start()

    def after_train_epoch(self, runner):
        if self.by_epoch:
            self.progress.update(self.task_id, advance=1)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        if not self.by_epoch:
            self.progress.update(self.task_id, advance=1)
        self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0])
        self.momentum_list.append(
            runner.optim_wrapper.get_momentum()['momentum'][0])
        self.wd_list.append(
            runner.optim_wrapper.param_groups[0]['weight_decay'])

    def after_train(self, runner):
        self.progress.stop()


def parse_args():
    parser = argparse.ArgumentParser(
        description='Visualize a hyper-parameter scheduler')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '-p',
        '--parameter',
        type=str,
        default='lr',
        choices=['lr', 'momentum', 'wd'],
        help='The parameter whose change curve is to be visualized; choose '
        'from "lr", "momentum" and "wd". Defaults to "lr".')
    parser.add_argument(
        '-d',
        '--dataset-size',
        type=int,
        help='The size of the dataset. If specified, `build_dataset` will '
        'be skipped and this value will be used as the dataset size.')
    parser.add_argument(
        '-n',
        '--ngpus',
        type=int,
        default=1,
        help='The number of GPUs used in training.')
    parser.add_argument(
        '-s',
        '--save-path',
        type=Path,
        help='The path to save the parameter curve plot.')
    parser.add_argument(
        '--log-level',
        default='WARNING',
        help='The log level of the handler and logger. Defaults to '
        'WARNING.')
    parser.add_argument('--title', type=str, help='title of the figure')
    parser.add_argument(
        '--style', type=str, default='whitegrid', help='style of the plot')
    parser.add_argument('--not-show', default=False, action='store_true')
    parser.add_argument(
        '--window-size',
        default='12*7',
        help='Size of the window to display images, in the format "$W*$H".')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config. The key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    if args.window_size != '':
        assert re.match(r'\d+\*\d+', args.window_size), \
            "'window-size' must be in the format 'W*H'."
    return args


def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True):
    """Plot the parameter's change curve against training iterations."""
    try:
        import seaborn as sns
        sns.set_style(args.style)
    except ImportError:
        pass

    wind_w, wind_h = args.window_size.split('*')
    wind_w, wind_h = int(wind_w), int(wind_h)
    plt.figure(figsize=(wind_w, wind_h))

    ax: plt.Axes = plt.subplot()
    ax.plot(lr_list, linewidth=1)

    if by_epoch:
        ax.xaxis.tick_top()
        ax.set_xlabel('Iters')
        ax.xaxis.set_label_position('top')
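        # Add a secondary x-axis at the bottom that converts iteration
        # indices to (fractional) epoch indices and back.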
        sec_ax = ax.secondary_xaxis(
            'bottom',
            functions=(lambda x: x / iters_per_epoch,
                       lambda y: y * iters_per_epoch))
        sec_ax.set_xlabel('Epochs')
    else:
        plt.xlabel('Iters')
    plt.ylabel(param_name)

    if args.title is None:
        plt.title(f'{osp.basename(args.config)} {param_name} curve')
    else:
        plt.title(args.title)


def simulate_train(data_loader, cfg, by_epoch):
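    """Simulate the training process with ``SimpleModel`` so that the
    parameter schedulers step exactly as in a real run.

    Most built-in hooks are disabled (set to None) so the dummy run has no
    side effects; only the param scheduler hook and the recording hook are
    active.
    """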
    model = SimpleModel()
    param_record_hook = ParamRecordHook(by_epoch=by_epoch)
    default_hooks = dict(
        param_scheduler=cfg.default_hooks['param_scheduler'],
        runtime_info=None,
        timer=None,
        logger=None,
        checkpoint=None,
        sampler_seed=None,
        param_record=param_record_hook)

    runner = Runner(
        model=model,
        work_dir=cfg.work_dir,
        train_dataloader=data_loader,
        train_cfg=cfg.train_cfg,
        log_level=cfg.log_level,
        optim_wrapper=cfg.optim_wrapper,
        param_scheduler=cfg.param_scheduler,
        default_scope=cfg.default_scope,
        default_hooks=default_hooks,
        visualizer=MagicMock(spec=Visualizer),
        custom_hooks=cfg.get('custom_hooks', None))

    runner.train()

    param_dict = dict(
        lr=param_record_hook.lr_list,
        momentum=param_record_hook.momentum_list,
        wd=param_record_hook.wd_list)

    return param_dict


def build_dataset(cfg):
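    """Build the dataset from config; in this script it is used only to
    query the dataset's length."""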
    return DATASETS.build(cfg)


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    init_default_scope(cfg.get('default_scope', 'mmocr'))

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if cfg.get('work_dir', None) is None:
        # Use the config filename as the default work_dir if it is not set.
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.log_level = args.log_level

    # Make sure the save directory exists.
    if args.save_path and not args.save_path.parent.exists():
        raise FileNotFoundError(
            f'The save path is {args.save_path}, and directory '
            f"'{args.save_path.parent}' does not exist.")

    # Print the parameter scheduler configuration.
    print('Param_scheduler:')
    rich.print_json(json.dumps(cfg.param_scheduler))

    # Prepare the data loader.
    batch_size = cfg.train_dataloader.batch_size * args.ngpus

    if 'by_epoch' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('by_epoch')
    elif 'type' in cfg.train_cfg:
        by_epoch = cfg.train_cfg.get('type') == 'EpochBasedTrainLoop'
    else:
        raise ValueError('Please set `train_cfg`.')

    if args.dataset_size is None and by_epoch:
        dataset_size = len(build_dataset(cfg.train_dataloader.dataset))
    else:
        dataset_size = args.dataset_size or batch_size
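
    # The simulated loop never reads real data, so a list-like stand-in
    # whose length equals the number of iterations per epoch is enough;
    # the mocked `dataset` attribute satisfies the Runner's interface.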
    class FakeDataloader(list):
        dataset = MagicMock(metainfo=None)

    data_loader = FakeDataloader(range(dataset_size // batch_size))
    dataset_info = (
        f'\nDataset infos:'
        f'\n - Dataset size: {dataset_size}'
        f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}'
        f'\n - Number of GPUs: {args.ngpus}'
        f'\n - Total batch size: {batch_size}')
    if by_epoch:
        dataset_info += f'\n - Iterations per epoch: {len(data_loader)}'
    rich.print(dataset_info + '\n')

    # Simulate the training process and collect the recorded parameters.
    param_dict = simulate_train(data_loader, cfg, by_epoch)
    param_list = param_dict[args.parameter]

    if args.parameter == 'lr':
        param_name = 'Learning Rate'
    elif args.parameter == 'momentum':
        param_name = 'Momentum'
    else:
        param_name = 'Weight Decay'
    plot_curve(param_list, args, param_name, len(data_loader), by_epoch)

    if args.save_path:
        plt.savefig(args.save_path)
        print(f'\nThe {param_name} graph is saved at {args.save_path}')

    if not args.not_show:
        plt.show()


if __name__ == '__main__':
    main()