mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
396 lines
12 KiB
396 lines
12 KiB
from typing import Any, List, Optional, Tuple |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from torch import Tensor |
|
from torch.cuda.amp import custom_bwd, custom_fwd |
|
from torch.distributed import ProcessGroup |
|
|
|
# Handle to the fused MoE kernel module; stays None until ``load_moe`` runs.
MOE_KERNEL = None


def load_moe():
    """Load the fused MoE kernel and store it in the module-level ``MOE_KERNEL``.

    The loader import is deferred until call time, so importing this module alone
    does not load the kernel. Callers in this module check ``MOE_KERNEL is None``
    before calling this.
    """
    global MOE_KERNEL
    from colossalai.kernel.kernel_loader import MoeLoader

    loader = MoeLoader()
    MOE_KERNEL = loader.load()
|
|
|
|
|
class AllGather(torch.autograd.Function):
    """
    Stack ``inputs`` from every rank of ``group`` along a new leading dimension.

    The backward pass reduce-scatters the incoming gradient (via ``ReduceScatter.forward``)
    so each rank gets back a gradient shaped like the slice it contributed.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        All-gather ``inputs`` into a tensor of shape ``(world_size, *inputs.shape)``.

        Args:
            ctx: autograd context, or ``None`` when called directly (e.g. from a backward
                pass); direct calls must use ``overlap=False``.
            inputs: this rank's tensor.
            group: process group to gather over; ``None`` means the default group.
            overlap: if True, launch the collective asynchronously and return its handle.
                ``outputs`` must not be read until the handle has been waited on.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # Asynchronous launches need a ctx; direct calls (ctx=None) must be synchronous.
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            # Single-rank group: no communication, just add the stacking dimension.
            return inputs.unsqueeze(0), None

        # Collective backends (e.g. NCCL) reject non-contiguous tensors, and gradients routed
        # here from ReduceScatter.backward are not guaranteed to be contiguous. This matches
        # what ReduceScatter and AllToAll already do with their inputs.
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        # Chunks along dim 0 of a freshly allocated tensor are views into ``outputs``,
        # so the collective writes directly into the result buffer.
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None
        handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
        return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # Only the gathered tensor carries a gradient; ``group`` and ``overlap`` get None.
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
|
|
|
|
|
class ReduceScatter(torch.autograd.Function):
    """
    Sum a rank-stacked tensor over ``group`` and keep only this rank's slice.

    The backward pass all-gathers the incoming gradient (via ``AllGather.forward``).
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Split ``inputs`` into one chunk per rank along dim 0 and reduce-scatter the
        chunks into a tensor of shape ``inputs.shape[1:]``.

        NOTE(review): each chunk must hold as many elements as the output for the
        collective to be valid, which presumably means dim 0 equals the group size;
        confirm with callers.

        Args:
            ctx: autograd context, or ``None`` when called directly (only valid with
                ``overlap=False``).
            inputs: tensor to reduce-scatter.
            group: process group to communicate over.
            overlap: if True, launch asynchronously and return the work handle.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # Asynchronous launches need a ctx; direct calls (ctx=None) must be synchronous.
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        world = dist.get_world_size(group)
        if world == 1:
            # Single-rank group: nothing to reduce; drop the leading dim only if it is 1.
            return inputs.squeeze(0), None

        send = inputs if inputs.is_contiguous() else inputs.contiguous()
        result = torch.empty(send.shape[1:], dtype=send.dtype, device=send.device)
        chunks = list(torch.chunk(send, world, dim=0))

        if overlap:
            work = dist.reduce_scatter(result, chunks, group=group, async_op=True)
            return result, work

        dist.reduce_scatter(result, chunks, group=group)
        return result, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        grad_inputs = AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0]
        # ``group`` and ``overlap`` receive no gradient.
        return grad_inputs, None, None
