import os
import warnings
from typing import Dict

import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_moe.moe_utils import MoeGradientHandler


def _tp_slice(tp_param: torch.Tensor, full_param: torch.Tensor, start_dim: int = 0) -> tuple:
    """Build the index selecting this rank's TP shard out of ``full_param``.

    The sharded dimension is the first dimension at or after ``start_dim`` whose
    size differs between ``tp_param`` and ``full_param``. The shard position is
    ``get_ep_rank(tp_param)``, and the shard width is ``tp_param``'s size along
    that dimension.

    Args:
        tp_param (torch.Tensor): the TP-sharded parameter.
        full_param (torch.Tensor): the unsharded counterpart.
        start_dim (int): dimensions before this index are never treated as sharded.

    Returns:
        tuple: an index usable as ``full_param[index]``.

    Raises:
        IndexError: if no dimension at or after ``start_dim`` differs in size.
    """
    tp_dim = [
        i
        for i, (d1, d2) in enumerate(zip(tp_param.shape, full_param.shape))
        if i >= start_dim and d1 != d2
    ][0]
    tp_rank = get_ep_rank(tp_param)
    shard_size = tp_param.shape[tp_dim]
    return tuple([slice(None)] * tp_dim + [slice(shard_size * tp_rank, shard_size * (tp_rank + 1))])


def _all_gather_ep(tensor: torch.Tensor, ep_param: torch.Tensor) -> torch.Tensor:
    """All-gather ``tensor`` over ``ep_param``'s EP group and concatenate along dim 0.

    This is a collective call: every rank in the EP group must reach it, in the
    same order, or the ranks will block.

    Args:
        tensor (torch.Tensor): local tensor to gather (the EP param or its grad).
        ep_param (torch.Tensor): EP parameter that supplies the group, group size,
            and the shape/dtype of the receive buffers.

    Returns:
        torch.Tensor: the per-rank tensors concatenated along dim 0.
    """
    gathered = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
    dist.all_gather(gathered, tensor, group=get_ep_group(ep_param))
    return torch.cat(gathered, dim=0)


def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None:
    """Sync the parameters of the TP model from the local model, or check that they match.

    Parameters are paired positionally via ``named_parameters()`` and must have
    identical names. Non-MoE parameters are copied or compared whole. MoE
    parameters are copied from, or compared against, this rank's TP shard of
    the local parameter.

    Args:
        tp_model (SparseMLP): model built under the "TP" MoE setting.
        local_model (SparseMLP): unsharded reference model.
        assert_grad_flag (bool): if True, assert that params and grads already
            match instead of copying.
    """
    for (tp_name, tp_param), (local_name, local_param) in zip(
        tp_model.named_parameters(), local_model.named_parameters()
    ):
        assert tp_name == local_name
        if not is_moe_tensor(tp_param):
            if assert_grad_flag:
                assert torch.allclose(tp_param, local_param)
                assert torch.allclose(tp_param.grad, local_param.grad)
            else:
                tp_param.data.copy_(local_param.data)
            continue

        tp_slice = _tp_slice(tp_param, local_param)
        if assert_grad_flag:
            assert torch.allclose(tp_param, local_param[tp_slice])
            assert torch.allclose(tp_param.grad, local_param.grad[tp_slice])
        else:
            tp_param.data.copy_(local_param[tp_slice].data)


def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
    """Sync the parameters of the TP model from the EP model, or check that they match.

    Non-MoE parameters are copied or compared whole. For MoE parameters, the EP
    shards are all-gathered along dim 0, and this rank's TP shard is then cut
    from the result.

    NOTE: MoE parameters trigger collective all-gathers, so all ranks must call
    this with the same ``assert_grad_flag``.

    Args:
        tp_model (SparseMLP): model built under the "TP" MoE setting.
        ep_model (SparseMLP): model built under the "EP" MoE setting.
        assert_grad_flag (bool): if True, assert that params and grads already
            match instead of copying.
    """
    for (tp_name, tp_param), (ep_name, ep_param) in zip(
        tp_model.named_parameters(), ep_model.named_parameters()
    ):
        assert tp_name == ep_name
        if not is_moe_tensor(tp_param):
            if assert_grad_flag:
                assert torch.allclose(tp_param, ep_param)
                assert torch.allclose(tp_param.grad, ep_param.grad)
            else:
                tp_param.data.copy_(ep_param.data)
            continue

        # Rebuild the full tensor (and grad) from the EP shards.
        all_param = _all_gather_ep(ep_param, ep_param)
        if assert_grad_flag:
            all_grad = _all_gather_ep(ep_param.grad, ep_param)

        # Dim 0 is the dim the EP shards were concatenated along (presumably
        # the expert dim), so the TP-sharded dim is searched from dim 1 onward.
        tp_slice = _tp_slice(tp_param, all_param, start_dim=1)
        new_tp_param = all_param[tp_slice]
        if assert_grad_flag:
            assert torch.allclose(tp_param, new_tp_param)
            assert torch.allclose(tp_param.grad, all_grad[tp_slice])
        else:
            tp_param.data.copy_(new_tp_param.data)


