[moe] fix tests

pull/5372/head
ver217 2024-02-08 12:46:37 +08:00
parent 65e5d6baa5
commit 06db94fbc9
4 changed files with 31 additions and 39 deletions

View File

@ -47,7 +47,7 @@ class MoeRouter(nn.Module, ABC):
def get_capacity(self, num_tokens, num_experts, ep_group=None):
if ep_group is not None:
num_tokens_tensor = torch.tensor(num_tokens, device=get_current_device())
num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
dist.all_reduce(num_tokens_tensor, group=ep_group)
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval

View File

@ -911,11 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
else:
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
if hasattr(self, "master_moe_params"):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.copy_(working_moe_param)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
if hasattr(self, "moe_master_to_working_map"):
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
return self._param_store.master_to_working_param

View File

@ -12,7 +12,6 @@ import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
sys.path.append(
@ -95,6 +94,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=1,
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
@ -103,6 +103,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=dist.get_world_size(),
zero_stage=2,
custom_policy=OpenMoeForCausalLMPolicy(),
)
@ -111,6 +112,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=1,
ep_size=2,
zero_stage=2,
extra_dp_size=2,
custom_policy=OpenMoeForCausalLMPolicy(),
@ -120,6 +122,7 @@ def get_model(parallel):
precision="bf16",
tp_size=1,
pp_size=2,
ep_size=2,
zero_stage=1,
microbatch_size=1,
custom_policy=OpenMoeForCausalLMPolicy(),
@ -130,27 +133,6 @@ def get_model(parallel):
def _test_moe_checkpoint(rank, parallel):
if parallel == None:
MOE_MANAGER.setup(
parallel=None,
)
elif parallel == "ep":
MOE_MANAGER.setup(
parallel="EP",
)
elif parallel == "ep_zero":
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=2,
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
model3, booster3, optim3 = get_model(parallel)
@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
_test_moe_checkpoint(rank, parallel)
@pytest.mark.skip(reason="This is tested in ColossalMOE")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])

View File

@ -4,15 +4,21 @@ import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@pytest.mark.parametrize(["router", "num_groups"], [
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
])
@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
(4, 5, 8),
(3, 4, 4),
])
@pytest.mark.parametrize(
["router", "num_groups"],
[
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
],
)
@pytest.mark.parametrize(
["batch_size", "seq_len", "num_experts"],
[
(4, 5, 8),
(3, 4, 4),
],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
if num_groups > 1:
@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
router.train()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
router.eval()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)