diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index 86961dd93..b705632f8 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -4,6 +4,7 @@ from .parallel_2d import *
 from .parallel_2p5d import *
 from .parallel_3d import *
 from .parallel_sequence import *
+from .moe import *
 from .utils import *
 from .vanilla import *
 from .wrapper import *
diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
index f102ddc01..36977ee05 100644
--- a/colossalai/nn/layer/moe/__init__.py
+++ b/colossalai/nn/layer/moe/__init__.py
@@ -1,8 +1,8 @@
 from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, build_ffn_experts
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'build_ffn_experts'
+    'UniformNoiseGenerator', 'build_ffn_experts'
 ]
diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py
index a28c1cda8..7928dbfcf 100644
--- a/colossalai/nn/layer/moe/_operation.py
+++ b/colossalai/nn/layer/moe/_operation.py
@@ -1,10 +1,8 @@
 import torch
 import torch.distributed as dist
 from torch import Tensor
-
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from typing import Any, Tuple
+from typing import Any, Tuple, Optional
+from torch.distributed import ProcessGroup
 
 U_CUDA_MODE = False
 try:
@@ -18,35 +16,35 @@ except ImportError:
 class AllGather(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
 
-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.unsqueeze(0)
 
         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
-        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_gather(buffer_list, inputs, group=group)
         return outputs
 
     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+        return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
 
 
 class ReduceScatter(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
 
-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.squeeze(0)
 
@@ -56,12 +54,12 @@ class ReduceScatter(torch.autograd.Function):
         output_shape = inputs.shape[1:]
         outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
-        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        dist.reduce_scatter(outputs, buffer_list, group=group)
         return outputs
 
     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+        return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
 
 
 class AllToAll(torch.autograd.Function):
@@ -70,20 +68,20 @@
     """
 
     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
-        if gpc.get_world_size(parallel_mode) == 1:
+        if dist.get_world_size(group) == 1:
             return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=group)
         return output
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
+        return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
 
 
 class MoeDispatch(torch.autograd.Function):
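Usage sketch (not part of the diff): after this refactor the MoE communication ops take a plain torch.distributed ProcessGroup, or None for the default group, instead of a colossalai ParallelMode, so callers no longer need the global context. The helper names below (dispatch_to_experts, gather_outputs, ep_group) are illustrative only.

    from typing import Optional

    import torch
    import torch.distributed as dist

    from colossalai.nn.layer.moe._operation import AllGather, AllToAll

    def dispatch_to_experts(expert_inputs: torch.Tensor,
                            ep_group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
        # None falls back to the default process group; a dedicated
        # expert-parallel group can be created with dist.new_group(ranks).
        return AllToAll.apply(expert_inputs, ep_group)

    def gather_outputs(outputs: torch.Tensor,
                       ep_group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
        # Stacks each rank's tensor along a new leading dimension.
        return AllGather.apply(outputs, ep_group)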