update pytest

pull/5190/head
Xuanlei Zhao 2023-12-27 16:05:00 +08:00
parent 54b197cc02
commit 570f5cd693
3 changed files with 53 additions and 69 deletions

@@ -142,38 +142,37 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
     @torch.no_grad()
     def pre_save_model(self, model: nn.Module) -> dict:
         state_dict = model.state_dict()
-        for name, param in model.named_parameters():
-            if ".experts." in name:
-                if ".experts.gate_weight" in name:
-                    new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
-                    state_dict[new_name] = state_dict.pop(name).cpu()
-                elif ".experts." in name and is_moe_tensor(param):
-                    ep_group = get_ep_group(param)
-                    ep_rank = get_ep_rank(param)
-                    ep_size = get_ep_size(param)
-                    dp_rank = get_dp_rank(param)
+        for name, param in list(model.named_parameters()):
+            if ".gate_weight" in name:
+                new_name = name.replace(".gate_weight", ".gate.weight")
+                state_dict[new_name] = state_dict.pop(name).cpu()
+            elif ".experts." in name:
+                ep_group = get_ep_group(param)
+                ep_rank = get_ep_rank(param)
+                ep_size = get_ep_size(param)
+                dp_rank = get_dp_rank(param)
                 if dp_rank == 0:
                     param = param.data.cuda()
                     all_param = [torch.zeros_like(param) for _ in range(ep_size)]
                     # gather param from every ep rank
                     dist.all_gather(all_param, param, group=ep_group)
                     if ep_rank == 0:
                         all_param = torch.cat(all_param, dim=0)
                         assert all_param.shape[0] == 8
                         for i in range(8):
                             if ".wi_gate" in name:
                                 new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                             elif ".wi_up" in name:
                                 new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                             elif ".wo" in name:
                                 new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                             new_name = new_name.replace("module.", "")
                             new_param = all_param[i].transpose(-1, -2)
                             state_dict[new_name] = new_param.cpu()
                         state_dict.pop(name)
             else:
                 state_dict[name] = param.cpu()
         for name, param in list(state_dict.items()):
             new_name = name.replace("module.", "")
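Reviewer note: the per-expert renaming in this hunk can be exercised in isolation. The sketch below mirrors only the rename-and-transpose step on a single process (no dist.all_gather); the helper name and the example parameter name are hypothetical, and the gathered tensor is assumed to have shape [8, d_out, d_in], matching the assert above.

import torch

def rename_fused_experts(name: str, all_param: torch.Tensor) -> dict:
    # Map a gathered [8, d_out, d_in] expert tensor to per-expert HF-style keys,
    # mirroring the renaming branch of pre_save_model in this hunk.
    assert all_param.shape[0] == 8
    out = {}
    for i in range(8):
        if ".wi_gate" in name:
            new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
        elif ".wi_up" in name:
            new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
        elif ".wo" in name:
            new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
        else:
            raise KeyError(f"unexpected expert parameter: {name}")
        out[new_name.replace("module.", "")] = all_param[i].transpose(-1, -2).cpu()
    return out

# Hypothetical parameter name; yields keys such as
# "model.layers.0.block_sparse_moe.experts.0.w1.weight".
mapped = rename_fused_experts(
    "module.model.layers.0.block_sparse_moe.experts.wi_gate", torch.randn(8, 32, 64)
)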
@@ -186,9 +185,9 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
         # and gather them one by one
         new_state_dict = {}
         state_dict_keys = list(state_dict.keys())
-        gap_keys = len(state_dict_keys) // 10
+        gap_keys = len(state_dict_keys) // 10 + 1
         for i in range(10):
-            cur_keys = state_dict_keys[(i - 1) * gap_keys : i * gap_keys]
+            cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
             cur_state_dict = {}
             for k in cur_keys:
                 cur_state_dict[k] = state_dict[k]
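The slicing change above is easiest to verify with concrete numbers; a quick standalone check (23 keys split into 10 chunks, the numbers chosen only for illustration):

keys = [f"k{i}" for i in range(23)]

# Old scheme: chunk 0 is the empty slice keys[-2:0], and the last 5 keys are never visited.
gap_old = len(keys) // 10                                        # 2
old_chunks = [keys[(i - 1) * gap_old : i * gap_old] for i in range(10)]
print(sum(len(c) for c in old_chunks))                           # 18 of 23 keys

# New scheme: every key lands in exactly one chunk.
gap_new = len(keys) // 10 + 1                                    # 3
new_chunks = [keys[i * gap_new : (i + 1) * gap_new] for i in range(10)]
print(sum(len(c) for c in new_chunks))                           # 23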

@@ -44,17 +44,17 @@ class MixtralSparseMLP:
             router_top_k=module.top_k,
             router_norm=True,
             router_loss=False,
-            # router_capacity_factor_train = .
-            # router_capacity_factor_eval = .
+            # router_capacity_factor_train=
+            # router_capacity_factor_eval=
             mlp_activation="silu",
             mlp_gated=True,
-            # enable_load_balance = .
-            # load_balance_tolerance = .
-            # load_balance_beam_width = .
-            # load_balance_group_swap_factor = .
+            # enable_load_balance=
+            # load_balance_tolerance=
+            # load_balance_beam_width=
+            # load_balance_group_swap_factor=
             enable_kernel=enable_kernel,
-            # enable_comm_overlap = .
-            # enable_hierarchical_comm = .
+            # enable_comm_overlap=
+            # enable_hierarchical_comm=
             return_gate_logits=True,
         )
         dtype = module.gate.weight.dtype

@@ -1,11 +1,11 @@
 import os
 import shutil
-import sys
 import pytest
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
+from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -16,13 +16,6 @@ from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device

-sys.path.append(
-    os.path.join(
-        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
-        "examples/language/openmoe",
-    )
-)

 def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
     input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
@@ -71,7 +64,7 @@ def get_config():
     config = MixtralConfig(
         vocab_size=300,
         hidden_size=32,
-        intermediate_size=128,
+        intermediate_size=16,
         num_hidden_layers=2,
         dropout_rate=0.0,
     )
@@ -81,6 +74,7 @@ def get_config():
 def get_model(parallel):
     config = get_config()
     model = MixtralForCausalLM(config).to(torch.bfloat16)
+    replace_moe_layer(model)
     optim = torch.optim.Adam(model.parameters())
     args = dict(
         precision="bf16",
@@ -89,22 +83,11 @@ def get_model(parallel):
         custom_policy=MixtralForCausalLMPolicy(),
         checkpoint_io=MixtralMoECheckpointIO,
     )
-    if parallel == None:
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "ep":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "ep_zero":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            extra_dp_size=2,
-            **args,
-        )
+    if parallel == "ep":
+        plugin = MoeHybridParallelPlugin(
+            pp_size=1,
+            **args,
+        )
     elif parallel == "hybrid":
         plugin = MoeHybridParallelPlugin(
             pp_size=2,
@@ -117,6 +100,13 @@ def get_model(parallel):
 def _test_moe_checkpoint(parallel):
+    if dist.get_rank() == 0:
+        if os.path.exists("./tmp_ckpt1"):
+            shutil.rmtree("./tmp_ckpt1")
+        if os.path.exists("./tmp_ckpt2"):
+            shutil.rmtree("./tmp_ckpt2")
+    dist.barrier()
     if parallel == None:
         MOE_MANAGER.setup(
             parallel=None,
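The lines added at the top of _test_moe_checkpoint follow the rank-0-deletes-then-barrier pattern, so no rank reuses a stale checkpoint directory from an earlier run. If the same cleanup is needed in other tests, it could be factored into a small helper; the one below is a hypothetical sketch, not part of this commit:

import os
import shutil
import torch.distributed as dist

def clean_checkpoint_dirs(*dirs: str) -> None:
    # Remove stale checkpoint directories on rank 0 only, then sync all ranks
    # so nobody starts writing before the cleanup has finished.
    if dist.get_rank() == 0:
        for d in dirs:
            if os.path.exists(d):
                shutil.rmtree(d)
    dist.barrier()

# Usage inside the test (requires an initialized process group):
# clean_checkpoint_dirs("./tmp_ckpt1", "./tmp_ckpt2")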
@@ -125,11 +115,6 @@ def _test_moe_checkpoint(parallel):
         MOE_MANAGER.setup(
             parallel="EP",
         )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
     elif parallel == "hybrid":
         MOE_MANAGER.setup(
             parallel="EP",
@@ -184,11 +169,11 @@ def _run_dist(rank, world_size, port, parallel):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", ["ep", "ep_zero"])
+@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
 @rerun_if_address_is_in_use()
 def test_moe_checkpoint(world_size, parallel):
     spawn(_run_dist, world_size, parallel=parallel)


 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="ep")
+    test_moe_checkpoint(world_size=4, parallel="hybrid")