import torch
import torch.distributed as dist
from torch import Tensor
from typing import Any, Tuple, Optional
from torch.distributed import ProcessGroup

COL_MOE_KERNEL_FLAG = False
# The fused MoE CUDA kernels are optional; fall back to the pure PyTorch paths
# when the extension is not installed.
try:
    import colossal_moe_cuda

    COL_MOE_KERNEL_FLAG = True
except ImportError:
    print("If you want to activate cuda mode for MoE, please install with cuda_ext!")

class AllGather(torch.autograd.Function):
    """Gathers the input tensor from every process in the group into a new
    leading dimension, so the output has shape ``(world_size, *inputs.shape)``.
    The backward pass is the matching reduce-scatter.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.unsqueeze(0)

        buffer_shape = (comm_size,) + inputs.shape
        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
        dist.all_gather(buffer_list, inputs, group=group)
        return outputs

    @staticmethod
    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
        # The gradient of an all-gather is a reduce-scatter over the same group.
        return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None

class ReduceScatter(torch.autograd.Function):
    """Sums the input tensor element-wise across the group and scatters the
    result, so each rank keeps the chunk at its own index along the leading
    dimension. The backward pass is the matching all-gather.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
        if ctx is not None:
            ctx.comm_grp = group

        comm_size = dist.get_world_size(group)
        if comm_size == 1:
            return inputs.squeeze(0)

        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

        output_shape = inputs.shape[1:]
        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
        dist.reduce_scatter(outputs, buffer_list, group=group)
        return outputs

    @staticmethod
    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
        # The gradient of a reduce-scatter is an all-gather over the same group.
        return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
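
# --- Illustrative sketch (not part of the original module) ---
# A minimal example of how AllGather and ReduceScatter pair up: gathering a
# per-rank tensor and reduce-scattering it back yields the original tensor
# scaled by the world size, since reduce-scatter sums across ranks. The helper
# below is hypothetical and added only for illustration; `group` is whatever
# ProcessGroup the caller created (e.g. via ``dist.new_group``), and
# ``dist.init_process_group`` must already have been called.
def _example_gather_scatter_roundtrip(x: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
    # (world_size, *x.shape): every rank sees every rank's tensor
    gathered = AllGather.apply(x, group)
    # sum over ranks, keep this rank's chunk -> same shape as x, scaled by world size
    return ReduceScatter.apply(gathered, group)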

class AllToAll(torch.autograd.Function):
    """Dispatches the input tensor ``[e, c, h]`` to all experts using the
    ``all_to_all_single`` operation in ``torch.distributed``.
    """

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
        if ctx is not None:
            ctx.comm_grp = group
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()
        if dist.get_world_size(group) == 1:
            return inputs
        output = torch.empty_like(inputs)
        dist.all_to_all_single(output, inputs, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        # All-to-all is its own inverse, so the backward pass reuses the forward exchange.
        return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
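
# --- Illustrative sketch (not part of the original module) ---
# AllToAll exchanges equal-sized slices of the leading dimension between ranks,
# so applying it twice over the same group restores the original layout. This
# hypothetical helper only shows the call pattern; `group` is a caller-provided
# ProcessGroup, and the leading dimension is assumed to be divisible by the
# group's world size, as ``all_to_all_single`` requires.
def _example_all_to_all_roundtrip(tokens: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
    dispatched = AllToAll.apply(tokens, group)    # send chunk i to rank i, receive one chunk per rank
    return AllToAll.apply(dispatched, group)      # the inverse exchange restores the original tokens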

class MoeDispatch(torch.autograd.Function):
    """Dispatches tokens to expert slots with the fused ``colossal_moe_cuda``
    kernel. Requires the optional cuda_ext build.
    """

    @staticmethod
    def forward(ctx, tokens, mask, dest_idx, ec):
        s = tokens.size(0)
        h = tokens.size(1)

        expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

        # Only the routing tensors and sizes are needed for the backward kernel.
        ctx.save_for_backward(mask, dest_idx)
        ctx.s = s
        ctx.h = h
        ctx.ec = ec

        return expert_input

    @staticmethod
    def backward(ctx, output_grad):
        mask, dest_idx = ctx.saved_tensors
        d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
        return d_tokens, None, None, None

class MoeCombine(torch.autograd.Function):
    """Combines expert outputs back into token order, weighted by the fp32
    routing logits, with the fused ``colossal_moe_cuda`` kernel. fp16 expert
    tokens are upcast to fp32 for the kernel and cast back afterwards.
    """

    @staticmethod
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        assert logits.dtype == torch.float32

        s = logits.size(0)
        e = logits.size(1)
        c = ec // e
        h = expert_tokens.size(-1)

        # The combine kernel works in fp32, so upcast fp16 inputs and cast the result back.
        fp16_flag = (expert_tokens.dtype == torch.float16)
        cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
        output = ctokens.to(torch.float16) if fp16_flag else ctokens

        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s = s
        ctx.e = e
        ctx.c = c
        ctx.h = h
        ctx.fp16_flag = fp16_flag

        return output

    @staticmethod
    def backward(ctx, tokens_grad):
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors

        # Upcast fp16 gradients and inputs before calling the fp32 backward kernel.
        cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad
        cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
        d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
                                                                mask, dest_idx)
        d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert

        return d_expert, d_logits, None, None, None
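
# --- Illustrative sketch (not part of the original module) ---
# MoeDispatch and MoeCombine are meant to be used as a pair around the expert
# computation, and both need the colossal_moe_cuda extension
# (COL_MOE_KERNEL_FLAG must be True). The helper below is hypothetical and only
# shows the call pattern: the routing tensors (`logits`, `mask`, `dest_idx`) and
# the total expert capacity `ec` are assumed to come from the caller's router,
# `expert_fn` is any caller-supplied expert computation, and `logits` must be
# fp32 as MoeCombine asserts.
def _example_dispatch_compute_combine(tokens: Tensor, logits: Tensor, mask: Tensor,
                                      dest_idx: Tensor, ec: int, expert_fn) -> Tensor:
    expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)    # scatter tokens to expert slots
    expert_output = expert_fn(expert_input)                         # run the experts
    # gather expert outputs back to token order, weighted by the routing logits
    return MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)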

def moe_cumsum(inputs: Tensor):
    """Cumulative sum along dim 0 minus one, so counting starts from zero.

    Uses the fused ``cumsum_sub_one`` CUDA kernel when it is available and the
    leading dimension satisfies the kernel's size constraints; otherwise falls
    back to ``torch.cumsum``.
    """
    dim0 = inputs.size(0)
    flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
    if flag and COL_MOE_KERNEL_FLAG:
        return colossal_moe_cuda.cumsum_sub_one(inputs)
    else:
        return torch.cumsum(inputs, dim=0) - 1
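
# --- Illustrative sketch (not part of the original module) ---
# moe_cumsum is typically applied to a one-hot expert-assignment mask to turn
# it into per-expert positions: after the cumulative sum minus one, each
# selected entry holds that token's 0-based slot index within its expert. The
# helper below is a hypothetical example modeled on that usage.
def _example_positions_from_mask(mask: Tensor) -> Tensor:
    # mask: [num_tokens, num_experts], one-hot along the expert dimension
    ranks = moe_cumsum(mask)                                  # running count per expert, starting at 0
    return torch.sum(ranks * mask, dim=-1).to(torch.long)     # keep only the chosen expert's slot index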