mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
42 lines
1.4 KiB
42 lines
1.4 KiB
import pytest
|
|
import torch
|
|
|
|
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["router", "num_groups"],
    [
        (Top1Router(), 1),
        (Top2Router(), 1),
        # (TopKRouter(num_selected_experts=3), 4),
    ],
)
@pytest.mark.parametrize(
    ["batch_size", "seq_len", "num_experts"],
    [
        (4, 5, 8),
        (3, 4, 4),
    ],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
    """Check the shapes and selection bound of a router's outputs in train and eval mode.

    A random CUDA tensor of shape ``(batch_size * seq_len, num_experts)`` is fed
    to ``router``; when ``num_groups > 1`` it is first expanded with a leading
    ``num_groups`` dimension. The same three assertions are applied after
    ``router.train()`` and again after ``router.eval()``:

    * ``combine_array.shape[:-1]`` equals the input shape,
    * ``dispatch_mask.shape[:-1]`` equals the input shape,
    * ``dispatch_mask`` summed over its last two dimensions never exceeds
      ``router.k_value``.

    Args:
        router: Router instance under test.
        batch_size: Multiplied by ``seq_len`` to get the number of input rows.
        seq_len: Multiplied by ``batch_size`` to get the number of input rows.
        num_experts: Size of the last input dimension.
        num_groups: When greater than 1, size of the extra leading dimension.
    """
    inputs = torch.randn((batch_size * seq_len, num_experts)).cuda()
    if num_groups > 1:
        inputs = inputs.expand(num_groups, -1, -1)

    # Only TopKRouter is called with an explicit expert capacity.
    forward_kwargs = {"expert_capacity": 2} if isinstance(router, TopKRouter) else {}

    # Same checks in both modes: training first, then evaluation.
    for switch_mode in (router.train, router.eval):
        switch_mode()
        _, combine_array, dispatch_mask = router(inputs, **forward_kwargs)

        assert combine_array.shape[:-1] == inputs.shape
        assert dispatch_mask.shape[:-1] == inputs.shape

        selected_per_row = dispatch_mask.sum(-1).sum(-1)
        assert torch.all(selected_per_row <= router.k_value)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run a single Top2Router case directly, without going through pytest.
    test_router_forward(
        router=Top2Router(),
        batch_size=4,
        seq_len=4,
        num_experts=4,
        num_groups=1,
    )
|