diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index f9badf1..631a9da 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -16,6 +16,7 @@ from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
 from .base_moe import BaseMoELayer
+from .utils import _AllToAll
 
 # global llm logger
 logger = get_logger(__file__)
@@ -59,30 +60,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
 
-# Based on https://github.com/pytorch/pytorch/pull/40762
-class _AllToAll(torch.autograd.Function):
-    """
-    All to all communication
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        # TODO: replace with DS process group
-        group: torch.distributed.ProcessGroup,
-        inputs: Tensor,
-    ) -> Tensor:  # type: ignore
-        ctx.group = group
-        inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
-        return output
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
-
-
 # einsum rewrites are on par or more performant
 # switch can be bubbled up in future
 USE_EINSUM = True
diff --git a/internlm/moe/utils.py b/internlm/moe/utils.py
new file mode 100644
index 0000000..cdb8aed
--- /dev/null
+++ b/internlm/moe/utils.py
@@ -0,0 +1,29 @@
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+    """
+    All to all communication
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        # TODO: replace with DS process group
+        group: torch.distributed.ProcessGroup,
+        inputs: Tensor,
+    ) -> Tensor:  # type: ignore
+        ctx.group = group
+        inputs = inputs.contiguous()
+        output = torch.empty_like(inputs)
+        dist.all_to_all_single(output, inputs, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
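
For context, a minimal usage sketch (not part of the patch) of how the relocated _AllToAll op is typically invoked after this move; the exchange_tokens helper, ep_group, and the tensor shape are illustrative assumptions, not code from the repository.

# Sketch: calling _AllToAll from its new location, internlm/moe/utils.py.
# `ep_group` and the (num_experts, capacity, d_model) layout are assumed for illustration.
import torch
import torch.distributed as dist

from internlm.moe.utils import _AllToAll


def exchange_tokens(dispatched_input: torch.Tensor, ep_group: dist.ProcessGroup) -> torch.Tensor:
    # dist.all_to_all_single splits along dim 0, so the first dimension of
    # dispatched_input must be divisible by the size of ep_group.
    return _AllToAll.apply(ep_group, dispatched_input)

Wrapping the collective in a custom autograd.Function (rather than calling dist.all_to_all_single directly) is what makes the exchange differentiable: the backward pass is itself an all-to-all on the incoming gradients, as the moved code shows.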