[MOE] changed parallelmode to dist process group (#460)

HELSON committed 3 years ago via GitHub (commit bccbc15861, parent 8f9617c313, merged via pull/467/head)

@@ -4,6 +4,7 @@ from .parallel_2d import *
 from .parallel_2p5d import *
 from .parallel_3d import *
 from .parallel_sequence import *
+from .moe import *
 from .utils import *
 from .vanilla import *
 from .wrapper import *
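The wildcard re-export makes the MoE symbols importable from the parent layer package alongside the other parallel layer collections. A one-line illustration; the package path is an assumption based on the usual ColossalAI layout, it is not shown in this diff:

    # Illustrative only: after `from .moe import *` is added to the package __init__,
    # MoE symbols can be imported from the parent layer package directly.
    from colossalai.nn.layer import MoeLayer, Top1Router  # path assumed, not confirmed by the diff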

@@ -1,8 +1,8 @@
 from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, build_ffn_experts
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'build_ffn_experts'
+    'UniformNoiseGenerator', 'build_ffn_experts'
 ]
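The utils module now also exports UniformNoiseGenerator. Its implementation is not part of this diff; as a hedged sketch, a uniform router-noise helper typically perturbs the gate logits with noise drawn from a narrow uniform range, in contrast to the Gaussian noise of NormalNoiseGenerator:

    import torch

    class UniformNoiseSketch:
        """Illustrative sketch only; the name, signature and scaling are assumptions,
        not the library's actual UniformNoiseGenerator."""

        def __init__(self, eps: float = 1e-2):
            # multiplicative noise drawn uniformly from [1 - eps, 1 + eps)
            self.low, self.high = 1.0 - eps, 1.0 + eps

        def __call__(self, logits: torch.Tensor) -> torch.Tensor:
            noise = torch.empty_like(logits).uniform_(self.low, self.high)
            return logits * noise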

@ -1,10 +1,8 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import Tensor from torch import Tensor
from typing import Any, Tuple, Optional
from colossalai.context import ParallelMode from torch.distributed import ProcessGroup
from colossalai.core import global_context as gpc
from typing import Any, Tuple
U_CUDA_MODE = False U_CUDA_MODE = False
try: try:
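The import change summarizes the whole commit: the MoE collectives no longer resolve their communicator through ColossalAI's global context (gpc) and a ParallelMode, but accept a plain torch.distributed.ProcessGroup, with None falling back to the default (world) group. Callers can therefore build the group with stock PyTorch; a minimal sketch, assuming the process group has already been initialized with dist.init_process_group:

    import torch.distributed as dist

    # Sketch: build an expert-parallel group with plain PyTorch and pass it to
    # the MoE collectives below. `ep_ranks` is a hypothetical rank list.
    ep_ranks = list(range(dist.get_world_size()))
    ep_group = dist.new_group(ranks=ep_ranks)

    # Passing group=None uses the default group created by init_process_group.
    world_size = dist.get_world_size(None)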
@@ -18,35 +16,35 @@ except ImportError:
 
 class AllGather(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
 
-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.unsqueeze(0)
 
         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
-        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_gather(buffer_list, inputs, group=group)
         return outputs
 
     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+        return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
 
 
 class ReduceScatter(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
 
-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.squeeze(0)
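AllGather stacks each rank's tensor along a new leading dimension, and its backward is implemented as a ReduceScatter over that dimension (and vice versa), so the pair stays differentiable end to end. A hedged usage sketch with the new ProcessGroup argument, assuming a NCCL-backed distributed environment is already initialized:

    import torch
    import torch.distributed as dist

    # Hypothetical usage of the autograd functions above; `ep_group` is assumed
    # to be a ProcessGroup shared by the expert-parallel ranks.
    ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    x = torch.randn(4, 8, device='cuda', requires_grad=True)
    gathered = AllGather.apply(x, ep_group)   # shape: (group_world_size, 4, 8)
    gathered.sum().backward()                 # gradient flows back via ReduceScatter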
@@ -56,12 +54,12 @@ class ReduceScatter(torch.autograd.Function):
         output_shape = inputs.shape[1:]
         outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
-        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        dist.reduce_scatter(outputs, buffer_list, group=group)
         return outputs
 
     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+        return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
 
 
 class AllToAll(torch.autograd.Function):
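For reference, the reduce_scatter call splits the (world_size, ...) input into world_size chunks along dim 0, and each rank receives the element-wise sum of its own chunk across all ranks. A single-process illustration of those semantics (a sketch only, not library code and not an actual collective call):

    import torch

    # Simulate two "ranks", each holding a (2, 3) tensor; after reduce-scatter
    # each rank keeps the sum of its own row across ranks, shape (3,).
    rank0_in = torch.tensor([[1., 1., 1.], [2., 2., 2.]])
    rank1_in = torch.tensor([[10., 10., 10.], [20., 20., 20.]])
    summed = rank0_in + rank1_in                   # what the reduction produces
    rank0_out, rank1_out = summed[0], summed[1]    # each rank's scattered share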
@@ -70,20 +68,20 @@ class AllToAll(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
-        if gpc.get_world_size(parallel_mode) == 1:
+        if dist.get_world_size(group) == 1:
             return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=group)
         return output
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
+        return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
 
 
 class MoeDispatch(torch.autograd.Function):
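AllToAll exchanges equal-sized slices of the leading dimension between the ranks of the group, which is how dispatched tokens reach their experts and return; its backward is another all-to-all, since the exchange is its own inverse. A hedged usage sketch under the same assumptions as above (initialized NCCL backend, hypothetical expert-parallel group):

    import torch
    import torch.distributed as dist

    # Sketch: token exchange across an expert-parallel group. The leading
    # dimension must be divisible by the group's world size, because
    # all_to_all_single splits dim 0 into equal chunks, one per rank.
    ep_group = dist.new_group(ranks=list(range(dist.get_world_size())))
    ep_size = dist.get_world_size(ep_group)

    tokens = torch.randn(ep_size * 16, 512, device='cuda', requires_grad=True)
    routed = AllToAll.apply(tokens, ep_group)  # same shape, chunks swapped between ranks
    routed.sum().backward()                    # gradient flows through another all-to-all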
