diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
new file mode 100644
index 000000000..26fa81921
--- /dev/null
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -0,0 +1,140 @@
+import os
+import shutil
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.test_moe_checkpoint import check_model_equal
+
+NUM_BATCH = 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 2
+TOP_K = 1
+
+
+def split_grad(grad, world_size):
+    with torch.no_grad():
+        grad = grad.clone().detach().flatten()
+        padding_size = (world_size - grad.numel() % world_size) % world_size
+        if padding_size > 0:
+            grad = torch.nn.functional.pad(grad, [0, padding_size])
+        splited_grad = grad.split(grad.numel() // world_size)
+    return splited_grad
+
+
+@parameterize("stage", [1])
+@parameterize("ep_size", [1, 2, 4])
+def run_zero_with_original_model(stage: int, ep_size: int):
+    dtype = torch.float32
+
+    rank = torch.distributed.get_rank()
+    torch.cuda.set_device(dist.get_rank())
+
+    plugin = MoeHybridParallelPlugin(
+        pp_size=1,
+        tp_size=1,
+        ep_size=ep_size,
+        zero_stage=stage,
+        overlap_communication=False,
+        initial_scale=1,
+        precision="fp32",
+    )
+    booster = Booster(plugin=plugin)
+
+    seed_all(10086)
+
+    config = MixtralConfig(
+        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        num_hidden_layers=2,
+        num_attention_heads=NUM_HEADS,
+        num_key_value_heads=NUM_HEADS,
+        num_local_experts=NUM_EXPERTS,
+        num_experts_per_tok=TOP_K,
+    )
+
+    torch_model = MixtralModel(config).to(dtype).cuda()
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+    zero_model = deepcopy(torch_model).to(dtype)
+    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+
+    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+    # create different input
+    seed_all(1453 + rank)
+
+    torch_model.train()
+    zero_model.train()
+    for _ in range(1):
+        # zero-dp forward
+        input_data = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+        # zero-dp backward
+        print(zero_output.dtype)
+        zero_optimizer.backward(zero_output)
+        zero_optimizer.step()
+
+        dist.all_reduce(zero_output)
+
+        all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
+        dist.all_gather(all_inputs, input_data)
+
+        torch_output_sum = 0
+        for input_data_ in all_inputs:
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+            torch_output.backward()
+            torch_output_sum += torch_output.detach()
+
+        # avg dp grads
+        for p in torch_model.parameters():
+            if p.grad is not None:
+                p.grad /= dist.get_world_size()
+
+        loose_close(zero_output, torch_output_sum, dtype=dtype)
+        torch_optimizer.step()
+
+    # use checkpoint to load sharded zero model
+    model_dir = "./test_mixtral"
+    if dist.get_rank() == 0:
+        os.makedirs(model_dir, exist_ok=True)
+
+    dist.barrier()
+    booster.save_model(zero_model, model_dir, shard=True)
+    dist.barrier()
+
+    if dist.get_rank() == 0:
+        saved_model = MixtralModel.from_pretrained(model_dir).cuda()
+        check_model_equal(torch_model, saved_model)
+        shutil.rmtree(model_dir)
+
+    print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_mixtral(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_mixtral(world_size=4)
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index b7332a937..e49edb6f4 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -2,7 +2,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
-from torch.testing import assert_close
 
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
@@ -146,6 +145,10 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
     elif dtype is torch.bfloat16:
         rtol = 4e-3
         atol = 4e-3
+    else:
+        assert dtype is torch.float32
+        rtol = 1e-5
+        atol = 1e-5
 
     a = a.detach().to(dtype)
     b = b.detach().to(dtype).to(a.device)
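
For reference, a minimal standalone sketch of the dtype-aware tolerance selection that `loose_close` performs after this change. The bfloat16 and new float32 branches mirror the hunk above; the float16 handling and the final comparison call are assumptions, since neither appears in the context shown here.

```python
import torch


def loose_close_sketch(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype = torch.float32):
    """Hypothetical standalone version of the tolerance selection in loose_close
    (float16 branch omitted because it is not visible in this hunk)."""
    if dtype is torch.bfloat16:
        # Existing branch shown in the hunk: looser tolerances for bf16.
        rtol = 4e-3
        atol = 4e-3
    else:
        # Branch added by this diff: only fp32 is expected here,
        # and it falls back to tight tolerances.
        assert dtype is torch.float32
        rtol = 1e-5
        atol = 1e-5

    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)
    # Assumed final check; the real helper may report `name` on mismatch instead.
    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
```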