mirror of https://github.com/InternLM/InternLM
move all2all to utils
parent 07c98c4a39
commit fdd60691d3
@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -16,6 +16,7 @@ from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
 from .base_moe import BaseMoELayer
+from .utils import _AllToAll
 
 # global llm logger
 logger = get_logger(__file__)
@@ -59,30 +60,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
 
-# Based on https://github.com/pytorch/pytorch/pull/40762
-class _AllToAll(torch.autograd.Function):
-    """
-    All to all communication
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        # TODO: replace with DS process group
-        group: torch.distributed.ProcessGroup,
-        inputs: Tensor,
-    ) -> Tensor:  # type: ignore
-        ctx.group = group
-        inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
-        return output
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
-
-
 # einsum rewrites are on par or more performant
 # switch can be bubbled up in future
 USE_EINSUM = True
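The call sites of the moved wrapper are outside these hunks. For context, in the DeepSpeed-derived expert-parallel layer this kind of wrapper is typically applied on both sides of the local expert computation; the sketch below is illustrative only, and ep_group, experts, and the tensor layout are assumptions rather than code from this repository.

from typing import Callable

from torch import Tensor
from torch.distributed import ProcessGroup

from .utils import _AllToAll  # module added by this commit; relative import assumes the same package


def expert_parallel_forward(
    ep_group: ProcessGroup,
    experts: Callable[[Tensor], Tensor],
    dispatched_input: Tensor,
) -> Tensor:
    # Exchange capacity buffers so each rank receives the tokens routed to its local experts.
    dispatched_input = _AllToAll.apply(ep_group, dispatched_input)
    # Run the local experts on the received tokens.
    expert_output = experts(dispatched_input)
    # Send expert outputs back to the ranks that originally held the tokens.
    return _AllToAll.apply(ep_group, expert_output)

Because backward re-applies the same all-to-all to the incoming gradient, gradients flow back to the source ranks along the reverse routing.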
@@ -0,0 +1,29 @@
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+    """
+    All to all communication
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        # TODO: replace with DS process group
+        group: torch.distributed.ProcessGroup,
+        inputs: Tensor,
+    ) -> Tensor:  # type: ignore
+        ctx.group = group
+        inputs = inputs.contiguous()
+        output = torch.empty_like(inputs)
+        dist.all_to_all_single(output, inputs, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
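For a quick sanity check of the moved wrapper, the following standalone sketch runs it on a single process; the import path internlm.model.moe.utils, the gloo backend's CPU all_to_all support, and the single-rank setup are assumptions made only so the example can run without a multi-GPU launch.

import os

import torch
import torch.distributed as dist

# Assumed location of the new utils module introduced by this commit.
from internlm.model.moe.utils import _AllToAll


def main():
    # Single-process "gloo" group: with one rank the all-to-all is an identity copy,
    # but the autograd wiring matches the multi-rank expert-parallel case.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(4, 8, requires_grad=True)
    y = _AllToAll.apply(dist.group.WORLD, x)
    y.sum().backward()  # backward performs the reverse all-to-all on the gradient

    assert torch.allclose(y, x)
    assert torch.allclose(x.grad, torch.ones_like(x))
    dist.destroy_process_group()


if __name__ == "__main__":
    main()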