From 696fced0d722ab582568fb5b6f6d7dbc536d3053 Mon Sep 17 00:00:00 2001 From: botbw Date: Fri, 13 Sep 2024 14:30:05 +0800 Subject: [PATCH] [fp8] fix missing fp8_comm flag in mixtral (#6057) --- colossalai/shardformer/modeling/mixtral.py | 7 ++++++- colossalai/shardformer/policies/mixtral.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 0103808dc..4f8ec162f 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -31,6 +31,7 @@ from colossalai.moe._operation import ( all_to_all_uneven, ) from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import all_reduce_fp8 from colossalai.shardformer.layer._operation import ( all_to_all_comm, gather_forward_split_backward, @@ -142,7 +143,11 @@ class EPMixtralSparseMoeBlock(ParallelModule): for i in range(1, self.ep_size): activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep] activate_experts = (activate_experts > 0).float() - dist.all_reduce(activate_experts, group=self.moe_dp_group) + + if self.fp8_communication: + all_reduce_fp8(activate_experts, group=self.moe_dp_group) + else: + dist.all_reduce(activate_experts, group=self.moe_dp_group) input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 9f03319e7..8e2ca5de0 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -178,6 +178,7 @@ class MixtralPolicy(Policy): "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, + "fp8_communication": self.shard_config.fp8_communication, }, ) ],