mirror of https://github.com/InternLM/InternLM

move all2all to utils

parent 07c98c4a39
commit fdd60691d3
@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -16,6 +16,7 @@ from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
 from .base_moe import BaseMoELayer
+from .utils import _AllToAll
 
 # global llm logger
 logger = get_logger(__file__)
@@ -59,30 +60,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
 
-# Based on https://github.com/pytorch/pytorch/pull/40762
-class _AllToAll(torch.autograd.Function):
-    """
-    All to all communication
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        # TODO: replace with DS process group
-        group: torch.distributed.ProcessGroup,
-        inputs: Tensor,
-    ) -> Tensor:  # type: ignore
-        ctx.group = group
-        inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
-        return output
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
-
-
 # einsum rewrites are on par or more performant
 # switch can be bubbled up in future
 USE_EINSUM = True
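The class removed here (and re-added unchanged in the new .utils module below) wraps dist.all_to_all_single in an autograd.Function whose backward pass is simply another all-to-all on the incoming gradient. A minimal standalone sketch of why that is valid, assuming torch.distributed is already initialised and every rank contributes equal-sized splits; the helper name round_trip_check is hypothetical and not part of this commit:

import torch
import torch.distributed as dist


def round_trip_check(x: torch.Tensor, group=None) -> bool:
    # With equal splits, all_to_all_single permutes row-blocks across ranks:
    # rank i's j-th block lands on rank j as its i-th block. Applying the same
    # collective twice therefore restores the original tensor, which is why
    # _AllToAll.backward can reuse _AllToAll.apply on the gradient.
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x.contiguous(), group=group)
    back = torch.empty_like(out)
    dist.all_to_all_single(back, out, group=group)
    return torch.equal(back, x)

In the MoE layer this property means the gradient of a token dispatch is just the reverse exchange over the same process group.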
@@ -0,0 +1,29 @@
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+    """
+    All to all communication
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        # TODO: replace with DS process group
+        group: torch.distributed.ProcessGroup,
+        inputs: Tensor,
+    ) -> Tensor:  # type: ignore
+        ctx.group = group
+        inputs = inputs.contiguous()
+        output = torch.empty_like(inputs)
+        dist.all_to_all_single(output, inputs, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
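For context, a sketch of how the relocated _AllToAll is typically driven from an MoE layer after this change: routed tokens are exchanged across the expert-parallel group before the local experts run, then exchanged back afterwards. The function and argument names below (dispatch_to_experts, experts, ep_group) are illustrative assumptions rather than code from this commit; only `from .utils import _AllToAll` matches the diff:

import torch
import torch.distributed as dist

from .utils import _AllToAll  # as added by this commit


def dispatch_to_experts(dispatched_input: torch.Tensor, experts, ep_group: dist.ProcessGroup) -> torch.Tensor:
    # Send each rank's routed tokens to the ranks that own the target experts.
    expert_input = _AllToAll.apply(ep_group, dispatched_input)
    # Run the local experts on the tokens received from every rank.
    expert_output = experts(expert_input)
    # Reverse the exchange so every token's output returns to its source rank.
    return _AllToAll.apply(ep_group, expert_output)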