From 102b784a10f0cd1c740d9ceba343a78166314290 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Fri, 12 Jul 2024 09:08:16 +0000
Subject: [PATCH] [chore] arg pass & remove drop token

---
 .../plugin/moe_hybrid_parallel_plugin.py   | 15 ++++----
 colossalai/shardformer/modeling/mixtral.py | 34 +++++++++++++------
 tests/test_moe/test_mixtral_layer.py       |  8 ++++-
 tests/test_moe/test_moe_checkpoint.py      |  2 --
 tests/test_moe/test_moe_ep_tp.py           |  2 +-
 5 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 22157b5cf..047782aa9 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,8 +1,8 @@
 import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
-import numpy as np
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -21,7 +21,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
-from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 
@@ -89,11 +88,9 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
-            # dp_process_group=dp_process_group,
             tp_process_group=tp_process_group,
             pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
-            ## moe args
             pg_to_param_list=pg_param_list,
         )
 
@@ -104,6 +101,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     """
 
     def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
+        if "overlap_communication" not in kwargs:
+            kwargs["overlap_communication"] = False
+
         super().__init__(*args, **kwargs)
         self.use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
 
@@ -142,7 +142,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.moe_dp_group = None
         self.ep_group = None
         self.moe_tp_group = None
-        
+
         # create submesh for ep, moe_dp, moe_tp
         ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
             [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
@@ -182,7 +182,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 assert self.moe_tp_group is None
                 self.moe_tp_group = group
 
-        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}", ranks=[0])
+        self.logger.info(
+            f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}",
+            ranks=[0],
+        )
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 609fc6f3e..5a42a1073 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -14,7 +14,13 @@ from transformers.models.mixtral.modeling_mixtral import (
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import DPGradScalerIn, DPGradScalerOut, EPGradScalerIn, EPGradScalerOut, all_to_all_uneven, drop_tokens, gather_tokens
+from colossalai.moe._operation import (
+    DPGradScalerIn,
+    DPGradScalerOut,
+    EPGradScalerIn,
+    EPGradScalerOut,
+    all_to_all_uneven,
+)
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
@@ -25,7 +31,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup):
+    def setup_process_groups(
+        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -59,7 +67,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
     @staticmethod
     def from_native_module(
-        module: MixtralSparseMoeBlock, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup, *args, **kwargs
+        module: MixtralSparseMoeBlock,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        moe_tp_group: ProcessGroup,
+        *args,
+        **kwargs,
     ) -> "EPMixtralSparseMoeBlock":
         # TODO: better init
         LazyInitContext.materialize(module)
@@ -96,8 +110,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        if self.tp_group.size() > 1:
-            dispatch_states = drop_tokens(dispatch_states, -1, self.tp_group)
+        # TODO drop tokens to reduce tp group redundant communication
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
 
         # compute expert output
@@ -116,20 +129,21 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
             for i, split_states in enumerate(output_states_splits):
                 if split_states.size(0) == 0:
                     continue
-                split_states = DPGradScalerIn.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+                split_states = DPGradScalerIn.apply(
+                    split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+                )
                 expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                 split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                 split_states = expert.w2(split_states)
-                split_states = DPGradScalerOut.apply(split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item())
+                split_states = DPGradScalerOut.apply(
+                    split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+                )
                 output_states_list.append(split_states)
             output_states = torch.cat(output_states_list)
 
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
 
-        if self.tp_group.size() > 1:
-            dispatch_states = gather_tokens(dispatch_states, -1, self.tp_group)
-
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
             selected_experts_idx.size(0), device=selected_experts_idx.device
diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py
index b7b0322e0..de34b5c7b 100644
--- a/tests/test_moe/test_mixtral_layer.py
+++ b/tests/test_moe/test_mixtral_layer.py
@@ -36,7 +36,13 @@ def check_mixtral_moe_layer():
     x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
     orig_output, orig_logits = orig_model(x)
     model = deepcopy(orig_model)
-    model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
+    model = EPMixtralSparseMoeBlock.from_native_module(
+        model,
+        ep_group=plugin.ep_group,
+        tp_group=plugin.tp_group,
+        moe_dp_group=plugin.moe_dp_group,
+        moe_tp_group=plugin.moe_tp_group,
+    )
     ep_output, ep_logits = model(x)
     assert_close(orig_logits, ep_logits)
     assert_close(orig_output, ep_output)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 773036358..6f3c5b299 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -13,7 +13,6 @@ from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing import parameterize, spawn
 from colossalai.testing.utils import spawn
@@ -115,7 +114,6 @@ def check_moe_checkpoint(test_config):
         pp_size=2,
         ep_size=2,
         tp_size=1,
-        checkpoint_io=MoECheckpointIO,
         microbatch_size=1,
         zero_stage=1,
     )
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index 632a8ce38..cc5448e51 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -34,7 +34,7 @@ def split_grad(grad, world_size):
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
 @parameterize("tp_size", [1, 2, 4])
-def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int = 1):
+def run_zero_with_original_model(stage: int, ep_size: int, tp_size: int):
     dtype = torch.bfloat16
     rank = torch.distributed.get_rank()
 
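Note for reviewers (not part of the patch): a minimal usage sketch of the reworked `EPMixtralSparseMoeBlock.from_native_module` signature after this change, mirroring the updated `tests/test_moe/test_mixtral_layer.py`. The plugin sizes and the toy `MixtralConfig` values below are illustrative assumptions, and a real run needs distributed initialization (e.g. via `colossalai.launch_from_torch`) and a CUDA device first.

```python
# Illustrative sketch only -- not part of the patch.
# Assumes torch.distributed is already initialized and CUDA is available.
import torch
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock

# Example parallel sizes (assumed): 2 EP ranks, no TP/PP, ZeRO stage 1.
plugin = MoeHybridParallelPlugin(ep_size=2, tp_size=1, pp_size=1, zero_stage=1)

# Toy config (assumed values) just to build a native MoE block to wrap.
config = MixtralConfig(hidden_size=32, intermediate_size=64, num_local_experts=4, num_experts_per_tok=2)
block = MixtralSparseMoeBlock(config).cuda()

# All four process groups are now passed through explicitly; the removed
# drop_tokens/gather_tokens path is no longer involved.
ep_block = EPMixtralSparseMoeBlock.from_native_module(
    block,
    tp_group=plugin.tp_group,
    moe_dp_group=plugin.moe_dp_group,
    ep_group=plugin.ep_group,
    moe_tp_group=plugin.moe_tp_group,
)

hidden_states = torch.rand(1, 8, config.hidden_size, device="cuda")
output, router_logits = ep_block(hidden_states)
```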