Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

336 lines
10 KiB

from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup
# Handle to the loaded MoE kernel module; stays ``None`` until ``load_moe`` runs.
MOE_KERNEL = None


def load_moe():
    """Load the MoE kernel through ``MOEBuilder`` and store it in ``MOE_KERNEL``.

    The builder is imported inside the function, so the import only happens
    the first time a caller actually asks for the kernel.
    """
    global MOE_KERNEL
    from colossalai.kernel.op_builder import MOEBuilder

    builder = MOEBuilder()
    MOE_KERNEL = builder.load()
class AllGather(torch.autograd.Function):
    """All-gather with autograd support.

    The forward pass stacks the tensor of every rank along a new leading
    dimension; the backward pass is a synchronous ``ReduceScatter`` over the
    same process group.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: Optional[ProcessGroup] = None,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Gather ``inputs`` from every rank of ``group``.

        Args:
            ctx: Autograd context, or ``None`` when the method is called
                directly (``ReduceScatter.backward`` does this).
            inputs: Local tensor contributed by this rank.
            group: Process group to communicate over.
            overlap: When ``True`` the collective is launched asynchronously
                and its work handle is returned. Requires a non-``None`` ctx.

        Returns:
            outputs: Tensor of shape ``(world_size,) + inputs.shape``
            handle: Optional[Work], if overlap is True
        """
        assert ctx is not None or not overlap

        if ctx is not None:
            # Remembered so that backward communicates over the same group.
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            # Nothing to exchange: just add the leading "rank" dimension.
            return inputs.unsqueeze(0), None

        # Collective backends expect a contiguous send buffer. The sibling
        # ReduceScatter / AllToAll already normalise their inputs this way;
        # gradients handed over by ReduceScatter.backward may be strided.
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        # Each chunk is a view into ``outputs``, so the gather fills it in place.
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))

        if not overlap:
            dist.all_gather(buffer_list, inputs, group=group)
            return outputs, None

        handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
        return outputs, handle

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """Reduce-scatter the gradient of ``outputs``; ``group``/``overlap`` get no gradient."""
        return (
            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
            None,
            None,
        )
class ReduceScatter(torch.autograd.Function):
    """Reduce-scatter with autograd support; its gradient is a synchronous ``AllGather``."""

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Reduce ``inputs`` across ``group`` and keep this rank's slice.

        ``inputs`` is split along dim 0 into one chunk per rank; the result
        has the shape ``inputs.shape[1:]``.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # Asynchronous launch is only permitted when an autograd ctx is present.
        assert not (ctx is None and overlap)

        if ctx is not None:
            ctx.comm_grp = group

        n_ranks = dist.get_world_size(group)
        if n_ranks == 1:
            # Single rank: drop the leading dimension, no communication.
            return inputs.squeeze(0), None

        # ``contiguous()`` returns the tensor itself when it is already contiguous.
        send = inputs.contiguous()
        recv = torch.empty(send.shape[1:], dtype=send.dtype, device=send.device)
        send_chunks = list(torch.chunk(send, n_ranks, dim=0))

        if overlap:
            work = dist.reduce_scatter(recv, send_chunks, group=group, async_op=True)
            return recv, work

        dist.reduce_scatter(recv, send_chunks, group=group)
        return recv, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """All-gather the incoming gradient; ``group``/``overlap`` get no gradient."""
        # TODO: support async backward
        gathered = AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0]
        return gathered, None, None
