mirror of https://github.com/hpcaitech/ColossalAI
[feat] EPMixtralSparseMoeBlock (op in MoE) supports zbv
parent
abd455189d
commit
160e9a4175
|
@ -60,6 +60,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||||
moe_dp_group: ProcessGroup,
|
moe_dp_group: ProcessGroup,
|
||||||
ep_group: ProcessGroup,
|
ep_group: ProcessGroup,
|
||||||
fp8_communication: bool = False,
|
fp8_communication: bool = False,
|
||||||
|
use_zbv: bool = False,
|
||||||
):
|
):
|
||||||
assert tp_group is not None
|
assert tp_group is not None
|
||||||
assert moe_dp_group is not None
|
assert moe_dp_group is not None
|
||||||
|
@ -70,6 +71,7 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||||
self.ep_rank = dist.get_rank(ep_group)
|
self.ep_rank = dist.get_rank(ep_group)
|
||||||
self.ep_group = ep_group
|
self.ep_group = ep_group
|
||||||
self.fp8_communication = fp8_communication
|
self.fp8_communication = fp8_communication
|
||||||
|
self.use_zbv = use_zbv
|
||||||
|
|
||||||
if self.num_experts % self.ep_size != 0:
|
if self.num_experts % self.ep_size != 0:
|
||||||
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
||||||
|
@ -89,13 +91,13 @@ class EPMixtralSparseMoeBlock(ParallelModule):
|
||||||
if self.tp_group.size() > 1:
|
if self.tp_group.size() > 1:
|
||||||
for expert in held_experts:
|
for expert in held_experts:
|
||||||
expert.w1 = Linear1D_Col.from_native_module(
|
expert.w1 = Linear1D_Col.from_native_module(
|
||||||
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
|
expert.w1, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||||
)
|
)
|
||||||
expert.w3 = Linear1D_Col.from_native_module(
|
expert.w3 = Linear1D_Col.from_native_module(
|
||||||
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
|
expert.w3, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||||
)
|
)
|
||||||
expert.w2 = Linear1D_Row.from_native_module(
|
expert.w2 = Linear1D_Row.from_native_module(
|
||||||
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
|
expert.w2, self.tp_group, fp8_communication=self.fp8_communication, use_zbv=self.use_zbv
|
||||||
)
|
)
|
||||||
|
|
||||||
for p in self.experts.parameters():
|
for p in self.experts.parameters():
|
||||||
|
|
|
@ -195,6 +195,7 @@ class MixtralPolicy(Policy):
|
||||||
"tp_group": self.shard_config.tensor_parallel_process_group,
|
"tp_group": self.shard_config.tensor_parallel_process_group,
|
||||||
"moe_dp_group": self.shard_config.moe_dp_group,
|
"moe_dp_group": self.shard_config.moe_dp_group,
|
||||||
"fp8_communication": self.shard_config.fp8_communication,
|
"fp8_communication": self.shard_config.fp8_communication,
|
||||||
|
"use_zbv": self.shard_config.use_zbv,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|
Loading…
Reference in New Issue