"""
Distributed ops for supporting sequence parallel.
"""

from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import Tensor

from common.cache import Cache
from common.distributed.advanced import (
    get_sequence_parallel_group,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
)

from .basic import get_device

# Module-level buffers and metadata used by the sequence-parallel sync utilities below.
_SEQ_DATA_BUF = defaultdict(lambda: [None, None, None])
_SEQ_DATA_META_SHAPES = defaultdict()
_SEQ_DATA_META_DTYPES = defaultdict()
_SEQ_DATA_ASYNC_COMMS = defaultdict(list)
_SYNC_BUFFER = defaultdict(dict)


def single_all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    """
    A function to do all-to-all on a tensor, scattering along `scatter_dim`
    and gathering along `gather_dim`.
    """
    seq_world_size = dist.get_world_size(group)
    prev_scatter_dim = scatter_dim
    if scatter_dim != 0:
        # Move the scatter dim to the front so the communication buffer is contiguous.
        local_input = local_input.transpose(0, scatter_dim)
        if gather_dim == 0:
            gather_dim = scatter_dim
        scatter_dim = 0

    inp_shape = list(local_input.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    input_t = local_input.reshape(
        [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]
    ).contiguous()
    output = torch.empty_like(input_t)
    comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
    if async_op:
        # Caller is responsible for waiting on `comm` and restoring the layout.
        return output, comm, prev_scatter_dim

    output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0)
    if prev_scatter_dim:
        output = output.transpose(0, prev_scatter_dim).contiguous()
    return output
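

# Illustrative sketch (not part of the original API): how `single_all_to_all`
# might be called synchronously inside an attention block. The dimension
# indices (1 = sequence, 2 = heads) are assumptions for the example, and an
# initialized sequence-parallel process group is required.
def _example_single_all_to_all(x: Tensor) -> Tensor:
    group = get_sequence_parallel_group()
    if group is None:
        return x
    # Scatter heads (dim 2) across ranks while gathering the sequence (dim 1).
    return single_all_to_all(x, scatter_dim=2, gather_dim=1, group=group)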


def _all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
):
    seq_world_size = dist.get_world_size(group)
    input_list = [
        t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)
    ]
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class SeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        if async_op:
            output, comm, prev_scatter_dim = single_all_to_all(
                local_input, scatter_dim, gather_dim, group, async_op=async_op
            )
            ctx.prev_scatter_dim = prev_scatter_dim
            return output, comm

        return _all_to_all(local_input, scatter_dim, gather_dim, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        if ctx.async_op:
            # Restore the layout produced by the async forward before the reverse all-to-all.
            input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0)
            if ctx.prev_scatter_dim:
                input_t = input_t.transpose(0, ctx.prev_scatter_dim)
        else:
            input_t = grad_output[0]
        return (
            None,
            _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group),
            None,
            None,
            None,
        )
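

# Illustrative sketch (an assumption, not from the original file): using
# `SeqAllToAll` directly to move from a sequence-sharded layout (local sequence
# chunk, all heads) to a head-sharded one (full sequence, local heads). The
# dimension indices below (seq at dim 1, heads at dim 2) are example
# assumptions; an initialized sequence-parallel group is required.
def _example_seq_all_to_all(x: Tensor) -> Tensor:
    group = get_sequence_parallel_group()
    if group is None:
        return x
    # Forward: scatter dim 2 (heads), gather dim 1 (sequence); gradients flow
    # through the reverse all-to-all in backward.
    return SeqAllToAll.apply(group, x, 2, 1, False)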


class Slice(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        seq_world_size = dist.get_world_size(group)
        ctx.seq_world_size = seq_world_size
        ctx.dim = dim
        dim_size = local_input.shape[dim]
        return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
        dim_size = list(grad_output.size())
        split_size = dim_size[0]
        dim_size[0] = dim_size[0] * ctx.seq_world_size
        output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, grad_output, group=ctx.group)
        return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)


class Gather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        dim: int,
        grad_scale: Optional[bool] = False,
    ) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        seq_world_size = dist.get_world_size(group)
        ctx.seq_world_size = seq_world_size
        dim_size = list(local_input.size())
        split_size = dim_size[0]
        ctx.part_size = dim_size[dim]
        dim_size[0] = dim_size[0] * seq_world_size
        output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
        return torch.cat(output.split(split_size), dim=dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        if ctx.grad_scale:
            grad_output = grad_output * ctx.seq_world_size
        return (
            None,
            grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
            None,
            None,
        )
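

# Illustrative sketch (example assumption, not from the original module): a
# Slice/Gather round trip on the head dimension around an attention call.
# `attn_fn` and the dimension index are placeholders; a sequence-parallel group
# must already be initialized.
def _example_slice_then_gather(x: Tensor, attn_fn: Callable[[Tensor], Tensor]) -> Tensor:
    group = get_sequence_parallel_group()
    if group is None:
        return attn_fn(x)
    local = Slice.apply(group, x, 2)  # keep only this rank's heads
    local = attn_fn(local)  # run attention on the local head shard
    return Gather.apply(group, local, 2, False)  # all-gather the heads back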


def gather_seq_scatter_heads_qkv(
    qkv_tensor: Tensor,
    *,
    seq_dim: int,
    qkv_shape: Optional[Tensor] = None,
    cache: Cache = Cache(disable=True),
    restore_shape: bool = True,
):
    """
    A func to sync a split qkv tensor across the sequence-parallel group.
    qkv_tensor: the tensor to run all-to-all on. Its last dim must be the
        projection dim, which is split into 3 parts. After splitting, the
        scatter dim used for all-to-all is projection_idx + 1.
    seq_dim: the gather dim for the all-to-all comm.
    restore_shape: if True, the output has the same number of dims as the input.
    """
    group = get_sequence_parallel_group()
    if not group:
        return qkv_tensor
    world = get_sequence_parallel_world_size()
    orig_shape = qkv_tensor.shape
    scatter_dim = qkv_tensor.dim()
    bef_all2all_shape = list(orig_shape)
    qkv_proj_dim = bef_all2all_shape[-1]
    bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
    qkv_tensor = qkv_tensor.view(bef_all2all_shape)
    qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False)
    if restore_shape:
        out_shape = list(orig_shape)
        out_shape[seq_dim] *= world
        out_shape[-1] = qkv_proj_dim // world
        qkv_tensor = qkv_tensor.view(out_shape)

    # Remove the padding added when the unpadded sequence length is not
    # divisible by the sequence-parallel world size.
    if qkv_shape is not None:
        unpad_dim_size = cache(
            "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item()
        )
        if unpad_dim_size % world != 0:
            padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size
            qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size)
    return qkv_tensor
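

# Illustrative sketch (assumed shapes, not from the original file): applying
# `gather_seq_scatter_heads_qkv` to the output of a fused qkv projection whose
# layout is [batch, local_seq, 3 * hidden_dim], with the sequence at dim 1.
def _example_qkv_all_to_all(qkv: Tensor) -> Tensor:
    # Gathers the full sequence on dim 1 and scatters the fused projection dim
    # across the sequence-parallel group, returning
    # [batch, seq, 3 * hidden_dim / sp_world]; a no-op when no group is initialized.
    return gather_seq_scatter_heads_qkv(qkv, seq_dim=1)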


def slice_inputs(x: Tensor, dim: int, padding: bool = True):
    """
    A func to slice the input sequence in sequence parallel: each rank keeps its
    own chunk along `dim`, padding first if the length is not evenly divisible.
    """
    group = get_sequence_parallel_group()
    if group is None:
        return x
    sp_rank = get_sequence_parallel_rank()
    sp_world = get_sequence_parallel_world_size()
    dim_size = x.shape[dim]
    unit = (dim_size + sp_world - 1) // sp_world
    if padding and dim_size % sp_world:
        padding_size = sp_world - (dim_size % sp_world)
        x = _pad_tensor(x, dim, padding_size)
    slc = [slice(None)] * len(x.shape)
    slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
    return x[slc]
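

# Illustrative sketch (the shape below is an assumption): slicing a full-length
# batch so each sequence-parallel rank keeps only its own chunk of dim 1.
def _example_slice_inputs(tokens: Tensor) -> Tensor:
    # tokens: [batch, seq, hidden]; returns [batch, ceil(seq / sp_world), hidden]
    # on every rank (the last rank's chunk may include zero padding).
    return slice_inputs(tokens, dim=1)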


def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
    """
    A func to remove the padding part of the tensor based on its original shape
    """
    group = get_sequence_parallel_group()
    if group is None:
        return x
    sp_world = get_sequence_parallel_world_size()
    if unpad_dim_size % sp_world == 0:
        return x
    padding_size = sp_world - (unpad_dim_size % sp_world)
    assert (padding_size + unpad_dim_size) % sp_world == 0
    return _unpad_tensor(x, dim=dim, padding_size=padding_size)


def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
    """
    A func to sync the attention result with all-to-all in sequence parallel:
    gather the head dim and scatter the sequence dim.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    dim_size = x.size(seq_dim)
    sp_world = get_sequence_parallel_world_size()
    if dim_size % sp_world != 0:
        padding_size = sp_world - (dim_size % sp_world)
        x = _pad_tensor(x, seq_dim, padding_size)
    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)


def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor:
    """
    A func to sync the embedding input with all-to-all in sequence parallel:
    gather the sequence dim and scatter the head dim.
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return SeqAllToAll.apply(group, x, head_dim, seq_dim, False)
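

# Illustrative sketch (shapes and dims are assumptions): the typical attention
# round trip in sequence parallel — gather the full sequence and scatter heads
# before attention, then gather heads and re-scatter the sequence afterwards.
def _example_sp_attention(x: Tensor, attn_fn: Callable[[Tensor], Tensor]) -> Tensor:
    # x: [batch, local_seq, heads, head_dim]
    x = gather_seq_scatter_heads(x, seq_dim=1, head_dim=2)
    x = attn_fn(x)  # full sequence, local heads
    return gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)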


def scatter_heads(x: Tensor, dim: int) -> Tensor:
    """
    A func to split heads before attention in sequence parallel
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return Slice.apply(group, x, dim)


def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor:
    """
    A func to gather heads for the attention result in sequence parallel
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    return Gather.apply(group, x, dim, grad_scale)


def gather_outputs(
    x: Tensor,
    *,
    gather_dim: int,
    padding_dim: Optional[int] = None,
    unpad_shape: Optional[Tensor] = None,
    cache: Cache = Cache(disable=True),
    scale_grad=True,
):
    """
    A func to gather the outputs for the model result in sequence parallel
    """
    group = get_sequence_parallel_group()
    if not group:
        return x
    x = Gather.apply(group, x, gather_dim, scale_grad)
    if padding_dim is not None:
        unpad_dim_size = cache(
            "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item()
        )
        x = remove_seqeunce_parallel_padding(x, padding_dim, unpad_dim_size)
    return x


def _pad_tensor(x: Tensor, dim: int, padding_size: int):
    shape = list(x.shape)
    shape[dim] = padding_size
    pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)


def _unpad_tensor(x: Tensor, dim: int, padding_size):
    slc = [slice(None)] * len(x.shape)
    slc[dim] = slice(0, -padding_size)
    return x[slc]
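

# Illustrative sketch: the pad/unpad helpers are inverses along a given dim.
# The sizes below are made up for the example; it runs without any process group.
def _example_pad_unpad() -> None:
    x = torch.randn(2, 5, 8)
    padded = _pad_tensor(x, dim=1, padding_size=3)  # -> [2, 8, 8], zero padded at the end
    restored = _unpad_tensor(padded, dim=1, padding_size=3)  # -> [2, 5, 8]
    assert torch.equal(x, restored)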


def _broadcast_data(data, shape, dtype, src, group, async_op):
    # Recursively broadcast every tensor leaf in a (possibly nested) list/dict payload.
    comms = []
    if isinstance(data, (list, tuple)):
        for i, sub_shape in enumerate(shape):
            comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op)
    elif isinstance(data, dict):
        for key, sub_data in data.items():
            comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op)
    elif isinstance(data, Tensor):
        comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op))
    return comms


def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]:
    if isinstance(data, (list, tuple)):
        return [_traverse(sub_data, op) for sub_data in data]
    elif isinstance(data, dict):
        return {key: _traverse(sub_data, op) for key, sub_data in data.items()}
    elif isinstance(data, Tensor):
        return op(data)
    else:
        return None


def _get_shapes(data):
    return _traverse(data, op=lambda x: x.shape)


def _get_dtypes(data):
    return _traverse(data, op=lambda x: x.dtype)


def _construct_broadcast_buffer(shapes, dtypes, device):
    if isinstance(shapes, torch.Size):
        return torch.empty(shapes, dtype=dtypes, device=device)

    if isinstance(shapes, (list, tuple)):
        buffer = []
        for i, sub_shape in enumerate(shapes):
            buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device))
    elif isinstance(shapes, dict):
        buffer = {}
        for key, sub_shape in shapes.items():
            buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device)
    else:
        return None
    return buffer
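

# Illustrative sketch: how the traversal helpers and buffer construction fit
# together for a nested payload. The shapes below are made up; this runs on CPU
# without any process group.
def _example_broadcast_buffer() -> None:
    payload = {"x": torch.randn(2, 3), "extra": [torch.ones(4, dtype=torch.int64)]}
    shapes = _get_shapes(payload)  # {"x": torch.Size([2, 3]), "extra": [torch.Size([4])]}
    dtypes = _get_dtypes(payload)  # {"x": torch.float32, "extra": [torch.int64]}
    # Receiver ranks allocate empty buffers with a matching structure before broadcast.
    buffer = _construct_broadcast_buffer(shapes, dtypes, device="cpu")
    assert buffer["x"].shape == payload["x"].shape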


class SPDistForward:
    """A forward tool to sync each rank's result across the sp group.

    Calling an instance returns a generator that, on each iteration, yields the
    inputs owned by one sp rank (rank 0 first), while asynchronously broadcasting
    the next rank's data into a double buffer.

    Args:
        name: a distinct str to key the saved meta info and async comms
        comm_shape: if different ranks may have different shapes, set this to True
        device: the device for the current rank; defaults to get_device()
    """

    def __init__(
        self,
        name: str,
        comm_shape: bool,
        device: torch.device = None,
    ):
        self.name = name
        self.comm_shape = comm_shape
        if device:
            self.device = device
        else:
            self.device = get_device()

    def __call__(self, inputs) -> Any:
        group = get_sequence_parallel_group()
        if not group:
            yield inputs
        else:
            device = self.device
            sp_world = get_sequence_parallel_world_size()
            sp_rank = get_sequence_parallel_rank()
            for local_step in range(sp_world):
                src_rank = dist.get_global_rank(group, local_step)
                is_src = sp_rank == local_step
                local_shapes = []
                local_dtypes = []
                if local_step == 0:
                    local_result = inputs
                    _SEQ_DATA_BUF[self.name][-1] = local_result
                    local_shapes = _get_shapes(local_result)
                    local_dtypes = _get_dtypes(local_result)
                    if self.comm_shape:
                        group_shapes_lists = [None] * sp_world
                        dist.all_gather_object(group_shapes_lists, local_shapes, group=group)
                        _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists
                    else:
                        _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world
                    _SEQ_DATA_META_DTYPES[self.name] = local_dtypes
                shapes = _SEQ_DATA_META_SHAPES[self.name][local_step]
                dtypes = _SEQ_DATA_META_DTYPES[self.name]
                buf_id = local_step % 2
                if local_step == 0:
                    # First step is a blocking broadcast of rank 0's data.
                    sync_data = (
                        local_result
                        if is_src
                        else _construct_broadcast_buffer(shapes, dtypes, device)
                    )
                    _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False)
                    _SEQ_DATA_BUF[self.name][buf_id] = sync_data

                # Wait for the async broadcast issued on the previous step.
                if _SEQ_DATA_ASYNC_COMMS[self.name]:
                    for comm in _SEQ_DATA_ASYNC_COMMS[self.name]:
                        comm.wait()

                if local_step < sp_world - 1:
                    # Prefetch the next rank's data into the other buffer.
                    next_buf_id = 1 - buf_id
                    shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1]
                    src_rank = dist.get_global_rank(group, local_step + 1)
                    is_src = sp_rank == local_step + 1
                    next_sync_data = (
                        _SEQ_DATA_BUF[self.name][-1]
                        if is_src
                        else _construct_broadcast_buffer(shapes, dtypes, device)
                    )
                    _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data(
                        next_sync_data, shapes, dtypes, src_rank, group, True
                    )
                    _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data
                yield _SEQ_DATA_BUF[self.name][buf_id]


sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True)
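

# Illustrative sketch (assumed training-loop context): `sync_inputs` yields one
# rank's inputs per iteration, so every rank runs the model on every rank's
# micro-batch. `model` and `local_batch` are placeholders for this example.
def _example_sync_inputs(model, local_batch):
    outputs = []
    for shared_batch in sync_inputs(local_batch):
        outputs.append(model(shared_batch))
    return outputs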


def sync_data(data, sp_idx, name="tmp"):
    """
    Broadcast a picklable object from sequence-parallel rank `sp_idx` to all
    ranks in the sequence-parallel group.
    """
    group = get_sequence_parallel_group()
    if group is None:
        return data

    sp_rank = get_sequence_parallel_rank()
    src_rank = dist.get_global_rank(group, sp_idx)
    objects = [data] if sp_rank == sp_idx else [None]
    dist.broadcast_object_list(objects, src=src_rank, group=group)

    return objects[0]
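

# Illustrative sketch (values are made up): sharing small CPU-side metadata,
# e.g. a sampled config, from sp rank 0 to the whole sequence-parallel group.
def _example_sync_data():
    cfg = {"num_frames": 16, "resolution": (256, 256)}
    # Every rank receives rank 0's dict (a no-op when no sp group exists).
    return sync_data(cfg, sp_idx=0)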