# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
import copy
from typing import List, Optional
import json
import logging
import os
import torch
import yaml
import torch.optim as optim
import torch.distributed as dist

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    CPUOffload, MixedPrecision,
                                    sharded_grad_scaler, ShardingStrategy)

try:
    import deepspeed
    from deepspeed.runtime.zero.stage_1_and_2 import (
        estimate_zero2_model_states_mem_needs_all_live)
    from deepspeed.runtime.zero.stage3 import (
        estimate_zero3_model_states_mem_needs_all_live)
    from deepspeed.utils.zero_to_fp32 import (
        convert_zero_checkpoint_to_fp32_state_dict)
except ImportError:
    pass

from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str,
                                tensor_to_scalar)
from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
                                    apply_fsdp_checkpointing,
                                    wenet_fsdp_wrap_policy)
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE
from wenet.utils.init_dataset import init_dataset


def add_model_args(parser):
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument("--enc_init",
                        default=None,
                        type=str,
                        help="Pre-trained model to initialize encoder")
    parser.add_argument(
        '--enc_init_mods',
        default="encoder.",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="List of encoder modules to initialize, separated by commas")
    parser.add_argument(
        '--freeze_modules',
        default="",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of modules to freeze during training',
    )
    return parser
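
# Illustrative usage of --enc_init / --enc_init_mods (paths and module
# prefixes below are examples, not defaults from this repo):
#   python wenet/bin/train.py --config conf/train_conformer.yaml \
#       --train_data data/train.list --cv_data data/dev.list \
#       --model_dir exp/finetune \
#       --enc_init exp/pretrain/final.pt --enc_init_mods "encoder.,ctc."
# The comma-separated strings are matched as prefixes against parameter names
# when the pre-trained checkpoint is loaded.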


def add_trace_args(parser):
    parser.add_argument('--jit',
                        action='store_true',
                        default=False,
                        help='whether to trace the model with jit '
                        'during the training stage')
    parser.add_argument('--print_model',
                        action='store_true',
                        default=False,
                        help='print model')
    return parser


def add_dataset_args(parser):
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='use pinned memory buffers for reading')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    return parser


def add_lora_args(parser):
    '''Configure parameters for LoRA fine-tuning. Set use_lora and
    only_optimize_lora to true to enable LoRA.
    LoRA is injected into the model through (lora_modules, lora_attn_attr,
    lora_list).
    LoRA weights are merged after calling model.eval()
    (or model.train(mode=False)).
    LoRA weights need to be loaded after fine-tuning with DeepSpeed.
    '''
    parser.add_argument("--use_lora",
                        default=False,
                        type=bool,
                        help="whether to use LoRA fine-tuning.")
    parser.add_argument("--only_optimize_lora",
                        default=False,
                        type=bool,
                        help="freeze all other parameters and only optimize "
                        "LoRA-related parameters.")
    parser.add_argument(
        '--lora_modules',
        default="encoder.encoders",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of the modules to inject LoRA into',
    )
    parser.add_argument(
        "--lora_attn_attr",
        default="self_attn,src_attn",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="attention attribute names to inject LoRA into.")
    parser.add_argument(
        "--lora_list",
        default="linear_out,linear_q,linear_k,linear_v",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="linear layer names to which LoRA is applied.")
    parser.add_argument("--lora_rank",
                        default=8,
                        type=int,
                        help="lora rank num.")
    parser.add_argument("--lora_alpha",
                        default=8,
                        type=int,
                        help="lora scale param, scale=lora_alpha/lora_rank.")
    parser.add_argument("--lora_dropout",
                        default=0,
                        type=float,
                        help="lora dropout param.")
    parser.add_argument("--lora_ckpt_path",
                        default=None,
                        type=str,
                        help="lora checkpoint path.")
    parser.add_argument("--lora_reinit",
                        default=False,
                        type=bool,
                        help="whether to reinitialize LoRA weights; "
                        "the default is zero init.")
    parser.add_argument('--lora_init_yaml',
                        default="wenet/finetune/lora/config.yaml",
                        type=str,
                        help='path to the LoRA init configuration YAML file')
    return parser
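
# A minimal sketch of the LoRA flags (values below are illustrative, not
# recommendations):
#   --use_lora true --only_optimize_lora true \
#   --lora_modules encoder.encoders --lora_attn_attr self_attn,src_attn \
#   --lora_list linear_q,linear_k,linear_v,linear_out \
#   --lora_rank 8 --lora_alpha 16 --lora_dropout 0.1
# Caveat: because --use_lora/--only_optimize_lora use `type=bool`, any
# non-empty string (even "false") parses as True; pass an empty string ("")
# to keep them disabled.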


def add_ddp_args(parser):
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo', "hccl"],
                        help='distributed backend')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--fp16_grad_sync',
                        action='store_true',
                        default=False,
                        help='Use fp16 gradient sync for ddp')
    return parser


def add_deepspeed_args(parser):
    parser.add_argument('--timeout',
                        default=30,
                        type=int,
                        help='timeout (in seconds) of wenet_join. ' +
                        '30s for aishell & 300s for wenetspeech')
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        default='model_only',
                        choices=['model_only', 'model+optimizer'],
                        help='save model/optimizer states')
    # DeepSpeed automatically adds '--deepspeed' and '--deepspeed_config'
    try:
        parser = deepspeed.add_config_arguments(parser)
    except Exception as e:
        print(e)
    return parser
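
# Illustrative launch (paths and world size are placeholders; assumes torchrun
# is used as the process launcher):
#   torchrun --nnodes 1 --nproc_per_node 8 wenet/bin/train.py \
#       --train_engine deepspeed --deepspeed \
#       --deepspeed_config conf/ds_config.json \
#       --config conf/train_conformer.yaml ...
# where --deepspeed and --deepspeed_config are the flags added by
# deepspeed.add_config_arguments() above.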


def add_fsdp_args(parser):
    parser.add_argument(
        '--dtype',
        default='fp32',
        choices=['fp32', 'fp16', 'bf16'],
        help='when amp is used, dtype is automatically set to fp16. '
        'this arg has no effect when deepspeed is enabled.')
    parser.add_argument(
        '--fsdp_cpu_offload',
        default=False,
        type=bool,
        help='whether to offload parameters to CPU',
    )
    parser.add_argument(
        '--fsdp_sync_module_states',
        type=bool,
        default=True,
        help='each FSDP module will broadcast module parameters and buffers '
        'from rank 0 to ensure that they are replicated across ranks',
    )
    parser.add_argument(
        '--fsdp_sharding_strategy',
        default='zero2',
        # TODO(Mddct): pipeline and model parallel (3-D parallelism)
        choices=['no_shard', 'model', 'zero2', 'zero3'],
        help='Sharding strategy for FSDP. Choose from the following options:\n'
        '  - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
        '  - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
        '  - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
        '  - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
        'For more information, refer to the FSDP API documentation.')
    return parser
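
# Illustrative FSDP flags (placeholder values): full sharding in bf16 could be
# requested with
#   --train_engine torch_fsdp --fsdp_sharding_strategy zero3 --dtype bf16
# Note that --use_amp forces dtype to fp16 regardless of --dtype
# (see check_modify_and_save_config below).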


def init_distributed(args):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
                 ', rank {}, world_size {}'.format(rank, world_size))
    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
        if "cuda" in args.device:
            torch.cuda.set_device(local_rank)
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            torch.npu.set_device(local_rank)
        else:
            logging.error("not supported device: {}".format(args.device))
        dist.init_process_group(args.dist_backend)
    elif args.train_engine == "deepspeed":
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    else:
        logging.error("not supported engine: {}".format(args.train_engine))
    return world_size, local_rank, rank


def check_modify_and_save_config(args, configs, symbol_table):
    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
        if args.use_amp:
            configs["dtype"] = "fp16"
            args.dtype = 'fp16'
        else:
            configs["dtype"] = args.dtype
    elif args.train_engine == "deepspeed":
        # NOTE(xcsong): DeepSpeed does not support uneven data. When using a
        #   custom dataset, we need to manually ensure that the data is evenly
        #   distributed across all processes. We implement
        #   `train_utils.py::wenet_join` for this purpose.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/2223
        #
        # NOTE(xcsong): We also need to keep:
        #   1. `train_micro_batch_size_per_gpu == 1`
        #   2. `accum_grad (in train_conformer.yaml)
        #       == gradient_accumulation_steps (in ds_config.json)`
        #   3. `grad_clip (in train_conformer.yaml)
        #       == gradient_clipping (in ds_config.json)`
        #   The reason for this consistency check is that DeepSpeed's native
        #   dataloader uses PyTorch's torch.utils.data.DistributedSampler,
        #   which does not support IterableDataset. IterableDataset is
        #   extremely useful in large-scale training because it lets you
        #   stream the data without having to download the complete dataset.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/1371
        #        https://github.com/microsoft/DeepSpeed/issues/285
        #   To make DeepSpeed training compatible with IterableDataset, we
        #   have to use a custom dataloader instead of DeepSpeed's native
        #   loader, and thus we should configure the batch size in
        #   train_conformer.yaml instead of ds_config.json. Conversely,
        #   gradient accumulation / clipping should be configured in
        #   ds_config.json since they are handled by DeepSpeed automatically.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/62
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
            configs["dtype"] = "fp16"
        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
            configs["dtype"] = "bf16"
        else:
            configs["dtype"] = "fp32"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        assert ds_configs["gradient_accumulation_steps"] == configs[
            'accum_grad']
        assert ds_configs["gradient_clipping"] == configs['grad_clip']
        assert ds_configs["steps_per_print"] == configs['log_interval']
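        # A hypothetical ds_config.json fragment that satisfies the checks
        # above (values are examples and must mirror the training yaml; the
        # arrows are annotations, not valid JSON):
        #   {
        #     "train_micro_batch_size_per_gpu": 1,
        #     "gradient_accumulation_steps": 4,   <- accum_grad
        #     "gradient_clipping": 5.0,           <- grad_clip
        #     "steps_per_print": 100,             <- log_interval
        #     "bf16": {"enabled": true}
        #   }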

    if args.use_lora:
        configs['lora_conf'] = {}
        configs['lora_conf']['lora_modules'] = args.lora_modules
        configs['lora_conf']['lora_attn_attr'] = args.lora_attn_attr
        configs['lora_conf']['lora_list'] = args.lora_list
        configs['lora_conf']['lora_rank'] = args.lora_rank
        configs['lora_conf']['lora_alpha'] = args.lora_alpha
        configs['lora_conf']['lora_dropout'] = args.lora_dropout

    if configs["model"] == 'asr_model':
        if 'input_dim' not in configs:
            if 'fbank_conf' in configs['dataset_conf']:
                input_dim = configs['dataset_conf']['fbank_conf'][
                    'num_mel_bins']
            elif 'log_mel_spectrogram_conf' in configs['dataset_conf']:
                input_dim = configs['dataset_conf'][
                    'log_mel_spectrogram_conf']['num_mel_bins']
            else:
                input_dim = configs['dataset_conf']['mfcc_conf'][
                    'num_mel_bins']
        else:
            input_dim = configs['input_dim']
        configs['input_dim'] = input_dim

    configs, _ = get_blank_id(configs, symbol_table)
    configs['output_dim'] = configs['vocab_size']

    configs['train_engine'] = args.train_engine
    configs['use_amp'] = args.use_amp
    configs['model_dir'] = args.model_dir
    configs['save_states'] = args.save_states

    # Save configs to model_dir/train.yaml for inference and export
    if int(os.environ.get('RANK', 0)) == 0:
        saved_config_path = os.path.join(args.model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    if configs["model_conf"].get("apply_non_blank_embedding", False):
        logging.warning('Had better load a well trained model '
                        'if apply_non_blank_embedding is true !!!')

    return configs


def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
    generator = torch.Generator()
    generator.manual_seed(seed)
    # if save_interval in configs, steps mode else epoch mode
    if "save_interval" in configs:
        configs['dataset_conf']['cycle'] = configs.get('max_epoch', 100)
    conf = configs['dataset_conf']
    dataset_type = configs.get('dataset', 'asr')
    configs['vocab_size'] = tokenizer.vocab_size()
    train_dataset = init_dataset(dataset_type,
                                 args.data_type,
                                 args.train_data,
                                 tokenizer,
                                 conf,
                                 True,
                                 split='train')
    tag = configs["init_infos"].get("tag", "init")
    train_dataset.set_epoch(configs["init_infos"].get('epoch', 0) +
                            int("epoch_" in tag) - 1)
    cv_conf = copy.deepcopy(conf)
    cv_conf['split_num'] = 1
    cv_dataset = init_dataset(dataset_type,
                              args.data_type,
                              args.cv_data,
                              tokenizer,
                              cv_conf,
                              partition=False,
                              split='cv')

    # NOTE(xcsong): Why we prefer persistent_workers=True ?
    #   https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   persistent_workers=True,
                                   generator=generator,
                                   prefetch_factor=args.prefetch)
    cv_data_loader = DataLoader(cv_dataset,
                                batch_size=None,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers,
                                persistent_workers=True,
                                generator=generator,
                                prefetch_factor=args.prefetch)
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader


def wrap_cuda_model(args, model, configs=None):
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if hasattr(model, 'encoder'):
        grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False)
    else:
        grad_ckpt = False
    if args.train_engine == "torch_ddp":  # native pytorch ddp
        device = torch.device(args.device)
        model.to(device)
        # model = torch.nn.parallel.DistributedDataParallel(
        #     model, find_unused_parameters=not grad_ckpt)
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
    elif args.train_engine == "deepspeed":  # deepspeed
        # NOTE(xcsong): look in detail how the memory estimator API works:
        #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
        if int(os.environ.get('RANK', 0)) == 0:
            logging.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
            logging.info("Estimating model states memory needs (zero3)...")
            estimate_zero3_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
        device = torch.device(args.device)  # Init device later
        pass  # Init DeepSpeed later
    elif args.train_engine == 'torch_fsdp':
        assert configs is not None
        mixed_precision_dtype = {
            'fp32': torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[configs['dtype']]

        sharding_strategy = {
            'model': ShardingStrategy.SHARD_GRAD_OP,
            'zero2': ShardingStrategy.SHARD_GRAD_OP,
            'zero3': ShardingStrategy.FULL_SHARD,
            'no_shard': ShardingStrategy.NO_SHARD,
        }[args.fsdp_sharding_strategy]
        wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy)
        layer_types = check_gradient_checkpoint(model)
        if "cuda" in args.device:
            device_id = torch.cuda.current_device()
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            device_id = torch.npu.current_device()
        else:
            logging.error("not supported device: {}".format(args.device))
        model = FSDP(
            model,
            auto_wrap_policy=wrap_policy,
            cpu_offload=CPUOffload(offload_params=True)
            if args.fsdp_cpu_offload is True else None,
            mixed_precision=MixedPrecision(
                param_dtype=mixed_precision_dtype,
                reduce_dtype=mixed_precision_dtype,
                buffer_dtype=mixed_precision_dtype,
            ),
            sharding_strategy=sharding_strategy,
            limit_all_gathers=True,
            use_orig_params=True,
            sync_module_states=args.fsdp_sync_module_states,
            # init_distributed is called (torch.cuda.set_device),
            # we should set device_id, see FSDP api
            device_id=device_id)
        apply_fsdp_checkpointing(model, layer_types)
        device = torch.device(args.device)
    else:
        logging.error("not supported engine: {}".format(args.train_engine))
    if args.train_engine in ["torch_fsdp", "torch_ddp"]:
        if args.fp16_grad_sync:
            from torch.distributed.algorithms.ddp_comm_hooks import (
                default as comm_hooks, )
            model.register_comm_hook(state=None,
                                     hook=comm_hooks.fp16_compress_hook)
    return model, device


def init_optimizer_and_scheduler(args, configs, model):
    groups = []
    lr = configs['optim_conf'].get('lr')
    if isinstance(lr, List):
        assert configs['scheduler'] == 'warmuplr'
        modules_m = configs['optim_conf']['modules']
        assert isinstance(modules_m, List)
        assert len(modules_m) + 1 == len(lr)
        special_param_ids = set()
        rest_params = []
        for (i, m_str) in enumerate(modules_m):
            sub_module = get_nested_attribute(model, m_str)
            subs_params = []
            for _, sub_params in sub_module.named_parameters():
                subs_params.append(sub_params)
                special_param_ids.add(id(sub_params))
            groups.append({'params': subs_params, 'lr': lr[i]})
        # other model's parameters
        for _, param in model.named_parameters():
            if id(param) not in special_param_ids:
                rest_params.append(param)
        groups.append({'params': rest_params, 'lr': lr[-1]})

    params = groups if len(groups) > 0 else model.parameters()
    optim_conf = copy.deepcopy(configs['optim_conf'])
    if 'modules' in optim_conf:
        del optim_conf['modules']
    if isinstance(lr, List):
        optim_conf['lr'] = lr[-1]
    if configs['optim'] == 'adam':
        optimizer = optim.Adam(params, **optim_conf)
    elif configs['optim'] == 'adamw':
        optimizer = optim.AdamW(params, **optim_conf)
    else:
        raise ValueError("unknown optimizer: " + configs['optim'])

    scheduler_type = None
    if configs['scheduler'] == 'warmuplr':
        scheduler_type = WarmupLR
        scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
    elif configs['scheduler'] == 'NoamHoldAnnealing':
        scheduler_type = NoamHoldAnnealing
        scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf'])
    else:
        raise ValueError("unknown scheduler: " + configs['scheduler'])

    # NOTE(xcsong): Custom optimizer might yield poor performance when
    #   zero-offload is enabled. If you do want to offload the optimizer to
    #   CPU, please set the optimizer in ds_config.json, see:
    #   https://www.deepspeed.ai/docs/config-json/#optimizer-parameters
    if args.train_engine == "deepspeed":
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "optimizer" in ds_configs:
            # NOTE(xcsong): Disable the custom optimizer if one is set in
            #   ds_config. This is extremely useful when cpu_offload is
            #   enabled, since DeepSpeedCPUAdam can be 4~5x faster than
            #   torch's native adam.
            optimizer = None
        if "scheduler" in ds_configs:
            scheduler = None
        else:

            def scheduler(opt):
                return scheduler_type(opt, **configs['scheduler_conf'])

        model, optimizer, _, scheduler = deepspeed.initialize(
            args=args,
            model=model,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            model_parameters=model.parameters())

    step = configs.get("init_infos", {}).get("step", -1)
    scheduler.set_step(step)
    return model, optimizer, scheduler
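
# For reference, a minimal optimizer/scheduler section of the training yaml
# consumed above might look like (example values only):
#   optim: adamw
#   optim_conf:
#     lr: 0.001
#     weight_decay: 0.01
#   scheduler: warmuplr
#   scheduler_conf:
#     warmup_steps: 25000
# Per-module learning rates are enabled by giving `lr` as a list together with
# `modules`, e.g. `lr: [0.0001, 0.001]` and `modules: ['encoder']`; the last
# entry of `lr` applies to all remaining parameters, and this mode requires
# the 'warmuplr' scheduler.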


def trace_and_print_model(args, model):
    # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
    # the code to satisfy the script export requirements
    if int(os.environ.get('RANK', 0)) == 0:
        if args.jit:
            script_model = torch.jit.script(model)
            script_model.save(os.path.join(args.model_dir, 'init.zip'))
        if args.print_model:
            print(model)
            num_params = sum(p.numel() for p in model.parameters())
            print('the number of model params: {:,d}'.format(num_params))


def init_summarywriter(args):
    writer = None
    if int(os.environ.get('RANK', 0)) == 0:
        os.makedirs(args.model_dir, exist_ok=True)
        exp_id = os.path.basename(args.model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))
    return writer


def init_scaler(args):
    scaler = None
    if args.use_amp:
        if "cuda" in args.device:
            scaler = torch.cuda.amp.GradScaler()
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            scaler = torch.npu.amp.GradScaler()
        else:
            logging.error("not supported device: {}".format(args.device))
    elif args.train_engine == 'torch_fsdp':
        # why bf16 doesn't need a scaler:
        #   https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596
        if args.dtype in ['fp16']:
            scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
    return scaler


def save_model(model, info_dict):
    rank = int(os.environ.get('RANK', 0))
    tag = info_dict["tag"]
    model_dir = info_dict["model_dir"]
    save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
    # save ckpt
    if info_dict["train_engine"] == "deepspeed":
        # NOTE(xcsong): All ranks should call this API, but only rank 0
        #   save the general model params. see:
        #   https://github.com/microsoft/DeepSpeed/issues/2993
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
                                  tag=tag,
                                  client_state=info_dict)
            if info_dict["save_states"] == "model_only" and rank == 0:
                convert_zero_checkpoint_to_fp32_state_dict(model_dir,
                                                           save_model_path,
                                                           tag=tag)
                os.system("rm -rf {}/{}".format(model_dir, tag))
    elif info_dict['train_engine'] == "torch_fsdp":
        fsdp_save_model(model, save_model_path, info_dict)
    elif rank == 0:
        # NOTE(xcsong): For torch_ddp, only rank-0 should call this.
        save_checkpoint(model, save_model_path, info_dict)
    # save yaml
    if rank == 0:
        with open("{}/{}.yaml".format(model_dir, tag), 'w') as fout:
            data = yaml.dump(info_dict)
            fout.write(data)


def wenet_join(group_join, info_dict):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    train_engine = info_dict.get('train_engine', "torch_ddp")

    if info_dict["batch_idx"] == 0 or train_engine == "torch_ddp":
        # NOTE(xcsong): skip first batch because its processing time includes
        #   dataloader initialization time, which may exceed 30 seconds
        return False

    try:
        # NOTE(xcsong): Why we need a new group?
        #   Because DeepSpeed has its own group where all the relevant
        #   communication operations are executed. If we add a communication
        #   operation that is not managed by DeepSpeed in this group, it's
        #   highly likely to cause communication chaos, resulting in
        #   hard-to-troubleshoot hangs.
        dist.monitored_barrier(group=group_join,
                               timeout=group_join.options._timeout)
    except RuntimeError as e:
        logging.info("Detected uneven workload distribution: {}\n".format(e) +
                     "Break current worker to manually join all workers, " +
                     "world_size {}, current rank {}, current local_rank {}\n".
                     format(world_size, rank, local_rank))
        return True
    return False
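
# The join group is created by the caller; a sketch of the expected setup
# (assumes a gloo backend, which monitored_barrier requires, and the
# --timeout flag from add_deepspeed_args):
#   import datetime
#   group_join = dist.new_group(
#       backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
#   ...
#   if wenet_join(group_join, info_dict):
#       break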


def batch_forward(model, batch, scaler, info_dict, device):
    train_engine = info_dict.get('train_engine', "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)

    dtype = info_dict.get("dtype", "fp32")
    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    else:  # fp32
        dtype = None

    # autocast context
    # More details about amp can be found in
    #   https://pytorch.org/docs/stable/notes/amp_examples.html
    amp_autocast = torch.cuda.amp.autocast
    if "npu" in device.__str__() and TORCH_NPU_AVAILABLE:
        amp_autocast = torch.npu.amp.autocast
    autocast = {
        "deepspeed":
        amp_autocast(enabled=dtype is not None,
                     dtype=dtype,
                     cache_enabled=False),
        "torch_ddp":
        amp_autocast(enabled=scaler is not None),
        "torch_fsdp":
        amp_autocast(enabled=True, dtype=dtype)
        if dtype is not None else nullcontext()
    }[train_engine]
    with autocast:
        loss_dict = model(batch, device)
    info_dict['loss_dict'] = loss_dict

    return info_dict


def batch_backward(model, scaler, info_dict):
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    use_amp = info_dict.get('use_amp', False)
    if use_amp:
        assert scaler is not None
    loss = info_dict['loss_dict']['loss']

    if train_engine == "deepspeed":
        # NOTE(xcsong): `model.backward(loss)` is equivalent to
        #   `scale_loss_wrt_accum_grad + loss.backward()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        scaled_loss = model.backward(loss)
    else:
        assert train_engine in ["torch_ddp", "torch_fsdp"]
        scaled_loss = loss / accum_grad
        if scaler is not None:
            # fp16 (amp and fsdp)
            scaler.scale(scaled_loss).backward()
        else:
            # float32 (ddp and fsdp)
            # bf16 (fsdp)
            scaled_loss.backward()
    info_dict['loss_dict']['loss'] = scaled_loss
    for loss_name, loss_value in info_dict['loss_dict'].items():
        if loss_value is not None:
            info_dict['loss_dict'][loss_name] = tensor_to_scalar(loss_value)

    return info_dict


def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
    rank = int(os.environ.get('RANK', 0))
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    use_amp = info_dict.get('use_amp', False)
    clip = info_dict.get('grad_clip', 50.0)
    batch_idx = info_dict["batch_idx"]
    if use_amp:
        assert scaler is not None

    grad_norm = 0.0
    if train_engine == "deepspeed":
        # NOTE(xcsong): The step() function in the DeepSpeed engine updates
        #   the model parameters as well as the learning rate.
        #   Zeroing the gradients is handled automatically by DeepSpeed after
        #   the weights have been updated using a mini-batch. DeepSpeed also
        #   performs gradient averaging automatically at the gradient
        #   accumulation boundaries and addresses clip_grad_norm internally.
        #   `ds_model.step() = clip_grad_norm_() + optimizer.step()
        #                      + optimizer.zero_grad() + scheduler.step()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        info_dict["is_gradient_accumulation_boundary"] = \
            model.is_gradient_accumulation_boundary()
        model.step()
        grad_norm = model.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = 0.0
    elif (batch_idx + 1) % accum_grad == 0:
        # Use mixed precision training
        # fp16 (ddp fsdp)
        if scaler is not None:
            scaler.unscale_(optimizer)
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:
                # fsdp
                grad_norm = model.clip_grad_norm_(clip)
            # Must invoke scaler.update() if unscale_() is used in
            # the iteration to avoid the following error:
            #   RuntimeError: unscale_() has already been called
            #   on this optimizer since the last update().
            # We don't check the gradient here, because if it has
            # inf/nan values, scaler.step will skip optimizer.step().
            scaler.step(optimizer)
            scaler.update()
        else:
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:
                grad_norm = model.clip_grad_norm_(clip)
            if torch.isfinite(grad_norm):
                optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups]
    info_dict["grad_norm"] = tensor_to_scalar(grad_norm)

    return info_dict
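
# These helpers are meant to be chained per batch; a condensed sketch of the
# inner training loop (device placement, wenet_join handling, step counting
# and checkpointing omitted):
#   for batch_idx, batch_dict in enumerate(train_data_loader):
#       info_dict["batch_idx"] = batch_idx
#       info_dict = batch_forward(model, batch_dict, scaler, info_dict, device)
#       info_dict = batch_backward(model, scaler, info_dict)
#       info_dict = update_parameter_and_lr(model, optimizer, scheduler,
#                                           scaler, info_dict)
#       log_per_step(writer, info_dict, timer=step_timer)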


def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
    tag = info_dict["tag"]
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    epoch = info_dict.get('epoch', 0)
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
    log_interval = info_dict.get('log_interval', 10)
    lrs = info_dict.get("lrs", [0.0])
    is_gradient_accumulation_boundary = info_dict.get(
        "is_gradient_accumulation_boundary", False)

    rank = int(os.environ.get('RANK', 0))
    # TRAIN Tensorboard
    if tag == "TRAIN" and rank == 0 and writer is not None:
        if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
            ) or (train_engine in ["torch_ddp", "torch_fsdp"] and
                  (batch_idx + 1) % accum_grad == 0):
            writer.add_scalar('train/train_loss',
                              tensor_to_scalar(loss_dict['loss']) * accum_grad,
                              step)
            writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step)
            for name, value in loss_dict.items():
                if name != 'loss' and value is not None:
                    writer.add_scalar('train/{}'.format(name),
                                      tensor_to_scalar(value), step)
            # lr
            for i, lr in enumerate(lrs):
                writer.add_scalar('train/lr_{}'.format(i), lr, step)
    # CV Tensorboard
    elif "step_" in tag and rank == 0 and writer is not None:
        for name, value in loss_dict.items():
            writer.add_scalar('cv/{}'.format(name), tensor_to_scalar(value),
                              step)
        logging.info(
            'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
                epoch, step + 1, lrs_to_str(lrs),
                tensor_to_scalar(loss_dict["loss"]), rank,
                tensor_to_scalar(loss_dict["acc"])))
        return

    # TRAIN & CV, Shell log (stdout)
    if (batch_idx + 1) % log_interval == 0:
        log_str = '{} | '.format(tag)
        if timer is not None:
            timer_step = step
            if info_dict.get("cv_step", None) is not None:
                timer_step = info_dict['cv_step']
            steps_per_second = timer.steps_per_second(timer_step)
            log_str += 'steps/sec {:.3f}| '.format(steps_per_second)
        log_str += 'Batch {}/{} loss {:.6f} '.format(
            epoch, batch_idx + 1 if 'save_interval' not in info_dict else
            (step + 1) * accum_grad,
            tensor_to_scalar(loss_dict['loss']) * accum_grad)
        for name, value in loss_dict.items():
            if name != 'loss' and value is not None:
                log_str += '{} {:.6f} '.format(name, tensor_to_scalar(value))
        if tag == "TRAIN":
            log_str += 'lr {} grad_norm {:.6f} rank {}'.format(
                lrs_to_str(lrs), info_dict['grad_norm'], rank)
        logging.debug(log_str)


def log_per_epoch(writer, info_dict):
    epoch = info_dict["epoch"]
    loss_dict = info_dict["loss_dict"]
    lrs = info_dict['lrs']
    rank = int(os.environ.get('RANK', 0))
    step = info_dict["step"]
    logging.info(
        'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
            epoch, step, lrs_to_str(lrs), tensor_to_scalar(loss_dict["loss"]),
            rank, tensor_to_scalar(loss_dict["acc"])))

    if int(os.environ.get('RANK', 0)) == 0:
        for i, lr in enumerate(info_dict["lrs"]):
            writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch)
        for name, value in loss_dict.items():
            writer.add_scalar('epoch/{}'.format(name), tensor_to_scalar(value),
                              epoch)


def freeze_modules(model, args):
    for name, param in model.named_parameters():
        for module_name in args.freeze_modules:
            if module_name in name:
                param.requires_grad = False
                logging.debug("{} module is frozen".format(name))
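
# Example (module names are illustrative): freezing the encoder embedding and
# the CTC head during fine-tuning could be requested with
#   --freeze_modules "encoder.embed,ctc"
# since matching is a simple substring test on parameter names.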


def reinit_lora(model, args, configs, tokenizer, seed=777):
    from tqdm import tqdm
    from wenet.finetune.lora.utils import estimate_gradient, reinit_lora_modules
    from wenet.finetune.lora.layers import LoRALayer
    from types import SimpleNamespace

    logging.info("reinit lora modules.")
    with open(args.lora_init_yaml, 'r') as file:
        lora_config = yaml.safe_load(file)
    generator = torch.Generator()
    generator.manual_seed(seed)
    dataset_conf = copy.deepcopy(configs['dataset_conf'])
    dataset_conf['batch_conf']['batch_size'] = lora_config['init_batch_size']
    dataset_type = configs.get('dataset', 'asr')
    dataset = init_dataset(dataset_type, args.data_type, args.train_data,
                           tokenizer, dataset_conf, True)
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            pin_memory=args.pin_memory,
                            num_workers=args.num_workers,
                            persistent_workers=True,
                            generator=generator,
                            prefetch_factor=args.prefetch)
    additional_kwargs = {}
    if lora_config["init_config"]["mode"] == "gradient":
        named_grads = estimate_gradient(model, dataloader,
                                        lora_config['init_iters'])
        additional_kwargs["named_grads"] = named_grads
    lora_config = SimpleNamespace(**lora_config["init_config"])
    for name, module in tqdm(
            model.named_modules(),
            desc="Reinitializing Lora",
            total=len(list(model.named_modules())),
    ):
        if isinstance(module, LoRALayer):
            reinit_lora_modules(name, module, lora_config, **additional_kwargs)

    # lora_init_model needs to be saved, w0 = w0 - A0 * B0
    save_checkpoint(model, os.path.join(args.model_dir, "lora_init.pt"),
                    infos={"tag": "lora_init", **configs})