mirror of https://github.com/InternLM/InternLM
30 lines · 795 B · Python
from typing import Any, Tuple

import torch
import torch.distributed as dist
from torch import Tensor


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
    """
    All-to-all communication across a process group, with autograd support.
    """

    @staticmethod
    def forward(
        ctx: Any,
        # TODO: replace with DS process group
        group: torch.distributed.ProcessGroup,
        inputs: Tensor,
    ) -> Tensor:  # type: ignore
        ctx.group = group
        # all_to_all_single requires a contiguous buffer; the output has the
        # same shape as the input, so every rank exchanges equal-sized chunks.
        inputs = inputs.contiguous()
        output = torch.empty_like(inputs)
        dist.all_to_all_single(output, inputs, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        # The gradient of an all-to-all is the reverse all-to-all of the
        # incoming gradients, so backward simply applies the collective again.
        # The leading None matches the `group` argument, which gets no gradient.
        return (None, _AllToAll.apply(ctx.group, *grad_output))
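
# Usage sketch: a minimal, hypothetical round trip through _AllToAll, in the
# spirit of an expert-parallel MoE token dispatch. The launch command, tensor
# shapes, and process-group setup below are illustrative assumptions, not part
# of this file's API. Run under torchrun, e.g.:
#   torchrun --nproc_per_node=2 demo.py
if __name__ == "__main__":
    # all_to_all_single needs a backend that implements it, e.g. NCCL.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Row i of `tokens` is the chunk this rank sends to rank i. All chunks must
    # be equal-sized, since forward allocates the receive buffer with empty_like.
    tokens = torch.full((world_size, 8), float(rank), device="cuda", requires_grad=True)

    # After the exchange, row i on this rank is the chunk sent by rank i.
    dispatched = _AllToAll.apply(dist.group.WORLD, tokens)

    # Gradients flow back through the reverse all-to-all automatically.
    dispatched.sum().backward()
    assert tokens.grad is not None

    dist.destroy_process_group()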