import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal

NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1


@parameterize("config", [(1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
    # config is (zero_stage, ep_size, tp_size)
    stage, ep_size, tp_size = config
    dtype = torch.float16

    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    plugin = MoeHybridParallelPlugin(
        pp_size=1,
        tp_size=tp_size,
        moe_tp_size=tp_size,
        ep_size=ep_size,
        zero_stage=stage,
        overlap_communication=False,
        initial_scale=1,
        precision="fp16",  # must match `dtype` above; the unsharded baseline also runs in fp16
    )
    booster = Booster(plugin=plugin)

    seed_all(10086)

    # build a tiny DeepSeek-MoE model from the pretrained config
    config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
    config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
    config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
    config.num_hidden_layers = 2
    config.num_attention_heads = NUM_HEADS
    config.num_key_value_heads = NUM_HEADS
    config.n_routed_experts = NUM_EXPERTS
    config.num_experts_per_tok = TOP_K
    torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)

    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    zero_model = deepcopy(torch_model).to(dtype)
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

    # create different input per rank
    seed_all(1453 + rank)

    torch_model.train()
    zero_model.train()
    for _ in range(2):
        input_data = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(input_data, group=plugin.tp_group)  # tp requires duplicate input

        # hybrid-parallel forward / backward / step
        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
        zero_optimizer.backward(zero_output)
        zero_optimizer.step()
        zero_optimizer.zero_grad()
        dist.all_reduce(zero_output)

        # replay every rank's input through the unsharded baseline model
        all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
        dist.all_gather(all_inputs, input_data)

        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()

        # avg dp grads to match ZeRO's gradient reduction
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dist.get_world_size()
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        loose_close(zero_output, torch_output_sum, dtype=dtype)

    # use checkpoint to save/load the sharded zero model and verify the round-trip
    model_dir = "./test_deepseek"
    if dist.get_rank() == 0:
        os.makedirs(model_dir, exist_ok=True)
    dist.barrier()

    booster.save_model(zero_model, model_dir, shard=True)
    dist.barrier()

    # the reloaded checkpoint must match the unsharded baseline parameter-for-parameter
    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
    check_model_equal(torch_model, saved_model)
    dist.barrier()

    if dist.get_rank() == 0:
        shutil.rmtree(model_dir)

    print(f"{dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_deepseek(world_size=4)
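# Usage sketch (assumes 4 visible CUDA devices and a working NCCL install):
#   python <this_file>.py          # the __main__ guard spawns the 4 workers itself
#   pytest -m dist <this_file>.py  # under pytest, the test is gated behind the `dist` marker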