import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep


def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
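    """Check that SparseMLP produces identical outputs and gradients when run
    without MoE parallelism (local), with expert parallelism (EP), and with
    tensor parallelism (TP) on the same data.
    """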
    assert batch_size % world_size == 0

    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # build the same MoE layer three ways: local (no MoE parallelism),
    # expert parallel (EP), and tensor parallel (TP)
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel=None)
    local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel="EP")
    ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel="TP")
    tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
    ep_model = ep_model.to(get_current_device())
    tp_model = tp_model.to(get_current_device())
    local_model = local_model.to(get_current_device())

    # sync ep param
    sync_moe_model_param(ep_model)
    dist_dict = MOE_MANAGER.parallel_info_dict
    assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
    assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
    grad_handler = MoeGradientHandler(ep_model)
    # sync tp param
    sync_tp_from_ep(tp_model, ep_model)
    # sync local param
    sync_local_from_ep(local_model, ep_model)

    # the TP and local models consume the full batch; the EP model consumes
    # only this rank's micro-batch slice of it
    rank = dist.get_rank()
    torch.cuda.manual_seed(seed)
    tp_data = torch.randn(batch_size, dim, device=get_current_device())
    micro_batch_size = batch_size // world_size
    ep_data = tp_data.detach()[micro_batch_size * rank : micro_batch_size * (rank + 1)]

    # forward: every setup must produce the same output on this rank's slice
    out_local = local_model(tp_data)
    MOE_MANAGER.reset_loss()
    out_tp = tp_model(tp_data)
    MOE_MANAGER.reset_loss()
    out_ep = ep_model(ep_data)
    MOE_MANAGER.reset_loss()
    assert torch.allclose(out_ep, out_tp[micro_batch_size * rank : micro_batch_size * (rank + 1)])
    assert torch.allclose(out_ep, out_local[micro_batch_size * rank : micro_batch_size * (rank + 1)])

    # backward: sync the EP model's expert gradients before comparing
    out_local.mean().backward()
    out_tp.mean().backward()
    out_ep.mean().backward()
    grad_handler.handle_gradient()

    assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
    assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)

    # check local and TP gradients against the EP model
    sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
    sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)


@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 8])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("dim", [32])
@pytest.mark.parametrize("seed", [42])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
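    """Spawn two workers and run the EP/TP/local consistency check on each."""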
    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)


if __name__ == "__main__":
    test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42)