# Mirror of https://github.com/hpcaitech/ColossalAI
# Python, 144 lines, 4.6 KiB
import os
|
|
import shutil
|
|
from copy import deepcopy
|
|
from typing import Tuple
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
|
from transformers.models.mixtral.modeling_mixtral import MixtralModel
|
|
|
|
import colossalai
|
|
from colossalai.booster.booster import Booster
|
|
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|
from colossalai.testing.random import seed_all
|
|
from tests.test_moe.moe_utils import loose_close
|
|
from tests.test_moe.test_moe_checkpoint import check_model_equal
|
|
|
|
# Sizes of the synthetic workload and of the tiny Mixtral model under test.
NUM_BATCH = 4  # leading (batch) dim of each rank's random ``inputs_embeds``
NUM_TOK_PER_BATCH = 8  # sequence length of each random input
NUM_EXPERTS = 4  # passed as ``num_local_experts`` to MixtralConfig
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4  # hidden_size == HIDDEN_SIZE_PER_HEAD * NUM_HEADS
TOP_K = 1  # passed as ``num_experts_per_tok`` to MixtralConfig
|
|
|
|
|
|
def _build_moe_plugin(ep_size: int, stage: int, pp_size: int, tp_size: int, sp_size: int, precision: str):
    """Create the MoeHybridParallelPlugin for one parallel layout.

    The MoE tensor-parallel size mirrors ``tp_size``, and the pipeline uses one
    micro-batch per stage. Sequence parallelism (all-to-all mode) is enabled only
    when ``sp_size > 1``.
    """
    return MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        moe_tp_size=tp_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
    )


def _build_mixtral_config():
    """Return the tiny two-layer MixtralConfig shared by the reference and boosted models."""
    return MixtralConfig(
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        num_hidden_layers=2,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_HEADS,
        num_local_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        attn_implementation="flash_attention_2",
    )


def _train_and_compare(torch_model, torch_optimizer, zero_model, zero_optimizer, plugin, dtype, num_steps: int = 2):
    """Train both models side by side for ``num_steps`` steps and compare their outputs.

    Each rank draws its own random ``inputs_embeds``. The tensor is all-reduced
    within the TP and SP groups so that ranks in those groups see identical input.
    The boosted model steps on its local input, and its scalar output (mean of
    ``last_hidden_state``) is summed across all ranks.

    The reference model instead runs on every rank's input, gathered with
    ``all_gather``. Its gradients are divided by the world size before the
    optimizer step, and the sum of its outputs is compared with the boosted
    result via ``loose_close``.
    """
    world_size = dist.get_world_size()
    torch_model.train()
    zero_model.train()
    for _ in range(num_steps):
        input_data = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(input_data, group=plugin.tp_group)  # tp group requires duplicate input
        dist.all_reduce(input_data, group=plugin.sp_group)  # sp group requires duplicate input

        # Boosted (sharded) model: one step on this rank's input.
        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
        zero_optimizer.backward(zero_output)
        zero_optimizer.step()
        zero_optimizer.zero_grad()
        dist.all_reduce(zero_output)

        # Reference model: replay every rank's input locally.
        all_inputs = [torch.empty_like(input_data) for _ in range(world_size)]
        dist.all_gather(all_inputs, input_data)

        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # Average the accumulated gradients over all ranks' inputs, as data parallelism would.
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= world_size
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        loose_close(zero_output, torch_output_sum, dtype=dtype)


def _check_checkpoint_roundtrip(booster, zero_model, torch_model, dtype, model_dir: str = "./test_mixtral"):
    """Save the boosted model as a sharded checkpoint and verify it matches the reference.

    Rank 0 creates ``model_dir`` and removes it once every rank has finished
    loading. The barriers keep ranks from reading before the save completes and
    keep rank 0 from deleting the directory while other ranks still read it.
    """
    if dist.get_rank() == 0:
        os.makedirs(model_dir, exist_ok=True)
    dist.barrier()

    booster.save_model(zero_model, model_dir, shard=True)
    dist.barrier()

    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
    check_model_equal(torch_model, saved_model)

    dist.barrier()
    if dist.get_rank() == 0:
        shutil.rmtree(model_dir)


@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
    """Check a boosted Mixtral model against a plain PyTorch copy for one parallel layout.

    Args:
        config: ``(ep_size, zero_stage, dp_size, pp_size, tp_size, sp_size)``.
            ``dp_size`` is not passed to the plugin.
    """
    ep_size, stage, _dp_size, pp_size, tp_size, sp_size = config
    print(config)
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(rank)

    plugin = _build_moe_plugin(ep_size, stage, pp_size, tp_size, sp_size, precision)
    booster = Booster(plugin=plugin)

    # Same seed on every rank so the reference weights are identical everywhere.
    seed_all(10086)
    mixtral_config = _build_mixtral_config()
    torch_model = MixtralModel(mixtral_config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    # torch_model is already in ``dtype``; deepcopy preserves it.
    zero_model = deepcopy(torch_model)
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

    # create different input
    seed_all(1453 + rank)
    _train_and_compare(torch_model, torch_optimizer, zero_model, zero_optimizer, plugin, dtype)

    # use checkpoint to load sharded zero model
    _check_checkpoint_roundtrip(booster, zero_model, torch_model, dtype)

    print(f"{rank} test passed")
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Worker entry point: join the NCCL process group on localhost, then run the parity test.

    ``run_zero_with_original_model`` is called without arguments; its
    ``@parameterize`` decorator supplies each parallel layout.
    """
    launch_kwargs = dict(host="localhost", port=port, backend="nccl")
    colossalai.launch(rank=rank, world_size=world_size, **launch_kwargs)
    run_zero_with_original_model()
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
    """Spawn ``world_size`` worker processes, each executing :func:`run_dist`.

    NOTE(review): despite the name, this exercises ``MixtralModel`` (Mixtral MoE), not Mistral.
    """
    num_procs = world_size
    spawn(run_dist, num_procs)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run outside pytest with the same world size the parametrized test uses.
    default_world_size = 8
    test_mistral(world_size=default_world_size)
|