|
|
|
|
|
class AllToAll(torch.autograd.Function):
    """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
    operation in torch.distributed.

    The backward pass sends the gradient through the same all-to-all exchange.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Exchange equal dim-0 slices of ``inputs`` among the ranks of ``group``.

        Args:
            ctx: autograd context, or ``None`` when called directly (only valid with
                ``overlap=False``).
            inputs: tensor to exchange; copied to a contiguous buffer if needed.
            group: process group to communicate over.
            overlap: if True, launch asynchronously and return the work handle.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # Asynchronous launches need a ctx; direct calls (ctx=None) must be synchronous.
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group
        # Contiguity is enforced even for a single rank, so the returned tensor is always dense.
        send = inputs if inputs.is_contiguous() else inputs.contiguous()

        if dist.get_world_size(group) == 1:
            return send, None

        recv = torch.empty_like(send)
        if overlap:
            return recv, dist.all_to_all_single(recv, send, group=group, async_op=True)
        dist.all_to_all_single(recv, send, group=group)
        return recv, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # An even all-to-all undoes itself, so the gradient takes the same route back.
        grad_inputs, _ = AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)
        return grad_inputs, None, None
|
|
|
|
|
class HierarchicalAllToAll(torch.autograd.Function):
    """
    Two-level all-to-all routed through ``src_rank``.

    The phases are: gather within the intra-node group, all-to-all across the
    inter-node group, re-order, then scatter back within the intra-node group.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor:
        """
        Args:
            ctx: autograd context, or ``None`` when called directly (as the backward does).
            inputs: this rank's tensor; chunked along dim 0.
            groups: ``(intra_node_group, inter_node_group)``; the inter-node group may be
                ``None``, in which case only the intra-node phases run.
            src_rank: rank compared against ``dist.get_rank()`` and used as gather/scatter
                root. NOTE(review): presumably the leader of this rank's intra-node group;
                confirm against callers.

        Returns:
            outputs: Tensor
        """
        # TODO: we can reduce comm volume by removing empty capacity
        if ctx is not None:
            ctx.comm_grps = groups
            ctx.src_rank = src_rank
        intra_group, inter_group = groups

        intra_size = dist.get_world_size(intra_group)
        node_count = 1 if inter_group is None else dist.get_world_size(inter_group)
        total_size = intra_size * node_count
        outputs = torch.empty_like(inputs)

        def interleave(chunk_lists):
            # Take the j-th chunk of every list in turn: l0[0], l1[0], ..., l0[1], l1[1], ...
            return torch.cat([piece for column in zip(*chunk_lists) for piece in column])

        if dist.get_rank() != src_rank:
            # Non-root ranks only contribute their tensor and receive their share back.
            dist.gather(inputs, dst=src_rank, group=intra_group)
            dist.scatter(outputs, src=src_rank, group=intra_group)
            return outputs

        # Root: collect every local rank's tensor.
        gathered = [torch.empty_like(inputs) for _ in range(intra_size)]
        dist.gather(inputs, gathered, dst=src_rank, group=intra_group)
        merged = interleave([t.chunk(total_size, dim=0) for t in gathered])

        if inter_group is not None:
            # Exchange with the other nodes' roots, then restore per-local-rank ordering.
            exchanged = torch.empty_like(merged)
            dist.all_to_all_single(exchanged, merged, group=inter_group)
            per_node = exchanged.chunk(node_count, dim=0)
            merged = interleave([t.chunk(intra_size, dim=0) for t in per_node])

        # Hand each local rank its contiguous slice.
        dist.scatter(outputs, list(merged.chunk(intra_size, dim=0)), src=src_rank, group=intra_group)
        return outputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        grad_inputs = HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank)
        # ``groups`` and ``src_rank`` receive no gradient.
        return grad_inputs, None, None
