import pytest
import torch

from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter


@pytest.mark.parametrize(
    ["router", "num_groups"],
    [
        (Top1Router(), 1),
        (Top2Router(), 1),
        # (TopKRouter(num_selected_experts=3), 4),
    ],
)
@pytest.mark.parametrize(
    ["batch_size", "seq_len", "num_experts"],
    [
        (4, 5, 8),
        (3, 4, 4),
    ],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
    # Fake routing logits: one row of expert scores per token.
    x = torch.randn((batch_size * seq_len, num_experts)).cuda()
    if num_groups > 1:
        x = x.expand(num_groups, -1, -1)

    router.train()
    if isinstance(router, TopKRouter):
        combine_array, dispatch_mask = router(x, expert_capacity=2)
    else:
        # Top1/Top2 routers return several tensors; slice out the combine
        # weights and the dispatch mask.
        combine_array, dispatch_mask = router(x)[1:3]
    assert combine_array.shape[:-1] == x.shape
    assert dispatch_mask.shape[:-1] == x.shape
    # Each token is dispatched to at most k experts.
    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)

    router.eval()
    if isinstance(router, TopKRouter):
        combine_array, dispatch_mask = router(x, expert_capacity=2)
    else:
        combine_array, dispatch_mask = router(x)[1:3]
    assert combine_array.shape[:-1] == x.shape
    assert dispatch_mask.shape[:-1] == x.shape
    assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)


if __name__ == "__main__":
    test_router_forward(Top2Router(), 4, 4, 4, 1)