class AllToAll(torch.autograd.Function):
    """Exchange an input tensor [e, c, h] between ranks with
    ``torch.distributed.all_to_all_single``.

    The backward pass applies the same (synchronous) all-to-all to the gradient.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        group: ProcessGroup,
        overlap: bool = False,
    ) -> Tuple[Tensor, Any]:
        """Run the all-to-all over ``group``.

        Returns:
            outputs: Tensor
            handle: Optional[Work], if overlap is True
        """
        # Asynchronous launch is only permitted when an autograd ctx is present.
        assert not (ctx is None and overlap)

        if ctx is not None:
            ctx.comm_grp = group

        # ``contiguous()`` returns the tensor itself when it is already contiguous.
        send = inputs.contiguous()
        if dist.get_world_size(group) == 1:
            # Single rank: the (contiguous) input is already the result.
            return send, None

        recv = torch.empty_like(send)
        if overlap:
            work = dist.all_to_all_single(recv, send, group=group, async_op=True)
            return recv, work

        dist.all_to_all_single(recv, send, group=group)
        return recv, None

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """All-to-all the incoming gradient; ``group``/``overlap`` get no gradient."""
        grad_inputs = AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)[0]
        return grad_inputs, None, None
class HierarchicalAllToAll(torch.autograd.Function):
    """All-to-all routed through the rank ``src_rank`` of the intra-node group.

    Every rank joins an intra-node gather followed by an intra-node scatter.
    The rank whose ``dist.get_rank()`` equals ``src_rank`` additionally
    re-orders the gathered data and, when an inter-node group is supplied,
    runs ``all_to_all_single`` over it in between.
    """

    @staticmethod
    def forward(
        ctx: Any,
        inputs: Tensor,
        groups: Tuple[ProcessGroup, ProcessGroup],
        src_rank: int
    ) -> Tensor:
        """Run the hierarchical exchange.

        Args:
            ctx: Autograd context, or ``None`` when called directly from backward.
            inputs: Local tensor; it is chunked along dim 0.
            groups: ``(intra_node_group, inter_node_group)``; the second entry
                may be ``None``, in which case the inter-node step is skipped.
            src_rank: Rank that gathers from and scatters to the intra-node group.

        Returns:
            outputs: Tensor
        """
        # TODO: we can reduce comm volume by removing empty capacity
        if ctx is not None:
            ctx.comm_grps = groups
            ctx.src_rank = src_rank
        intra_node_group, inter_node_group = groups

        local_world_size = dist.get_world_size(intra_node_group)
        if inter_node_group is not None:
            num_group = dist.get_world_size(inter_node_group)
        else:
            num_group = 1
        world_size = local_world_size * num_group
        outputs = torch.empty_like(inputs)

        if dist.get_rank() != src_rank:
            # Other ranks only hand their data over and wait for their share.
            dist.gather(inputs, dst=src_rank, group=intra_node_group)
            dist.scatter(outputs, src=src_rank, group=intra_node_group)
            return outputs

        # intra-node gather
        gathered = [torch.empty_like(inputs) for _ in range(local_world_size)]
        dist.gather(inputs, gathered, dst=src_rank, group=intra_node_group)
        # Interleave: chunk 0 of every gathered tensor, then chunk 1 of every one, ...
        split_per_rank = [t.chunk(world_size, dim=0) for t in gathered]
        staged = torch.cat([piece for column in zip(*split_per_rank) for piece in column])

        # inter-node all-to-all
        if inter_node_group is not None:
            exchanged = torch.empty_like(staged)
            dist.all_to_all_single(exchanged, staged, group=inter_node_group)
            # layout transform: same interleaving, now across the per-group blocks
            blocks = exchanged.chunk(num_group, dim=0)
            split_per_block = [blk.chunk(local_world_size, dim=0) for blk in blocks]
            staged = torch.cat([piece for column in zip(*split_per_block) for piece in column])

        # intra-node scatter
        shares = list(staged.chunk(local_world_size, dim=0))
        dist.scatter(outputs, shares, src=src_rank, group=intra_node_group)
        return outputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
        """Apply the same exchange to the gradient; ``groups``/``src_rank`` get no gradient."""
        grad_inputs = HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank)
        return grad_inputs, None, None
class MoeDispatch(torch.autograd.Function):
    """Autograd wrapper around ``MOE_KERNEL.dispatch_forward`` / ``dispatch_backward``.

    The kernel is always called with float32 data; results are cast back to
    the dtype the caller's ``tokens`` had.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, tokens, mask, dest_idx, ec):
        """Call the dispatch kernel on ``tokens``.

        ``tokens.size(0)`` and ``tokens.size(1)`` are passed to the kernel
        together with ``ec``, ``mask`` and ``dest_idx``, which are forwarded
        unchanged. The kernel is loaded on first use.
        """
        num_tokens, hidden = tokens.size(0), tokens.size(1)
        orig_dtype = tokens.dtype

        if MOE_KERNEL is None:
            load_moe()

        # ``Tensor.to`` returns the same tensor when the dtype already matches.
        kernel_in = tokens.to(torch.float32)
        expert_input = MOE_KERNEL.dispatch_forward(num_tokens, ec, hidden, kernel_in, mask, dest_idx)
        expert_input = expert_input.to(orig_dtype)

        # Everything backward needs: routing tensors plus the scalar sizes.
        ctx.save_for_backward(mask, dest_idx)
        ctx.s, ctx.h, ctx.ec, ctx.dtype = num_tokens, hidden, ec, orig_dtype
        return expert_input

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad):
        """Return the gradient w.r.t. ``tokens``; the other inputs get ``None``."""
        mask, dest_idx = ctx.saved_tensors
        grad_fp32 = output_grad.to(torch.float32)
        d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, grad_fp32, mask, dest_idx)
        return d_tokens.to(ctx.dtype), None, None, None
