import os
import re
import json
import gc
import functools
import contextlib
from typing import Dict, Union, Optional, Type, Set

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    StateDictType,
    FullOptimStateDictConfig,
    FullStateDictConfig,
)
import torch.distributed.checkpoint as torch_dcp
import torch.distributed.checkpoint.state_dict
from torch.distributed.fsdp.api import (
    ShardingStrategy,
    BackwardPrefetch,
    MixedPrecision,
)

import accelerate
import safetensors.torch
import diffusers
import transformers
from huggingface_hub.serialization import split_torch_state_dict_into_shards

# from .ema_utils import EMAModel

def upcast_trainable_param_to_fp32_(fsdp_model):
    # Promote trainable flat parameters to fp32 in place. This relies on FSDP
    # internals (_flat_param and the handle's original-dtype bookkeeping).
    for m in FSDP.fsdp_modules(fsdp_model):
        if m._has_params:
            param = m._flat_param
            if (
                param.dtype != torch.float32
                and param.device != torch.device("meta")
                and param.requires_grad
            ):
                param.data = param.data.to(torch.float32)
                m._handle._orig_param_dtype = torch.float32

def get_module_to_ignore_mixed_precision():
    try:
        from apex.normalization import FusedLayerNorm

        return [
            torch.nn.GroupNorm,
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            FusedLayerNorm,
        ]
    except ImportError:
        return [
            torch.nn.GroupNorm,
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
        ]

def is_fsdp_model(model):
    return len(FSDP.fsdp_modules(model)) > 0

def size_based_auto_wrap_policy(
    module: torch.nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[torch.nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[torch.nn.Module]]] = None,
) -> bool:
    """
    A size-based auto wrap policy.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        min_num_params (int): Customizable policy input that controls the size
            threshold over which a module is ready to be wrapped. This is in
            units of numel.
        force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
            as leaves, i.e. their children will never be wrapped.
        exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
            excluded in wrapping.

    Returns:
        Whether ``module`` should be wrapped.
    """
    force_leaf_modules = (
        size_based_auto_wrap_policy.FORCE_LEAF_MODULES  # type: ignore[attr-defined]
        if force_leaf_modules is None
        else force_leaf_modules
    )
    exclude_wrap_modules = (
        size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES  # type: ignore[attr-defined]
        if exclude_wrap_modules is None
        else exclude_wrap_modules
    )
    # Keep the argument `min_num_params` for BC for now, but it represents the
    # minimum non-wrapped *numel* before triggering a wrapping.
    min_nonwrapped_numel = min_num_params
    is_large = nonwrapped_numel >= min_nonwrapped_numel

    STOP_FLAG_NAME = "__FSDP_STOP_WRAP_FLAG_CUSTOM_POLICY_size_based_auto_wrap_policy"
    if recurse:
        # Because MixedPrecision is used, ALWAYS recurse. Modules forced to be
        # leaves mark all of their children with the stop flag so those children
        # are never wrapped.
        if isinstance(module, tuple(force_leaf_modules)):
            for m in module.children():
                m.apply(lambda m: setattr(m, STOP_FLAG_NAME, True))
        return True
    else:
        if getattr(module, size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME, False):
            return True
        elif getattr(module, STOP_FLAG_NAME, False):
            return False
        else:
            # If we are not recursing, determine if we should wrap based on size.
            return is_large and not isinstance(module, tuple(exclude_wrap_modules))

# Set these defaults on the size_based_auto_wrap_policy function so they are easy to import.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {torch.nn.ModuleList, torch.nn.ModuleDict}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {torch.nn.MultiheadAttention}  # type: ignore[attr-defined]
size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME = (
    "__FSDP_LEAF_ROOT_FLAG_CUSTOM_POLICY_size_based_auto_wrap_policy"
)


def mark_leaf_root_(module):
    setattr(
        module,
        size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME,
        True,
    )

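# --- Illustrative sketch (not part of the original utilities): how mark_leaf_root_ and
# size_based_auto_wrap_policy are typically combined. The "transformer_blocks" submodule
# name and the numeric threshold are placeholder assumptions; the returned partial can be
# passed as `auto_wrap_policy=` to make_model_fsdp below or directly to FSDP.
def _example_build_wrap_policy(model):
    # Pin a specific submodule so it always becomes its own FSDP unit, regardless of size.
    for name, child in model.named_children():
        if name == "transformer_blocks":  # hypothetical submodule name
            mark_leaf_root_(child)
    # Wrap anything above ~10M unwrapped parameters, but keep MultiheadAttention modules
    # as leaves (their children receive the stop flag and are never wrapped).
    return functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=int(1e7),
        force_leaf_modules={torch.nn.MultiheadAttention},
    )
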
def make_model_fsdp(
    model,
    param_dtype,
    device,
    reduce_dtype=None,
    buffer_dtype=None,
    sync_module_states=True,
    process_group=None,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    module_classes_to_ignore_mixed_precision=None,
    ignored_states=None,
    ignored_modules=None,
    auto_wrap_policy=None,
    part_size=1e6,
    force_leaf_modules=None,
    exclude_wrap_modules=None,
    use_orig_params=False,
):
    if module_classes_to_ignore_mixed_precision is None:
        module_classes_to_ignore_mixed_precision = (
            get_module_to_ignore_mixed_precision()
        )
    # Build the default size-based auto-wrap policy unless the caller supplied one
    # or no sharding is requested.
    if auto_wrap_policy is None and sharding_strategy != ShardingStrategy.NO_SHARD:
        auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=part_size,
            force_leaf_modules=force_leaf_modules,
            exclude_wrap_modules=exclude_wrap_modules,
        )
    model = FSDP(
        model,
        sharding_strategy=sharding_strategy,
        process_group=process_group,
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        use_orig_params=use_orig_params,
        sync_module_states=sync_module_states,
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype or torch.float32,
            buffer_dtype=buffer_dtype or torch.float32,
            keep_low_precision_grads=False,
            cast_forward_inputs=False,
            cast_root_forward_inputs=True,
            _module_classes_to_ignore=module_classes_to_ignore_mixed_precision,
        ),
        auto_wrap_policy=auto_wrap_policy,
        ignored_states=ignored_states,
        ignored_modules=ignored_modules,
        device_id=device,
    )
    torch.cuda.empty_cache()
    gc.collect()
    return model

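# --- Illustrative sketch (not part of the original utilities): wrapping a model with
# make_model_fsdp. Assumes the default process group has already been initialized
# (e.g. via torchrun) and that a CUDA device is available; the module, dtypes, and
# threshold below are placeholders.
def _example_make_model_fsdp():
    import torch.distributed as dist

    assert dist.is_initialized(), "call dist.init_process_group(...) first"
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    model = torch.nn.TransformerEncoder(  # stand-in for the real network
        torch.nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6
    )
    fsdp_model = make_model_fsdp(
        model,
        param_dtype=torch.bfloat16,               # compute dtype inside FSDP
        device=torch.device("cuda", local_rank),  # used as FSDP's device_id
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        part_size=int(1e7),                       # auto-wrap threshold in numel
    )
    return fsdp_model
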
def save_fsdp_lora(
    model_to_save,  # FSDP-wrapped model
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    lora_regex: str = r"(?:lora)",  # adjust to your own naming convention
):
    """
    Save only the LoRA layer weights. Works with FSDP and is safetensors-compatible.
    """
    # 1. Unwrap FSDP to get the bare model
    unwrapped_model = accelerate.utils.extract_model_from_parallel(model_to_save)
    # 2. Create the output directory
    if is_main_process:
        os.makedirs(save_directory, exist_ok=True)
    # 3. Gather the full state_dict (on CPU)
    state_dict = torch_dcp.state_dict.get_model_state_dict(
        model_to_save,
        options=torch_dcp.state_dict.StateDictOptions(
            full_state_dict=True,
            cpu_offload=True,
            ignore_frozen_params=False,
        ),
    )
    # 4. Filter out the LoRA parameters
    lora_pattern = re.compile(lora_regex)
    lora_state_dict = {
        k: v for k, v in state_dict.items() if lora_pattern.search(k) is not None
    }
    if not lora_state_dict:
        raise ValueError(
            "No parameters matching the LoRA pattern were found. "
            "Check that lora_regex matches your parameter naming convention."
        )
    # 5. Save as a single *.safetensors file
    if is_main_process:
        weight_file = os.path.join(save_directory, "adapter_model.safetensors")
        safetensors.torch.save_file(
            lora_state_dict, weight_file, metadata={"format": "pt", "type": "lora"}
        )

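# --- Illustrative sketch (not part of the original utilities): saving LoRA weights from
# an FSDP-wrapped model. The "lora_A"/"lora_B" naming is an assumption about the adapter
# implementation; adjust lora_regex to whatever your parameter names actually contain.
def _example_save_lora(fsdp_model, output_dir="./lora_ckpt"):
    import torch.distributed as dist

    is_main = (not dist.is_initialized()) or dist.get_rank() == 0
    # Every rank must call this: gathering the full state dict is a collective op.
    save_fsdp_lora(
        fsdp_model,
        save_directory=output_dir,
        is_main_process=is_main,
        lora_regex=r"lora_(?:A|B)",  # hypothetical naming scheme
    )
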
def load_fsdp_model_(model_to_load: FSDP, save_directory: Union[str, os.PathLike]):
    with FSDP.state_dict_type(
        model_to_load,
        state_dict_type=StateDictType.FULL_STATE_DICT,
        state_dict_config=FullStateDictConfig(
            rank0_only=False,
        ),
    ):
        # FSDP forwards attribute access to the wrapped module, so `from_pretrained`
        # resolves to the underlying (e.g. transformers/diffusers) class method.
        _model = model_to_load.from_pretrained(save_directory)
        model_to_load.load_state_dict(_model.state_dict())

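# --- Illustrative sketch (not part of the original utilities): reloading weights saved
# with the underlying module's save_pretrained() into an already-wrapped FSDP model.
# Assumes the wrapped module exposes from_pretrained; `ckpt_dir` is a placeholder.
def _example_reload_model(fsdp_model, ckpt_dir="./model_ckpt"):
    # Every rank calls this (rank0_only=False): each rank materializes the full
    # state dict and FSDP then re-shards it across the process group.
    load_fsdp_model_(fsdp_model, ckpt_dir)
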
def save_fsdp_optimizer(
    models: Dict,
    optimizer_to_save: torch.optim.Optimizer,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
):
    _fsdp_state_dict_config = dict(
        state_dict_type=StateDictType.FULL_STATE_DICT,
        optim_state_dict_config=FullOptimStateDictConfig(
            offload_to_cpu=True,
            rank0_only=True,
        ),
    )
    mgrs = list()
    for m in models.values():
        if len(FSDP.fsdp_modules(m)) > 0:
            mgrs.append(FSDP.state_dict_type(m, **_fsdp_state_dict_config))
    with contextlib.ExitStack() as stack:
        for mgr in mgrs:
            stack.enter_context(mgr)
        optim_state_dict = FSDP.optim_state_dict(
            torch.nn.ModuleDict(models),
            optimizer_to_save,
        )
        if is_main_process:
            torch.save(
                optim_state_dict, os.path.join(save_directory, "optim_states.pth")
            )

def load_fsdp_optimizer_(
    models: Dict,
    optimizer_to_load: torch.optim.Optimizer,
    save_directory: Union[str, os.PathLike],
):
    _fsdp_state_dict_config = dict(
        state_dict_type=StateDictType.FULL_STATE_DICT,
        optim_state_dict_config=FullOptimStateDictConfig(
            rank0_only=False,
        ),
    )
    mgrs = list()
    for m in models.values():
        if len(FSDP.fsdp_modules(m)) > 0:
            mgrs.append(FSDP.state_dict_type(m, **_fsdp_state_dict_config))
    with contextlib.ExitStack() as stack:
        for mgr in mgrs:
            stack.enter_context(mgr)
        optimizer_path = os.path.join(save_directory, "optim_states.pth")
        assert os.path.isfile(optimizer_path)
        optim_state_dict = torch.load(optimizer_path)
        optim_state_dict = FSDP.optim_state_dict_to_load(
            torch.nn.ModuleDict(models),
            optimizer_to_load,
            optim_state_dict,
        )
        optimizer_to_load.load_state_dict(optim_state_dict)

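# --- Illustrative sketch (not part of the original utilities): a save/load round trip
# for optimizer state. The "unet" key and checkpoint directory are placeholders; the
# same `models` dict (same keys, same FSDP wrapping) must be used for saving and
# loading, since the state is resolved against a ModuleDict built from it.
def _example_optimizer_round_trip(fsdp_model, optimizer, ckpt_dir="./ckpt"):
    import torch.distributed as dist

    models = {"unet": fsdp_model}  # hypothetical key
    is_main = (not dist.is_initialized()) or dist.get_rank() == 0
    if is_main:
        os.makedirs(ckpt_dir, exist_ok=True)
    if dist.is_initialized():
        dist.barrier()  # ensure the directory exists before rank 0 writes to it
    # All ranks participate in both calls; only rank 0 writes optim_states.pth.
    save_fsdp_optimizer(models, optimizer, ckpt_dir, is_main_process=is_main)
    load_fsdp_optimizer_(models, optimizer, ckpt_dir)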