From 9b9b76bdcd4b2130796f59ce830c781b54762744 Mon Sep 17 00:00:00 2001 From: botbw Date: Fri, 12 Jul 2024 03:27:20 +0000 Subject: [PATCH] [moe] add mixtral dp grad scaling when not all experts are activated --- .../plugin/moe_hybrid_parallel_plugin.py | 6 ++- colossalai/moe/_operation.py | 51 +++++++++++++++++-- colossalai/shardformer/layer/moe/experts.py | 6 +-- colossalai/shardformer/layer/moe/routers.py | 6 +-- colossalai/shardformer/modeling/mixtral.py | 46 ++++++++++++----- colossalai/shardformer/policies/mixtral.py | 14 +---- colossalai/shardformer/shard/shard_config.py | 3 ++ ...o_fwd_bwd_optim.py => test_moe_ep_zero.py} | 8 ++- 8 files changed, 98 insertions(+), 42 deletions(-) rename tests/test_moe/{test_moe_zero_fwd_bwd_optim.py => test_moe_ep_zero.py} (97%) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index f689fe988..902500e42 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -141,6 +141,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): # set ep_group after super init # TODO do it in a better way self.shard_config.ep_group = self.ep_group + self.shard_config.moe_dp_group = self.moe_dp_group + self.shard_config.moe_tp_group = self.moe_tp_group self.force_overlap_comm = force_overlap_comm @@ -159,7 +161,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): # create groups from submesh for stage_idx, stage_rank in enumerate(ranks_by_pp_stage): - # axis 0 is dp, axis 1 is tp, axis 2 is sp + # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size) # hardcode here since we only have 3 axis @@ -188,7 +190,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): assert self.moe_tp_group is None self.moe_tp_group = group - self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}") + self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0]) def get_checkpoint_io(self) -> MoECheckpointIO: return MoECheckpointIO( diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index cad9573fb..abec2aa6e 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -290,7 +290,7 @@ def moe_cumsum(inputs: Tensor, use_kernel: bool = False): return torch.cumsum(inputs, dim=0) - 1 -class MoeInGradScaler(torch.autograd.Function): +class EPGradScalerIn(torch.autograd.Function): """ Scale the gradient back by the number of experts because the batch size increases in the moe stage @@ -298,8 +298,7 @@ class MoeInGradScaler(torch.autograd.Function): @staticmethod def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor: - if ctx is not None: - ctx.ep_size = ep_size + ctx.ep_size = ep_size return inputs @staticmethod @@ -311,7 +310,7 @@ class MoeInGradScaler(torch.autograd.Function): return grad, None -class MoeOutGradScaler(torch.autograd.Function): +class EPGradScalerOut(torch.autograd.Function): """ Scale the gradient by the number of experts because the batch size increases in the moe stage @@ -331,6 +330,50 @@ class MoeOutGradScaler(torch.autograd.Function): return grad, 
None +class DPGradScalerIn(torch.autograd.Function): + """ + Scale the gradient back by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor: + assert activated_experts != 0, f"shouldn't be called when no expert is activated" + ctx.moe_dp_size = moe_dp_size + ctx.activated_experts = activated_experts + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.moe_dp_size != ctx.activated_experts: + grad.mul_(ctx.activated_experts / ctx.moe_dp_size) + return grad, None, None + + +class DPGradScalerOut(torch.autograd.Function): + """ + Scale the gradient by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor: + assert activated_experts != 0, f"shouldn't be called when no expert is activated" + ctx.moe_dp_size = moe_dp_size + ctx.activated_experts = activated_experts + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.moe_dp_size != ctx.activated_experts: + grad.mul_(ctx.moe_dp_size / ctx.activated_experts) + return grad, None, None + + def _all_to_all( inputs: torch.Tensor, input_split_sizes: Optional[List[int]] = None, diff --git a/colossalai/shardformer/layer/moe/experts.py b/colossalai/shardformer/layer/moe/experts.py index 1be7a2754..109740dbb 100644 --- a/colossalai/shardformer/layer/moe/experts.py +++ b/colossalai/shardformer/layer/moe/experts.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON -from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler +from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation from colossalai.shardformer.layer.utils import Randomizer @@ -118,7 +118,7 @@ class MLPExperts(nn.Module): Returns: torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) """ - x = MoeInGradScaler.apply(x, self.ep_size) + x = EPGradScalerIn.apply(x, self.ep_size) e = x.size(1) h = x.size(-1) @@ -157,5 +157,5 @@ class MLPExperts(nn.Module): x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() - x = MoeOutGradScaler.apply(x, self.ep_size) + x = EPGradScalerOut.apply(x, self.ep_size) return x diff --git a/colossalai/shardformer/layer/moe/routers.py b/colossalai/shardformer/layer/moe/routers.py index 1be7a2754..109740dbb 100644 --- a/colossalai/shardformer/layer/moe/routers.py +++ b/colossalai/shardformer/layer/moe/routers.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON -from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler +from colossalai.moe._operation import EPGradScalerIn, EPGradScalerOut from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation from colossalai.shardformer.layer.utils import Randomizer @@ -118,7 +118,7 @@ class MLPExperts(nn.Module): Returns: torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) 
""" - x = MoeInGradScaler.apply(x, self.ep_size) + x = EPGradScalerIn.apply(x, self.ep_size) e = x.size(1) h = x.size(-1) @@ -157,5 +157,5 @@ class MLPExperts(nn.Module): x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() - x = MoeOutGradScaler.apply(x, self.ep_size) + x = EPGradScalerOut.apply(x, self.ep_size) return x diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 5d2dc1dc3..609fc6f3e 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -14,18 +14,23 @@ from transformers.models.mixtral.modeling_mixtral import ( from transformers.utils import is_flash_attn_2_available, logging from colossalai.lazy import LazyInitContext -from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven, drop_tokens, gather_tokens +from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): - def __init__(self, config, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None): - super().__init__(config) - self.setup_process_groups(ep_group, tp_group, moe_tp_group) + def __init__(self, *args, **kwargs): + raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") + + def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup): + assert tp_group is not None + assert moe_dp_group is not None + assert ep_group is not None + assert moe_tp_group is not None - def setup_process_groups(self, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None): # setup ep group self.ep_size = dist.get_world_size(ep_group) self.ep_rank = dist.get_rank(ep_group) @@ -40,7 +45,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): set_tensors_to_none(self.experts, exclude=set(held_experts)) for p in self.experts.parameters(): - p.ep_group = ep_group + set_moe_tensor_ep_group(p, ep_group) + + # setup moe_dp group + self.moe_dp_group = moe_dp_group + self.moe_dp_size = moe_dp_group.size() # setup global tp group self.tp_group = tp_group @@ -50,11 +59,12 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): @staticmethod def from_native_module( - module: MixtralSparseMoeBlock, ep_group: ProcessGroup, tp_group: Optional[ProcessGroup]=None, moe_tp_group: Optional[ProcessGroup]=None, *args, **kwargs + module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs ) -> "EPMixtralSparseMoeBlock": + # TODO: better init LazyInitContext.materialize(module) module.__class__ = EPMixtralSparseMoeBlock - module.setup_process_groups(ep_group) + module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -76,36 +86,48 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): output_split_sizes = torch.zeros_like(input_split_sizes) 
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + with torch.no_grad(): + activate_experts = output_split_sizes[: self.num_experts_per_ep].clone() + for i in range(1, self.ep_size): + activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] + activate_experts = (activate_experts > 0).float() + dist.all_reduce(activate_experts, group=self.moe_dp_group) + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() - if self.tp_group is not None and self.tp_group.size() > 1: + if self.tp_group.size() > 1: dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group) output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) # compute expert output - output_states = MoeInGradScaler.apply(output_states, self.ep_size) + output_states = EPGradScalerIn.apply(output_states, self.ep_size) if output_states.size(0) > 0: if self.num_experts_per_ep == 1: # no need to split expert = self.experts[self.expert_start_idx] + output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item()) output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) output_states = expert.w2(output_states) + output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item()) else: output_states_splits = output_states.split(output_split_sizes.tolist()) output_states_list = [] for i, split_states in enumerate(output_states_splits): if split_states.size(0) == 0: continue + split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()) expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) split_states = expert.w2(split_states) + split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()) output_states_list.append(split_states) output_states = torch.cat(output_states_list) - output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + + output_states = EPGradScalerOut.apply(output_states, self.ep_size) dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) - if self.tp_group is not None and self.tp_group.size() > 1: + if self.tp_group.size() > 1: dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group) recover_experts_idx = torch.empty_like(selected_experts_idx) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 14d57c79d..69bcc54ed 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -76,18 +76,6 @@ class MixtralPolicy(Policy): suffix="self_attn.o_proj", target_module=Linear1D_Row, ), - # SubModuleReplacementDescription( # TODO: enable moe tp parallel - # suffix="mlp.gate_proj", - # target_module=Linear1D_Col, - # ), - # SubModuleReplacementDescription( - # suffix="mlp.up_proj", - # target_module=Linear1D_Col, - # ), - # SubModuleReplacementDescription( - # suffix="mlp.down_proj", - # target_module=Linear1D_Row, - # ), ], ) @@ -98,7 +86,7 @@ class MixtralPolicy(Policy): SubModuleReplacementDescription( suffix="block_sparse_moe", target_module=EPMixtralSparseMoeBlock, - 
kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group}, + kwargs={"ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "moe_tp_group": self.shard_config.moe_tp_group}, ) ], policy=policy, diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index d1aebd5b2..f12c78526 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -46,6 +46,9 @@ class ShardConfig: make_vocab_size_divisible_by: int = 64 gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + + # for moe related + moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None moe_tp_group: Optional[ProcessGroup] = None diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_ep_zero.py similarity index 97% rename from tests/test_moe/test_moe_zero_fwd_bwd_optim.py rename to tests/test_moe/test_moe_ep_zero.py index 3d6af2b1a..c5adaad06 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py +++ b/tests/test_moe/test_moe_ep_zero.py @@ -18,8 +18,7 @@ NUM_BATCH=4 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4 HIDDEN_SIZE_PER_HEAD = 4 NUM_HEADS=2 -TOP_K = 2 - +TOP_K = 1 def split_grad(grad, world_size): with torch.no_grad(): @@ -96,7 +95,6 @@ def run_zero_with_original_model(stage: int, ep_size: int): # check grad name_to_p = {n: p for n, p in ddp_model.named_parameters()} for n, p in zero_model.named_parameters(): - print(f"rank {dist.get_rank()} {n}") zero_grad = zero_optimizer.get_param_grad(p) if name_to_p[n].grad is None: name_to_p[n].grad = torch.zeros_like(name_to_p[n].data) @@ -124,9 +122,9 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(world_size): +def test_moe_ep_zero(world_size): spawn(run_dist, world_size) if __name__ == "__main__": - test_moe_ep_tp(world_size=4) + test_moe_ep_zero(world_size=4)
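
Illustrative sketch (not the ColossalAI code path itself; the classes below are re-implemented and the toy numbers are assumed purely for demonstration): the DPGradScalerIn / DPGradScalerOut pair added in this patch wraps each expert — In is applied to the tokens entering the expert and shrinks their backward gradient by activated_experts / moe_dp_size, while Out is applied to the expert output and enlarges the backward gradient by the inverse factor. The two factors cancel for activations, but the expert parameter gradients in between end up scaled by moe_dp_size / activated_experts, which appears intended to compensate for the subsequent gradient averaging over the full moe DP group when only some replicas actually routed tokens to that expert.

# Simplified, self-contained sketch of the DP grad scaling idea from this patch.
# Assumption: illustrative re-implementation only, not the classes in
# colossalai/moe/_operation.py; the moe_dp_size / activated_experts values are made up.
from typing import Any, Tuple

import torch
from torch import Tensor


class DPGradScalerIn(torch.autograd.Function):
    """Identity in forward; scales the backward gradient by activated_experts / moe_dp_size."""

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
        assert activated_experts != 0, "shouldn't be called when no expert is activated"
        ctx.moe_dp_size = moe_dp_size
        ctx.activated_experts = activated_experts
        return inputs

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None, None]:
        if ctx.moe_dp_size != ctx.activated_experts:
            grad_output = grad_output * (ctx.activated_experts / ctx.moe_dp_size)
        return grad_output, None, None


class DPGradScalerOut(torch.autograd.Function):
    """Identity in forward; scales the backward gradient by moe_dp_size / activated_experts."""

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, moe_dp_size: int, activated_experts: int) -> Tensor:
        assert activated_experts != 0, "shouldn't be called when no expert is activated"
        ctx.moe_dp_size = moe_dp_size
        ctx.activated_experts = activated_experts
        return inputs

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[Tensor, None, None]:
        if ctx.moe_dp_size != ctx.activated_experts:
            grad_output = grad_output * (ctx.moe_dp_size / ctx.activated_experts)
        return grad_output, None, None


if __name__ == "__main__":
    # Toy setup: 4 DP replicas of this expert exist, but only 2 were activated this step.
    moe_dp_size, activated_experts = 4, 2

    x = torch.ones(3, requires_grad=True)           # activations entering the expert
    w = torch.full((3,), 3.0, requires_grad=True)   # stand-in for an expert weight

    h = DPGradScalerIn.apply(x, moe_dp_size, activated_experts)
    h = h * w                                        # stand-in for the expert MLP
    y = DPGradScalerOut.apply(h, moe_dp_size, activated_experts)
    y.sum().backward()

    print(x.grad)  # tensor([3., 3., 3.]) -- same as without scaling: the two factors cancel
    print(w.grad)  # tensor([2., 2., 2.]) -- 2x the unscaled value, i.e. * moe_dp_size / activated_experts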