def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
    """Sync the parameters of the local model from the EP model, or check that they match.

    Parameters whose name does not contain "experts" are copied or compared
    whole. Expert parameters are all-gathered across the EP group along dim 0
    and compared against, or copied into, the unsharded local parameter.

    NOTE: expert parameters trigger collective all-gathers. If a caller swallows
    an AssertionError raised on some ranks but not others, the ranks that keep
    going may block in a later all-gather.

    Args:
        local_model (SparseMLP): unsharded reference model.
        ep_model (SparseMLP): model built under the "EP" MoE setting.
        assert_grad_flag (bool): if True, assert that params and grads already
            match instead of copying.
    """
    for (local_name, local_param), (ep_name, ep_param) in zip(
        local_model.named_parameters(), ep_model.named_parameters()
    ):
        assert local_name == ep_name
        if "experts" not in local_name:
            if assert_grad_flag:
                assert torch.allclose(local_param, ep_param)
                assert torch.allclose(local_param.grad, ep_param.grad)
            else:
                local_param.data.copy_(ep_param.data)
            continue

        # Rebuild the full expert tensor (and grad) from the EP shards.
        all_param = _all_gather_ep(ep_param, ep_param)
        if assert_grad_flag:
            all_grad = _all_gather_ep(ep_param.grad, ep_param)
            assert torch.allclose(local_param, all_param)
            assert torch.allclose(local_param.grad, all_grad)
        else:
            local_param.data.copy_(all_param.data)


def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
    """Per-process worker: compare local, EP and TP SparseMLP forward/backward results.

    Args:
        rank (int): process rank supplied by ``spawn``.
        world_size (int): number of processes.
        port (int): rendezvous port.
        num_experts (int): number of experts in each SparseMLP.
        batch_size (int): global batch size; must be divisible by ``world_size``.
        dim (int): hidden size (the intermediate size is ``2 * dim``).
        config (Dict): extra options; reads ``"enable_hierarchical_comm"`` (default False).
    """
    assert batch_size % world_size == 0
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Each model is built after re-initialising MOE_MANAGER with its own parallel mode.
    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel=None)
    local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)

    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel="EP")
    enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
    if enable_hierarchical_comm:
        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
    ep_model = SparseMLP(
        num_experts=num_experts,
        hidden_size=dim,
        intermediate_size=dim * 2,
        enable_hierarchical_comm=enable_hierarchical_comm,
    )

    MOE_MANAGER.__init__()
    MOE_MANAGER.setup(parallel="TP")
    tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)

    ep_model = ep_model.to(get_current_device())
    tp_model = tp_model.to(get_current_device())
    local_model = local_model.to(get_current_device())

    # sync ep param
    sync_moe_model_param(ep_model)
    dist_dict = MOE_MANAGER.parallel_info_dict
    assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
    assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
    ep_grad_handler = MoeGradientHandler(ep_model)
    # sync local param
    sync_local_from_ep(local_model, ep_model)
    # sync tp param
    sync_tp_from_ep(tp_model, ep_model)
    tp_grad_handler = MoeGradientHandler(tp_model)

    rank = dist.get_rank()
    input_data = torch.randn(batch_size, dim, device=get_current_device())
    micro_batch_size = batch_size // world_size
    index = rank * micro_batch_size
    # NOTE: ep & tp takes in sharded data for each process
    shard_data = input_data.detach()[index:index + micro_batch_size]

    out_local = local_model(input_data)
    MOE_MANAGER.reset_loss()
    out_tp = tp_model(shard_data)
    MOE_MANAGER.reset_loss()
    out_ep = ep_model(shard_data)
    MOE_MANAGER.reset_loss()
    assert torch.allclose(out_tp, out_ep, atol=1e-6), \
        f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
    try:
        out_local_slice = out_local[index:index + micro_batch_size]
        assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \
            f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
    except AssertionError:
        # Why EP/TP may legitimately diverge from the local model:
        # e.g. in the local model, tokens = 4, capacity = 2, experts = 2, topk = 1.
        # The router yields [01] --> [0], [23] --> [1]; this is valid because capacity is 2.
        # In EP mode, however, there are 2 separate routers handling sharded data.
        # Assume router 0 handles tokens [01] and router 1 handles tokens [23].
        # For each router the capacity is only 1, so router 0 may yield [0] --> [0]
        # or [1] --> [0], but not both; the same happens on router 1. As a result,
        # some tokens are dropped because of the sharding.
        warnings.warn(
            "EP & TP may result in different behavior from local model. "
            "Please check the comments for details."
        )

    out_local.mean().backward()
    out_tp.mean().backward()
    tp_grad_handler.handle_gradient()
    out_ep.mean().backward()
    ep_grad_handler.handle_gradient()

    assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
    assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
    sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
    try:
        sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
    except AssertionError:
        warnings.warn(
            "EP & TP may result in different behavior from local model. "
            "Please check the comments for details."
        )


@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("config", [
    {"enable_hierarchical_comm": False},
    {"enable_hierarchical_comm": True},
])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
    """Spawn 2 processes running ``run_test`` with the given parameters."""
    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)


if __name__ == '__main__':
    # `config` is a required argument; the original call omitted it and raised TypeError.
    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, config={"enable_hierarchical_comm": False})