|
|
|
import os
|
|
|
|
import warnings
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
from colossalai.moe import SparseMLP
|
|
|
|
from colossalai.moe.manager import MOE_MANAGER
|
|
|
|
from colossalai.moe.utils import sync_moe_model_param
|
|
|
|
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
|
|
|
|
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
from tests.test_moe.moe_utils import MoeGradientHandler
|
|
|
|
|
|
|
|
|
|
|
|
def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
|
|
|
"""Sync the parameters of tp model from local model
|
|
|
|
|
|
|
|
Args:
|
|
|
|
tp_model (MoeModule)
|
|
|
|
local_model (MoeModule)
|
|
|
|
"""
|
|
|
|
for (tp_name, tp_param), (local_name, local_param) in \
|
|
|
|
zip(tp_model.named_parameters(), local_model.named_parameters()):
|
|
|
|
assert tp_name == local_name
|
|
|
|
if not is_moe_tensor(tp_param):
|
|
|
|
if assert_grad_flag:
|
|
|
|
assert torch.allclose(tp_param, local_param)
|
|
|
|
assert torch.allclose(tp_param.grad, local_param.grad)
|
|
|
|
else:
|
|
|
|
tp_param.data.copy_(local_param.data)
|
|
|
|
continue
|
|
|
|
|
|
|
|
tp_rank = get_ep_rank(tp_param)
|
|
|
|
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0]
|
|
|
|
tp_slice = [slice(None)] * tp_dim + [
|
|
|
|
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
|
|
|
|
]
|
|
|
|
|
|
|
|
if assert_grad_flag:
|
|
|
|
assert torch.allclose(tp_param, local_param[tuple(tp_slice)])
|
|
|
|
assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)])
|
|
|
|
else:
|
|
|
|
tp_param.data.copy_(local_param[tuple(tp_slice)].data)
|
|
|
|
|
|
|
|
|
|
|
|
def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
|
|
|
"""Sync the parameters of tp model from ep model
|
|
|
|
|
|
|
|
Args:
|
|
|
|
tp_model (MoeModule)
|
|
|
|
ep_model (MoeModule)
|
|
|
|
"""
|
|
|
|
for (tp_name, tp_param), (ep_name, ep_param) in \
|
|
|
|
zip(tp_model.named_parameters(), ep_model.named_parameters()):
|
|
|
|
assert tp_name == ep_name
|
|
|
|
if not is_moe_tensor(tp_param):
|
|
|
|
if assert_grad_flag:
|
|
|
|
assert torch.allclose(tp_param, ep_param)
|
|
|
|
assert torch.allclose(tp_param.grad, ep_param.grad)
|
|
|
|
else:
|
|
|
|
tp_param.data.copy_(ep_param.data)
|
|
|
|
continue
|
|
|
|
|
|
|
|
# gather param from ep model
|
|
|
|
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
|
|
|
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
|
|
|
|
all_param = torch.cat(param_list, dim=0)
|
|
|
|
if assert_grad_flag:
|
|
|
|
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
|
|
|
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
|
|
|
|
all_grad = torch.cat(grad_list, dim=0)
|
|
|
|
|
|
|
|
# get tp param
|
|
|
|
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1
|
|
|
|
tp_rank = get_ep_rank(tp_param)
|
|
|
|
tp_slice = [slice(None)] * tp_dim + [
|
|
|
|
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
|
|
|
|
]
|
|
|
|
new_tp_param = all_param[tuple(tp_slice)]
|
|
|
|
if assert_grad_flag:
|
|
|
|
new_grad = all_grad[tuple(tp_slice)]
|
|
|
|
if assert_grad_flag:
|
|
|
|
assert torch.allclose(tp_param, new_tp_param)
|
|
|
|
assert torch.allclose(tp_param.grad, new_grad)
|
|
|
|
else:
|
|
|
|
tp_param.data.copy_(new_tp_param.data)
|
|
|
|
|
|
|
|
|
|
|
|
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
|
|
|
|
"""Sync the parameters of tp model from ep model
|
|
|
|
|
|
|
|
Args:
|
|
|
|
local_model (MoeModule)
|
|
|
|
ep_model (MoeModule)
|
|
|
|
"""
|
|
|
|
for (local_name, local_param), (ep_name, ep_param) in \
|
|
|
|
zip(local_model.named_parameters(), ep_model.named_parameters()):
|
|
|
|
assert local_name == ep_name
|
|
|
|
if "experts" not in local_name:
|
|
|
|
if assert_grad_flag:
|
|
|
|
assert torch.allclose(local_param, ep_param)
|
|
|
|
assert torch.allclose(local_param.grad, ep_param.grad)
|
|
|
|
else:
|
|
|
|
local_param.data.copy_(ep_param.data)
|
|
|
|
continue
|
|
|
|
|
|
|
|
# gather param from ep model
|
|
|
|
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
|
|
|
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
|
|
|
|
all_param = torch.cat(param_list, dim=0)
|
|
|
|
if assert_grad_flag:
|
|
|
|
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
|
|
|
|
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
|
|
|
|
all_grad = torch.cat(grad_list, dim=0)
|
|
|
|
|
|
|
|
if assert_grad_flag:
|
|
|
|
assert torch.allclose(local_param, all_param)
|
|
|
|
assert torch.allclose(local_param.grad, all_grad)
|
|
|
|
else:
|
|
|
|
local_param.data.copy_(all_param.data)
|
|
|
|
|
|
|
|
|
|
|
|
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
|
|
|
|
assert batch_size % world_size == 0
|
|
|
|
|
|
|
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
|
|
|
|
|
|
|
MOE_MANAGER.__init__()
|
|
|
|
MOE_MANAGER.setup(parallel=None)
|
|
|
|
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
|
|
|
|
MOE_MANAGER.__init__()
|
|
|
|
MOE_MANAGER.setup(parallel="EP")
|
|
|
|
enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
|
|
|
|
if enable_hierarchical_comm:
|
|
|
|
os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
|
|
|
|
ep_model = SparseMLP(
|
|
|
|
num_experts=num_experts,
|
|
|
|
hidden_size=dim,
|
|
|
|
intermediate_size=dim * 2,
|
|
|
|
enable_hierarchical_comm=enable_hierarchical_comm
|
|
|
|
)
|
|
|
|
MOE_MANAGER.__init__()
|
|
|
|
MOE_MANAGER.setup(parallel="TP")
|
|
|
|
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
|
|
|
|
ep_model = ep_model.to(get_current_device())
|
|
|
|
tp_model = tp_model.to(get_current_device())
|
|
|
|
local_model = local_model.to(get_current_device())
|
|
|
|
|
|
|
|
# sync ep param
|
|
|
|
sync_moe_model_param(ep_model)
|
|
|
|
dist_dict = MOE_MANAGER.parallel_info_dict
|
|
|
|
assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
|
|
|
|
assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
|
|
|
|
ep_grad_handler = MoeGradientHandler(ep_model)
|
|
|
|
# sync local param
|
|
|
|
sync_local_from_ep(local_model, ep_model)
|
|
|
|
# sync tp param
|
|
|
|
sync_tp_from_ep(tp_model, ep_model)
|
|
|
|
tp_grad_handler = MoeGradientHandler(tp_model)
|
|
|
|
|
|
|
|
rank = dist.get_rank()
|
|
|
|
input_data = torch.randn(batch_size, dim, device=get_current_device())
|
|
|
|
micro_batch_size = batch_size // world_size
|
|
|
|
index = rank * micro_batch_size
|
|
|
|
# NOTE: ep & tp takes in sharded data for each process
|
|
|
|
shard_data = input_data.detach()[index:index + micro_batch_size]
|
|
|
|
|
|
|
|
out_local = local_model(input_data)
|
|
|
|
MOE_MANAGER.reset_loss()
|
|
|
|
out_tp = tp_model(shard_data)
|
|
|
|
MOE_MANAGER.reset_loss()
|
|
|
|
out_ep = ep_model(shard_data)
|
|
|
|
MOE_MANAGER.reset_loss()
|
|
|
|
|
|
|
|
assert torch.allclose(out_tp, out_ep, atol=1e-6), \
|
|
|
|
f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
|
|
|
|
try:
|
|
|
|
out_local_slice = out_local[index:index + micro_batch_size]
|
|
|
|
assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \
|
|
|
|
f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
|
|
|
|
except AssertionError as e:
|
|
|
|
"""
|
|
|
|
e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
|
|
|
|
router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
|
|
|
|
However, in ep mode, there are 2 separate routers dealing with sharded data.
|
|
|
|
Assume router 0 handles token [01] and router 1 handles token [23].
|
|
|
|
Note that for each router the capacity is only 1 !!!
|
|
|
|
Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both.
|
|
|
|
The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
|
|
|
|
"""
|
|
|
|
warnings.warn(
|
|
|
|
"EP & TP may result in different behavior from local model. "
|
|
|
|
"Please check the comments for details."
|
|
|
|
)
|
|
|
|
|
|
|
|
out_local.mean().backward()
|
|
|
|
out_tp.mean().backward()
|
|
|
|
tp_grad_handler.handle_gradient()
|
|
|
|
out_ep.mean().backward()
|
|
|
|
ep_grad_handler.handle_gradient()
|
|
|
|
|
|
|
|
assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
|
|
|
|
assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
|
|
|
|
sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
|
|
|
|
try:
|
|
|
|
sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
|
|
|
|
except AssertionError as e:
|
|
|
|
warnings.warn(
|
|
|
|
"EP & TP may result in different behavior from local model. "
|
|
|
|
"Please check the comments for details."
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
|
|
|
|
@pytest.mark.parametrize("num_experts", [4, 64])
|
|
|
|
@pytest.mark.parametrize("batch_size", [16])
|
|
|
|
@pytest.mark.parametrize("dim", [64])
|
|
|
|
@pytest.mark.parametrize("config", [
|
|
|
|
{"enable_hierarchical_comm": False},
|
|
|
|
{"enable_hierarchical_comm": True},
|
|
|
|
])
|
|
|
|
@rerun_if_address_is_in_use()
|
|
|
|
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
|
|
|
|
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
|