mirror of https://github.com/hpcaitech/ColossalAI
update pytest
parent 54b197cc02
commit 570f5cd693
@@ -142,38 +142,37 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
     @torch.no_grad()
     def pre_save_model(self, model: nn.Module) -> dict:
         state_dict = model.state_dict()
-        for name, param in model.named_parameters():
-            if ".experts." in name:
-                if ".experts.gate_weight" in name:
-                    new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
-                    state_dict[new_name] = state_dict.pop(name).cpu()
-                elif ".experts." in name and is_moe_tensor(param):
-                    ep_group = get_ep_group(param)
-                    ep_rank = get_ep_rank(param)
-                    ep_size = get_ep_size(param)
-                    dp_rank = get_dp_rank(param)
+        for name, param in list(model.named_parameters()):
+            if ".gate_weight" in name:
+                new_name = name.replace(".gate_weight", ".gate.weight")
+                state_dict[new_name] = state_dict.pop(name).cpu()
+            elif ".experts." in name:
+                ep_group = get_ep_group(param)
+                ep_rank = get_ep_rank(param)
+                ep_size = get_ep_size(param)
+                dp_rank = get_dp_rank(param)
 
-                    if dp_rank == 0:
-                        param = param.data.cuda()
-                        all_param = [torch.zeros_like(param) for _ in range(ep_size)]
-                        # gather param from every ep rank
-                        dist.all_gather(all_param, param, group=ep_group)
-                        if ep_rank == 0:
-                            all_param = torch.cat(all_param, dim=0)
-                            assert all_param.shape[0] == 8
-                            for i in range(8):
-                                if ".wi_gate" in name:
-                                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
-                                elif ".wi_up" in name:
-                                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
-                                elif ".wo" in name:
-                                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
-                                new_name = new_name.replace("module.", "")
-                                new_param = all_param[i].transpose(-1, -2)
-                                state_dict[new_name] = new_param.cpu()
-                            state_dict.pop(name)
-                else:
-                    state_dict[name] = param.cpu()
+                if dp_rank == 0:
+                    param = param.data.cuda()
+                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
+                    # gather param from every ep rank
+                    dist.all_gather(all_param, param, group=ep_group)
+                    if ep_rank == 0:
+                        all_param = torch.cat(all_param, dim=0)
+                        assert all_param.shape[0] == 8
+                        for i in range(8):
+                            if ".wi_gate" in name:
+                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
+                            elif ".wi_up" in name:
+                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
+                            elif ".wo" in name:
+                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
+                            new_name = new_name.replace("module.", "")
+                            new_param = all_param[i].transpose(-1, -2)
+                            state_dict[new_name] = new_param.cpu()
+                        state_dict.pop(name)
+            else:
+                state_dict[name] = param.cpu()
 
         for name, param in list(state_dict.items()):
             new_name = name.replace("module.", "")
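For orientation: the rewritten pre_save_model gathers each expert-parallel tensor across the EP group and re-emits it under the per-expert HuggingFace Mixtral names. Below is a minimal standalone sketch of that gather-and-rename step; gather_expert_weights is a hypothetical helper, while the name patterns, the expert count of 8, and the transpose follow the diff above.

import torch
import torch.distributed as dist


def gather_expert_weights(name: str, param: torch.Tensor, ep_group, ep_rank: int, ep_size: int) -> dict:
    """Gather one fused expert tensor across the expert-parallel group and split it
    into per-expert entries under the HuggingFace Mixtral names (w1/w3/w2)."""
    shards = [torch.zeros_like(param) for _ in range(ep_size)]
    dist.all_gather(shards, param.contiguous(), group=ep_group)  # one shard per EP rank
    converted = {}
    if ep_rank == 0:
        full = torch.cat(shards, dim=0)  # experts stacked along dim 0 (8 for Mixtral-8x7B)
        for i in range(full.shape[0]):
            if ".wi_gate" in name:
                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
            elif ".wi_up" in name:
                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
            else:  # ".wo"
                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
            # ColossalAI keeps (in_features, out_features); HF Linear weights are the transpose.
            converted[new_name.replace("module.", "")] = full[i].transpose(-1, -2).cpu()
    return converted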
@@ -186,9 +185,9 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
         # and gather them one by one
         new_state_dict = {}
         state_dict_keys = list(state_dict.keys())
-        gap_keys = len(state_dict_keys) // 10
+        gap_keys = len(state_dict_keys) // 10 + 1
         for i in range(10):
-            cur_keys = state_dict_keys[(i - 1) * gap_keys : i * gap_keys]
+            cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
             cur_state_dict = {}
             for k in cur_keys:
                 cur_state_dict[k] = state_dict[k]
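This hunk fixes an off-by-one in how the state dict is split into ten chunks before gathering. A self-contained illustration of why the old slicing drops keys and the new slicing covers all of them (the key list here is made up for the example):

# Standalone illustration of the chunking fix.
keys = [f"k{i}" for i in range(23)]

# Old logic: 23 // 10 == 2, the first slice is keys[-2:0] == [], and the tail
# beyond 10 * gap is never visited, so some keys are silently dropped.
gap_old = len(keys) // 10
old_chunks = [keys[(i - 1) * gap_old : i * gap_old] for i in range(10)]

# New logic: 23 // 10 + 1 == 3, and slicing [i*gap : (i+1)*gap] walks the list
# front to back, so every key lands in exactly one chunk.
gap_new = len(keys) // 10 + 1
new_chunks = [keys[i * gap_new : (i + 1) * gap_new] for i in range(10)]

assert sum(new_chunks, []) == keys  # full coverage, original order
assert sum(old_chunks, []) != keys  # the old slicing loses keys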
@@ -44,17 +44,17 @@ class MixtralSparseMLP:
             router_top_k=module.top_k,
             router_norm=True,
             router_loss=False,
-            # router_capacity_factor_train = .
-            # router_capacity_factor_eval = .
+            # router_capacity_factor_train=
+            # router_capacity_factor_eval=
             mlp_activation="silu",
             mlp_gated=True,
-            # enable_load_balance = .
-            # load_balance_tolerance = .
-            # load_balance_beam_width = .
-            # load_balance_group_swap_factor = .
+            # enable_load_balance=
+            # load_balance_tolerance=
+            # load_balance_beam_width=
+            # load_balance_group_swap_factor=
             enable_kernel=enable_kernel,
-            # enable_comm_overlap = .
-            # enable_hierarchical_comm = .
+            # enable_comm_overlap=
+            # enable_hierarchical_comm=
             return_gate_logits=True,
         )
         dtype = module.gate.weight.dtype
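In the arguments above, mlp_gated=True together with mlp_activation="silu" selects the gated feed-forward that Mixtral experts use, which is also why the checkpoint conversion maps wi_gate/wi_up/wo onto w1/w3/w2. A minimal sketch of that computation under those assumptions; the module below is illustrative, not the ColossalAI SparseMLP itself:

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedSiluMLP(nn.Module):
    """Illustrative gated MLP: out = w2(silu(w1(x)) * w3(x)), as in a Mixtral expert."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)  # up projection
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))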
@@ -1,11 +1,11 @@
 import os
 import shutil
 import sys
 
 import pytest
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
 from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -16,13 +16,6 @@ from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 
-sys.path.append(
-    os.path.join(
-        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
-        "examples/language/openmoe",
-    )
-)
-
 
 def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
     input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
@@ -71,7 +64,7 @@ def get_config():
     config = MixtralConfig(
         vocab_size=300,
         hidden_size=32,
-        intermediate_size=128,
+        intermediate_size=16,
         num_hidden_layers=2,
         dropout_rate=0.0,
     )
@@ -81,6 +74,7 @@ def get_config():
 def get_model(parallel):
     config = get_config()
     model = MixtralForCausalLM(config).to(torch.bfloat16)
+    replace_moe_layer(model)
     optim = torch.optim.Adam(model.parameters())
     args = dict(
         precision="bf16",
@@ -89,22 +83,11 @@ def get_model(parallel):
         custom_policy=MixtralForCausalLMPolicy(),
         checkpoint_io=MixtralMoECheckpointIO,
     )
-    if parallel == None:
+    if parallel == "ep":
         plugin = MoeHybridParallelPlugin(
             pp_size=1,
             **args,
         )
-    elif parallel == "ep":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "ep_zero":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            extra_dp_size=2,
-            **args,
-        )
     elif parallel == "hybrid":
         plugin = MoeHybridParallelPlugin(
             pp_size=2,
@@ -117,6 +100,13 @@ def get_model(parallel):
 
 
 def _test_moe_checkpoint(parallel):
+    if dist.get_rank() == 0:
+        if os.path.exists("./tmp_ckpt1"):
+            shutil.rmtree("./tmp_ckpt1")
+        if os.path.exists("./tmp_ckpt2"):
+            shutil.rmtree("./tmp_ckpt2")
+    dist.barrier()
+
     if parallel == None:
         MOE_MANAGER.setup(
             parallel=None,
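The cleanup added at the top of _test_moe_checkpoint runs only on rank 0 and then synchronizes, so no rank starts writing a checkpoint into a directory another rank is still deleting. The same pattern in isolation; clean_checkpoint_dirs is a hypothetical helper, and the directory names follow the diff:

import os
import shutil

import torch.distributed as dist


def clean_checkpoint_dirs(paths=("./tmp_ckpt1", "./tmp_ckpt2")) -> None:
    """Remove stale checkpoint directories on rank 0, then block until every rank sees it."""
    if dist.get_rank() == 0:
        for path in paths:
            if os.path.exists(path):
                shutil.rmtree(path)
    dist.barrier()  # ensure deletion finished before any rank writes a new checkpoint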
@@ -125,11 +115,6 @@ def _test_moe_checkpoint(parallel):
         MOE_MANAGER.setup(
             parallel="EP",
         )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
     elif parallel == "hybrid":
         MOE_MANAGER.setup(
             parallel="EP",
@@ -184,11 +169,11 @@ def _run_dist(rank, world_size, port, parallel):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", ["ep", "ep_zero"])
+@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
 @rerun_if_address_is_in_use()
 def test_moe_checkpoint(world_size, parallel):
     spawn(_run_dist, world_size, parallel=parallel)
 
 
 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="ep")
+    test_moe_checkpoint(world_size=4, parallel="hybrid")