mirror of https://github.com/hpcaitech/ColossalAI
276 lines
8.2 KiB
Python
276 lines
8.2 KiB
Python
![]() |
from typing import Any, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
from torch import Tensor
|
||
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||
|
from torch.distributed import ProcessGroup
|
||
|
|
||
|
from colossalai.moe.manager import MOE_MANAGER
|
||
|
|
||
|
# Module-level handle to the object returned by ``MOEBuilder().load()``.
# It stays ``None`` until ``load_moe()`` is called; callers in this module check it
# and call ``load_moe()`` lazily before first use.
MOE_KERNEL = None


def load_moe():
    """Load the MoE kernel module via ``MOEBuilder`` and store it in ``MOE_KERNEL``."""
    # Imported inside the function, presumably so that importing this module does not
    # require the op builder machinery until a kernel is actually needed.
    from colossalai.kernel.op_builder import MOEBuilder

    global MOE_KERNEL
    builder = MOEBuilder()
    MOE_KERNEL = builder.load()
|
||
|
|
||
|
|
||
|
class AllGather(torch.autograd.Function):
    """All-gather ``inputs`` from every rank of ``group`` into a new leading dimension.

    The forward output has shape ``(world_size, *inputs.shape)``. The backward pass
    reduce-scatters the incoming gradient back to this rank's slice via
    ``ReduceScatter.forward``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Args:
            ctx: autograd context, or ``None`` when called directly (e.g. from another
                Function's backward). ``overlap=True`` requires a non-``None`` ctx.
            inputs: local tensor contributed by this rank.
            group: process group to gather over; ``None`` selects the default group.
            overlap: if True, launch the collective with ``async_op=True`` and return
                its work handle instead of waiting.

        Returns:
            outputs: Tensor of shape ``(world_size, *inputs.shape)``
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            # Nothing to communicate: just add the leading world-size dimension.
            return inputs.unsqueeze(0), None

        # torch.distributed backends reject non-contiguous tensors, and gradients
        # routed here from ReduceScatter.backward carry no contiguity guarantee.
        # ReduceScatter and AllToAll normalize their inputs the same way.
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        # One view per rank into the output buffer, so the gather writes in place.
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None
        else:
            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """Reduce-scatter the gradient of ``outputs`` back to ``inputs``.

        Only ``grad_outputs[0]`` (the gradient of the gathered tensor) is used; the
        handle output has no gradient. ``group`` and ``overlap`` get ``None``.
        """
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
|
||
|
|
||
|
|
||
|
class ReduceScatter(torch.autograd.Function):
    """Reduce a rank-stacked tensor across ``group`` and keep this rank's slice.

    ``inputs`` is split along dim 0 into one piece per rank. ``dist.reduce_scatter``
    sums them (its default op) and this rank receives a tensor of shape
    ``inputs.shape[1:]``. The backward pass all-gathers the gradient via
    ``AllGather.forward``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Args:
            ctx: autograd context, or ``None`` for a direct call; ``overlap=True``
                is only allowed with a real ctx.
            inputs: tensor whose dim 0 is split across the ranks of ``group``.
            group: process group (``None`` selects the default group).
            overlap: launch asynchronously and return the work handle.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        world = dist.get_world_size(group)
        if world == 1:
            # Single rank: drop the leading rank dimension, no communication needed.
            return inputs.squeeze(0), None

        src = inputs if inputs.is_contiguous() else inputs.contiguous()
        result = torch.empty(src.shape[1:], dtype=src.dtype, device=src.device)
        shards = list(torch.chunk(src, world, dim=0))

        if overlap:
            work = dist.reduce_scatter(result, shards, group=group, async_op=True)
            return result, work

        dist.reduce_scatter(result, shards, group=group)
        return result, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """All-gather the incoming gradient back to the stacked input layout."""
        # TODO: support async backward
        grad_in = AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0]
        return grad_in, None, None
|
||
|
|
||
|
|
||
|
class AllToAll(torch.autograd.Function):
    """Exchange equal slices of ``inputs`` between all ranks of ``group`` using
    ``torch.distributed.all_to_all_single`` (used to dispatch an ``[e, c, h]``
    expert tensor to the ranks hosting each expert).

    The operation is its own adjoint, so the backward pass applies the same
    all-to-all to the gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Args:
            ctx: autograd context, or ``None`` for a direct call.
            inputs: tensor to exchange; made contiguous before use.
            group: process group (``None`` selects the default group).
            overlap: launch asynchronously and return the work handle.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        if ctx is not None:
            ctx.comm_grp = group

        src = inputs if inputs.is_contiguous() else inputs.contiguous()

        if dist.get_world_size(group) == 1:
            return src, None

        dst = torch.empty_like(src)
        if overlap:
            work = dist.all_to_all_single(dst, src, group=group, async_op=True)
            return dst, work

        dist.all_to_all_single(dst, src, group=group)
        return dst, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """Route the gradient back with the same synchronous all-to-all."""
        grad_in = AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0]
        return grad_in, None, None
|
||
|
|
||
|
|
||
|
class MoeDispatch(torch.autograd.Function):
    """Dispatch token rows to expert slots with ``MOE_KERNEL.dispatch_forward``.

    The kernel is always called with float32 data. Forward output and backward
    gradient are cast back to the dtype of the incoming ``tokens``.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, tokens, mask, dest_idx, ec):
        # Sizes and dtype are taken from the caller's tokens, before any cast.
        num_tokens = tokens.size(0)
        hidden = tokens.size(1)
        orig_dtype = tokens.dtype

        if MOE_KERNEL is None:
            load_moe()

        if orig_dtype != torch.float32:
            tokens = tokens.to(torch.float32)
        dispatched = MOE_KERNEL.dispatch_forward(num_tokens, ec, hidden, tokens, mask, dest_idx)
        if dispatched.dtype != orig_dtype:
            dispatched = dispatched.to(orig_dtype)

        # Routing tensors and sizes are needed again by dispatch_backward.
        ctx.save_for_backward(mask, dest_idx)
        ctx.s = num_tokens
        ctx.h = hidden
        ctx.ec = ec
        ctx.dtype = orig_dtype

        return dispatched

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad):
        """Return the gradient for ``tokens``; mask, dest_idx and ec get ``None``."""
        mask, dest_idx = ctx.saved_tensors
        grad32 = output_grad if output_grad.dtype == torch.float32 else output_grad.to(torch.float32)
        d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, grad32, mask, dest_idx)
        if d_tokens.dtype != ctx.dtype:
            d_tokens = d_tokens.to(ctx.dtype)
        return d_tokens, None, None, None
