@@ -1,10 +1,8 @@
 import torch
 import torch.distributed as dist
 from torch import Tensor
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from typing import Any, Tuple
+from typing import Any, Tuple, Optional
+from torch.distributed import ProcessGroup

 U_CUDA_MODE = False
 try:
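Reviewer note, not part of the diff: with the ColossalAI global context removed, callers now construct or reuse a torch.distributed ProcessGroup themselves. A minimal setup sketch follows; the `ep_group` name and rank list are illustrative, and a launcher such as torchrun is assumed to set the usual rank/world-size environment variables.

    import torch.distributed as dist

    # Previously a group came from gpc.get_group(parallel_mode); now the caller
    # builds or reuses a ProcessGroup directly. Passing group=None to the
    # functions below falls back to the default WORLD group.
    dist.init_process_group(backend="nccl")
    ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))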
@@ -18,35 +16,35 @@ except ImportError:


 class AllGather(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group

-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.unsqueeze(0)

         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
-        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_gather(buffer_list, inputs, group=group)
         return outputs

     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+        return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None


 class ReduceScatter(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group

-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.squeeze(0)
@@ -56,12 +54,12 @@ class ReduceScatter(torch.autograd.Function):

         output_shape = inputs.shape[1:]
         outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
-        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        dist.reduce_scatter(outputs, buffer_list, group=group)
         return outputs

     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+        return AllGather.forward(None, grad_outputs, ctx.comm_grp), None


 class AllToAll(torch.autograd.Function):
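Not part of the diff: a usage sketch of the refactored pair under the new group-based API. It assumes the process group setup above; `gather_then_scatter` is a hypothetical helper, and it relies on the duality the diff preserves, where `AllGather.backward` calls `ReduceScatter.forward` with the stored `ctx.comm_grp` and vice versa, so the pair stays differentiable end to end.

    from typing import Optional
    from torch.distributed import ProcessGroup

    def gather_then_scatter(x: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
        # group=None now means "use the default group" instead of a ParallelMode.
        gathered = AllGather.apply(x, group)         # shape: (world_size, *x.shape)
        # reduce_scatter sums corresponding chunks across ranks; since every rank
        # holds the same gathered buffer, rank r receives world_size * x_from_rank_r.
        return ReduceScatter.apply(gathered, group)  # shape: x.shape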
@@ -70,20 +68,20 @@ class AllToAll(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
-        if gpc.get_world_size(parallel_mode) == 1:
+        if dist.get_world_size(group) == 1:
             return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=group)
         return output

     @staticmethod
     def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
+        return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None


 class MoeDispatch(torch.autograd.Function):
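Likewise for AllToAll, not part of the diff: a round-trip sketch reusing the imports above. `all_to_all_single` sends the j-th chunk of each rank's tensor to rank j, so applying the exchange twice with the same group restores the original tensor; the leading dimension must be divisible by the group's world size.

    def all_to_all_round_trip(x: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
        # First exchange: rank i's output chunk j is rank j's input chunk i.
        shuffled = AllToAll.apply(x, group)
        # Second exchange inverts the permutation, recovering x on every rank.
        return AllToAll.apply(shuffled, group)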