[feat] EPMixtralSparseMoeBlock (op in MoE) supports ZBV

pull/6083/head
duanjunwen 2024-10-14 08:22:51 +00:00
parent abd455189d
commit 160e9a4175
2 changed files with 6 additions and 3 deletions

@@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
         moe_dp_group: ProcessGroup,
         ep_group: ProcessGroup,
         fp8_communication: bool = False,
+        use_zbv: bool = False,
     ):
         assert tp_group is not None
         assert moe_dp_group is not None
@@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
         self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
         self.fp8_communication = fp8_communication
+        self.use_zbv = use_zbv
         if self.num_experts % self.ep_size != 0:
             raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule):
         if self.tp_group.size() > 1:
             for expert in held_experts:
                 expert.w1 = Linear1D_Col.from_native_module(
-                    expert.w1, self.tp_group, fp8_communication=self.fp8_communication
+                    expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
                 )
                 expert.w3 = Linear1D_Col.from_native_module(
-                    expert.w3, self.tp_group, fp8_communication=self.fp8_communication
+                    expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
                 )
                 expert.w2 = Linear1D_Row.from_native_module(
-                    expert.w2, self.tp_group, fp8_communication=self.fp8_communication
+                    expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
                 )
         for p in self.experts.parameters():

@@ -195,6 +195,7 @@ class MixtralPolicy(Policy):
                        "tp_group": self.shard_config.tensor_parallel_process_group,
                        "moe_dp_group": self.shard_config.moe_dp_group,
                        "fp8_communication": self.shard_config.fp8_communication,
+                        "use_zbv": self.shard_config.use_zbv,
                    },
                )
            ],
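On the policy side the flag is simply read off the shard config and forwarded in the replacement kwargs. A hedged sketch of that wiring follows, with a hypothetical ToyShardConfig dataclass standing in for ColossalAI's ShardConfig (which, per this diff, must now expose a use_zbv attribute).

from dataclasses import dataclass

# Hypothetical stand-in for ShardConfig; only the fields touched by this diff
# are modeled, and use_zbv is the newly consumed attribute.
@dataclass
class ToyShardConfig:
    tensor_parallel_process_group: object = None
    moe_dp_group: object = None
    fp8_communication: bool = False
    use_zbv: bool = False

def build_moe_replacement_kwargs(shard_config: ToyShardConfig) -> dict:
    # Mirrors the kwargs dict MixtralPolicy builds after this change.
    return {
        "tp_group": shard_config.tensor_parallel_process_group,
        "moe_dp_group": shard_config.moe_dp_group,
        "fp8_communication": shard_config.fp8_communication,
        "use_zbv": shard_config.use_zbv,
    }

print(build_moe_replacement_kwargs(ToyShardConfig(use_zbv=True))["use_zbv"])  # True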