# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch
from torch.distributed.fsdp import MixedPrecision

# fp16 requires a grad scaler (e.g. ShardedGradScaler) in the main training loop; bf16 does not.
fpSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)
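

# Minimal sketch (an assumption, not part of the original file) of the
# grad-scaler pairing the comment above refers to. FSDP's ShardedGradScaler
# plays the role torch.cuda.amp.GradScaler plays for unsharded models;
# `loss_fn` and `data_loader` are placeholder names.
def _fp16_scaler_loop_sketch(model, optimizer, loss_fn, data_loader):
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    scaler = ShardedGradScaler()  # created once so the scale factor persists across steps
    for batch, targets in data_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch), targets)
        scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
        scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
        scaler.update()                # adjusts the scale factor for the next step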


bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    # Also cast inputs to the forward pass to bf16.
    cast_forward_inputs=True,
)

# Params kept in fp32; gradient reduction and buffers in bf16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full fp32 everywhere, i.e. mixed precision effectively disabled.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
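

# Usage sketch (an assumption, not part of the original file): these policies
# are handed to FSDP via its `mixed_precision` argument. Preferring bf16 when
# the GPU supports it is a common pattern; the helper name is illustrative.
def _wrap_with_mixed_precision_sketch(model):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    policy = bfSixteen if torch.cuda.is_bf16_supported() else fpSixteen
    # bf16 needs no grad scaler; fp16 should be paired with the
    # ShardedGradScaler loop sketched above.
    return FSDP(model, mixed_precision=policy)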