mirror of https://github.com/hpcaitech/ColossalAI
[moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)
parent
63314ce4e4
commit
4fa6b9509c
|
@ -109,6 +109,19 @@ class EPDeepseekMoE(ParallelModule):
|
||||||
for p in self.experts.parameters():
|
for p in self.experts.parameters():
|
||||||
set_moe_tensor_ep_group(p, ep_group)
|
set_moe_tensor_ep_group(p, ep_group)
|
||||||
|
|
||||||
|
if self.config.n_shared_experts is not None:
|
||||||
|
self.shared_experts.gate_proj = Linear1D_Col.from_native_module(
|
||||||
|
self.shared_experts.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
|
||||||
|
self.shared_experts.up_proj = Linear1D_Col.from_native_module(
|
||||||
|
self.shared_experts.up_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
|
||||||
|
self.shared_experts.down_proj = Linear1D_Row.from_native_module(
|
||||||
|
self.shared_experts.down_proj, self.tp_group, fp8_communication=self.fp8_communication
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(
|
def from_native_module(
|
||||||
module,
|
module,
|
||||||
|
|
|
@ -20,14 +20,15 @@ from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
|
||||||
NUM_BATCH = 8
|
NUM_BATCH = 8
|
||||||
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
|
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
|
||||||
NUM_LAYERS = 4
|
NUM_LAYERS = 4
|
||||||
HIDDEN_SIZE_PER_HEAD = 4
|
HIDDEN_SIZE_PER_HEAD = 8
|
||||||
NUM_HEADS = 8
|
NUM_HEADS = 8
|
||||||
TOP_K = 2
|
TOP_K = 2
|
||||||
|
|
||||||
|
|
||||||
def run_deepseek_commom(config: Tuple[int, ...]):
|
def run_deepseek_commom(parallel_config: Tuple[int, ...]):
|
||||||
Randomizer.reset_index()
|
Randomizer.reset_index()
|
||||||
stage, ep_size, pp_size, tp_size, sp_size = config
|
print(f"rank {dist.get_rank()} testing {parallel_config}")
|
||||||
|
stage, ep_size, pp_size, tp_size, sp_size = parallel_config
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
dtype, precision = torch.bfloat16, "bf16"
|
dtype, precision = torch.bfloat16, "bf16"
|
||||||
|
@ -65,6 +66,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
|
||||||
attn_implementation="flash_attention_2",
|
attn_implementation="flash_attention_2",
|
||||||
torch_dtype="float16",
|
torch_dtype="float16",
|
||||||
n_routed_experts=NUM_EXPERTS,
|
n_routed_experts=NUM_EXPERTS,
|
||||||
|
n_shared_experts=2,
|
||||||
num_experts_per_tok=TOP_K,
|
num_experts_per_tok=TOP_K,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
@ -159,7 +161,7 @@ def run_deepseek_commom(config: Tuple[int, ...]):
|
||||||
if rank == world_size - 1:
|
if rank == world_size - 1:
|
||||||
shutil.rmtree(model_dir)
|
shutil.rmtree(model_dir)
|
||||||
|
|
||||||
print(f"rank {dist.get_rank()} test passed")
|
print(f"rank {dist.get_rank()} passed {parallel_config}")
|
||||||
|
|
||||||
|
|
||||||
@parameterize(
|
@parameterize(
|
||||||
|
|
Loading…
Reference in New Issue