2024-01-25 07:48:46 +00:00
|
|
|
from typing import Any, List, Optional, Tuple
|
2023-11-02 02:21:24 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch import Tensor
|
|
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
|
|
|
# Handle to the loaded MoE extension. It stays None until load_moe() runs;
# callers in this module check for None and call load_moe() before use.
MOE_KERNEL = None


def load_moe():
    """Load the MoE extension through colossalai's ``MoeLoader`` and cache it.

    The loaded object is stored in the module-level ``MOE_KERNEL`` global;
    nothing is returned.
    """
    global MOE_KERNEL
    # Deferred import: the loader is only pulled in when a kernel is requested.
    from colossalai.kernel.kernel_loader import MoeLoader

    loader = MoeLoader()
    MOE_KERNEL = loader.load()
|
2023-11-02 02:21:24 +00:00
|
|
|
|
|
|
|
|
|
|
|
class AllGather(torch.autograd.Function):
    """Gather ``inputs`` from every rank of ``group`` into a new leading dimension.

    The result has shape ``(world_size, *inputs.shape)``. The backward pass
    reduce-scatters the incoming gradient over the same group.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # overlap=True is only permitted when a ctx is supplied.
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        world = dist.get_world_size(group)
        if world == 1:
            # Nothing to exchange: only add the leading rank dimension.
            return inputs.unsqueeze(0), None

        gathered = torch.empty((world, *inputs.shape), dtype=inputs.dtype, device=inputs.device)
        # One view per rank into the preallocated result buffer.
        slots = list(torch.chunk(gathered, world, dim=0))

        if overlap:
            work = dist.all_gather(slots, inputs, group=group, async_op=True)
            return gathered, work

        dist.all_gather(slots, inputs, group=group)
        return gathered, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # Gradient of all-gather: reduce-scatter over the same group (synchronously).
        grad_inputs, _ = ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)
        return grad_inputs, None, None
|
|
|
|
|
|
|
|
|
|
|
|
class ReduceScatter(torch.autograd.Function):
    """Reduce ``inputs`` across ``group`` and give each rank one slice of dim 0.

    ``inputs`` is split into ``world_size`` chunks along dim 0. The output has
    shape ``inputs.shape[1:]``. The backward pass all-gathers the gradient over
    the same group.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # overlap=True is only permitted when a ctx is supplied.
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        world = dist.get_world_size(group)
        if world == 1:
            # Single rank: drop the leading rank dimension.
            return inputs.squeeze(0), None

        # .contiguous() is a no-op for tensors that are already contiguous.
        inputs = inputs.contiguous()

        reduced = torch.empty(inputs.shape[1:], dtype=inputs.dtype, device=inputs.device)
        shards = list(torch.chunk(inputs, world, dim=0))

        if overlap:
            work = dist.reduce_scatter(reduced, shards, group=group, async_op=True)
            return reduced, work

        dist.reduce_scatter(reduced, shards, group=group)
        return reduced, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        # Gradient of reduce-scatter: all-gather over the same group.
        grad_inputs, _ = AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)
        return grad_inputs, None, None
