from functools import partial
import os

from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    FullStateDictConfig, StateDictType)
from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy,
                                         transformer_auto_wrap_policy)

from wenet.LLM.decoder import DecoderOnly
from wenet.branchformer.encoder_layer import BranchformerEncoderLayer
from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
from wenet.paraformer.layers import AliParaformerEncoderLayer, SanmDecoderLayer
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.encoder_layer import (ConformerEncoderLayer,
                                             TransformerEncoderLayer)
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.utils.checkpoint import save_state_dict_and_infos
from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

WENET_ENCODER_LAYERS_CLASSES = {
    'transformer_encoder_layer': TransformerEncoderLayer,
    'conformer_encoder_layer': ConformerEncoderLayer,
    'paraformer_encoder_layer': AliParaformerEncoderLayer,
    'squeezeformer_encoder_layer': SqueezeformerEncoderLayer,
    'ebranchformer_encoder_layer': EBranchformerEncoderLayer,
    'efficient_conformer_encoder_layer': StrideConformerEncoderLayer,
    'branchformer_encoder_layer': BranchformerEncoderLayer,
}

WENET_DECODER_LAYERS_CLASSES = {
    'transformer_decoder_layer': DecoderLayer,
    'paraformer_decoder_layer': SanmDecoderLayer,
    # TODO(Mddct):
    #     1 wrap transducer's predictor and joint
    #     2 wrap paraformer's cif and ignore lstm
}


def wenet_fsdp_wrap_policy(mode):
    # Different wrap methods; please refer to:
    # https://openmmlab.medium.com/its-2023-is-pytorch-s-fsdp-the-best-choice-for-training-large-models-fe8d2848832f # noqa
    assert mode in ['no_shard', 'model', 'zero2', 'zero3']
    if mode == 'no_shard':
        return None
    else:
        # TODO(Mddct): support user customization.
        # See more wrap methods:
        # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
        if mode == 'model':
            # Wrap each whole encoder/decoder module as a single FSDP unit.
            enc_dec_wrap_policy = partial(
                lambda_auto_wrap_policy,
                lambda_fn=lambda module: isinstance(
                    module,
                    tuple(WENET_ENCODER_CLASSES.values()) + tuple(
                        WENET_DECODER_CLASSES.values())))
            return enc_dec_wrap_policy
        else:
            # zero2 / zero3: wrap at the granularity of individual layers.
            to_wrap_class = set()
            to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
            to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
            layers_wrap_policy = partial(transformer_auto_wrap_policy,
                                         transformer_layer_cls=to_wrap_class)
            return layers_wrap_policy
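

def example_wrap_with_fsdp(model, mode, device_id=None):
    # A minimal sketch, not part of wenet's API, showing how the policy
    # returned by wenet_fsdp_wrap_policy is typically handed to FSDP. The
    # mode-to-sharding mapping below is an illustrative assumption
    # (zero2 ~ ZeRO stage 2, zero3 ~ ZeRO stage 3); wenet's real training
    # entry point may choose differently.
    from torch.distributed.fsdp import ShardingStrategy
    sharding = {
        'no_shard': ShardingStrategy.NO_SHARD,
        'model': ShardingStrategy.FULL_SHARD,
        'zero2': ShardingStrategy.SHARD_GRAD_OP,
        'zero3': ShardingStrategy.FULL_SHARD,
    }[mode]
    # For 'no_shard' the wrap policy is None, i.e. no auto-wrapping.
    return FSDP(model,
                auto_wrap_policy=wenet_fsdp_wrap_policy(mode),
                sharding_strategy=sharding,
                device_id=device_id)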


fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True,
                                            rank0_only=True)


def fsdp_save_model(model, save_model_path, info_dict):
    # TODO(Mddct): Saving a large model takes a long time. We only need to
    # save the shards asynchronously, but this is good enough for now. This
    # will be revisited when LLM training is supported in the future.
    rank = int(os.environ.get('RANK', 0))
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              fullstate_save_policy):
        state_dict = model.state_dict()
        if rank == 0:
            save_state_dict_and_infos(state_dict, save_model_path, info_dict)


def check_gradient_checkpoint(model):
    ckpt_layer_types = []
    if hasattr(model, 'encoder') and hasattr(model.encoder,
                                             'gradient_checkpointing'):
        if model.encoder.gradient_checkpointing:
            # Let FSDP's activation-checkpoint wrapper take over from the
            # module-level flag.
            model.encoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
    if hasattr(model, 'decoder') and hasattr(model.decoder,
                                             'gradient_checkpointing'):
        if model.decoder.gradient_checkpointing:
            model.decoder.gradient_checkpointing = False
            ckpt_layer_types += list(WENET_DECODER_LAYERS_CLASSES.values())
            if isinstance(model.decoder, DecoderOnly):
                ckpt_layer_types += [DecoderOnly]
    return tuple(ckpt_layer_types)


def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
    # NOTE(Mddct): torch.utils.checkpoint is currently incompatible with
    # wenet's model mode, so we wrap layers with the non-reentrant checkpoint
    # wrapper instead. Please refer to:
    # https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21 # noqa
    if len(ckpt_layer_types) == 0:
        return
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
        CheckpointImpl,
        apply_activation_checkpointing,
    )
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))
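

def example_fsdp_training_setup(model, mode, save_model_path, info_dict):
    # A minimal end-to-end sketch, not part of wenet's API, of how the helpers
    # in this file fit together. check_gradient_checkpoint inspects
    # model.encoder / model.decoder directly, so it runs before wrapping: it
    # clears the module-level gradient_checkpointing flags and returns the
    # layer classes that apply_fsdp_checkpointing should wrap with the
    # non-reentrant activation-checkpoint wrapper.
    ckpt_layer_types = check_gradient_checkpoint(model)
    model = FSDP(model, auto_wrap_policy=wenet_fsdp_wrap_policy(mode))
    apply_fsdp_checkpointing(model, ckpt_layer_types)
    # ... run the training loop, then save a full (rank-0 only) checkpoint ...
    fsdp_save_model(model, save_model_path, info_dict)
    return model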