|
||
|
|
||
|
|
||
|
class MoeCombine(torch.autograd.Function):
    """Combine expert outputs back into token order with ``MOE_KERNEL.combine_forward``.

    ``logits`` must be float32. ``expert_tokens`` is upcast to float32 for the
    kernel, and the result and its gradient are cast back to the caller's dtype.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        assert logits.dtype == torch.float32

        num_tokens = logits.size(0)
        num_experts = logits.size(1)
        capacity = ec // num_experts
        hidden = expert_tokens.size(-1)
        orig_dtype = expert_tokens.dtype

        if orig_dtype != torch.float32:
            expert_tokens = expert_tokens.to(torch.float32)
        if MOE_KERNEL is None:
            load_moe()

        combined = MOE_KERNEL.combine_forward(
            num_tokens, num_experts, capacity, hidden, expert_tokens, logits, mask, dest_idx
        )
        if combined.dtype != orig_dtype:
            combined = combined.to(orig_dtype)

        # The float32 expert_tokens (not the caller's copy) is what backward hands
        # to combine_backward.
        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s = num_tokens
        ctx.e = num_experts
        ctx.c = capacity
        ctx.h = hidden
        ctx.dtype = orig_dtype

        return combined

    @staticmethod
    @custom_bwd
    def backward(ctx, tokens_grad):
        """Return gradients for ``expert_tokens`` and ``logits``; the rest get ``None``."""
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
        grad32 = tokens_grad if tokens_grad.dtype == torch.float32 else tokens_grad.to(torch.float32)

        d_expert, d_logits = MOE_KERNEL.combine_backward(
            ctx.s, ctx.e, ctx.c, ctx.h, grad32, expert_tokens, logits, mask, dest_idx
        )
        if d_expert.dtype != ctx.dtype:
            d_expert = d_expert.to(ctx.dtype)

        return d_expert, d_logits, None, None, None
|
||
|
|
||
|
|
||
|
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
    """Return the inclusive cumulative sum of ``inputs`` along dim 0, minus one.

    When ``use_kernel`` is True and the dim-0 size passes the checks below
    (presumably launch constraints of the fused kernel), ``MOE_KERNEL.cumsum_sub_one``
    is used. Otherwise the result is ``torch.cumsum(inputs, dim=0) - 1``.
    """
    rows = inputs.size(0)
    kernel_ok = rows <= 1024 or (rows <= 2048 and rows % 2 == 0) or rows % 4 == 0

    if not (kernel_ok and use_kernel):
        return torch.cumsum(inputs, dim=0) - 1

    if MOE_KERNEL is None:
        load_moe()
    return MOE_KERNEL.cumsum_sub_one(inputs)
|
||
|
|
||
|
|
||
|
class MoeInGradScaler(torch.autograd.Function):
    """Identity in forward; multiplies the incoming gradient by ``ep_size`` in backward.

    The gradient is scaled back up by this factor because the batch size increases
    in the MoE stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        # A None ctx means a direct (non-autograd) call: nothing to remember.
        if ctx is None:
            return inputs
        ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        (grad,) = grad_outputs
        scale = ctx.ep_size
        scaled = grad if scale == 1 else grad * scale
        # ep_size is a plain int and receives no gradient.
        return scaled, None
|
||
|
|
||
|
|
||
|
class MoeOutGradScaler(torch.autograd.Function):
    """Identity in forward; divides the incoming gradient by ``ep_size`` in backward.

    The gradient is scaled by this factor because the batch size increases in the
    MoE stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        """Return ``inputs`` unchanged, recording ``ep_size`` on ``ctx`` when present.

        ``ctx`` may be ``None`` for a direct call, matching ``MoeInGradScaler.forward``
        and the collective Functions in this module. Previously, that call raised
        AttributeError.
        """
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        """Divide the single incoming gradient by ``ctx.ep_size``; ``ep_size`` gets ``None``."""
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        # Skip the division (and the extra tensor it would allocate) when it is a no-op.
        if ctx.ep_size != 1:
            grad = grad / ctx.ep_size
        return grad, None
|