File size: 4,073 Bytes
3ed0796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from collections import OrderedDict
from diffusers.utils import logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_fsdp_plugin(fsdp_cfg, mixed_precision):
    """Build an accelerate ``FullyShardedDataParallelPlugin`` from a config object.

    Args:
        fsdp_cfg: Config object read via attribute access. Fields used:
            ``sharding_strategy``, ``backward_prefetch``, ``state_dict_type``
            (strings, looked up in the tables below), ``min_num_params``,
            ``cpu_offload``, ``limit_all_gathers``, ``use_orig_params``,
            ``sync_module_states``, ``forward_prefetch`` and
            ``activation_checkpointing`` (passed through as-is).
        mixed_precision: ``"fp16"`` or ``"bf16"`` select that dtype for
            parameters and gradient reduction; any other value selects fp32.

    Returns:
        The configured ``FullyShardedDataParallelPlugin``.

    Raises:
        KeyError: if ``sharding_strategy``, ``backward_prefetch`` or
            ``state_dict_type`` is not one of the recognised option names.
    """
    import functools
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        CPUOffload,
        ShardingStrategy,
        MixedPrecision,
        StateDictType,
        FullStateDictConfig,
        FullOptimStateDictConfig,
    )
    from accelerate.utils import FullyShardedDataParallelPlugin
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    if mixed_precision == "fp16":
        dtype = torch.float16
    elif mixed_precision == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    sharding_strategy = {
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
        "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
        "NO_SHARD": ShardingStrategy.NO_SHARD,
        "HYBRID_SHARD": ShardingStrategy.HYBRID_SHARD,
        "HYBRID_SHARD_ZERO2": ShardingStrategy._HYBRID_SHARD_ZERO2,
    }[fsdp_cfg.sharding_strategy]

    backward_prefetch = {
        "BACKWARD_PRE": BackwardPrefetch.BACKWARD_PRE,
        "BACKWARD_POST": BackwardPrefetch.BACKWARD_POST,
        # None leaves the prefetch choice to the plugin's own default.
        "NO_PREFETCH": None,
    }[fsdp_cfg.backward_prefetch]

    state_dict_type = {
        "FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
        "LOCAL_STATE_DICT": StateDictType.LOCAL_STATE_DICT,
        "SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,
    }[fsdp_cfg.state_dict_type]

    # torch FSDP requires the state-dict config class to match the requested
    # StateDictType (Full*Config only pairs with FULL_STATE_DICT). For the
    # sharded/local types pass None so the matching defaults are used instead
    # of a mismatched Full*Config that fails at checkpoint save/load time.
    if state_dict_type == StateDictType.FULL_STATE_DICT:
        state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        optim_state_dict_config = FullOptimStateDictConfig(
            offload_to_cpu=True, rank0_only=True
        )
    else:
        state_dict_config = None
        optim_state_dict_config = None

    fsdp_plugin = FullyShardedDataParallelPlugin(
        sharding_strategy=sharding_strategy,
        backward_prefetch=backward_prefetch,
        mixed_precision_policy=MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=dtype,
        ),
        # Wrap any submodule whose parameter count reaches min_num_params.
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params
        ),
        cpu_offload=CPUOffload(offload_params=fsdp_cfg.cpu_offload),
        state_dict_type=state_dict_type,
        state_dict_config=state_dict_config,
        optim_state_dict_config=optim_state_dict_config,
        limit_all_gathers=fsdp_cfg.limit_all_gathers,
        use_orig_params=fsdp_cfg.use_orig_params,
        sync_module_states=fsdp_cfg.sync_module_states,
        forward_prefetch=fsdp_cfg.forward_prefetch,
        activation_checkpointing=fsdp_cfg.activation_checkpointing,
    )
    return fsdp_plugin


def freeze_model(model, trainable_modules=None, verbose=False, *, apply=False):
    """Classify ``model``'s parameters as frozen or trainable and log the result.

    A parameter is "trainable" when any entry of ``trainable_modules`` is a
    substring of its name as yielded by ``model.named_parameters()``; every
    other parameter is "frozen".

    NOTE: by default this does NOT modify ``requires_grad``. The assignments
    were disabled in the original implementation, so callers relying on the
    default only get logging. Pass ``apply=True`` to actually set
    ``param.requires_grad`` (False for frozen, True for trainable).

    Args:
        model: Object exposing ``named_parameters()`` (e.g. ``torch.nn.Module``).
        trainable_modules: Iterable of name substrings to keep trainable.
            ``None`` (the default) means no parameter is trainable.
        verbose: If True, log one line per parameter, plus an extra line for
            each parameter classified as trainable.
        apply: Keyword-only opt-in; when True, write ``requires_grad`` on
            every parameter according to its classification.

    Returns:
        None.
    """
    # Materialize once: a generator would otherwise be exhausted after the
    # first parameter, silently leaving every later parameter unmatched.
    patterns = () if trainable_modules is None else tuple(trainable_modules)

    logger.info("Start freeze")
    for name, param in model.named_parameters():
        trainable = any(pattern in name for pattern in patterns)
        if apply:
            param.requires_grad = trainable
        if verbose:
            logger.info("freeze module: %s", name)
            if trainable:
                logger.info("unfreeze module: %s", name)
    logger.info("End freeze")


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """Move every EMA parameter one step towards the live model, in place.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``.
    Wrappers that expose the wrapped network as ``.module`` (such as
    DistributedDataParallel) are unwrapped on both sides first. Parameters
    are matched by name, so every parameter of ``model`` must also exist in
    ``ema_model`` (otherwise a KeyError is raised). Runs under
    ``torch.no_grad()``.
    """
    source = model.module if hasattr(model, "module") else model
    target = ema_model.module if hasattr(ema_model, "module") else ema_model

    target_by_name = dict(target.named_parameters())
    step = 1 - decay

    # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
    for key, value in source.named_parameters():
        target_by_name[key].mul_(decay).add_(value.data, alpha=step)


def log_validation(model):
    """Placeholder for validation logging; currently ignores ``model`` and does nothing."""
    return None