Spaces:
Configuration error
Configuration error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
def fsdp_auto_wrap_policy(model, transformer_layer_name):
    """Build an FSDP auto-wrap policy for a PEFT-adapted model.

    The returned policy wraps a module when either of two sub-policies
    accepts it (the two are combined with ``_or_policy``):

    * a lambda policy that wraps leaf modules (modules with no children)
      that have a ``weight`` attribute whose ``requires_grad`` is true; and
    * a transformer policy that wraps instances of the PEFT prompt-learning
      modules (``PrefixEncoder``, ``PromptEncoder``, ``PromptEmbedding``)
      and of the given transformer layer class.

    Args:
        model: The model that will be wrapped. It is only consulted when
            ``transformer_layer_name`` is a string, to look up the layer
            class among the classes of ``model.modules()``.
        transformer_layer_name: The transformer block class to wrap, or that
            class's ``__name__`` as a string. Despite the parameter name,
            passing the class object itself is supported (and was the only
            working form previously).

    Returns:
        A ``functools.partial`` of ``_or_policy`` that can be passed as the
        ``auto_wrap_policy`` argument of ``FullyShardedDataParallel``.

    Raises:
        ValueError: If ``transformer_layer_name`` is a string and no module
            in ``model`` is an instance of a class with that name.
    """
    import functools

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
    from torch.distributed.fsdp.wrap import (
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    if isinstance(transformer_layer_name, str):
        # transformer_auto_wrap_policy matches modules with isinstance(), so a
        # bare class name must be resolved to the class object; a raw string
        # in transformer_layer_cls would raise TypeError at wrap time.
        transformer_layer_cls = next(
            (
                type(module)
                for module in model.modules()
                if type(module).__name__ == transformer_layer_name
            ),
            None,
        )
        if transformer_layer_cls is None:
            raise ValueError(
                f"Could not find a module class named {transformer_layer_name!r} in the model."
            )
    else:
        transformer_layer_cls = transformer_layer_name

    def lambda_policy_fn(module):
        # Wrap leaf modules that own a trainable ``weight``; presumably these
        # are the PEFT adapter parameters that sit outside the transformer
        # layers matched below.
        is_leaf = next(module.children(), None) is None
        return bool(
            is_leaf
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            transformer_layer_cls,
        ),
    )

    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy