You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_moe/test_moe_router.py

48 lines
1.4 KiB

import pytest
import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
10 months ago
@pytest.mark.skipif(not torch.cuda.is_available(), reason="router test moves inputs to CUDA and needs a GPU")
@pytest.mark.parametrize(
    ["router", "num_groups"],
    [
        (Top1Router(), 1),
        (Top2Router(), 1),
        # (TopKRouter(num_selected_experts=3), 4),
    ],
)
@pytest.mark.parametrize(
    ["batch_size", "seq_len", "num_experts"],
    [
        (4, 5, 8),
        (3, 4, 4),
    ],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
    """Check router output shapes and per-token dispatch counts in train and eval mode.

    A random logits tensor of shape ``(batch_size * seq_len, num_experts)`` is
    fed to ``router`` once in train mode and once in eval mode. In both modes
    the returned combine array and dispatch mask must carry ``x.shape`` as
    their leading dimensions, and no token may be dispatched more than
    ``router.k_value`` times.

    Args:
        router: Router instance under test. Note that parametrized instances are
            created once at import time and shared across parameter combinations.
        batch_size: Number of sequences; multiplied by ``seq_len`` to get the token count.
        seq_len: Tokens per sequence.
        num_experts: Width of the logits tensor fed to the router.
        num_groups: When > 1, the logits are expanded to
            ``(num_groups, batch_size * seq_len, num_experts)``.
    """
    x = torch.randn((batch_size * seq_len, num_experts)).cuda()
    if num_groups > 1:
        # expand() returns a broadcast view, not a copy: every group sees identical logits.
        x = x.expand(num_groups, -1, -1)

    # The same invariants must hold in both modes, so iterate over the mode
    # setters instead of duplicating the checks (train first, then eval).
    for set_mode in (router.train, router.eval):
        set_mode()
        if isinstance(router, TopKRouter):
            combine_array, dispatch_mask = router(x, expert_capacity=2)
        else:
            # Other routers return a tuple; elements 1 and 2 are taken as the
            # combine array and dispatch mask (presumably — confirm against the
            # router's forward() return order).
            combine_array, dispatch_mask = router(x)[1:3]
        assert combine_array.shape[:-1] == x.shape
        assert dispatch_mask.shape[:-1] == x.shape
        # dispatch_mask has x.shape plus one trailing axis, so summing the last
        # two axes leaves one dispatch count per token; each must be <= k.
        assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
if __name__ == "__main__":
    # Ad-hoc single-configuration run without going through pytest.
    test_router_forward(
        Top2Router(),
        batch_size=4,
        seq_len=4,
        num_experts=4,
        num_groups=1,
    )