mirror of https://github.com/hpcaitech/ColossalAI

[MOE] changed parallelmode to dist process group (#460)

parent 8f9617c313
commit bccbc15861
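In short, the MoE communication ops (AllGather, ReduceScatter, AllToAll) stop taking ColossalAI's ParallelMode enum (resolved through the global context gpc) and instead accept a torch.distributed ProcessGroup directly. A minimal before/after sketch of a call site; the import path, mode name, and subgroup construction below are illustrative assumptions, not part of this diff:

    import torch
    import torch.distributed as dist
    from colossalai.nn.layer.moe._operation import AllGather  # path assumed from repo layout

    # Before #460: ops resolved the group from a ParallelMode via gpc.
    # out = AllGather.apply(x, ParallelMode.MOE_MODEL)   # hypothetical mode name

    # After #460: ops take a ProcessGroup; None falls back to the default group.
    moe_group = dist.new_group(ranks=[0, 1])             # illustrative subgroup
    x = torch.randn(4, 8)
    out = AllGather.apply(x, moe_group)                  # shape (2, 4, 8) on ranks 0-1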
@@ -4,6 +4,7 @@ from .parallel_2d import *
 from .parallel_2p5d import *
 from .parallel_3d import *
 from .parallel_sequence import *
+from .moe import *
 from .utils import *
 from .vanilla import *
 from .wrapper import *
@@ -1,8 +1,8 @@
 from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, build_ffn_experts
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'build_ffn_experts'
+    'UniformNoiseGenerator', 'build_ffn_experts'
 ]
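The export list gains UniformNoiseGenerator alongside NormalNoiseGenerator. A minimal sketch of the idea (multiplicative uniform jitter on router logits, in the style of Switch Transformer); the real class in .utils may differ in name and signature:

    import torch

    class UniformNoiseGeneratorSketch:
        """Illustrative only; not the actual colossalai implementation."""

        def __init__(self, eps: float = 1e-2):
            # multipliers drawn uniformly from [1 - eps, 1 + eps]
            self.low, self.high = 1.0 - eps, 1.0 + eps

        def __call__(self, logits: torch.Tensor) -> torch.Tensor:
            noise = torch.empty_like(logits).uniform_(self.low, self.high)
            return logits * noise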
@@ -1,10 +1,8 @@
 import torch
 import torch.distributed as dist
 from torch import Tensor
-
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from typing import Any, Tuple
+from typing import Any, Tuple, Optional
+from torch.distributed import ProcessGroup

 U_CUDA_MODE = False
 try:
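A side effect of the new Optional[ProcessGroup] = None defaults: torch.distributed treats group=None as the default (world) process group, so a caller that omits the argument gets a world-wide collective rather than an error. For example (assuming dist.init_process_group has already run):

    import torch.distributed as dist

    # group=None selects the default group created by init_process_group
    assert dist.get_world_size(None) == dist.get_world_size()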
@@ -18,35 +16,35 @@ except ImportError:
 class AllGather(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group

-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.unsqueeze(0)

         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
-        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_gather(buffer_list, inputs, group=group)
         return outputs

     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+        return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None


 class ReduceScatter(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group

-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.squeeze(0)
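Note the pattern in backward: AllGather's gradient is computed by calling ReduceScatter.forward with ctx=None (and vice versa), because the two collectives are adjoint. Stacking each rank's tensor in the forward pass means the incoming gradient must be summed across ranks and scattered back chunk by chunk. A single-process numeric sketch of that relationship (plain tensor ops, no process group; names are illustrative):

    import torch

    world = 4
    x = torch.randn(world, 3)            # row i "owned" by rank i
    # all_gather: every rank j sees the full stack y_j = x
    g = torch.randn(world, world, 3)     # upstream grad g[j] held on rank j
    # adjoint: grad wrt x[i] is sum_j g[j][i] -- exactly a reduce-scatter(sum)
    grad_x = g.sum(dim=0)
    assert grad_x.shape == x.shape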
@@ -56,12 +54,12 @@ class ReduceScatter(torch.autograd.Function):
         output_shape = inputs.shape[1:]
         outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
-        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        dist.reduce_scatter(outputs, buffer_list, group=group)
         return outputs

     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+        return AllGather.forward(None, grad_outputs, ctx.comm_grp), None


 class AllToAll(torch.autograd.Function):
@@ -70,20 +68,20 @@ class AllToAll(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
-        if gpc.get_world_size(parallel_mode) == 1:
+        if dist.get_world_size(group) == 1:
             return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=group)
         return output

     @staticmethod
     def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
+        return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None


 class MoeDispatch(torch.autograd.Function):
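AllToAll's backward likewise re-invokes AllToAll.forward on the gradient: a uniform all-to-all is a block transpose across ranks, hence its own adjoint. This is the collective that ships MoE tokens to the ranks hosting their selected experts and back. A single-process sketch of the permutation (simulating world-size ranks with a 2-D tensor; illustrative only):

    import torch

    world = 4
    # row i = rank i's buffer; chunk j of rank i travels to rank j
    inputs = torch.arange(world * world).reshape(world, world)
    outputs = inputs.t().contiguous()        # simulated all_to_all_single
    # applying the exchange twice restores the original layout
    assert torch.equal(outputs.t().contiguous(), inputs)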