update pytest

pull/5190/head
Xuanlei Zhao 2023-12-27 16:05:00 +08:00
parent 54b197cc02
commit 570f5cd693
3 changed files with 53 additions and 69 deletions


@@ -142,38 +142,37 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
     @torch.no_grad()
     def pre_save_model(self, model: nn.Module) -> dict:
         state_dict = model.state_dict()
-        for name, param in model.named_parameters():
-            if ".experts." in name:
-                if ".experts.gate_weight" in name:
-                    new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
-                    state_dict[new_name] = state_dict.pop(name).cpu()
-                elif ".experts." in name and is_moe_tensor(param):
-                    ep_group = get_ep_group(param)
-                    ep_rank = get_ep_rank(param)
-                    ep_size = get_ep_size(param)
-                    dp_rank = get_dp_rank(param)
-                    if dp_rank == 0:
-                        param = param.data.cuda()
-                        all_param = [torch.zeros_like(param) for _ in range(ep_size)]
-                        # gather param from every ep rank
-                        dist.all_gather(all_param, param, group=ep_group)
-                        if ep_rank == 0:
-                            all_param = torch.cat(all_param, dim=0)
-                            assert all_param.shape[0] == 8
-                            for i in range(8):
-                                if ".wi_gate" in name:
-                                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
-                                elif ".wi_up" in name:
-                                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
-                                elif ".wo" in name:
-                                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
-                                new_name = new_name.replace("module.", "")
-                                new_param = all_param[i].transpose(-1, -2)
-                                state_dict[new_name] = new_param.cpu()
-                            state_dict.pop(name)
-                else:
-                    state_dict[name] = param.cpu()
+        for name, param in list(model.named_parameters()):
+            if ".gate_weight" in name:
+                new_name = name.replace(".gate_weight", ".gate.weight")
+                state_dict[new_name] = state_dict.pop(name).cpu()
+            elif ".experts." in name:
+                ep_group = get_ep_group(param)
+                ep_rank = get_ep_rank(param)
+                ep_size = get_ep_size(param)
+                dp_rank = get_dp_rank(param)
+                if dp_rank == 0:
+                    param = param.data.cuda()
+                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
+                    # gather param from every ep rank
+                    dist.all_gather(all_param, param, group=ep_group)
+                    if ep_rank == 0:
+                        all_param = torch.cat(all_param, dim=0)
+                        assert all_param.shape[0] == 8
+                        for i in range(8):
+                            if ".wi_gate" in name:
+                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
+                            elif ".wi_up" in name:
+                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
+                            elif ".wo" in name:
+                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
+                            new_name = new_name.replace("module.", "")
+                            new_param = all_param[i].transpose(-1, -2)
+                            state_dict[new_name] = new_param.cpu()
+                        state_dict.pop(name)
+            else:
+                state_dict[name] = param.cpu()
         for name, param in list(state_dict.items()):
             new_name = name.replace("module.", "")
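
Note on the hunk above: pre_save_model renames ColossalAI's fused expert tensors back to the HuggingFace Mixtral per-expert layout (w1/w3/w2) after gathering shards across the expert-parallel group. A minimal single-process sketch of just the renaming step, using a hypothetical helper split_fused_experts and a dummy fused tensor (no distributed gather), could look like this:

import torch

def split_fused_experts(name: str, fused: torch.Tensor) -> dict:
    # fused: (num_experts, hidden, intermediate); emit one HF-style entry per expert
    out = {}
    for i in range(fused.shape[0]):
        if ".wi_gate" in name:
            new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
        elif ".wi_up" in name:
            new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
        elif ".wo" in name:
            new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
        else:
            continue
        # transpose the last two dims before saving, mirroring the diff above
        out[new_name.replace("module.", "")] = fused[i].transpose(-1, -2).contiguous()
    return out

# example: split_fused_experts("model.layers.0.block_sparse_moe.experts.wi_gate", torch.zeros(8, 32, 16))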
@@ -186,9 +185,9 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
         # and gather them one by one
         new_state_dict = {}
         state_dict_keys = list(state_dict.keys())
-        gap_keys = len(state_dict_keys) // 10
+        gap_keys = len(state_dict_keys) // 10 + 1
         for i in range(10):
-            cur_keys = state_dict_keys[(i - 1) * gap_keys : i * gap_keys]
+            cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
             cur_state_dict = {}
             for k in cur_keys:
                 cur_state_dict[k] = state_dict[k]
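
This hunk fixes the partitioning of state-dict keys into ten batches: the old slice `(i - 1) * gap_keys : i * gap_keys` yields an empty batch at i == 0 and drops the last keys whenever the key count is not a multiple of 10, while the new bounds with `// 10 + 1` cover every key exactly once. A quick standalone check with a hypothetical key list:

keys = [f"k{i}" for i in range(23)]
gap_keys = len(keys) // 10 + 1
chunks = [keys[i * gap_keys : (i + 1) * gap_keys] for i in range(10)]
assert sum(chunks, []) == keys  # every key appears exactly once, in order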


@@ -44,17 +44,17 @@ class MixtralSparseMLP:
             router_top_k=module.top_k,
             router_norm=True,
             router_loss=False,
-            # router_capacity_factor_train = .
-            # router_capacity_factor_eval = .
+            # router_capacity_factor_train=
+            # router_capacity_factor_eval=
             mlp_activation="silu",
             mlp_gated=True,
-            # enable_load_balance = .
-            # load_balance_tolerance = .
-            # load_balance_beam_width = .
-            # load_balance_group_swap_factor = .
+            # enable_load_balance=
+            # load_balance_tolerance=
+            # load_balance_beam_width=
+            # load_balance_group_swap_factor=
             enable_kernel=enable_kernel,
-            # enable_comm_overlap = .
-            # enable_hierarchical_comm = .
+            # enable_comm_overlap=
+            # enable_hierarchical_comm=
             return_gate_logits=True,
         )
         dtype = module.gate.weight.dtype


@@ -1,11 +1,11 @@
 import os
-import sys
+import shutil
 import pytest
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
 from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -16,13 +16,6 @@ from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device

-sys.path.append(
-    os.path.join(
-        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
-        "examples/language/openmoe",
-    )
-)

 def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
     input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
@@ -71,7 +64,7 @@ def get_config():
     config = MixtralConfig(
         vocab_size=300,
         hidden_size=32,
-        intermediate_size=128,
+        intermediate_size=16,
         num_hidden_layers=2,
         dropout_rate=0.0,
     )
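
The test config trims intermediate_size from 128 to 16 so the checkpoint round-trip stays cheap. For a rough sense of scale, a model built from these values (taken from the diff; dropout_rate omitted here) ends up with only a few tens of thousands of parameters:

from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import torch

config = MixtralConfig(vocab_size=300, hidden_size=32, intermediate_size=16, num_hidden_layers=2)
model = MixtralForCausalLM(config).to(torch.bfloat16)
print(sum(p.numel() for p in model.parameters()))  # roughly 50k parameters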
@@ -81,6 +74,7 @@ def get_config():
 def get_model(parallel):
     config = get_config()
     model = MixtralForCausalLM(config).to(torch.bfloat16)
+    replace_moe_layer(model)
     optim = torch.optim.Adam(model.parameters())
     args = dict(
         precision="bf16",
@@ -89,22 +83,11 @@ def get_model(parallel):
         custom_policy=MixtralForCausalLMPolicy(),
         checkpoint_io=MixtralMoECheckpointIO,
     )
-    if parallel == None:
+    if parallel == "ep":
         plugin = MoeHybridParallelPlugin(
             pp_size=1,
             **args,
         )
-    elif parallel == "ep":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "ep_zero":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            extra_dp_size=2,
-            **args,
-        )
     elif parallel == "hybrid":
         plugin = MoeHybridParallelPlugin(
             pp_size=2,
@@ -117,6 +100,13 @@ def get_model(parallel):
 def _test_moe_checkpoint(parallel):
+    if dist.get_rank() == 0:
+        if os.path.exists("./tmp_ckpt1"):
+            shutil.rmtree("./tmp_ckpt1")
+        if os.path.exists("./tmp_ckpt2"):
+            shutil.rmtree("./tmp_ckpt2")
+    dist.barrier()
     if parallel == None:
         MOE_MANAGER.setup(
             parallel=None,
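
The new lines at the top of _test_moe_checkpoint follow a common pattern: only rank 0 deletes stale checkpoint directories, and every rank waits on a barrier before writing. In isolation, a sketch with a hypothetical helper clean_stale_ckpts (assuming the process group is already initialized, paths as in the diff):

import os
import shutil
import torch.distributed as dist

def clean_stale_ckpts(paths=("./tmp_ckpt1", "./tmp_ckpt2")):
    # rank 0 removes leftovers from a previous run; the barrier keeps the other
    # ranks from touching a directory that is still being deleted
    if dist.get_rank() == 0:
        for p in paths:
            if os.path.exists(p):
                shutil.rmtree(p)
    dist.barrier()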
@@ -125,11 +115,6 @@ def _test_moe_checkpoint(parallel):
         MOE_MANAGER.setup(
             parallel="EP",
         )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
     elif parallel == "hybrid":
         MOE_MANAGER.setup(
             parallel="EP",
@@ -184,11 +169,11 @@ def _run_dist(rank, world_size, port, parallel):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", ["ep", "ep_zero"])
+@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
 @rerun_if_address_is_in_use()
 def test_moe_checkpoint(world_size, parallel):
     spawn(_run_dist, world_size, parallel=parallel)


 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="ep")
+    test_moe_checkpoint(world_size=4, parallel="hybrid")