mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
2.0 KiB
64 lines
2.0 KiB
from copy import deepcopy
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
|
|
from torch.testing import assert_close
|
|
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
|
|
|
import colossalai
|
|
from colossalai.moe import MOE_MANAGER
|
|
from colossalai.testing.utils import spawn
|
|
|
|
# Sizes for the MoE layer check: number of input tokens, number of experts,
# model hidden size, and how many experts each token is routed to.
tokens = 7
n_experts = 4
hidden_size = 8
top_k = 2
|
|
|
|
|
|
def check_mixtral_moe_layer():
    """Check the expert-parallel Mixtral MoE block against the reference block.

    Builds a HuggingFace ``MixtralSparseMoeBlock``, converts a deep copy of it
    into an ``EPMixtralSparseMoeBlock``, and feeds both the same random input.
    It then asserts that these quantities match between the two blocks:

    * router logits,
    * block outputs,
    * mean-reduced losses,
    * the gradient of every parameter the EP block holds and has a grad for.

    Must be called inside an initialized distributed process group
    (this module calls it from ``run_dist`` after ``colossalai.launch``).
    """
    rank = dist.get_rank()
    ep_size = dist.get_world_size()

    # The global rank is used directly as the CUDA device index.
    torch.cuda.set_device(rank)

    # Pure expert parallelism across every rank: no data- or pipeline-parallel split.
    MOE_MANAGER.setup(
        parallel="EP",
        mode="fixed",
        fixed_dp_size=1,
        fixed_ep_size=ep_size,
        fixed_pp_size=1,
    )

    moe_config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=2 * hidden_size,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
    )

    # Seed right before weight init and input sampling, so every rank running
    # this function builds identical reference weights and an identical input.
    torch.manual_seed(0)
    ref_block = MixtralSparseMoeBlock(moe_config).cuda()
    inputs = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()

    # Forward: reference first, then the EP block built from a copy of it.
    ref_out, ref_router_logits = ref_block(inputs)
    ep_block = EPMixtralSparseMoeBlock.from_native_module(deepcopy(ref_block))
    ep_out, ep_router_logits = ep_block(inputs)

    assert_close(ref_router_logits, ep_router_logits)
    assert_close(ref_out, ep_out)

    # Backward on a simple mean loss for each block.
    ref_loss = ref_out.mean()
    ref_loss.backward()
    ep_loss = ep_out.mean()
    ep_loss.backward()
    assert_close(ref_loss, ep_loss)

    # Compare gradients parameter by parameter, matched by name.
    ref_params = dict(ref_block.named_parameters())
    for param_name, ep_param in ep_block.named_parameters():
        ref_param = ref_params[param_name]
        # Parameters without a gradient in the EP block are skipped. Presumably
        # these are experts that received no tokens; confirm against the
        # EP implementation.
        if ep_param.grad is None:
            continue
        assert_close(ref_param.grad, ep_param.grad)
|
|
|
|
|
|
def run_dist(rank: int, world_size: int, port: int):
    """Per-process entry point handed to ``spawn``.

    Joins a ``world_size``-process group rendezvousing on ``localhost:port``
    as ``rank``, then runs the Mixtral MoE layer check in that process.
    """
    host = "localhost"
    colossalai.launch(rank, world_size, host, port)
    check_mixtral_moe_layer()
|
|
|
|
|
|
@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
    """Run the EP-vs-reference Mixtral MoE check on ``world_size`` spawned processes.

    Each process executes ``run_dist``.
    """
    entry_point = run_dist
    spawn(entry_point, world_size)
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct invocation runs only the smallest parametrized world size.
    test_mixtral_moe_layer(world_size=2)
|