move all2all to utils

pull/567/head
Wenwen Qu 2024-01-08 13:16:17 +08:00
parent 07c98c4a39
commit fdd60691d3
2 changed files with 31 additions and 25 deletions


@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -16,6 +16,7 @@ from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
 from .base_moe import BaseMoELayer
+from .utils import _AllToAll
 
 # global llm logger
 logger = get_logger(__file__)
@@ -59,30 +60,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
 
-# Based on https://github.com/pytorch/pytorch/pull/40762
-class _AllToAll(torch.autograd.Function):
-    """
-    All to all communication
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        # TODO: replace with DS process group
-        group: torch.distributed.ProcessGroup,
-        inputs: Tensor,
-    ) -> Tensor:  # type: ignore
-        ctx.group = group
-        inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
-        return output
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
-
-
 # einsum rewrites are on par or more performant
 # switch can be bubbled up in future
 USE_EINSUM = True

internlm/moe/utils.py (new file, 29 lines added)

@@ -0,0 +1,29 @@
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+    """
+    All to all communication
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        # TODO: replace with DS process group
+        group: torch.distributed.ProcessGroup,
+        inputs: Tensor,
+    ) -> Tensor:  # type: ignore
+        ctx.group = group
+        inputs = inputs.contiguous()
+        output = torch.empty_like(inputs)
+        dist.all_to_all_single(output, inputs, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
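
For context only, a minimal usage sketch that is not part of this commit: an MoE layer typically wraps its expert computation in two all-to-all exchanges over the expert-parallel process group, one to dispatch tokens to the ranks that host the chosen experts and one to combine the results. The dispatch_and_combine helper, the expert module, and the ep_group argument below are illustrative assumptions; only the internlm.moe.utils import path follows from the file added in this commit.

# Illustrative sketch only: assumes torch.distributed is already initialized,
# `ep_group` is the expert-parallel process group, and `expert` is any module
# hosted on this rank. dist.all_to_all_single requires the exchanged tensor to
# have the same shape on every rank and a first dimension divisible by the
# group size.
import torch
import torch.distributed as dist

from internlm.moe.utils import _AllToAll


def dispatch_and_combine(
    expert: torch.nn.Module, dispatched_input: torch.Tensor, ep_group: dist.ProcessGroup
) -> torch.Tensor:
    # Scatter each rank's per-expert slices to the ranks that own those experts.
    expert_input = _AllToAll.apply(ep_group, dispatched_input)
    # Run the locally hosted expert(s) on the tokens received from every rank.
    expert_output = expert(expert_input)
    # Send the results back to the ranks that originally produced the tokens.
    return _AllToAll.apply(ep_group, expert_output)

Because backward() applies the same all-to-all to the incoming gradients, the gradient of each exchanged chunk flows back to the rank that sent it, so autograd sees the combine step as the dispatch step run in reverse; moving the class into utils.py changes where it lives, not how it behaves.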