|
|
|
|
|
|
|
|
|
|
|
|
class AllToAll(torch.autograd.Function):
    """Exchange equal dim-0 slices of ``inputs`` between all ranks of ``group``.

    Uses ``dist.all_to_all_single``; the original description is dispatching an
    ``[e, c, h]`` tensor to all experts. The backward pass applies the same
    all-to-all to the gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # overlap=True is only permitted when a ctx is supplied.
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        # Done before the single-rank shortcut, so even that path returns a contiguous tensor.
        inputs = inputs.contiguous()

        if dist.get_world_size(group) == 1:
            return inputs, None

        received = torch.empty_like(inputs)
        if overlap:
            work = dist.all_to_all_single(received, inputs, group=group, async_op=True)
            return received, work

        dist.all_to_all_single(received, inputs, group=group)
        return received, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # The gradient is routed back with the same (synchronous) all-to-all.
        grad_inputs, _ = AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)
        return grad_inputs, None, None
|
|
|
|
|
|
|
|
|
|
|
|
class HierarchicalAllToAll(torch.autograd.Function):
    """Two-level all-to-all: intra-node gather, inter-node all-to-all, intra-node scatter.

    ``groups`` is ``(intra_node_group, inter_node_group)``; the inter-node group
    may be None, in which case that stage is skipped. ``src_rank`` is compared
    against ``dist.get_rank()`` (the global rank). It is the rank that gathers,
    re-lays-out and scatters on behalf of its intra-node group.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor:
        """
        Returns:
            outputs: Tensor
        """
        # TODO: we can reduce comm volume by removing empty capacity
        if ctx is not None:
            ctx.comm_grps = groups
            ctx.src_rank = src_rank

        intra_group, inter_group = groups
        local_size = dist.get_world_size(intra_group)
        num_nodes = dist.get_world_size(inter_group) if inter_group is not None else 1
        total_ranks = local_size * num_nodes
        outputs = torch.empty_like(inputs)

        if dist.get_rank() != src_rank:
            # Non-root ranks only send to the gather and receive from the scatter.
            dist.gather(inputs, dst=src_rank, group=intra_group)
            dist.scatter(outputs, src=src_rank, group=intra_group)
            return outputs

        # intra-node gather: collect every local rank's tensor on src_rank
        gathered = [torch.empty_like(inputs) for _ in range(local_size)]
        dist.gather(inputs, gathered, dst=src_rank, group=intra_group)

        # Interleave: chunk j of every local rank, then chunk j+1 of every local rank, ...
        per_rank_chunks = [t.chunk(total_ranks, dim=0) for t in gathered]
        staged = torch.cat([piece for column in zip(*per_rank_chunks) for piece in column])

        # inter-node all-to-all
        if inter_group is not None:
            exchanged = torch.empty_like(staged)
            dist.all_to_all_single(exchanged, staged, group=inter_group)

            # layout transform: interleave the per-node blocks by local rank
            per_node_chunks = [blk.chunk(local_size, dim=0) for blk in exchanged.chunk(num_nodes, dim=0)]
            staged = torch.cat([piece for column in zip(*per_node_chunks) for piece in column])

        # intra-node scatter: hand each local rank its slice
        dist.scatter(outputs, list(staged.chunk(local_size, dim=0)), src=src_rank, group=intra_group)
        return outputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # Gradients travel back through the same hierarchical exchange.
        grad_inputs = HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank)
        return grad_inputs, None, None
|
|
|
|
|
|
|
|
|
|
|
|
class MoeDispatch(torch.autograd.Function):
    """Route tokens to expert slots via ``MOE_KERNEL.dispatch_forward``.

    The kernel is always fed float32. The result is cast back to the dtype of
    ``tokens``, and the gradient is likewise cast back in backward. ``ec`` is
    passed straight to the kernel; presumably experts * capacity — TODO confirm.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, tokens, mask, dest_idx, ec):
        s, h = tokens.size(0), tokens.size(1)
        orig_dtype = tokens.dtype

        if MOE_KERNEL is None:
            load_moe()

        tokens_fp32 = tokens if tokens.dtype == torch.float32 else tokens.to(torch.float32)
        expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens_fp32, mask, dest_idx)
        if expert_input.dtype != orig_dtype:
            expert_input = expert_input.to(orig_dtype)

        # Only the routing tensors and sizes are needed to compute the gradient.
        ctx.save_for_backward(mask, dest_idx)
        ctx.s, ctx.h, ctx.ec, ctx.dtype = s, h, ec, orig_dtype

        return expert_input

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad):
        mask, dest_idx = ctx.saved_tensors

        grad_fp32 = output_grad if output_grad.dtype == torch.float32 else output_grad.to(torch.float32)
        d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, grad_fp32, mask, dest_idx)
        if d_tokens.dtype != ctx.dtype:
            d_tokens = d_tokens.to(ctx.dtype)

        # Only ``tokens`` receives a gradient; mask, dest_idx and ec do not.
        return d_tokens, None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
class MoeCombine(torch.autograd.Function):
    """Merge expert outputs back per token via ``MOE_KERNEL.combine_forward``.

    ``logits`` must be float32. ``s`` and ``e`` are its first two sizes, and
    ``c = ec // e``. ``expert_tokens`` is run through the kernel in float32 and
    the result is cast back to its original dtype. Gradients are returned for
    ``expert_tokens`` and ``logits`` only.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        assert logits.dtype == torch.float32

        s, e = logits.size(0), logits.size(1)
        c = ec // e
        h = expert_tokens.size(-1)
        orig_dtype = expert_tokens.dtype

        if expert_tokens.dtype != torch.float32:
            expert_tokens = expert_tokens.to(torch.float32)
        if MOE_KERNEL is None:
            load_moe()

        combined = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
        if combined.dtype != orig_dtype:
            combined = combined.to(orig_dtype)

        # Save the float32 copy of expert_tokens: that is what the backward kernel consumes.
        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s, ctx.e, ctx.c, ctx.h, ctx.dtype = s, e, c, h, orig_dtype

        return combined

    @staticmethod
    @custom_bwd
    def backward(ctx, tokens_grad):
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors

        grad_fp32 = tokens_grad if tokens_grad.dtype == torch.float32 else tokens_grad.to(torch.float32)
        d_expert, d_logits = MOE_KERNEL.combine_backward(
            ctx.s, ctx.e, ctx.c, ctx.h, grad_fp32, expert_tokens, logits, mask, dest_idx
        )
        if d_expert.dtype != ctx.dtype:
            d_expert = d_expert.to(ctx.dtype)

        # d_logits is left as the kernel returns it (logits were float32 on the way in).
        return d_expert, d_logits, None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
    """Return the inclusive cumulative sum of ``inputs`` along dim 0, minus one.

    ``MOE_KERNEL.cumsum_sub_one`` is used only when ``use_kernel`` is True and
    the leading size ``n`` satisfies one of:

    - ``n <= 1024``;
    - ``n <= 2048`` and ``n`` is even;
    - ``n`` is a multiple of 4.

    Otherwise ``torch.cumsum(inputs, dim=0) - 1`` is returned.
    """
    n = inputs.size(0)
    kernel_ok = n <= 1024 or (n <= 2048 and n % 2 == 0) or n % 4 == 0

    if not (kernel_ok and use_kernel):
        return torch.cumsum(inputs, dim=0) - 1

    if MOE_KERNEL is None:
        load_moe()
    return MOE_KERNEL.cumsum_sub_one(inputs)
