import torch
from rotary_embedding_torch import RotaryEmbedding
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened

__all__ = ["meta_non_persistent_buffer_init_fn"]


def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
    """
    Materialize non-persistent tensor buffers when resuming a model.

    Since non-persistent tensor buffers are not saved in the state_dict, the user
    must materialize those buffers manually when the model is initialized on the
    meta device.

    Currently, `rope.dummy` is the only such special case.
    """
    with torch.no_grad():
        for submodule in module.modules():
            if not isinstance(submodule, RotaryEmbedding):
                continue
            # Only this submodule's own buffers are inspected; children are
            # visited by the outer `module.modules()` loop.
            for buffer_name, buffer in submodule.named_buffers(recurse=False):
                if buffer.is_meta and "dummy" in buffer_name:
                    # Replace the meta-device placeholder with a real zero tensor on CPU.
                    materialized_buffer = torch.zeros_like(buffer, device="cpu")
                    setattr(submodule, buffer_name, materialized_buffer)
    # After materialization, no buffer in the module should remain on the meta device.
    assert not any(b.is_meta for n, b in module.named_buffers())
    return module
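

# A minimal usage sketch of this helper (for illustration only: `MyTransformer`,
# its config, and the checkpoint path are hypothetical and not part of this module):
#
#     with torch.device("meta"):
#         model = MyTransformer(config)               # params/buffers created on the meta device
#     state_dict = torch.load("ckpt.pt", map_location="cpu")
#     model.load_state_dict(state_dict, assign=True)  # persistent tensors come from the checkpoint
#     model = meta_non_persistent_buffer_init_fn(model)  # materialize remaining meta buffers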