From eb5ba40defe727927abe1e8da74103ecdd6e049f Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 21 Aug 2024 02:58:23 +0000
Subject: [PATCH] fix the merge

---
 colossalai/shardformer/modeling/deepseek.py | 4 +++-
 colossalai/shardformer/modeling/mixtral.py  | 5 +++++
 colossalai/shardformer/policies/qwen2.py    | 3 ++-
 3 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 59f9d4516..7ec390d6a 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -146,7 +146,9 @@ class EPDeepseekMoE(nn.Module):
 
         # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
         dist.all_to_all_single(
-            output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
+            output_split_sizes,
+            input_split_sizes,
+            group=self.ep_group,
         )
 
         with torch.no_grad():
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 50334677e..4850ef1b6 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -68,6 +68,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
+        self.fp8_communication = fp8_communication
+
+        if self.num_experts % self.ep_size != 0:
+            raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
+
         self.num_experts_per_ep = self.num_experts // self.ep_size
         self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
         held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 7f82f8290..1b066200d 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -119,6 +119,7 @@ class Qwen2Policy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
@@ -319,7 +320,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
         setattr(self.shard_config, "causal_lm", True)
 
         if self.shard_config.enable_tensor_parallelism:
-            # add a new item for casual lm
+            # add a new item for causal lm
             new_item = {
                 Qwen2ForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
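Note on the mixtral hunk above: the new guard makes the expert-parallel layout explicit before the experts are sliced per rank. Below is a minimal standalone sketch of that partitioning logic; partition_experts is a hypothetical helper written for illustration only (not part of the patch), and it assumes ep_size and ep_rank come from the expert-parallel process group as in the surrounding code.

import torch.nn as nn

def partition_experts(experts: nn.ModuleList, ep_size: int, ep_rank: int) -> nn.ModuleList:
    # Mirror of the guard added in EPMixtralSparseMoeBlock: the expert count
    # must divide evenly across expert-parallel groups, otherwise ranks would
    # end up holding unequal slices.
    num_experts = len(experts)
    if num_experts % ep_size != 0:
        raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
    num_experts_per_ep = num_experts // ep_size
    expert_start_idx = ep_rank * num_experts_per_ep
    # Each rank keeps a contiguous block, matching the held_experts slice in the hunk.
    return experts[expert_start_idx : expert_start_idx + num_experts_per_ep]

# Example: 8 experts over 4 expert-parallel ranks -> 2 experts per rank; rank 1 holds experts 2 and 3.
experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(8)])
assert len(partition_experts(experts, ep_size=4, ep_rank=1)) == 2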