You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_moe/test_moe_ep_tp.py

82 lines
3.3 KiB

import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
    """Per-process body of the MoE local / EP / TP consistency check.

    Three ``SparseMLP`` layers are built under ``MOE_MANAGER`` set up with
    ``parallel=None``, ``"EP"`` and ``"TP"``.  The TP and local models are
    synchronised from the EP model, all three run a forward pass, and the EP
    output is compared against this rank's slice of the other two outputs.
    After the backward passes the EP expert gradients are compared across the
    group and the sync helpers are invoked again with ``assert_grad_flag=True``.

    Args:
        rank: Rank forwarded to ``colossalai.launch``.
        world_size: Number of processes; must divide ``batch_size``.
        port: Port forwarded to ``colossalai.launch``.
        num_experts: Number of experts in each ``SparseMLP``.
        batch_size: Number of rows in the full input batch.
        dim: Hidden size; the intermediate size is ``dim * 2``.
        seed: Value passed to ``torch.cuda.manual_seed`` before drawing the input.
    """
    assert batch_size % world_size == 0

    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    def build_model(parallel):
        # Re-initialise the global manager so every model is created under its own setup.
        MOE_MANAGER.__init__()
        MOE_MANAGER.setup(parallel=parallel)
        return SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)

    # Construction order matters: local first, then EP, then TP.
    local_model = build_model(None)
    ep_model = build_model("EP")
    tp_model = build_model("TP")

    device = get_current_device()
    ep_model = ep_model.to(device)
    tp_model = tp_model.to(device)
    local_model = local_model.to(device)

    # sync ep param
    sync_moe_model_param(ep_model)
    dp_group = MOE_MANAGER.parallel_info_dict[world_size].dp_group
    ep_experts = ep_model.experts
    for weight in (ep_experts.wi, ep_experts.wo):
        assert_equal_in_group(weight.data, dp_group)
    grad_handler = MoeGradientHandler(ep_model)

    # sync tp param
    sync_tp_from_ep(tp_model, ep_model)
    # sync local param
    sync_local_from_ep(local_model, ep_model)

    # Rows of the full batch that belong to this process.
    current_rank = dist.get_rank()
    per_rank = batch_size // world_size
    shard = slice(per_rank * current_rank, per_rank * (current_rank + 1))

    torch.cuda.manual_seed(seed)
    full_batch = torch.randn(batch_size, dim, device=device)
    ep_input = full_batch.detach()[shard]

    # The local and TP models see the whole batch; the EP model only sees its shard.
    out_local = local_model(full_batch)
    MOE_MANAGER.reset_loss()
    out_tp = tp_model(full_batch)
    MOE_MANAGER.reset_loss()
    out_ep = ep_model(ep_input)
    MOE_MANAGER.reset_loss()

    assert torch.allclose(out_ep, out_tp[shard])
    assert torch.allclose(out_ep, out_local[shard])

    for output in (out_local, out_tp, out_ep):
        output.mean().backward()
    grad_handler.handle_gradient()

    for weight in (ep_experts.wi, ep_experts.wo):
        assert_equal_in_group(weight.grad, dp_group)
    # NOTE(review): with assert_grad_flag=True these helpers presumably compare
    # gradients instead of copying parameters; confirm in tests.test_moe.moe_utils.
    sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
    sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 8])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("dim", [32])
@pytest.mark.parametrize("seed", [42])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
    """Launch ``run_test`` through ``spawn`` for one parametrized configuration.

    ``2`` is handed to ``spawn`` as its second positional argument; the
    remaining settings are forwarded to ``run_test`` as keyword arguments.
    """
    worker_kwargs = {
        "num_experts": num_experts,
        "batch_size": batch_size,
        "dim": dim,
        "seed": seed,
    }
    spawn(run_test, 2, **worker_kwargs)
if __name__ == "__main__":
    # Direct execution: run a single configuration without going through pytest collection.
    manual_config = {"num_experts": 8, "batch_size": 8, "dim": 256, "seed": 42}
    test_moe_ep_tp(**manual_config)