[MOE] changed parallelmode to dist process group (#460)

pull/467/head
HELSON 3 years ago committed by GitHub
parent 8f9617c313
commit bccbc15861
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,7 @@ from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .moe import *
from .utils import *
from .vanilla import *
from .wrapper import *

@ -1,8 +1,8 @@
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, Top1Router, Top2Router
from .utils import NormalNoiseGenerator, build_ffn_experts
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'build_ffn_experts'
'UniformNoiseGenerator', 'build_ffn_experts'
]

@ -1,10 +1,8 @@
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from typing import Any, Tuple
from typing import Any, Tuple, Optional
from torch.distributed import ProcessGroup
U_CUDA_MODE = False
try:
@ -18,35 +16,35 @@ except ImportError:
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.parallel_mode = parallel_mode
ctx.comm_grp = group
comm_size = gpc.get_world_size(parallel_mode)
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
dist.all_gather(buffer_list, inputs, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.parallel_mode = parallel_mode
ctx.comm_grp = group
comm_size = gpc.get_world_size(parallel_mode)
comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)
@ -56,12 +54,12 @@ class ReduceScatter(torch.autograd.Function):
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
class AllToAll(torch.autograd.Function):
@ -70,20 +68,20 @@ class AllToAll(torch.autograd.Function):
"""
@staticmethod
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
if ctx is not None:
ctx.parallel_mode = parallel_mode
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if gpc.get_world_size(parallel_mode) == 1:
if dist.get_world_size(group) == 1:
return inputs
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
dist.all_to_all_single(output, inputs, group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
class MoeDispatch(torch.autograd.Function):

Loading…
Cancel
Save