diff --git a/colossalai/core.py b/colossalai/core.py
index a82586c7d..a2d3f57a7 100644
--- a/colossalai/core.py
+++ b/colossalai/core.py
@@ -4,4 +4,4 @@
 from colossalai.context import ParallelContext, MoeContext
 
 global_context = ParallelContext.get_instance()
-moe_context = MoeContext.get_instance()
+MOE_CONTEXT = MoeContext.get_instance()
diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
index fa2340196..3c260fdca 100644
--- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
@@ -1,4 +1,4 @@
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
@@ -30,5 +30,5 @@ class MoeGradientHandler(BaseGradientHandler):
                 bucket_allreduce(param_list=param_dict[1], group=gpc.get_group(ParallelMode.DATA))
 
             for ep_size in param_dict:
-                if ep_size != 1 and ep_size != moe_env.world_size:
-                    bucket_allreduce(param_list=param_dict[ep_size], group=moe_env.information[ep_size].dp_group)
+                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                    bucket_allreduce(param_list=param_dict[ep_size], group=MOE_CONTEXT.information[ep_size].dp_group)
diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py
index 7928dbfcf..dbf264297 100644
--- a/colossalai/nn/layer/moe/_operation.py
+++ b/colossalai/nn/layer/moe/_operation.py
@@ -4,11 +4,11 @@
 from torch import Tensor
 from typing import Any, Tuple, Optional
 from torch.distributed import ProcessGroup
 
-U_CUDA_MODE = False
+COL_MOE_KERNEL_FLAG = False
 try:
     import colossal_moe_cuda
 
-    U_CUDA_MODE = True
+    COL_MOE_KERNEL_FLAG = True
 except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
 
@@ -17,7 +17,6 @@ class AllGather(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-
         if ctx is not None:
             ctx.comm_grp = group
 
@@ -40,7 +39,6 @@ class ReduceScatter(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-
         if ctx is not None:
             ctx.comm_grp = group
 
@@ -149,7 +147,7 @@ class MoeCombine(torch.autograd.Function):
 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and U_CUDA_MODE:
+    if flag and COL_MOE_KERNEL_FLAG:
         return colossal_moe_cuda.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index 797eb9c24..5ee9a45a9 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -2,18 +2,24 @@ import math
 
 import torch
 import torch.nn as nn
-from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
+from colossalai.core import MOE_CONTEXT
 
 
 class MoeExperts(nn.Module):
+    """Basic class for experts in MoE. It stores which kind of communication the experts use
+    to exchange tokens, how many experts are kept on a single GPU, and parallel information such as
+    the expert parallel size, the data parallel size and their distributed communication groups.
+ """ - def __init__(self, comm: str): + def __init__(self, comm_name: str, num_experts: int): super().__init__() - assert comm in {"all_to_all", "all_gather"}, \ + assert comm_name in {"all_to_all", "all_gather"}, \ "This kind of communication has not been implemented yet.\n Please use Experts build function." - self.comm = comm + self.comm_name = comm_name + # Get the configuration of experts' deployment and parallel information from moe contex + self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) class Experts(MoeExperts): @@ -29,53 +35,48 @@ class Experts(MoeExperts): """ def __init__(self, expert, num_experts, **expert_args): - super().__init__("all_to_all") - - assert num_experts % moe_env.model_parallel_size == 0, \ - "The number of experts should be divied by moe model size" - - num_local_experts = num_experts // moe_env.model_parallel_size + super().__init__("all_to_all", num_experts) - with seed(ParallelMode.MOE_MODEL): - self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)]) + # Use seed to make every expert different from others + with seed(ParallelMode.TENSOR): + self.experts = nn.ModuleList([expert(**expert_args) for _ in range(self.num_local_experts)]) + # Attach parallel information for all parameters in Experts for exp in self.experts: for param in exp.parameters(): - param.__setattr__('moe_param', True) - - self.num_local_experts = num_local_experts + param.__setattr__('moe_info', self.dist_info) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor): + # Split inputs for each expert expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) expert_output = [] + # Get outputs from each expert for i in range(self.num_local_experts): expert_output.append(self.experts[i](expert_input[i])) + # Concatenate all outputs together output = torch.cat(expert_output, dim=1).contiguous() return output class FFNExperts(MoeExperts): + """Use torch.bmm to speed up for multiple experts. 
+ """ def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_to_all") + super().__init__("all_to_all", num_experts) - assert num_experts % moe_env.model_parallel_size == 0, \ - "The number of experts should be divied by moe model size" + self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) - num_local_experts = num_experts // moe_env.model_parallel_size - - self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device())) + self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) s1 = math.sqrt(0.1 / d_model) s2 = math.sqrt(0.1 / d_ff) - with seed(ParallelMode.MOE_MODEL): + with seed(ParallelMode.TENSOR): nn.init.trunc_normal_(self.w1, std=s1) nn.init.trunc_normal_(self.b1, std=s1) nn.init.trunc_normal_(self.w2, std=s2) @@ -85,7 +86,7 @@ class FFNExperts(MoeExperts): self.drop = nn.Dropout(p=drop_rate) for param in self.parameters(): - param.__setattr__('moe_param', True) + param.__setattr__('moe_info', self.dist_info) def forward(self, inputs): # inputs [g, el, c, h] @@ -99,9 +100,9 @@ class FFNExperts(MoeExperts): out_ff = torch.baddbmm(self.b1, inputs, self.w1) out_act = self.act(out_ff) with seed(ParallelMode.TENSOR): - inter = self.drop(out_act) + out_inter = self.drop(out_act) - out_model = torch.baddbmm(self.b2, inter, self.w2) + out_model = torch.baddbmm(self.b2, out_inter, self.w2) with seed(ParallelMode.TENSOR): outputs = self.drop(out_model) # outputs [el, gc, h] @@ -111,14 +112,18 @@ class FFNExperts(MoeExperts): class TPExperts(MoeExperts): + """Use tensor parallelism to split each expert evenly, which can deploy experts in + case that the number of experts can't be divied by maximum expert parallel size or + maximum expert parallel size can't be divied by the number of experts. 
+ """ def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_gather") + super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - assert d_ff % moe_env.model_parallel_size == 0, \ - "d_ff should be divied by moe model size" + assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ + "d_ff should be divied by maximum expert parallel size" - p_ff = d_ff // moe_env.model_parallel_size + p_ff = d_ff // MOE_CONTEXT.max_ep_size self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) @@ -129,7 +134,7 @@ class TPExperts(MoeExperts): s1 = math.sqrt(0.1 / d_model) s2 = math.sqrt(0.1 / d_ff) - with seed(ParallelMode.MOE_MODEL): + with seed(ParallelMode.TENSOR): nn.init.trunc_normal_(self.w1, std=s1) nn.init.trunc_normal_(self.b1, std=s1) nn.init.trunc_normal_(self.w2, std=s2) @@ -139,9 +144,9 @@ class TPExperts(MoeExperts): self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) - self.w1.__setattr__('moe_param', True) - self.w2.__setattr__('moe_param', True) - self.b1.__setattr__('moe_param', True) + self.w1.__setattr__('moe_info', self.dist_info) + self.w2.__setattr__('moe_info', self.dist_info) + self.b1.__setattr__('moe_info', self.dist_info) def forward(self, inputs): # inputs [g, e, c, h] @@ -155,9 +160,9 @@ class TPExperts(MoeExperts): out_ff = torch.baddbmm(self.b1, inputs, self.w1) out_act = self.act(out_ff) with seed(ParallelMode.TENSOR): - inter = self.drop(out_act) + out_inter = self.drop(out_act) - out_model = torch.baddbmm(self.b2, inter, self.w2) + out_model = torch.baddbmm(self.b2, out_inter, self.w2) outputs = self.drop(out_model) # outputs [e, gc, h] outputs = outputs.reshape(inshape) diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index f98e0764e..39b23abed 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -4,14 +4,13 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist -from colossalai.core import global_context as gpc -from colossalai.global_variables import moe_env -from colossalai.context import ParallelMode +from colossalai.core import MOE_CONTEXT from colossalai.utils import get_current_device -from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum +from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum from .experts import MoeExperts from .utils import autocast_softmax -from typing import Callable +from typing import Callable, Optional +from torch.distributed import ProcessGroup class Top1Router(nn.Module): @@ -19,8 +18,8 @@ class Top1Router(nn.Module): for routing usage. More deailted function can be found in the paper about Switch Transformer of Google. 
 
-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param select_policy: The policy about tokens selection
     :param noisy_func: Noisy function used in logits
@@ -66,7 +65,7 @@ class Top1Router(nn.Module):
         assert capacity > 0
         return capacity
 
-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
 
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
@@ -82,10 +81,10 @@ class Top1Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass
@@ -103,7 +102,7 @@ class Top1Router(nn.Module):
 
         ranks = torch.sum(mask * ranks, dim=-1)
 
-        if cuda_mode:
+        if use_kernel:
             mask = torch.sum(mask, dim=-1)
             mask = torch.stack([mask], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
@@ -120,8 +119,8 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More deailted function can be found in the paper about ViT-MoE.
 
-    :param capacity_factor_train: Capacity factor in routing of training
-    :param capacity_factor_eval: Capacity factor in routing of evaluation
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param noisy_func: Noisy function used in logits
     :param drop_tks: Whether drops tokens in evaluation
@@ -157,7 +156,7 @@ class Top2Router(nn.Module):
         assert capacity > 0
         return capacity
 
-    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
         # inputs: [s, h]
         if self.noisy_func is not None and self.training:
             inputs = self.noisy_func(inputs)
@@ -177,10 +176,10 @@ class Top2Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
-            moe_env.add_loss(l_aux)
+            MOE_CONTEXT.add_loss(l_aux)
         elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
-            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
         else:
             pass
@@ -195,7 +194,7 @@ class Top2Router(nn.Module):
         rank1 = torch.sum(mask1 * rank1, dim=-1)
         rank2 = torch.sum(mask2 * rank2, dim=-1)
 
-        if cuda_mode:
+        if use_kernel:
             mask1 = torch.sum(mask1, dim=-1)
             mask2 = torch.sum(mask2, dim=-1)
 
@@ -241,34 +240,36 @@ class MoeLayer(nn.Module):
         self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
         self.router = router
         self.experts = experts
-        self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
+        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
+        self.ep_group = experts.dist_info.ep_group
+        self.ep_size = experts.dist_info.ep_size
+        self.num_local_experts = experts.num_local_experts
 
     def a2a_process(self, dispatch_data: torch.Tensor):
-        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_input = AllToAll.apply(dispatch_data, self.ep_group)
         input_shape = expert_input.shape
-        expert_input = expert_input.reshape(moe_env.model_parallel_size,
-                                            self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
+        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
-        expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
+        expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output
 
     def tp_process(self, dispatch_data: torch.Tensor):
-        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_in = AllGather.apply(dispatch_data, self.ep_group)
         expert_out = self.experts(expert_in)
-        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
-        router_res = self.router(gate_output, self.cuda_mode)
+        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
 
-        if self.cuda_mode:
+        if self.use_kernel:
             dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
@@ -276,16 +277,16 @@
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
 
         # dispatch_data [e, c, h]
-        if self.experts.comm == "all_to_all":
+        if self.experts.comm_name == "all_to_all":
             expert_output = self.a2a_process(dispatch_data)
-        elif self.experts.comm == "all_gather":
+        elif self.experts.comm_name == "all_gather":
             expert_output = self.tp_process(dispatch_data)
         else:
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
 
         # expert_output [e, c, h]
-        if self.cuda_mode:
+        if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
             ans = MoeCombine.apply(expert_output, *router_res)
         else:
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
index 37c57f396..98f54cde7 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/nn/layer/moe/utils.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts
 
 
@@ -36,7 +36,7 @@ class UniformNoiseGenerator:
     :type eps: float
     """
 
-    def __init__(self, eps: float):
+    def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
                                                            high=torch.tensor(1.0 + eps, device=get_current_device())).rsample
 
@@ -55,10 +55,10 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
 
 
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    moe_mp_size = moe_env.model_parallel_size
-    if num_experts % moe_mp_size == 0:
+    mep_size = MOE_CONTEXT.max_ep_size
+    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
         return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % moe_mp_size == 0:
+    elif d_ff % mep_size == 0:
         return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     else:
-        raise NotImplementedError(f"Can not build {num_experts} experts in {moe_mp_size} GPUS.")
+        raise NotImplementedError(f"Cannot build {num_experts} experts on {mep_size} GPUs.")
diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py
index 50f42fcd3..4c9c0fac8 100644
--- a/colossalai/nn/loss/loss_moe.py
+++ b/colossalai/nn/loss/loss_moe.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.global_variables import moe_env
+from colossalai.core import MOE_CONTEXT
 
 
 @LOSSES.register_module
@@ -14,6 +14,7 @@ class MoeCrossEntropyLoss(_Loss):
 
     :type aux_weight: float, optional
     """
+
     def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
         super().__init__()
         self.loss = nn.CrossEntropyLoss(*args, **kwargs)
@@ -21,7 +22,7 @@ class MoeCrossEntropyLoss(_Loss):
 
     def forward(self, *args):
         main_loss = self.loss(*args)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss
 
 
@@ -37,6 +38,7 @@ class MoeLoss(_Loss):
     :type aux_weight: float
     :type loss_fn: Callable
     """
+
     def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
         super().__init__()
         self.loss_fn = loss_fn(*args, **kwargs)
@@ -44,5 +46,5 @@ class MoeLoss(_Loss):
 
     def forward(self, *args, **kwargs):
         main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = moe_env.get_loss()
+        aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss
diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py
index 4d7e02218..70f413cbd 100644
--- a/colossalai/utils/moe.py
+++ b/colossalai/utils/moe.py
@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, moe_context as moe_env
+from colossalai.core import global_context as gpc, MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List
@@ -45,7 +45,7 @@ def sync_moe_model_param(model: nn.Module):
 
         for ep_size in param_dict:
             # When ep_size = world_size, communication is not needed
-            if ep_size != 1 and ep_size != moe_env.world_size:
-                src_rank = dist.get_rank(moe_env.information[ep_size].ep_group)
+            if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                src_rank = dist.get_rank(MOE_CONTEXT.information[ep_size].ep_group)
                 for param in param_dict[ep_size]:
                     dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
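
Below the patch is a minimal usage sketch (not part of the diff) of how the pieces touched here are meant to fit together after the rename. It assumes the distributed environment has already been initialized through colossalai.launch so that MOE_CONTEXT carries valid parallel information; the expert module SimpleFFN, every numeric value, and the MoeLayer argument names are hypothetical, inferred from the hunks above rather than taken from a documented API.

# Hypothetical sketch only: SimpleFFN and all sizes are invented for illustration;
# constructor argument names are inferred from the hunks above.
import torch.nn as nn

from colossalai.nn.layer.moe.experts import Experts
from colossalai.nn.layer.moe.layers import MoeLayer, Top1Router
from colossalai.nn.layer.moe.utils import UniformNoiseGenerator


class SimpleFFN(nn.Module):
    """A toy expert: two linear layers with a GELU in between."""

    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.net(x)


# Experts now asks MOE_CONTEXT.get_info(num_experts) how many experts live on this
# rank and tags every expert parameter with `moe_info` (replacing the old
# `moe_param` flag), which the gradient handler uses to pick the right dp_group.
experts = Experts(SimpleFFN, num_experts=8, d_model=512, d_ff=2048)

# The router no longer touches ParallelMode.MOE_MODEL; MoeLayer passes use_kernel
# and ep_group into it at forward time.
router = Top1Router(capacity_factor_train=1.25, noisy_func=UniformNoiseGenerator())

moe_layer = MoeLayer(dim_model=512, num_experts=8, router=router, experts=experts)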