mirror of https://github.com/hpcaitech/ColossalAI

commit 570f5cd693 (parent 54b197cc02): update pytest

The diff touches the Mixtral MoE checkpoint IO, the MixtralSparseMLP wrapper, and the MoE checkpoint pytest suite.
@@ -142,38 +142,37 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
     @torch.no_grad()
     def pre_save_model(self, model: nn.Module) -> dict:
         state_dict = model.state_dict()
-        for name, param in model.named_parameters():
-            if ".experts." in name:
-                if ".experts.gate_weight" in name:
-                    new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
-                    state_dict[new_name] = state_dict.pop(name).cpu()
-                elif ".experts." in name and is_moe_tensor(param):
+        for name, param in list(model.named_parameters()):
+            if ".gate_weight" in name:
+                new_name = name.replace(".gate_weight", ".gate.weight")
+                state_dict[new_name] = state_dict.pop(name).cpu()
+            elif ".experts." in name:
                 ep_group = get_ep_group(param)
                 ep_rank = get_ep_rank(param)
                 ep_size = get_ep_size(param)
                 dp_rank = get_dp_rank(param)

                 if dp_rank == 0:
                     param = param.data.cuda()
                     all_param = [torch.zeros_like(param) for _ in range(ep_size)]
                     # gather param from every ep rank
                     dist.all_gather(all_param, param, group=ep_group)
                     if ep_rank == 0:
                         all_param = torch.cat(all_param, dim=0)
                         assert all_param.shape[0] == 8
                         for i in range(8):
                             if ".wi_gate" in name:
                                 new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                             elif ".wi_up" in name:
                                 new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                             elif ".wo" in name:
                                 new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                             new_name = new_name.replace("module.", "")
                             new_param = all_param[i].transpose(-1, -2)
                             state_dict[new_name] = new_param.cpu()
                         state_dict.pop(name)
             else:
                 state_dict[name] = param.cpu()

         for name, param in list(state_dict.items()):
             new_name = name.replace("module.", "")
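Note on the hunk above: pre_save_model converts the training-time layout back to a HuggingFace-style Mixtral checkpoint. The router weight is renamed from ".gate_weight" to ".gate.weight", and each fused expert tensor is gathered across the expert-parallel group, split into 8 per-expert slices, transposed, and stored under ".experts.{i}.w1/w2/w3.weight". The snippet below is a minimal standalone sketch of just the renaming and splitting, assuming a single process (so no all_gather) and purely illustrative key names; it is not the repository's code path.

import torch

# Toy fused checkpoint: 8 experts whose w1 weights live in one fused tensor,
# plus a fused router weight stored under ".gate_weight".
num_experts, d_model, d_ff = 8, 4, 6
state_dict = {
    "module.model.layers.0.block_sparse_moe.experts.wi_gate": torch.randn(num_experts, d_model, d_ff),
    "module.model.layers.0.block_sparse_moe.gate_weight": torch.randn(num_experts, d_model),
}

new_state_dict = {}
for name, param in state_dict.items():
    if ".gate_weight" in name:
        # router: ".gate_weight" -> ".gate.weight", drop the "module." prefix
        new_name = name.replace(".gate_weight", ".gate.weight").replace("module.", "")
        new_state_dict[new_name] = param.cpu()
    elif ".experts." in name:
        # one entry per expert, each slice transposed like transpose(-1, -2) above
        for i in range(num_experts):
            new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight").replace("module.", "")
            new_state_dict[new_name] = param[i].transpose(-1, -2).cpu()

print(sorted(new_state_dict))  # gate.weight plus experts.0..7.w1.weight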
@@ -186,9 +185,9 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
         # and gather them one by one
         new_state_dict = {}
         state_dict_keys = list(state_dict.keys())
-        gap_keys = len(state_dict_keys) // 10
+        gap_keys = len(state_dict_keys) // 10 + 1
         for i in range(10):
-            cur_keys = state_dict_keys[(i - 1) * gap_keys : i * gap_keys]
+            cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
             cur_state_dict = {}
             for k in cur_keys:
                 cur_state_dict[k] = state_dict[k]
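The two changed lines above fix how the state dict is split into 10 chunks before being gathered. With gap_keys = len(keys) // 10 and the slice (i - 1) * gap_keys : i * gap_keys, the i = 0 slice is empty and the tail of the key list is never visited, so part of the checkpoint would be silently dropped; the ceiling-style gap plus i * gap_keys : (i + 1) * gap_keys covers every key exactly once. A quick standalone check (the 23-key list is illustrative, not from the repository):

keys = [f"k{i}" for i in range(23)]

def chunk_old(keys, n=10):
    gap = len(keys) // n
    return [keys[(i - 1) * gap : i * gap] for i in range(n)]

def chunk_new(keys, n=10):
    gap = len(keys) // n + 1
    return [keys[i * gap : (i + 1) * gap] for i in range(n)]

print(sum(len(c) for c in chunk_old(keys)))  # 18 -- first chunk empty, last 5 keys dropped
print(sum(len(c) for c in chunk_new(keys)))  # 23 -- every key lands in exactly one chunk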
@@ -44,17 +44,17 @@ class MixtralSparseMLP:
            router_top_k=module.top_k,
            router_norm=True,
            router_loss=False,
-           # router_capacity_factor_train = .
-           # router_capacity_factor_eval = .
+           # router_capacity_factor_train=
+           # router_capacity_factor_eval=
            mlp_activation="silu",
            mlp_gated=True,
-           # enable_load_balance = .
-           # load_balance_tolerance = .
-           # load_balance_beam_width = .
-           # load_balance_group_swap_factor = .
+           # enable_load_balance=
+           # load_balance_tolerance=
+           # load_balance_beam_width=
+           # load_balance_group_swap_factor=
            enable_kernel=enable_kernel,
-           # enable_comm_overlap = .
-           # enable_hierarchical_comm = .
+           # enable_comm_overlap=
+           # enable_hierarchical_comm=
            return_gate_logits=True,
        )
        dtype = module.gate.weight.dtype
@@ -1,11 +1,11 @@
 import os
 import shutil
-import sys

 import pytest
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
+from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -16,13 +16,6 @@ from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device

-sys.path.append(
-    os.path.join(
-        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
-        "examples/language/openmoe",
-    )
-)
-

 def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
     input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
@@ -71,7 +64,7 @@ def get_config():
     config = MixtralConfig(
         vocab_size=300,
         hidden_size=32,
-        intermediate_size=128,
+        intermediate_size=16,
         num_hidden_layers=2,
         dropout_rate=0.0,
     )
@@ -81,6 +74,7 @@ def get_config():
 def get_model(parallel):
     config = get_config()
     model = MixtralForCausalLM(config).to(torch.bfloat16)
+    replace_moe_layer(model)
     optim = torch.optim.Adam(model.parameters())
     args = dict(
         precision="bf16",
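The added replace_moe_layer(model) call swaps the model's MoE sub-layers in place before the Adam optimizer is constructed. One reason the ordering matters: torch.optim.Adam captures references to the parameters it is given at construction time, so a module swap performed afterwards would leave the optimizer tracking tensors that are no longer part of the model. Below is a minimal sketch of that ordering with a stand-in, hypothetical swap function (replace_linear_layers); the real replace_moe_layer is imported from colossal_moe.models.mixtral_layer and does something different internally.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(4, 8)
        self.proj_out = nn.Linear(8, 4)

    def forward(self, x):
        return self.proj_out(torch.relu(self.proj_in(x)))

def replace_linear_layers(model: nn.Module) -> None:
    # stand-in for an in-place layer swap such as replace_moe_layer
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Linear):
            setattr(model, name, nn.Linear(child.in_features, child.out_features))
        else:
            replace_linear_layers(child)

model = TinyModel()
replace_linear_layers(model)                  # swap layers first ...
optim = torch.optim.Adam(model.parameters())  # ... then build the optimizer over the new parameters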
@@ -89,22 +83,11 @@ def get_model(parallel):
         custom_policy=MixtralForCausalLMPolicy(),
         checkpoint_io=MixtralMoECheckpointIO,
     )
-    if parallel == None:
+    if parallel == "ep":
         plugin = MoeHybridParallelPlugin(
             pp_size=1,
             **args,
         )
-    elif parallel == "ep":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "ep_zero":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            extra_dp_size=2,
-            **args,
-        )
     elif parallel == "hybrid":
         plugin = MoeHybridParallelPlugin(
             pp_size=2,
@@ -117,6 +100,13 @@ def get_model(parallel):


 def _test_moe_checkpoint(parallel):
+    if dist.get_rank() == 0:
+        if os.path.exists("./tmp_ckpt1"):
+            shutil.rmtree("./tmp_ckpt1")
+        if os.path.exists("./tmp_ckpt2"):
+            shutil.rmtree("./tmp_ckpt2")
+    dist.barrier()
+
     if parallel == None:
         MOE_MANAGER.setup(
             parallel=None,
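The lines added above clear any leftover ./tmp_ckpt1 and ./tmp_ckpt2 directories from a previous run before the checkpoint test writes new ones. Only rank 0 touches the filesystem, and dist.barrier() keeps the other ranks from proceeding until the deletion has finished. A minimal sketch of the same pattern as a helper, assuming torch.distributed has already been initialized; the helper name and path are illustrative, not part of the repository.

import os
import shutil

import torch.distributed as dist

def remove_stale_dir(path: str) -> None:
    # delete on a single rank to avoid racing removals ...
    if dist.get_rank() == 0 and os.path.exists(path):
        shutil.rmtree(path)
    # ... and make every rank wait until the directory is gone
    dist.barrier()

# usage inside a distributed test body:
# remove_stale_dir("./tmp_ckpt1")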
@@ -125,11 +115,6 @@ def _test_moe_checkpoint(parallel):
         MOE_MANAGER.setup(
             parallel="EP",
         )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
     elif parallel == "hybrid":
         MOE_MANAGER.setup(
             parallel="EP",
@@ -184,11 +169,11 @@ def _run_dist(rank, world_size, port, parallel):

 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", ["ep", "ep_zero"])
+@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
 @rerun_if_address_is_in_use()
 def test_moe_checkpoint(world_size, parallel):
     spawn(_run_dist, world_size, parallel=parallel)


 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="ep")
+    test_moe_checkpoint(world_size=4, parallel="hybrid")
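With the updated decorator, the "ep_zero" case is dropped from the test matrix and "hybrid" is exercised instead. Stacked parametrize decorators produce the cartesian product of their arguments, so the suite now collects two cases: (world_size=4, parallel="ep") and (world_size=4, parallel="hybrid"). A small self-contained illustration of that collection behaviour (the test name and assertions are illustrative only):

import pytest

@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
def test_collected_cases(world_size, parallel):
    # pytest runs this body once per combination: (4, "ep") and (4, "hybrid")
    assert world_size == 4
    assert parallel in ("ep", "hybrid")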