mirror of https://github.com/hpcaitech/ColossalAI
[moe] fix tests
parent 65e5d6baa5
commit 06db94fbc9
@@ -47,7 +47,7 @@ class MoeRouter(nn.Module, ABC):

     def get_capacity(self, num_tokens, num_experts, ep_group=None):
         if ep_group is not None:
-            num_tokens_tensor = torch.tensor(num_tokens, device=get_current_device())
+            num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
             dist.all_reduce(num_tokens_tensor, group=ep_group)
             num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
         capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
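The hunk above switches the device lookup to the accelerator API (get_accelerator().get_current_device()) and averages the token count over the expert-parallel group before deriving a capacity. Below is a minimal sketch of that capacity computation as a standalone helper; the name get_expert_capacity, the min_capacity floor, and the single capacity_factor argument are illustrative, not the router's real signature.

import torch
import torch.distributed as dist


def get_expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float,
                        min_capacity: int = 4, ep_group=None) -> int:
    # Hypothetical standalone helper: when an expert-parallel group is given,
    # sum the per-rank token counts and average them so every rank derives the
    # same capacity, mirroring the all_reduce in the hunk above. The real
    # router places the tensor on the current accelerator device; a CPU tensor
    # is used here so the sketch also runs with the gloo backend.
    if ep_group is not None:
        num_tokens_tensor = torch.tensor(num_tokens)
        dist.all_reduce(num_tokens_tensor, group=ep_group)
        num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
    # capacity = factor * tokens / experts, floored at a minimum slot count
    capacity = int(capacity_factor * num_tokens / num_experts)
    return max(capacity, min_capacity)


# Single-process usage (no EP group): 2.0 * 32 / 8 = 8 slots per expert.
print(get_expert_capacity(num_tokens=32, num_experts=8, capacity_factor=2.0))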
@@ -911,11 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
             else:
                 master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
-        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-            master_moe_param.copy_(working_moe_param)
+        if hasattr(self, "master_moe_params"):
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.copy_(working_moe_param)

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+        if hasattr(self, "moe_master_to_working_map"):
+            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+        return self._param_store.master_to_working_param
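Both changes in this hunk follow the same pattern: the MoE bookkeeping (master_moe_params, moe_master_to_working_map) only exists when MoE parameters were registered, so it is accessed behind a hasattr guard and dense-only runs fall through to the plain ZeRO path. A self-contained sketch of that pattern, with illustrative attribute names rather than the optimizer's real internals:

from typing import Dict

import torch


class OptimizerSketch:
    """Minimal sketch of the hasattr-guard pattern from the hunk above.

    The names master_to_working and moe_map are illustrative; only the
    guarded merge mirrors the LowLevelZeroOptimizer change.
    """

    def __init__(self, master_to_working: Dict[int, torch.Tensor], moe_map=None):
        self.master_to_working = master_to_working
        if moe_map is not None:
            # MoE bookkeeping is only defined when MoE parameters exist.
            self.moe_master_to_working_map = moe_map

    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        # Dense-only runs never define the MoE map, so guard before merging.
        if hasattr(self, "moe_master_to_working_map"):
            return {**self.master_to_working, **self.moe_master_to_working_map}
        return self.master_to_working


dense = OptimizerSketch({0: torch.zeros(2)})
assert len(dense.get_master_to_working_map()) == 1
moe = OptimizerSketch({0: torch.zeros(2)}, moe_map={1: torch.ones(2)})
assert len(moe.get_master_to_working_map()) == 2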
@ -12,7 +12,6 @@ import colossalai
|
|||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
|
||||
|
||||
sys.path.append(
|
||||
|
@@ -95,6 +94,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=1,
             zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
@@ -103,6 +103,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=dist.get_world_size(),
             zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
@@ -111,6 +112,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=2,
             zero_stage=2,
             extra_dp_size=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
@@ -120,6 +122,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=2,
+            ep_size=2,
             zero_stage=1,
             microbatch_size=1,
             custom_policy=OpenMoeForCausalLMPolicy(),
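Across the four get_model branches above, the addition is an explicit ep_size keyword for MoeHybridParallelPlugin. A hedged sketch of building a booster the same way for the all-rank EP case, using only the keyword arguments that appear in the hunks; the custom_policy for OpenMoE is omitted, and a prior distributed launch (for example via colossalai.launch_from_torch) is assumed to have initialized the process group.

import torch.distributed as dist

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin


def build_ep_booster() -> Booster:
    # Mirrors the "ep" branch above: expert parallelism spans every rank,
    # tensor/pipeline parallelism stay disabled, and gradients are sharded
    # with ZeRO stage 2. Requires an already-initialized process group.
    plugin = MoeHybridParallelPlugin(
        precision="bf16",
        tp_size=1,
        pp_size=1,
        ep_size=dist.get_world_size(),
        zero_stage=2,
    )
    return Booster(plugin=plugin)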
@@ -130,27 +133,6 @@ def get_model(parallel):


 def _test_moe_checkpoint(rank, parallel):
-    if parallel == None:
-        MOE_MANAGER.setup(
-            parallel=None,
-        )
-    elif parallel == "ep":
-        MOE_MANAGER.setup(
-            parallel="EP",
-        )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
-    elif parallel == "hybrid":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=1,
-            fixed_ep_size=2,
-            fixed_pp_size=2,
-        )
     model1, booster1, optim1 = get_model(parallel)
     model2, booster2, optim2 = get_model(parallel)
     model3, booster3, optim3 = get_model(parallel)
@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
     _test_moe_checkpoint(rank, parallel)


+@pytest.mark.skip(reason="This is tested in ColossalMOE")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@@ -4,15 +4,21 @@ import torch
 from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter


-@pytest.mark.parametrize(["router", "num_groups"], [
-    (Top1Router(), 1),
-    (Top2Router(), 1),
-    # (TopKRouter(num_selected_experts=3), 4),
-])
-@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
-    (4, 5, 8),
-    (3, 4, 4),
-])
+@pytest.mark.parametrize(
+    ["router", "num_groups"],
+    [
+        (Top1Router(), 1),
+        (Top2Router(), 1),
+        # (TopKRouter(num_selected_experts=3), 4),
+    ],
+)
+@pytest.mark.parametrize(
+    ["batch_size", "seq_len", "num_experts"],
+    [
+        (4, 5, 8),
+        (3, 4, 4),
+    ],
+)
 def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
     x = torch.randn((batch_size * seq_len, num_experts)).cuda()
     if num_groups > 1:
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex

     router.train()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)

     router.eval()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
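After this change the test keeps only the combine array and dispatch mask: TopKRouter returns the pair directly, while Top1Router and Top2Router return a longer tuple that is sliced with [1:3]. The sketch below runs the same unpacking and shape checks against a toy router; ToyTop1Router is purely illustrative, not one of ColossalAI's routers, and it skips the .cuda() call so it runs on CPU.

import torch


class ToyTop1Router:
    """Illustrative router returning (aux_loss, combine_array, dispatch_mask)."""
    k_value = 1

    def __call__(self, inputs: torch.Tensor, capacity: int = 2):
        num_tokens, num_experts = inputs.shape
        top1 = inputs.argmax(dim=-1)                            # chosen expert per token
        dispatch_mask = torch.zeros(num_tokens, num_experts, capacity, dtype=torch.bool)
        combine_array = torch.zeros(num_tokens, num_experts, capacity)
        slots = torch.zeros(num_experts, dtype=torch.long)      # next free slot per expert
        for t, e in enumerate(top1.tolist()):
            if slots[e] < capacity:                             # tokens over capacity are dropped
                dispatch_mask[t, e, slots[e]] = True
                combine_array[t, e, slots[e]] = torch.softmax(inputs[t], -1)[e]
                slots[e] += 1
        aux_loss = torch.tensor(0.0)
        return aux_loss, combine_array, dispatch_mask


x = torch.randn(4 * 5, 8)
combine_array, dispatch_mask = ToyTop1Router()(x)[1:3]          # same slicing as the test
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= ToyTop1Router.k_value)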