|
|
|
|
|
class MoeDispatch(torch.autograd.Function):
    """
    Route tokens into expert slots using the fused MoE kernel.

    The kernel is called with float32 data; the result (and, in backward, the token
    gradient) is cast back to the original token dtype.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, tokens, mask, dest_idx, ec):
        # ``ec`` is forwarded to the kernel unchanged. NOTE(review): presumably
        # experts * capacity, as the name suggests; confirm.
        num_tokens, hidden = tokens.size(0), tokens.size(1)
        orig_dtype = tokens.dtype

        if MOE_KERNEL is None:
            load_moe()
        fp32_tokens = tokens if tokens.dtype == torch.float32 else tokens.to(torch.float32)
        expert_input = MOE_KERNEL.dispatch_forward(num_tokens, ec, hidden, fp32_tokens, mask, dest_idx)
        if expert_input.dtype != orig_dtype:
            expert_input = expert_input.to(orig_dtype)

        # Only the routing tensors are needed to scatter gradients back.
        ctx.save_for_backward(mask, dest_idx)
        ctx.s, ctx.h, ctx.ec, ctx.dtype = num_tokens, hidden, ec, orig_dtype

        return expert_input

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad):
        mask, dest_idx = ctx.saved_tensors
        grad32 = output_grad if output_grad.dtype == torch.float32 else output_grad.to(torch.float32)
        d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, grad32, mask, dest_idx)
        if d_tokens.dtype != ctx.dtype:
            d_tokens = d_tokens.to(ctx.dtype)
        # mask, dest_idx and ec are not differentiable.
        return d_tokens, None, None, None
|
|
|
|
|
class MoeCombine(torch.autograd.Function):
    """
    Merge expert outputs back into token order using the fused MoE kernel.

    ``logits`` must already be float32. ``expert_tokens`` is cast to float32 for the
    kernel, and the combined output (and its gradient) is cast back to its original dtype.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        assert logits.dtype == torch.float32

        num_tokens, num_experts = logits.size(0), logits.size(1)
        capacity = ec // num_experts
        hidden = expert_tokens.size(-1)
        orig_dtype = expert_tokens.dtype

        if orig_dtype != torch.float32:
            expert_tokens = expert_tokens.to(torch.float32)
        if MOE_KERNEL is None:
            load_moe()
        combined = MOE_KERNEL.combine_forward(
            num_tokens, num_experts, capacity, hidden, expert_tokens, logits, mask, dest_idx
        )
        if combined.dtype != orig_dtype:
            combined = combined.to(orig_dtype)

        # The float32 copy of expert_tokens is what the backward kernel consumes.
        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s, ctx.e, ctx.c, ctx.h, ctx.dtype = num_tokens, num_experts, capacity, hidden, orig_dtype

        return combined

    @staticmethod
    @custom_bwd
    def backward(ctx, tokens_grad):
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
        if tokens_grad.dtype != torch.float32:
            tokens_grad = tokens_grad.to(torch.float32)

        grads = MOE_KERNEL.combine_backward(
            ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
        )
        d_expert, d_logits = grads
        if d_expert.dtype != ctx.dtype:
            d_expert = d_expert.to(ctx.dtype)

        # mask, dest_idx and ec are not differentiable.
        return d_expert, d_logits, None, None, None
|
|
|
|
|
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
    """Return ``cumsum(inputs, dim=0) - 1``.

    When ``use_kernel`` is True and dim 0 passes the size check below, the fused
    kernel's ``cumsum_sub_one`` is used instead (presumably the same result; the
    size check appears to encode the kernel's supported sizes). Otherwise plain torch
    is used.
    """
    rows = inputs.size(0)
    kernel_ok = rows <= 1024 or (rows <= 2048 and rows % 2 == 0) or rows % 4 == 0
    if not (kernel_ok and use_kernel):
        return torch.cumsum(inputs, dim=0) - 1

    if MOE_KERNEL is None:
        load_moe()
    return MOE_KERNEL.cumsum_sub_one(inputs)
