import os
import re
import json
import gc
import functools
import contextlib
from typing import Dict, Union, Optional, Type, Set
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
StateDictType,
FullOptimStateDictConfig,
FullStateDictConfig,
)
import torch.distributed.checkpoint as torch_dcp
import torch.distributed.checkpoint.state_dict
from torch.distributed.fsdp.api import (
ShardingStrategy,
BackwardPrefetch,
MixedPrecision,
)
import accelerate
import safetensors.torch
import diffusers
import transformers
from huggingface_hub.serialization import split_torch_state_dict_into_shards
# from .ema_utils import EMAModel
def upcast_trainable_param_to_fp32_(fsdp_model):
for m in FSDP.fsdp_modules(fsdp_model):
if m._has_params:
param = m._flat_param
if (
param.dtype != torch.float32
and param.device != torch.device("meta")
and param.requires_grad
):
param.data = param.data.to(torch.float32)
m._handle._orig_param_dtype = torch.float32
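# Usage sketch (illustrative, not part of the original call sites): after wrapping a
# model with a low-precision param_dtype, the trainable flat params can be upcast so
# the optimizer keeps fp32 master weights. Note this relies on FSDP internals
# (_flat_param, _handle), so behavior may vary across torch versions.
#
#   fsdp_model = make_model_fsdp(model, param_dtype=torch.bfloat16, device=torch.cuda.current_device())
#   upcast_trainable_param_to_fp32_(fsdp_model)
#   optimizer = torch.optim.AdamW([p for p in fsdp_model.parameters() if p.requires_grad])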
def get_module_to_ignore_mixed_precision():
try:
from apex.normalization import FusedLayerNorm
return [
torch.nn.GroupNorm,
torch.nn.modules.batchnorm._BatchNorm,
torch.nn.LayerNorm,
FusedLayerNorm,
]
    except ImportError:
return [
torch.nn.GroupNorm,
torch.nn.modules.batchnorm._BatchNorm,
torch.nn.LayerNorm,
]
def is_fsdp_model(model):
return len(FSDP.fsdp_modules(model)) > 0
def size_based_auto_wrap_policy(
module: torch.nn.Module,
recurse: bool,
nonwrapped_numel: int,
# Additional custom arguments
min_num_params: int = int(1e8),
force_leaf_modules: Optional[Set[Type[torch.nn.Module]]] = None,
exclude_wrap_modules: Optional[Set[Type[torch.nn.Module]]] = None,
) -> bool:
"""
A size-based auto wrap policy.
Args:
module (nn.Module): Current module being considered.
recurse (bool): If ``False``, then this function must decide whether
``module`` should be wrapped as an FSDP instance or not. If
``True``, then the function is still recursing down the module
tree as a part of the DFS.
nonwrapped_numel (int): Parameter numel not yet wrapped.
min_num_params (int): Customizable policy input that controls the size
threshold over which a module is ready to be wrapped. This is in
units of numel.
force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
as leaves, i.e. their children will never be wrapped.
exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
excluded in wrapping.
Returns:
Whether ``module`` should be wrapped.
"""
force_leaf_modules = (
size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined]
if force_leaf_modules is None
else force_leaf_modules
)
exclude_wrap_modules = (
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined]
if exclude_wrap_modules is None
else exclude_wrap_modules
)
# Keep the argument `min_num_params` for BC for now, but it represents the
# minimum non-wrapped *numel* before triggering a wrapping
min_nonwrapped_numel = min_num_params
is_large = nonwrapped_numel >= min_nonwrapped_numel
    STOP_FLAG_NAME = "__FSDP_STOP_WRAP_FLAG_CUSTOM_POLICY_size_based_auto_wrap_policy"
if recurse:
        # MixedPrecision makes FSDP always recurse, so mark the children of
        # force-leaf modules with a stop flag so nothing below them gets wrapped.
if isinstance(module, tuple(force_leaf_modules)):
for m in module.children():
m.apply(lambda m: setattr(m, STOP_FLAG_NAME, True))
return True
else:
if getattr(module, size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME, False):
return True
elif getattr(module, STOP_FLAG_NAME, False):
return False
else:
# If we are not recursing, determine if we should wrap.
return is_large and not isinstance(module, tuple(exclude_wrap_modules))
# Attach these defaults to the size_based_auto_wrap_policy function so they are easy to import and override.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {torch.nn.ModuleList, torch.nn.ModuleDict} # type: ignore[attr-defined]
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {torch.nn.MultiheadAttention} # type: ignore[attr-defined]
size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME = (
"__FSDP_LEAF_ROOT_FLAG_CUSTOM_POLICY_size_based_auto_wrap_policy"
)
def mark_leaf_root_(module):
setattr(
module,
size_based_auto_wrap_policy.LEAF_ROOT_FLAG_NAME,
True,
)
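# Usage sketch (illustrative): the policy is normally bound with functools.partial and
# handed to FSDP, which is what make_model_fsdp below does; mark_leaf_root_ tags a
# submodule so the policy wraps it as its own FSDP unit regardless of size.
#
#   policy = functools.partial(
#       size_based_auto_wrap_policy,
#       min_num_params=int(1e6),
#   )
#   fsdp_model = FSDP(model, auto_wrap_policy=policy)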
def make_model_fsdp(
model,
param_dtype,
device,
reduce_dtype=None,
buffer_dtype=None,
sync_module_states=True,
process_group=None,
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
module_classes_to_ignore_mixed_precision=None,
ignored_states=None,
ignored_modules=None,
auto_wrap_policy=None,
part_size=1e6,
force_leaf_modules=None,
exclude_wrap_modules=None,
use_orig_params=False
):
if module_classes_to_ignore_mixed_precision is None:
module_classes_to_ignore_mixed_precision = (
get_module_to_ignore_mixed_precision()
)
    if auto_wrap_policy is not None:
        pass  # use the caller-provided policy as-is
elif sharding_strategy == ShardingStrategy.NO_SHARD:
auto_wrap_policy = None
else:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy,
min_num_params=part_size,
force_leaf_modules=force_leaf_modules,
exclude_wrap_modules=exclude_wrap_modules,
)
model = FSDP(
model,
sharding_strategy=sharding_strategy,
process_group=process_group,
forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
use_orig_params=use_orig_params,
sync_module_states=sync_module_states,
mixed_precision=MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype or torch.float32,
buffer_dtype=buffer_dtype or torch.float32,
keep_low_precision_grads=False,
cast_forward_inputs=False,
cast_root_forward_inputs=True,
_module_classes_to_ignore=module_classes_to_ignore_mixed_precision,
),
auto_wrap_policy=auto_wrap_policy,
ignored_states=ignored_states,
ignored_modules=ignored_modules,
device_id=device,
)
torch.cuda.empty_cache()
gc.collect()
return model
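# Example (a sketch; assumes torch.distributed is already initialized and `model`
# is a regular nn.Module on the calling rank):
#
#   fsdp_model = make_model_fsdp(
#       model,
#       param_dtype=torch.bfloat16,
#       device=torch.cuda.current_device(),
#       sharding_strategy=ShardingStrategy.HYBRID_SHARD,
#       part_size=1e6,
#   )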
def save_fsdp_lora(
    model_to_save,  # FSDP-wrapped model
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
    lora_regex: str = r"(?:lora)",  # adjust to match your parameter naming convention
):
"""
仅保存 LoRA 层的权重。适用于 FSDP 并与 safetensors 兼容。
"""
# 1. 解包 FSDP,拿到裸模型
unwrapped_model = accelerate.utils.extract_model_from_parallel(model_to_save)
    # 2. Create the save directory
if is_main_process:
os.makedirs(save_directory, exist_ok=True)
    # 3. Gather the full state_dict (offloaded to CPU)
state_dict = torch_dcp.state_dict.get_model_state_dict(
model_to_save,
options=torch_dcp.state_dict.StateDictOptions(
full_state_dict=True,
cpu_offload=True,
ignore_frozen_params=False,
),
)
    # 4. Filter out the LoRA parameters
lora_pattern = re.compile(lora_regex)
lora_state_dict = {
k: v for k, v in state_dict.items() if lora_pattern.search(k) is not None
}
if not lora_state_dict:
raise ValueError(
"未找到匹配 LoRA 的参数。请检查 lora_regex 是否符合命名规则。"
)
    # 5. Save as a single *.safetensors file
if is_main_process:
weight_file = os.path.join(save_directory, "adapter_model.safetensors")
safetensors.torch.save_file(
lora_state_dict, weight_file, metadata={"format": "pt", "type": "lora"}
)
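# Example (sketch; the directory and regex below are placeholders):
#
#   save_fsdp_lora(
#       fsdp_model,
#       save_directory="checkpoints/step_1000/lora",
#       is_main_process=torch.distributed.get_rank() == 0,
#       lora_regex=r"(?:lora)",
#   )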
def load_fsdp_model_(model_to_load: FSDP, save_directory: Union[str, os.PathLike]):
with FSDP.state_dict_type(
model_to_load,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(
rank0_only=False,
),
):
_model = model_to_load.from_pretrained(save_directory)
model_to_load.load_state_dict(_model.state_dict())
def save_fsdp_optimizer(
models: Dict,
optimizer_to_save: torch.optim.Optimizer,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
):
_fsdp_state_dict_config = dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
optim_state_dict_config=FullOptimStateDictConfig(
offload_to_cpu=True,
rank0_only=True,
),
)
mgrs = list()
for m in models.values():
if len(FSDP.fsdp_modules(m)) > 0:
mgrs.append(FSDP.state_dict_type(m, **_fsdp_state_dict_config))
with contextlib.ExitStack() as stack:
for mgr in mgrs:
stack.enter_context(mgr)
optim_state_dict = FSDP.optim_state_dict(
torch.nn.ModuleDict(models),
optimizer_to_save,
)
if is_main_process:
torch.save(
optim_state_dict, os.path.join(save_directory, "optim_states.pth")
)
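# Example (sketch; the keys of `models` must be the same at save and load time so the
# optimizer state can be mapped back onto the same parameters):
#
#   save_fsdp_optimizer(
#       {"transformer": fsdp_model},
#       optimizer,
#       save_directory="checkpoints/step_1000",
#       is_main_process=torch.distributed.get_rank() == 0,
#   )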
def load_fsdp_optimizer_(
models: Dict,
optimizer_to_load: torch.optim.Optimizer,
save_directory: Union[str, os.PathLike],
):
_fsdp_state_dict_config = dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
optim_state_dict_config=FullOptimStateDictConfig(
rank0_only=False,
),
)
mgrs = list()
for m in models.values():
if len(FSDP.fsdp_modules(m)) > 0:
mgrs.append(FSDP.state_dict_type(m, **_fsdp_state_dict_config))
with contextlib.ExitStack() as stack:
for mgr in mgrs:
stack.enter_context(mgr)
optimizer_path = os.path.join(save_directory, "optim_states.pth")
assert os.path.isfile(optimizer_path)
optim_state_dict = torch.load(optimizer_path)
optim_state_dict = FSDP.optim_state_dict_to_load(
torch.nn.ModuleDict(models),
optimizer_to_load,
optim_state_dict,
)
optimizer_to_load.load_state_dict(optim_state_dict)
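# Example (sketch; call on every rank, since rank0_only=False gathers the full optimizer
# state on each rank before FSDP re-shards it into the local optimizer):
#
#   load_fsdp_optimizer_(
#       {"transformer": fsdp_model},
#       optimizer,
#       save_directory="checkpoints/step_1000",
#   )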