|
|
@ -1,10 +1,8 @@
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch import Tensor
|
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from typing import Any, Tuple, Optional
|
|
|
|
from colossalai.context import ParallelMode
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
|
|
|
from typing import Any, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
U_CUDA_MODE = False
|
|
|
|
U_CUDA_MODE = False
|
|
|
|
try:
|
|
|
|
try:
|
|
|
@ -18,35 +16,35 @@ except ImportError:
|
|
|
|
class AllGather(torch.autograd.Function):
|
|
|
|
class AllGather(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
|
|
|
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
if ctx is not None:
|
|
|
|
if ctx is not None:
|
|
|
|
ctx.parallel_mode = parallel_mode
|
|
|
|
ctx.comm_grp = group
|
|
|
|
|
|
|
|
|
|
|
|
comm_size = gpc.get_world_size(parallel_mode)
|
|
|
|
comm_size = dist.get_world_size(group)
|
|
|
|
if comm_size == 1:
|
|
|
|
if comm_size == 1:
|
|
|
|
return inputs.unsqueeze(0)
|
|
|
|
return inputs.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
|
|
buffer_shape = (comm_size,) + inputs.shape
|
|
|
|
buffer_shape = (comm_size,) + inputs.shape
|
|
|
|
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
|
|
|
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
|
|
|
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
|
|
|
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
|
|
|
dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
|
|
|
|
dist.all_gather(buffer_list, inputs, group=group)
|
|
|
|
return outputs
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
|
|
|
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
|
|
|
return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
|
|
|
|
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReduceScatter(torch.autograd.Function):
|
|
|
|
class ReduceScatter(torch.autograd.Function):
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
|
|
|
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
if ctx is not None:
|
|
|
|
if ctx is not None:
|
|
|
|
ctx.parallel_mode = parallel_mode
|
|
|
|
ctx.comm_grp = group
|
|
|
|
|
|
|
|
|
|
|
|
comm_size = gpc.get_world_size(parallel_mode)
|
|
|
|
comm_size = dist.get_world_size(group)
|
|
|
|
if comm_size == 1:
|
|
|
|
if comm_size == 1:
|
|
|
|
return inputs.squeeze(0)
|
|
|
|
return inputs.squeeze(0)
|
|
|
|
|
|
|
|
|
|
|
@ -56,12 +54,12 @@ class ReduceScatter(torch.autograd.Function):
|
|
|
|
output_shape = inputs.shape[1:]
|
|
|
|
output_shape = inputs.shape[1:]
|
|
|
|
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
|
|
|
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
|
|
|
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
|
|
|
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
|
|
|
dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
|
|
|
|
dist.reduce_scatter(outputs, buffer_list, group=group)
|
|
|
|
return outputs
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
|
|
|
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
|
|
|
return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
|
|
|
|
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AllToAll(torch.autograd.Function):
|
|
|
|
class AllToAll(torch.autograd.Function):
|
|
|
@ -70,20 +68,20 @@ class AllToAll(torch.autograd.Function):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
|
|
|
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
|
|
|
if ctx is not None:
|
|
|
|
if ctx is not None:
|
|
|
|
ctx.parallel_mode = parallel_mode
|
|
|
|
ctx.comm_grp = group
|
|
|
|
if not inputs.is_contiguous():
|
|
|
|
if not inputs.is_contiguous():
|
|
|
|
inputs = inputs.contiguous()
|
|
|
|
inputs = inputs.contiguous()
|
|
|
|
if gpc.get_world_size(parallel_mode) == 1:
|
|
|
|
if dist.get_world_size(group) == 1:
|
|
|
|
return inputs
|
|
|
|
return inputs
|
|
|
|
output = torch.empty_like(inputs)
|
|
|
|
output = torch.empty_like(inputs)
|
|
|
|
dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
|
|
|
|
dist.all_to_all_single(output, inputs, group=group)
|
|
|
|
return output
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
|
|
|
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
|
|
|
|
return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
|
|
|
|
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MoeDispatch(torch.autograd.Function):
|
|
|
|
class MoeDispatch(torch.autograd.Function):
|
|
|
|