# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code adapted from [SeedVR]
# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
"""
Basic distributed utility functions.
"""
import os
import torch
from torch import nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
def get_global_rank() -> int:
"""
Get the global rank, the global index of the GPU.
"""
return int(os.environ.get("RANK", "0"))
def get_local_rank() -> int:
"""
Get the local rank, the local index of the GPU.
"""
return int(os.environ.get("LOCAL_RANK", "0"))
def get_world_size() -> int:
"""
    Get the world size, the total number of GPUs (one process per GPU).
"""
return int(os.environ.get("WORLD_SIZE", "1"))
def get_device() -> torch.device:
"""
    Get the device of the current local rank.
"""
return torch.device("cuda", get_local_rank())
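
# The helpers above read the RANK / LOCAL_RANK / WORLD_SIZE environment
# variables that launchers such as `torchrun` export for every worker process,
# for example (the script name is a placeholder):
#
#   torchrun --nnodes=1 --nproc_per_node=8 train.py
#
# When these variables are absent, the defaults describe a single-process,
# single-GPU run.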
def barrier_if_distributed(*args, **kwargs):
"""
    Synchronizes all processes when running in a distributed context.
"""
if dist.is_initialized():
return dist.barrier(*args, **kwargs)
def init_torch(cudnn_benchmark=True):
"""
Common PyTorch initialization configuration.
"""
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.cuda.set_device(get_local_rank())
dist.init_process_group(
backend="nccl",
rank=get_global_rank(),
world_size=get_world_size(),
)
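
# Minimal usage sketch (assumes the process was started by `torchrun`, so the
# environment variables read by the helpers above are already set):
#
#   init_torch(cudnn_benchmark=True)
#   x = torch.randn(4, 4, device=get_device())
#   dist.all_reduce(x)             # default op sums `x` across all ranks
#   barrier_if_distributed()
#   dist.destroy_process_group()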
def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
    """
    Wrap a module in DistributedDataParallel, placed on the current local rank's GPU.
    """
return DistributedDataParallel(
module=module,
device_ids=[get_local_rank()],
output_device=get_local_rank(),
**kwargs,
)
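
# Sketch of wrapping a model for data-parallel training; `MyModel` is a
# placeholder module, not defined in this file:
#
#   init_torch()
#   model = MyModel().to(get_device())
#   ddp_model = convert_to_ddp(model, find_unused_parameters=False)
#   # gradients are now averaged across ranks on every backward pass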
def meta_param_init_fn(module: nn.Module) -> None:
"""
Used for model inited onto meta device.
Init meta param/buffer with empty tensor.
We don't care numerical correctness in this func.
FSDP will sync param/buffer state from rank0 to the other ranks.
"""
with torch.no_grad():
for submodule in module.modules():
for param_name, param in submodule.named_parameters(recurse=False):
if not _is_fsdp_flattened(param) and param.is_meta:
materialized_param = nn.Parameter(torch.empty_like(param, device="cpu"))
setattr(submodule, param_name, materialized_param)
for buffer_name, buffer in submodule.named_buffers(recurse=False):
if not _is_fsdp_flattened(buffer) and buffer.is_meta:
                    materialized_buffer = torch.empty_like(buffer, device="cpu")
                    setattr(submodule, buffer_name, materialized_buffer)
torch.cuda.empty_cache()
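
# Sketch of handing this init fn to FSDP for a model constructed on the meta
# device. The FSDP keyword arguments shown are standard PyTorch options, but
# the wrapping policy and `MyModel` are placeholders:
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   with torch.device("meta"):
#       model = MyModel()
#   model = FSDP(
#       model,
#       param_init_fn=meta_param_init_fn,
#       sync_module_states=True,    # rank 0 broadcasts the real weights
#       device_id=get_local_rank(),
#   )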
def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
"""
Materialize meta device buffers that are not persistent in state_dict.
Handles special cases like RotaryEmbedding.freqs.
"""
with torch.no_grad():
for submodule in module.modules():
if hasattr(submodule, "freqs"):
freqs = getattr(submodule, "freqs")
if isinstance(freqs, torch.Tensor) and freqs.is_meta:
                    # NOTE: submodule.dim is ignored; the rotary dimensions are
                    # hardcoded below for the 14B model (the 1.3B variant uses
                    # dim = 1536 and num_heads = 12).
                    def rope_params(max_seq_len, dim, theta=10000):
                        assert dim % 2 == 0
                        freqs = torch.outer(
                            torch.arange(max_seq_len),
                            1.0 / torch.pow(
                                theta,
                                torch.arange(0, dim, 2).to(torch.float32).div(dim)))
                        freqs = torch.polar(torch.ones_like(freqs), freqs)
                        return freqs

                    dim = 5120
                    num_heads = 40
d = dim // num_heads
freqs_tensor = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
], dim=1).to(dtype=torch.cfloat, device="cpu")
setattr(submodule, "freqs", freqs_tensor)
print(f"Successfully materialized freqs for {submodule.__class__.__name__}")
assert not any(b.is_meta for n, b in module.named_buffers())
return module
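
# Sketch of the intended call site (names are placeholders): after a model
# built under the meta device has had its persistent weights loaded with
# `assign=True` (available in recent PyTorch), any non-persistent buffers such
# as `RotaryEmbedding.freqs` are still on the meta device and get re-created
# here:
#
#   with torch.device("meta"):
#       model = build_dit()
#   model.load_state_dict(state_dict, assign=True)
#   model = meta_non_persistent_buffer_init_fn(model)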