diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index a783b5c5e..3687cfb99 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule): moe_dp_group: ProcessGroup, ep_group: ProcessGroup, fp8_communication: bool = False, + use_zbv: bool = False, ): assert tp_group is not None assert moe_dp_group is not None @@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule): self.ep_rank = dist.get_rank(ep_group) self.ep_group = ep_group self.fp8_communication = fp8_communication + self.use_zbv = use_zbv if self.num_experts % self.ep_size != 0: raise ValueError("The number of experts must be divisible by the number of expert parallel groups.") @@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule): if self.tp_group.size() > 1: for expert in held_experts: expert.w1 = Linear1D_Col.from_native_module( - expert.w1, self.tp_group, fp8_communication=self.fp8_communication + expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv ) expert.w3 = Linear1D_Col.from_native_module( - expert.w3, self.tp_group, fp8_communication=self.fp8_communication + expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv ) expert.w2 = Linear1D_Row.from_native_module( - expert.w2, self.tp_group, fp8_communication=self.fp8_communication + expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv ) for p in self.experts.parameters(): diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 705f2b19f..de546b3c5 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -195,6 +195,7 @@ class MixtralPolicy(Policy): "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, "fp8_communication": self.shard_config.fp8_communication, + "use_zbv": self.shard_config.use_zbv, }, ) ],