from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from torch.optim import Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing.utils import spawn

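# Tiny test dimensions: `tokens` is the sequence length of the random input batch,
# `n_experts` and `top_k` configure the MoE router, and `hidden_size` keeps the model small.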
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


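# Compare two models key-by-key and parameter-by-parameter; values are cast to fp16
# before comparison, presumably to tolerate dtype/precision differences introduced
# by mixed-precision checkpointing.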
def check_model_equal(model1, model2):
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        assert torch.equal(p1.half(), p2.half())


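# Take a snapshot of an optimizer: a deep copy of its per-parameter state keyed by
# id(param), plus its param_groups with the parameters replaced by their ids.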
def get_optimizer_snapshot(optim):
    state = {id(k): deepcopy(v) for k, v in optim.state.items()}
    param_groups = []
    for group in optim.param_groups:
        params = [id(p) for p in group["params"]]
        new_group = {"params": params}
        for k, v in group.items():
            if k != "params":
                new_group[k] = v
        param_groups.append(new_group)
    return {
        "state": state,
        "param_groups": param_groups,
    }


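# Assert that two optimizer snapshots match: same param_group hyperparameters and
# grouping, and element-wise equal state tensors for every parameter.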
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
    # check param_groups
    assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
    for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
        assert set(group1.keys()) == set(group2.keys())
        for k in group1.keys():
            assert group1[k] == group2[k]
    # check state
    assert set(snapshot1["state"].keys()) == set(
        snapshot2["state"].keys()
    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
    for pid in snapshot1["state"].keys():
        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
        assert set(state1.keys()) == set(state2.keys())
        for k in state1.keys():
            if isinstance(state1[k], torch.Tensor):
                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
            else:
                assert state1[k] == state2[k]


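# End-to-end checkpoint round-trip on a tiny Mixtral model: boost it with the MoE
# hybrid parallel plugin (pp=2, ep=2, ZeRO stage 1), run one pipeline forward/backward
# pass, then save and reload both the model and the optimizer through the booster
# and check they match.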
def check_mixtral_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    torch.manual_seed(0)
    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
    orig_model = MixtralForCausalLM(config).cuda()
    model = deepcopy(orig_model)
    optimizer = Adam(model.parameters(), lr=1e-3)
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        ep_size=2,
        custom_policy=MixtralForCausalLMPolicy(),
        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        microbatch_size=1,
        zero_stage=1,
    )
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
    # initialize grads
    data_iter = iter(
        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
    )
    booster.execute_pipeline(
        data_iter,
        model,
        lambda outputs, inputs: outputs.loss,
        optimizer,
    )

    # check save model
    booster.save_model(model, "mixtral_model", shard=True)
    dist.barrier()
    if dist.get_rank() == 0:
        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
        check_model_equal(orig_model, saved_model)
        saved_model.save_pretrained("mixtral_hf_model")
    dist.barrier()

    # check load model
    new_model = MixtralForCausalLM(config).cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
    booster.load_model(new_model, "mixtral_hf_model")
    check_model_equal(model, new_model)

    # check save optimizer
    optimizer.step()
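    # Bump the learning rate before taking the snapshot so the saved checkpoint also
    # carries a non-default param_group value, not just per-parameter state tensors.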
    for group in optimizer.param_groups:
        group["lr"] = 0.1
    snapshot = get_optimizer_snapshot(optimizer.unwrap())
    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
    dist.barrier()
    # reset optimizer state
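    # Zero the live state tensors first so the comparison below can only pass if
    # load_optimizer actually restores the saved values.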
    for state in optimizer.unwrap().state.values():
        for v in state.values():
            if isinstance(v, torch.Tensor):
                v.zero_()
    booster.load_optimizer(optimizer, "mixtral_optim")
    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)


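# Per-process entry point: initialize the ColossalAI distributed environment on each
# spawned worker, then run the checkpoint round-trip check.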
def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch(rank, world_size, "localhost", port)
    check_mixtral_moe_layer()


@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(4)