diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
index 3047923fc..8a7087330 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
@@ -142,38 +142,37 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
     @torch.no_grad()
     def pre_save_model(self, model: nn.Module) -> dict:
         state_dict = model.state_dict()
-        for name, param in model.named_parameters():
-            if ".experts." in name:
-                if ".experts.gate_weight" in name:
-                    new_name = name.replace(".experts.gate_weight", ".experts.gate.weight")
-                    state_dict[new_name] = state_dict.pop(name).cpu()
-                elif ".experts." in name and is_moe_tensor(param):
-                    ep_group = get_ep_group(param)
-                    ep_rank = get_ep_rank(param)
-                    ep_size = get_ep_size(param)
-                    dp_rank = get_dp_rank(param)
+        for name, param in list(model.named_parameters()):
+            if ".gate_weight" in name:
+                new_name = name.replace(".gate_weight", ".gate.weight")
+                state_dict[new_name] = state_dict.pop(name).cpu()
+            elif ".experts." in name:
+                ep_group = get_ep_group(param)
+                ep_rank = get_ep_rank(param)
+                ep_size = get_ep_size(param)
+                dp_rank = get_dp_rank(param)
 
-                    if dp_rank == 0:
-                        param = param.data.cuda()
-                        all_param = [torch.zeros_like(param) for _ in range(ep_size)]
-                        # gather param from every ep rank
-                        dist.all_gather(all_param, param, group=ep_group)
-                        if ep_rank == 0:
-                            all_param = torch.cat(all_param, dim=0)
-                            assert all_param.shape[0] == 8
-                            for i in range(8):
-                                if ".wi_gate" in name:
-                                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
-                                elif ".wi_up" in name:
-                                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
-                                elif ".wo" in name:
-                                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
-                                new_name = new_name.replace("module.", "")
-                                new_param = all_param[i].transpose(-1, -2)
-                                state_dict[new_name] = new_param.cpu()
-                            state_dict.pop(name)
-            else:
-                state_dict[name] = param.cpu()
+                if dp_rank == 0:
+                    param = param.data.cuda()
+                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
+                    # gather param from every ep rank
+                    dist.all_gather(all_param, param, group=ep_group)
+                    if ep_rank == 0:
+                        all_param = torch.cat(all_param, dim=0)
+                        assert all_param.shape[0] == 8
+                        for i in range(8):
+                            if ".wi_gate" in name:
+                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
+                            elif ".wi_up" in name:
+                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
+                            elif ".wo" in name:
+                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
+                            new_name = new_name.replace("module.", "")
+                            new_param = all_param[i].transpose(-1, -2)
+                            state_dict[new_name] = new_param.cpu()
+                        state_dict.pop(name)
+            else:
+                state_dict[name] = param.cpu()
 
         for name, param in list(state_dict.items()):
             new_name = name.replace("module.", "")
@@ -186,9 +185,9 @@ class MixtralMoECheckpointIO(MoECheckpintIO):
         # and gather them one by one
         new_state_dict = {}
         state_dict_keys = list(state_dict.keys())
-        gap_keys = len(state_dict_keys) // 10
+        gap_keys = len(state_dict_keys) // 10 + 1
        for i in range(10):
-            cur_keys = state_dict_keys[(i - 1) * gap_keys : i * gap_keys]
+            cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
             cur_state_dict = {}
             for k in cur_keys:
                 cur_state_dict[k] = state_dict[k]
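Note on the second hunk above (@@ -186,9 +185,9 @@): with the old arithmetic, gap_keys = len(state_dict_keys) // 10 and the slice (i - 1) * gap_keys : i * gap_keys yield an empty shard at i = 0 and never visit any key past index 9 * gap_keys, so the tail of the state dict was skipped. A minimal standalone sketch of the corrected partitioning, in plain Python with no ColossalAI dependency (shard_keys is a hypothetical helper name, not part of the patch):

def shard_keys(keys, num_shards=10):
    # Mirrors the fixed logic: chunk size rounded up, contiguous slices.
    gap = len(keys) // num_shards + 1
    return [keys[i * gap : (i + 1) * gap] for i in range(num_shards)]

keys = [f"k{i}" for i in range(23)]
shards = shard_keys(keys)
# Every key lands in exactly one shard, in order; trailing shards may be empty.
assert sum(shards, []) == keys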
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
index 6d61c501a..e395c8578 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
@@ -44,17 +44,17 @@ class MixtralSparseMLP:
             router_top_k=module.top_k,
             router_norm=True,
             router_loss=False,
-            # router_capacity_factor_train = .
-            # router_capacity_factor_eval = .
+            # router_capacity_factor_train=
+            # router_capacity_factor_eval=
             mlp_activation="silu",
             mlp_gated=True,
-            # enable_load_balance = .
-            # load_balance_tolerance = .
-            # load_balance_beam_width = .
-            # load_balance_group_swap_factor = .
+            # enable_load_balance=
+            # load_balance_tolerance=
+            # load_balance_beam_width=
+            # load_balance_group_swap_factor=
             enable_kernel=enable_kernel,
-            # enable_comm_overlap = .
-            # enable_hierarchical_comm = .
+            # enable_comm_overlap=
+            # enable_hierarchical_comm=
             return_gate_logits=True,
         )
         dtype = module.gate.weight.dtype
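For context on how the converted layer is exercised: the updated test below builds a tiny MixtralForCausalLM and calls replace_moe_layer on it before handing it to the plugin. A rough sketch of that conversion step, assuming the process group and MOE_MANAGER.setup have already been initialized as in the test's _run_dist; the config values mirror the test's get_config:

import torch
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

from colossal_moe.models.mixtral_layer import replace_moe_layer

# Small toy config matching the test; replace_moe_layer swaps the HF MoE
# blocks for ColossalAI SparseMLP modules in place.
config = MixtralConfig(vocab_size=300, hidden_size=32, intermediate_size=16, num_hidden_layers=2)
model = MixtralForCausalLM(config).to(torch.bfloat16)
replace_moe_layer(model)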
"hybrid": MOE_MANAGER.setup( parallel="EP", @@ -184,11 +169,11 @@ def _run_dist(rank, world_size, port, parallel): @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) -@pytest.mark.parametrize("parallel", ["ep", "ep_zero"]) +@pytest.mark.parametrize("parallel", ["ep", "hybrid"]) @rerun_if_address_is_in_use() def test_moe_checkpoint(world_size, parallel): spawn(_run_dist, world_size, parallel=parallel) if __name__ == "__main__": - test_moe_checkpoint(world_size=4, parallel="ep") + test_moe_checkpoint(world_size=4, parallel="hybrid")