fix the merge

pull/6023/head
wangbluo 3 months ago
parent 2d362ac090
commit eb5ba40def

@@ -146,7 +146,9 @@ class EPDeepseekMoE(nn.Module):
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
dist.all_to_all_single(
output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
output_split_sizes,
input_split_sizes,
group=self.ep_group,
)
with torch.no_grad():

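For reference, torch.distributed.all_to_all_single takes an output tensor, an input tensor, optional split sizes, and a process group; it has no fp8_communication keyword, which is presumably why the stray argument left over from the merge is dropped in the hunk above. A minimal standalone sketch of the corrected exchange (names here are illustrative, not ColossalAI's exact code):

import torch
import torch.distributed as dist

def exchange_split_sizes(input_split_sizes: torch.Tensor, ep_group) -> torch.Tensor:
    # Each rank sends its local split sizes and receives the peers' sizes in return.
    output_split_sizes = torch.zeros_like(input_split_sizes)
    dist.all_to_all_single(
        output_split_sizes,   # tensor filled with data received from the other ranks
        input_split_sizes,    # tensor scattered evenly across the expert-parallel group
        group=ep_group,
    )
    return output_split_sizes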
@@ -68,6 +68,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
self.fp8_communication = fp8_communication
if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
self.num_experts_per_ep = self.num_experts // self.ep_size
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

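The added guard and the two index computations implement contiguous expert sharding across the expert-parallel group; a short self-contained sketch of the same arithmetic (hypothetical helper, not the Mixtral block itself):

def local_expert_slice(num_experts: int, ep_size: int, ep_rank: int) -> range:
    # Experts must split evenly across expert-parallel ranks.
    if num_experts % ep_size != 0:
        raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
    num_experts_per_ep = num_experts // ep_size
    expert_start_idx = ep_rank * num_experts_per_ep
    return range(expert_start_idx, expert_start_idx + num_experts_per_ep)

# Example: 8 experts over 4 expert-parallel ranks -> rank 2 holds experts 4 and 5.
assert list(local_expert_slice(8, 4, 2)) == [4, 5]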
@@ -119,6 +119,7 @@ class Qwen2Policy(Policy):
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
@@ -319,7 +320,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
setattr(self.shard_config, "causal_lm", True)
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
# add a new item for casual lm
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[

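The Qwen2Policy hunk threads the shard config's fp8_communication flag into the column-parallel linear replacement via kwargs; a minimal sketch of one such replacement entry (import paths and the shard_config stand-in are assumptions, not part of this commit):

from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

def q_proj_replacement(shard_config, sp_mode):
    # Forward the fp8_communication flag from the shard config to the
    # tensor-parallel linear layer that replaces self_attn.q_proj.
    return SubModuleReplacementDescription(
        suffix="self_attn.q_proj",
        target_module=Linear1D_Col,
        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=shard_config.fp8_communication),
    )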