class MoeCombine(torch.autograd.Function):
    """Autograd wrapper around ``MOE_KERNEL.combine_forward`` / ``combine_backward``.

    ``logits`` must already be float32. ``expert_tokens`` is converted to
    float32 for the kernel and the result is cast back to its original dtype.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        """Call the combine kernel.

        The sizes handed to the kernel are ``s = logits.size(0)``,
        ``e = logits.size(1)``, ``c = ec // e`` and ``h = expert_tokens.size(-1)``.
        The kernel is loaded on first use.
        """
        assert logits.dtype == torch.float32

        s, e = logits.size(0), logits.size(1)
        c = ec // e
        h = expert_tokens.size(-1)
        orig_dtype = expert_tokens.dtype

        # ``Tensor.to`` returns the same tensor when the dtype already matches.
        expert_tokens = expert_tokens.to(torch.float32)
        if MOE_KERNEL is None:
            load_moe()
        output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
        output = output.to(orig_dtype)

        # The float32 version of ``expert_tokens`` is what backward receives.
        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s, ctx.e, ctx.c, ctx.h = s, e, c, h
        ctx.dtype = orig_dtype
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, tokens_grad):
        """Return gradients for ``expert_tokens`` and ``logits``; the rest get ``None``."""
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
        grad_fp32 = tokens_grad.to(torch.float32)
        d_expert, d_logits = MOE_KERNEL.combine_backward(
            ctx.s, ctx.e, ctx.c, ctx.h, grad_fp32, expert_tokens, logits, mask, dest_idx
        )
        # Only the expert gradient is cast back; d_logits is returned as produced.
        return d_expert.to(ctx.dtype), d_logits, None, None, None
def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
    """Return the cumulative sum of ``inputs`` along dim 0, minus one.

    With ``use_kernel=True`` the work is handed to
    ``MOE_KERNEL.cumsum_sub_one`` (loaded on first use), but only when the
    leading dimension is at most 1024, or at most 2048 and even, or a
    multiple of 4. In every other case the plain ``torch.cumsum`` path is used.
    """
    rows = inputs.size(0)
    kernel_ok = rows <= 1024 or (rows <= 2048 and rows % 2 == 0) or rows % 4 == 0

    if not (kernel_ok and use_kernel):
        return torch.cumsum(inputs, dim=0) - 1

    if MOE_KERNEL is None:
        load_moe()
    return MOE_KERNEL.cumsum_sub_one(inputs)
class MoeInGradScaler(torch.autograd.Function):
    """
    Identity in the forward pass; in the backward pass the gradient is
    multiplied by ``ep_size``. The original note motivates this as scaling
    the gradient back because the batch size increases in the moe stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        """Return ``inputs`` unchanged, remembering ``ep_size`` when a ctx is given."""
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        """Multiply the single incoming gradient by ``ep_size`` (skipped when it is 1)."""
        assert len(grad_outputs) == 1
        (grad,) = grad_outputs
        factor = ctx.ep_size
        scaled = grad if factor == 1 else grad * factor
        return scaled, None
class MoeOutGradScaler(torch.autograd.Function):
    """
    Identity in the forward pass; in the backward pass the gradient is
    divided by ``ep_size``. The original note motivates the scaling by the
    batch size increasing in the moe stage.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
        """Return ``inputs`` unchanged, remembering ``ep_size`` for backward.

        ``ctx`` may be ``None`` for a direct call outside autograd, matching
        ``MoeInGradScaler`` and the other Functions in this file.
        """
        if ctx is not None:
            ctx.ep_size = ep_size
        return inputs

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        """Divide the single incoming gradient by ``ep_size`` (skipped when it is 1).

        ``ep_size`` itself receives no gradient.
        """
        assert len(grad_outputs) == 1
        grad = grad_outputs[0]
        if ctx.ep_size != 1:
            grad = grad / ctx.ep_size
        return grad, None