|
|
|
|
|
|
|
|
|
|
|
|
class MoeInGradScaler(torch.autograd.Function):
    """Identity in forward; multiplies the incoming gradient by ``ep_size`` in backward.

    Rationale from the original note: scale the gradient back by the number of
    experts because the batch size increases in the MoE stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        # ctx may be None when forward is invoked directly rather than via apply().
        if ctx is None:
            return inputs
        ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        factor = ctx.ep_size
        # ep_size itself is not differentiable.
        return (grad if factor == 1 else grad * factor), None
|
|
|
|
|
|
|
|
|
|
|
|
class MoeOutGradScaler(torch.autograd.Function):
    """Identity in forward; divides the incoming gradient by ``ep_size`` in backward.

    Scale the gradient by the number of experts because the batch size
    increases in the MoE stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        # Guard on ctx like MoeInGradScaler and the comm Functions in this module,
        # so forward can also be called directly with ctx=None without crashing.
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        if ctx.ep_size != 1:
            grad = grad / ctx.ep_size
        # ep_size itself is not differentiable.
        return grad, None
|
2024-01-25 07:48:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _all_to_all(
|
|
|
|
inputs: torch.Tensor,
|
|
|
|
input_split_sizes: Optional[List[int]] = None,
|
|
|
|
output_split_sizes: Optional[List[int]] = None,
|
|
|
|
group=None,
|
|
|
|
async_op: bool = False,
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
Returns:
|
|
|
|
outputs: Tensor
|
|
|
|
handle: Optional[Work], if overlap is True
|
|
|
|
"""
|
|
|
|
outputs_shape = list(inputs.shape)
|
|
|
|
if output_split_sizes is not None:
|
|
|
|
outputs_shape[0] = sum(output_split_sizes)
|
|
|
|
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
|
|
|
|
inputs = inputs.contiguous()
|
|
|
|
outputs = outputs.contiguous()
|
|
|
|
handle = dist.all_to_all_single(
|
|
|
|
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
|
|
|
|
)
|
|
|
|
return outputs, handle
|
|
|
|
|
|
|
|
|
|
|
|
class AllToAllUneven(torch.autograd.Function):
    """Autograd wrapper around ``_all_to_all`` that allows uneven split sizes.

    The backward pass sends the gradient along the reverse route by swapping the
    input and output split sizes; it always runs synchronously.
    """

    @staticmethod
    def forward(
        ctx,
        inputs,
        input_split_sizes=None,
        output_split_sizes=None,
        group=None,
        overlap: bool = False,
    ):
        """
        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        ctx.input_split_sizes, ctx.output_split_sizes, ctx.group = (
            input_split_sizes,
            output_split_sizes,
            group,
        )
        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)

    @staticmethod
    def backward(ctx: Any, *grad_outputs):
        # Reverse route: what was received per rank is now sent, and vice versa.
        grad_inputs, _ = _all_to_all(
            grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False
        )
        # No gradients for the split sizes, group or overlap flag.
        return (grad_inputs,) + (None,) * 4
|
|
|
|
|
|
|
|
|
|
|
|
def all_to_all_uneven(
    inputs: torch.Tensor,
    input_split_sizes: Optional[List[int]] = None,
    output_split_sizes: Optional[List[int]] = None,
    group=None,
    overlap: bool = False,
):
    """Differentiable all-to-all with optional uneven splits along dim 0.

    Thin functional entry point for ``AllToAllUneven.apply``. Returns its
    ``(outputs, handle)`` pair; ``handle`` is a work object only when
    ``overlap`` is True.
    """
    apply_args = (inputs, input_split_sizes, output_split_sizes, group, overlap)
    return AllToAllUneven.apply(*apply_args)
|