# ColossalAI/colossalai/moe/_operation.py
# (source listing: 337 lines, 10 KiB, Python)

from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
# Handle to the MoE op extension; stays None until load_moe() is called.
MOE_KERNEL = None


def load_moe():
    """Load the MoE op extension and publish it as the module-level ``MOE_KERNEL``.

    The builder is imported inside the function, so nothing is loaded at module import
    time; callers in this module check ``MOE_KERNEL is None`` and call this on first use.
    """
    global MOE_KERNEL
    from colossalai.kernel.op_builder import MOEBuilder

    builder = MOEBuilder()
    MOE_KERNEL = builder.load()
class AllGather(torch.autograd.Function):
    """Gather ``inputs`` from every rank of ``group``, stacked along a new leading dim.

    The backward pass reduce-scatters the incoming gradient over the same group.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Args:
            ctx: autograd context, or ``None`` when called directly (as the backward
                passes in this module do). ``overlap=True`` is rejected when ``ctx`` is None.
            inputs: the tensor this rank contributes.
            group: process group to gather over (``None`` selects the default group).
            overlap: if True, launch the collective with ``async_op=True`` and return its handle.

        Returns:
            outputs: Tensor of shape ``(world_size, *inputs.shape)``; entry ``i`` along dim 0
                receives group rank ``i``'s ``inputs``.
            handle: Optional[Work], if overlap is True (wait on it before reading ``outputs``);
                otherwise None.
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.unsqueeze(0), None

        # Collectives operate on contiguous buffers. The gradient handed in by
        # ReduceScatter.backward is not guaranteed to be contiguous, so normalise it here,
        # as ReduceScatter and AllToAll already do for their inputs.
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        # Each chunk is a view of one slot of ``outputs``, so all_gather writes straight into it.
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None
        else:
            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
            return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # Reduce-scatter the gradient of ``outputs`` back to this rank's slice;
        # ``group`` and ``overlap`` receive no gradient.
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
class ReduceScatter(torch.autograd.Function):
    """Reduce ``inputs`` across ``group`` (``dist.reduce_scatter``'s default op) and keep
    this rank's slice of dim 0.

    The backward pass all-gathers the incoming gradient over the same group.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Args:
            ctx: autograd context, or ``None`` for a direct call; ``overlap=True`` requires a ctx.
            inputs: tensor whose dim 0 is split into one chunk per rank (presumably of size
                ``world_size``, mirroring AllGather's output; TODO confirm with callers).
            group: process group to reduce-scatter over.
            overlap: if True, run the collective asynchronously and return its handle.

        Returns:
            outputs: Tensor of shape ``inputs.shape[1:]``.
            handle: Optional[Work], if overlap is True; otherwise None.
        """
        assert ctx is not None or not overlap
        if ctx is not None:
            ctx.comm_grp = group

        world = dist.get_world_size(group)
        if world == 1:
            # Nothing to reduce: just drop the leading rank dimension.
            return inputs.squeeze(0), None

        src = inputs if inputs.is_contiguous() else inputs.contiguous()
        result = torch.empty(src.shape[1:], dtype=src.dtype, device=src.device)
        # One chunk per rank along dim 0; each rank receives the reduction of its own chunk.
        chunks = list(torch.chunk(src, world, dim=0))

        if overlap:
            work = dist.reduce_scatter(result, chunks, group=group, async_op=True)
            return result, work
        dist.reduce_scatter(result, chunks, group=group)
        return result, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # TODO: support async backward
        gathered, _ = AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)
        return gathered, None, None
class AllToAll(torch.autograd.Function):
    """Exchange equal slices of ``inputs`` along dim 0 between all ranks of ``group``
    using ``dist.all_to_all_single``; used to dispatch an ``[e, c, h]`` tensor to experts.

    The backward pass applies the same all-to-all to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """
        Args:
            ctx: autograd context, or ``None`` for a direct call; ``overlap=True`` requires a ctx.
            inputs: tensor to exchange; made contiguous before use.
            group: process group to exchange over.
            overlap: if True, run the collective asynchronously and return its handle.

        Returns:
            outputs: Tensor with the same shape as ``inputs``.
            handle: Optional[Work], if overlap is True; otherwise None.
        """
        assert ctx is not None or not overlap
        if ctx is not None:
            ctx.comm_grp = group

        # ``contiguous()`` returns the tensor itself when it is already contiguous.
        src = inputs.contiguous()
        if dist.get_world_size(group) == 1:
            return src, None

        dst = torch.empty_like(src)
        if overlap:
            work = dist.all_to_all_single(dst, src, group=group, async_op=True)
            return dst, work
        dist.all_to_all_single(dst, src, group=group)
        return dst, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        grad = grad_outputs[0]
        routed, _ = AllToAll.forward(None, grad, ctx.comm_grp, False)
        return routed, None, None
class HierarchicalAllToAll(torch.autograd.Function):
    """Two-level all-to-all.

    Each intra-node group gathers onto ``src_rank``, the gathering ranks run an
    all-to-all over ``inter_node_group`` (presumably one gathering rank per node; TODO
    confirm), and the result is scattered back within the intra-node group. The backward
    pass runs the same exchange on the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        groups: Tuple[ProcessGroup, ProcessGroup],
        src_rank: int,
    ) -> Tensor:
        """
        Args:
            ctx: autograd context, or ``None`` when called directly (as ``backward`` does).
            inputs: local tensor; its dim 0 is split into ``world_size`` chunks
                (presumably evenly divisible; TODO confirm).
            groups: ``(intra_node_group, inter_node_group)``. ``inter_node_group`` may be None,
                in which case the inter-node exchange is skipped.
            src_rank: global rank (compared with ``dist.get_rank()``) that gathers and
                scatters for this intra-node group.

        Returns:
            outputs: Tensor with the same shape, dtype and device as ``inputs``.
        """
        # TODO: we can reduce comm volume by removing empty capacity
        if ctx is not None:
            ctx.comm_grps = groups
            ctx.src_rank = src_rank
        intra_node_group, inter_node_group = groups

        # gather/scatter work on contiguous buffers, and the gradient passed to backward is
        # often non-contiguous (e.g. the expanded grad of ``.sum()``). Normalise it here, as
        # AllToAll does, before ``empty_like`` copies its strides into the output buffers.
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        local_world_size = dist.get_world_size(intra_node_group)
        num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
        world_size = local_world_size * num_group
        outputs = torch.empty_like(inputs)

        if dist.get_rank() == src_rank:
            # intra-node gather
            intra_output = [torch.empty_like(inputs) for _ in range(local_world_size)]
            dist.gather(inputs, intra_output, dst=src_rank, group=intra_node_group)
            # Interleave so chunk j of every local rank is adjacent:
            # [r0c0, r1c0, ..., r0c1, r1c1, ...]
            intra_output = [v.chunk(world_size, dim=0) for v in intra_output]
            intra_output = torch.cat(sum(zip(*intra_output), ()))

            # inter-node all-to-all
            if inter_node_group is not None:
                inter_output = torch.empty_like(intra_output)
                dist.all_to_all_single(inter_output, intra_output, group=inter_node_group)

                # layout transform: reorder from source-node-major to local-rank-major so
                # each local rank's share is one contiguous piece before scattering.
                inter_output = inter_output.chunk(num_group, dim=0)
                inter_output = [v.chunk(local_world_size, dim=0) for v in inter_output]
                intra_output = torch.cat(sum(zip(*inter_output), ()))

            # intra-node scatter
            intra_output = list(intra_output.chunk(local_world_size, dim=0))
            dist.scatter(outputs, intra_output, src=src_rank, group=intra_node_group)
        else:
            dist.gather(inputs, dst=src_rank, group=intra_node_group)
            dist.scatter(outputs, src=src_rank, group=intra_node_group)

        return outputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        # Route the gradient through the same hierarchical exchange; ``groups`` and
        # ``src_rank`` receive no gradient.
        return (
            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
            None,
            None,
        )
class MoeDispatch(torch.autograd.Function):
    """Autograd wrapper around ``MOE_KERNEL.dispatch_forward`` / ``dispatch_backward``.

    Tokens (and, in backward, the incoming gradient) are upcast to float32 for the kernel,
    and the kernel's result is cast back to the caller's original dtype.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, tokens, mask, dest_idx, ec):
        # ``s``/``h`` are dims 0 and 1 of ``tokens``; ``ec`` is forwarded to the kernel as-is
        # (presumably experts * capacity, as MoeCombine derives ``c = ec // e``; TODO confirm).
        s, h = tokens.size(0), tokens.size(1)
        orig_dtype = tokens.dtype

        if MOE_KERNEL is None:
            load_moe()

        fp32_tokens = tokens if tokens.dtype == torch.float32 else tokens.to(torch.float32)
        expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, fp32_tokens, mask, dest_idx)
        if expert_input.dtype != orig_dtype:
            expert_input = expert_input.to(orig_dtype)

        ctx.save_for_backward(mask, dest_idx)
        ctx.s, ctx.h, ctx.ec, ctx.dtype = s, h, ec, orig_dtype
        return expert_input

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad):
        mask, dest_idx = ctx.saved_tensors
        grad32 = output_grad if output_grad.dtype == torch.float32 else output_grad.to(torch.float32)
        d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, grad32, mask, dest_idx)
        if d_tokens.dtype != ctx.dtype:
            d_tokens = d_tokens.to(ctx.dtype)
        # Only ``tokens`` gets a gradient; mask, dest_idx and ec get None.
        return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
    """Autograd wrapper around ``MOE_KERNEL.combine_forward`` / ``combine_backward``.

    ``logits`` must already be float32. ``expert_tokens`` and the incoming gradient are
    upcast to float32 for the kernel; the token output and token gradient are cast back to
    the original dtype of ``expert_tokens``.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        assert logits.dtype == torch.float32
        s, e = logits.size(0), logits.size(1)
        c = ec // e
        h = expert_tokens.size(-1)
        orig_dtype = expert_tokens.dtype

        if orig_dtype != torch.float32:
            expert_tokens = expert_tokens.to(torch.float32)
        if MOE_KERNEL is None:
            load_moe()

        combined = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
        if combined.dtype != orig_dtype:
            combined = combined.to(orig_dtype)

        # expert_tokens is saved in its float32 form, which is what combine_backward receives.
        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s, ctx.e, ctx.c, ctx.h, ctx.dtype = s, e, c, h, orig_dtype
        return combined

    @staticmethod
    @custom_bwd
    def backward(ctx, tokens_grad):
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
        if tokens_grad.dtype != torch.float32:
            tokens_grad = tokens_grad.to(torch.float32)
        d_expert, d_logits = MOE_KERNEL.combine_backward(
            ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
        )
        if d_expert.dtype != ctx.dtype:
            d_expert = d_expert.to(ctx.dtype)
        # Gradients for expert_tokens and logits; mask, dest_idx and ec get None.
        return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
    """Return ``torch.cumsum(inputs, dim=0) - 1``.

    When ``use_kernel`` is True and the dim-0 size is one the fused kernel accepts (<= 1024,
    an even size <= 2048, or any multiple of 4), ``MOE_KERNEL.cumsum_sub_one`` is used
    instead (presumably computing the same result; TODO confirm).
    """
    rows = inputs.size(0)
    kernel_ok = rows <= 1024 or (rows <= 2048 and rows % 2 == 0) or rows % 4 == 0
    if not (use_kernel and kernel_ok):
        return torch.cumsum(inputs, dim=0) - 1

    if MOE_KERNEL is None:
        load_moe()
    return MOE_KERNEL.cumsum_sub_one(inputs)
class MoeInGradScaler(torch.autograd.Function):
    """
    Identity in forward; backward multiplies the gradient by ``ep_size`` (skipped when it
    is 1). This scales the gradient back up because the batch size grows in the MoE stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        # ctx may be None when forward is invoked directly rather than through apply().
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        (grad,) = grad_outputs
        scale = ctx.ep_size
        scaled = grad if scale == 1 else grad * scale
        # ``ep_size`` is a plain int and gets no gradient.
        return scaled, None
class MoeOutGradScaler(torch.autograd.Function):
    """
    Identity in forward; backward divides the gradient by ``ep_size`` (skipped when it is 1).
    This scales the gradient down because the batch size grows in the MoE stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        # Guard against a direct call with ctx=None, matching MoeInGradScaler and the
        # communication ops in this module (previously this raised AttributeError).
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        if ctx.ep_size != 1:
            grad = grad / ctx.ep_size
        # ``ep_size`` is a plain int and gets no gradient.
        return grad, None