mirror of https://github.com/hpcaitech/ColossalAI
[MOE] changed parallelmode to dist process group (#460)
parent 8f9617c313
commit bccbc15861
@@ -4,6 +4,7 @@ from .parallel_2d import *
 from .parallel_2p5d import *
 from .parallel_3d import *
 from .parallel_sequence import *
+from .moe import *
 from .utils import *
 from .vanilla import *
 from .wrapper import *

@@ -1,8 +1,8 @@
 from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, build_ffn_experts
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts

 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'build_ffn_experts'
+    'UniformNoiseGenerator', 'build_ffn_experts'
 ]
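The hunk above only widens the MoE package's public surface: UniformNoiseGenerator is now re-exported alongside NormalNoiseGenerator. A minimal import sketch, not part of the commit; the package path colossalai.nn.layer.moe is an assumption, since file paths are not shown in this view:

    # Hypothetical usage sketch; the module path is assumed, not taken from this diff.
    from colossalai.nn.layer.moe import (
        Experts, FFNExperts, TPExperts,
        MoeLayer, Top1Router, Top2Router,
        NormalNoiseGenerator,
        UniformNoiseGenerator,   # newly exported by this commit
        build_ffn_experts,
    )
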
@@ -1,10 +1,8 @@
 import torch
 import torch.distributed as dist
 from torch import Tensor
-
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from typing import Any, Tuple
+from typing import Any, Tuple, Optional
+from torch.distributed import ProcessGroup

 U_CUDA_MODE = False
 try:
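The import hunk above carries the main point of the commit: the MoE collectives no longer resolve their communicator through gpc and ParallelMode, but take a plain torch.distributed.ProcessGroup. A minimal sketch of preparing such a group, assuming the default process group has already been initialized by the launcher; the rank list is purely illustrative:

    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already been called elsewhere.
    # Build an explicit ProcessGroup for the ranks that exchange MoE tokens.
    moe_group = dist.new_group(ranks=[0, 1, 2, 3])

    # group=None falls back to the default (world) group, which is why the new
    # signatures default to Optional[ProcessGroup] = None.
    world_size = dist.get_world_size(moe_group)
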
@@ -18,35 +16,35 @@ except ImportError:
 class AllGather(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group

-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.unsqueeze(0)

         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
-        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_gather(buffer_list, inputs, group=group)
         return outputs

     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+        return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None


 class ReduceScatter(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group

-        comm_size = gpc.get_world_size(parallel_mode)
+        comm_size = dist.get_world_size(group)
         if comm_size == 1:
             return inputs.squeeze(0)

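With the new signatures, callers pass the process group directly when invoking the autograd functions. A hedged usage sketch, not part of the commit, reusing the hypothetical moe_group from the snippet above and assuming a CUDA/NCCL setup:

    import torch

    # AllGather and ReduceScatter are the autograd functions patched above.
    x = torch.randn(4, 16, device='cuda', requires_grad=True)
    gathered = AllGather.apply(x, moe_group)            # (group_size, 4, 16)
    reduced = ReduceScatter.apply(gathered, moe_group)  # (4, 16), summed over ranks
    reduced.sum().backward()  # AllGather's backward is ReduceScatter.forward, per the diff
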
@@ -56,12 +54,12 @@ class ReduceScatter(torch.autograd.Function):
         output_shape = inputs.shape[1:]
         outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
-        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        dist.reduce_scatter(outputs, buffer_list, group=group)
         return outputs

     @staticmethod
     def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+        return AllGather.forward(None, grad_outputs, ctx.comm_grp), None


 class AllToAll(torch.autograd.Function):

@@ -70,20 +68,20 @@ class AllToAll(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         if ctx is not None:
-            ctx.parallel_mode = parallel_mode
+            ctx.comm_grp = group
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
-        if gpc.get_world_size(parallel_mode) == 1:
+        if dist.get_world_size(group) == 1:
             return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=group)
         return output

     @staticmethod
     def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
+        return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None


 class MoeDispatch(torch.autograd.Function):
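AllToAll follows the same pattern. A minimal sketch of calling it after this change, again with the hypothetical moe_group and a CUDA/NCCL setup; dist.all_to_all_single expects the first dimension to be divisible by the group size:

    import torch

    # Exchange equal-sized chunks of expert inputs between the ranks in the group.
    tokens = torch.randn(8, 64, device='cuda', requires_grad=True)
    shuffled = AllToAll.apply(tokens, moe_group)  # same shape, chunks swapped across ranks
    shuffled.mean().backward()                    # backward is another all-to-all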