mirror of https://github.com/hpcaitech/ColossalAI
fix the merge
parent
2d362ac090
commit
eb5ba40def
|
@ -146,7 +146,9 @@ class EPDeepseekMoE(nn.Module):
|
||||||
|
|
||||||
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
|
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
|
||||||
dist.all_to_all_single(
|
dist.all_to_all_single(
|
||||||
output_split_sizes, input_split_sizes, group=self.ep_group, fp8_communication=fp8_communication
|
output_split_sizes,
|
||||||
|
input_split_sizes,
|
||||||
|
group=self.ep_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
@ -68,6 +68,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||||
self.ep_size = dist.get_world_size(ep_group)
|
self.ep_size = dist.get_world_size(ep_group)
|
||||||
self.ep_rank = dist.get_rank(ep_group)
|
self.ep_rank = dist.get_rank(ep_group)
|
||||||
self.ep_group = ep_group
|
self.ep_group = ep_group
|
||||||
|
self.fp8_communication = fp8_communication
|
||||||
|
|
||||||
|
if self.num_experts % self.ep_size != 0:
|
||||||
|
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
|
||||||
|
|
||||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||||
|
|
|
@ -119,6 +119,7 @@ class Qwen2Policy(Policy):
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.q_proj",
|
suffix="self_attn.q_proj",
|
||||||
target_module=Linear1D_Col,
|
target_module=Linear1D_Col,
|
||||||
|
kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attn.k_proj",
|
suffix="self_attn.k_proj",
|
||||||
|
@ -319,7 +320,7 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
||||||
setattr(self.shard_config, "causal_lm", True)
|
setattr(self.shard_config, "causal_lm", True)
|
||||||
|
|
||||||
if self.shard_config.enable_tensor_parallelism:
|
if self.shard_config.enable_tensor_parallelism:
|
||||||
# add a new item for causal lm
|
# add a new item for casual lm
|
||||||
new_item = {
|
new_item = {
|
||||||
Qwen2ForCausalLM: ModulePolicyDescription(
|
Qwen2ForCausalLM: ModulePolicyDescription(
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
|
|
Loading…
Reference in New Issue