|
|
|
|
|
class MoeInGradScaler(torch.autograd.Function):
    """
    Identity in the forward pass; multiplies the gradient by ``ep_size`` in the
    backward pass (skipped when ``ep_size == 1``).

    Original rationale: scale the gradient back by the number of experts
    because the batch size increases in the moe stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        # ctx may be None when forward is called directly; there is then nothing to record.
        if ctx is None:
            return inputs
        ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        scaled = grad if ctx.ep_size == 1 else grad * ctx.ep_size
        # ep_size is a plain int and gets no gradient.
        return scaled, None
|
|
|
|
|
class MoeOutGradScaler(torch.autograd.Function):
    """
    Identity in the forward pass; divides the gradient by ``ep_size`` in the
    backward pass (skipped when ``ep_size == 1``).

    Original rationale: scale the gradient by the number of experts
    because the batch size increases in the moe stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        # Tolerate ctx=None (direct forward calls), like MoeInGradScaler and the
        # communication Functions in this module; only record state when a ctx exists.
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        if ctx.ep_size != 1:
            grad = grad / ctx.ep_size
        # ep_size is a plain int and gets no gradient.
        return grad, None
|
|
|
|
|
def _all_to_all(
    inputs: torch.Tensor,
    input_split_sizes: Optional[List[int]] = None,
    output_split_sizes: Optional[List[int]] = None,
    group=None,
    async_op: bool = False,
):
    """
    All-to-all along dim 0 of ``inputs`` with optional uneven per-rank split sizes.

    Args:
        inputs: tensor whose dim 0 is split and sent to the ranks of ``group``;
            copied to a contiguous buffer if needed.
        input_split_sizes: rows of dim 0 sent to each rank; ``None`` means an even split.
        output_split_sizes: rows of dim 0 received from each rank; ``None`` means an even
            split, in which case the output has the same shape as ``inputs``.
        group: process group to communicate over; ``None`` means the default group.
        async_op: if True, launch the collective asynchronously and return its handle.

    Returns:
        outputs: Tensor
        handle: Optional[Work], the work handle if ``async_op`` is True, otherwise None
    """
    outputs_shape = list(inputs.shape)
    if output_split_sizes is not None:
        outputs_shape[0] = sum(output_split_sizes)
    # torch.empty already returns a contiguous buffer; only the input may need a copy.
    outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
    inputs = inputs.contiguous()
    handle = dist.all_to_all_single(
        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
    )
    return outputs, handle
|
|
|
|
|
class AllToAllUneven(torch.autograd.Function):
    """Differentiable wrapper around ``_all_to_all`` with per-rank split sizes.

    The backward pass runs the exchange in reverse by swapping the input and output
    split sizes.
    """

    @staticmethod
    def forward(
        ctx,
        inputs,
        input_split_sizes=None,
        output_split_sizes=None,
        group=None,
        overlap: bool = False,
    ):
        """
        Args:
            inputs: tensor to exchange along dim 0.
            input_split_sizes: rows sent to each rank (``None`` for an even split).
            output_split_sizes: rows received from each rank (``None`` for an even split).
            group: process group to communicate over.
            overlap: passed to ``_all_to_all`` as ``async_op``.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        ctx.input_split_sizes, ctx.output_split_sizes, ctx.group = (
            input_split_sizes,
            output_split_sizes,
            group,
        )
        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)

    @staticmethod
    def backward(ctx: Any, *grad_outputs):
        # What was received in forward is sent back, so the split roles swap.
        grad_inputs, _ = _all_to_all(
            grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False
        )
        # Only ``inputs`` is differentiable.
        return grad_inputs, None, None, None, None
|
|
|
|
|
def all_to_all_uneven(
    inputs: torch.Tensor,
    input_split_sizes: Optional[List[int]] = None,
    output_split_sizes: Optional[List[int]] = None,
    group=None,
    overlap: bool = False,
):
    """Run a differentiable all-to-all with optional uneven splits.

    Thin entry point for ``AllToAllUneven.apply``. Returns its ``(outputs, handle)`` pair;
    ``handle`` is the async work handle when ``overlap`` is True.
    """
    call_args = (inputs, input_split_sizes, output_split_sizes, group, overlap)
    return AllToAllUneven.apply(*call_args)
|
|
|