From da39d21b71b79462a0f922a3cb8ca480a06743ed Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Thu, 25 Jan 2024 15:48:46 +0800
Subject: [PATCH] [moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [moe] support save sharded mixtral

* [moe] support load sharded mixtral

* [moe] support save sharded optim

* [moe] integrate moe manager into plugin

* [moe] fix optimizer load

* [moe] fix mixtral layer
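
A minimal end-to-end sketch of the new flow, based on the updated infer.py
(the model name, ep_size and checkpoint path below are illustrative, not
fixed by this patch):

    import colossalai
    from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
    from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
    from transformers.models.mixtral import MixtralForCausalLM

    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    colossalai.launch_from_torch(config={})
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=4,  # illustrative; must divide num_local_experts
        zero_stage=1,
        precision="bf16",
        custom_policy=MixtralForCausalLMPolicy(),
        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
    )
    booster = Booster(plugin=plugin)
    model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    model, _, _, _, _ = booster.boost(model=model)
    # sharded save goes through MixtralMoEHybridParallelCheckpointIO
    booster.save_model(model, "ckpt/modeling", shard=True)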
---
 .../colossal_moe/models/mixtral_checkpoint.py | 750 ++++++++++++++----
 .../colossal_moe/models/mixtral_layer.py      | 146 ++--
 .../colossal_moe/models/mixtral_policy.py     |  14 +
 .../ColossalMoE/colossal_moe/utils.py         |  20 +-
 applications/ColossalMoE/infer.py             |  72 +-
 .../ColossalMoE/tests/test_mixtral_layer.py   |  63 ++
 .../ColossalMoE/tests/test_moe_checkpoint.py  | 269 +++----
 .../ColossalMoE/tests/test_moe_layer.py       |  31 -
 applications/ColossalMoE/train.py             |  71 +-
 .../plugin/moe_hybrid_parallel_plugin.py      |  21 +-
 .../checkpoint_io/checkpoint_io_base.py       |  12 +-
 colossalai/moe/_operation.py                  |  67 +-
 colossalai/tensor/moe_tensor/moe_info.py      |   2 +
 colossalai/zero/low_level/low_level_optim.py  |   8 +-
 14 files changed, 996 insertions(+), 550 deletions(-)
 create mode 100644 applications/ColossalMoE/tests/test_mixtral_layer.py
 delete mode 100644 applications/ColossalMoE/tests/test_moe_layer.py

diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
index ddef565c5..635eebd89 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
@@ -1,205 +1,617 @@
+import copy
 import logging
 import os
 from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
 
 from colossalai.checkpoint_io import CheckpointIndexFile
-from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
-from colossalai.moe import MoECheckpintIO
-from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
+from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
+from colossalai.checkpoint_io.index_file import CheckpointIndexFile
+from colossalai.checkpoint_io.utils import (
+    StateDictSharder,
+    gather_distributed_param,
+    get_model_base_filenames,
+    get_optimizer_base_filenames,
+    load_shard_state_dict,
+    load_states_into_optimizer,
+    save_config_file,
+    save_param_groups,
+    save_state_dict_shards,
+    search_tp_partition_dim,
+    sharded_optimizer_loading_epilogue,
+)
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.moe import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+
+try:
+    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
 
 
-class MixtralMoECheckpointIO(MoECheckpintIO):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
+    def __init__(
+        self,
+        dp_group: ProcessGroup,
+        pp_group: ProcessGroup,
+        tp_group: ProcessGroup,
+        zero_stage: int,
+        verbose: bool = True,
+    ) -> None:
+        super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
+        moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
+        self.ep_group = moe_info.ep_group
+        self.ep_size = moe_info.ep_size
+        self.ep_rank = moe_info.ep_rank
+        self.real_dp_rank = moe_info.dp_rank
 
-    @torch.no_grad()
-    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
+    @staticmethod
+    def _model_sharder(
+        model: nn.Module,
+        prefix: str = "",
+        keep_vars: bool = False,
+        size_per_shard: int = 1024,
+        param_name_pattern: Optional[str] = None,
+    ) -> Iterator[Tuple[OrderedDict, int]]:
+        # An internal method that breaks the state_dict of the model into shards of limited size.
+
+        state_dict_sharder = StateDictSharder(size_per_shard)
+
+        # Save parameters.
+        for name, param in model.named_parameters():
+            if param is None:
+                continue
+            if param_name_pattern is not None and param_name_pattern not in name:
+                continue
+            # Gather tensor pieces when using tensor parallel.
+            param_ = gather_distributed_param(param, keep_vars=False)
+            block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+            if block is not None:
+                yield block, block_size
+
+        # Save buffers.
+        for name, buf in model.named_buffers():
+            if buf is not None and name not in model._non_persistent_buffers_set:
+                buffer = buf if keep_vars else buf.detach()
+                block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
+                if block is not None:
+                    yield block, block_size
+
+        # Save extra states.
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if (
+            getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+            is not torch.nn.Module.get_extra_state
+        ):
+            extra_state = model.get_extra_state()
+            block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
+            if block is not None:
+                yield block, block_size
+
+        # Return the last block in sharder.
+        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+    def save_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        size_per_shard: int = 1024,
+        use_safetensors: bool = False,
+    ) -> None:
         """
-        Preprocess state_dict before loading and slice the state_dict of MOE tensors.
-        """
-        model_param_dict = dict(model.named_parameters())
-        for name, param in list(state_dict.items()):
-            if ".gate.weight" in name:
-                new_name = "module." + name.replace(".gate.weight", ".gate_weight")
-                state_dict[new_name] = state_dict.pop(name)
-            elif ".experts." in name:
-                # if is moe tensor
-                # in our moe module, expert is cat as one tensor
-                # but mixtral's experts is not cat
-                # we will insert the loaded expert into the position of cat tensor
+        Save sharded model checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+        - Multiple files that store state tensors of models.
+          If pipeline or expert parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-000XX-shard.bin".
+          If neither is used, "pytorch_model.<prefix>-000XX.bin".
 
-                # get model param
-                str_idx = name.index(".experts.")
-                expert_idx = int(name.split(".")[-3])
-                if ".w1." in name:
-                    model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
-                elif ".w2." in name:
-                    model_param_name = name.replace(name[str_idx:], ".experts.wo")
-                elif ".w3." in name:
-                    model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
-                model_param_name = "module." + model_param_name
-                # skip for pipeline
-                if model_param_name not in model_param_dict:
-                    continue
-                model_param = model_param_dict[model_param_name]
-                assert is_moe_tensor(model_param)
-                # get expert range
-                ep_rank = get_ep_rank(model_param)
-                ep_size = get_ep_size(model_param)
-                expert_num = 8 // ep_size
-                expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
-                # insert new param
-                if expert_idx in expert_range:
-                    new_param = model_param
-                    new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
-                    state_dict[model_param_name] = new_param
-                state_dict.pop(name)
-            else:
-                new_name = "module." + name
-                state_dict[new_name] = state_dict.pop(name)
-
-        dist.barrier()
-        return state_dict
-
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
-        """
-        Load sharded model with the given path to index file of checkpoint folder.
 
         Args:
-            model (nn.Module): The model to be loaded.
-            checkpoint_index_file (str): Path to the index file of checkpointing folder.
-            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
-                                     This argument should be manually set to False since params on same device might be stored in different files.
+            model (nn.Module): Model on local device to be saved.
+            checkpoint (str): Checkpointing path which should be a directory path.
+            gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+            prefix (str, optional): Prefix of the files to save. Defaults to None.
+            size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+            use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
 
-        # Check whether the checkpoint uses safetensors.
-        use_safetensors = False
-        if "safetensors" in checkpoint_index_file.name:
-            use_safetensors = True
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model = model.unwrap()
 
-        if use_safetensors and not is_safetensors_available():
-            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        if self.real_dp_rank != 0:
+            return
+
+        # ep_rank 0 saves all the parameters and buffers;
+        # other ep_ranks save only the expert parameters.
+        ep_param_pattern = "experts." if self.ep_rank != 0 else None
+
+        # Then collect the sharded parameters & buffers along tp_group.
+        # Only devices with tp_rank == 0 are responsible for model saving.
+        state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
+            model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
+        )
+        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+        index_file = CheckpointIndexFile(checkpoint)
+        control_saving = self.tp_rank == 0
+
+        if self.pp_size == 1 and self.ep_size == 1:
+            # When neither pipeline nor expert parallelism is used, save the model shards as in the general CheckpointIO.
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=control_saving,
+                use_safetensors=use_safetensors,
+            )
+            if control_saving:
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+                save_config_file(model, checkpoint)
+                if self.verbose and self.coordinator.is_master():
+                    logging.info(
+                        f"The model is split into checkpoint shards. "
+                        f"You can find where each parameters has been saved in the "
+                        f"index located at {save_index_file}."
+                    )
+
+        else:
+            # When pipeline or expert parallelism is used, each stage and ep rank produces its own shard files and index files.
+            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+            # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+            final_index_file_path = copy.deepcopy(save_index_file)
+            tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+            Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+            # Manage filenames of sharded weights and index file for each pipeline stage.
+            weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
+            weights_name = weights_name.replace(
+                ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
+            )
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
+            save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=control_saving,
+                use_safetensors=use_safetensors,
+                use_pp_format=True,
+            )
+            if control_saving:
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+            else:
+                return
+
+            dist.barrier(self.pp_group)
+            dist.barrier(self.ep_group)
+
+            # The global master rank integrates the index files and cleans up the temporary folder.
+            if self.coordinator.is_master():
+                final_index_file = CheckpointIndexFile(checkpoint)
+                final_index_file.append_meta_data("total_size", 0)
+
+                for filename in os.listdir(tmp_index_file_folder):
+                    stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+                    final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+                    for weight, weight_filename in stage_index_file.weight_map.items():
+                        final_index_file.append_weight_map(weight, weight_filename)
+
+                final_index_file.write_index_file(final_index_file_path)
+                save_config_file(model, checkpoint)
+                rmtree(tmp_index_file_folder)
+                if self.verbose and self.coordinator.is_master():
+                    logging.info(
+                        f"The model is split into checkpoint shards. "
+                        f"You can find where each parameters has been saved in the "
+                        f"index located at {final_index_file_path}."
+                    )
+
+    @staticmethod
+    def gather_from_sharded_optimizer_state(
+        state: OrderedDict,
+        param: torch.Tensor,
+        original_shape: torch.Size,
+        dp_group: ProcessGroup,
+        tp_group: ProcessGroup,
+        use_zero: bool,
+        inplace: bool,
+        is_moe_param: bool,
+        device: torch.device = torch.device("cpu"),
+    ) -> OrderedDict:
+        """
+        With given parameter and its optimizer states, gather the complete optimizer state for saving.
+
+        Args:
+            state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
+            param (torch.Tensor): The given parameter. It should be working_param when using Zero.
+            original_shape (torch.Size): The size of parameter before sharding.
+            dp_group (ProcessGroup): The process group of data parallel.
+            tp_group (ProcessGroup): The process group of tensor parallel.
+            use_zero (bool): Whether Zero is used.
+            inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
+            device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
+
+        Returns:
+            OrderedDict: The complete optimizer state of given parameter.
+        """
+        dp_size = dist.get_world_size(dp_group)
+        tp_size = dist.get_world_size(tp_group)
+        current_shape = param.shape
+        state_ = state if inplace else copy.deepcopy(state)
+
+        for k, v in state_.items():
+            if isinstance(v, torch.Tensor) and k != "step":
+                # First gather Zero shards.
+                if use_zero and not is_moe_param:
+                    v = v.cuda()
+                    gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
+                    dist.all_gather(gather_tensor, v, group=dp_group)
+                    v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+
+                # Then gather TP shards.
+                partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
+                if partition_dim is not None:
+                    gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
+                    dist.all_gather(gather_tensor, v, group=tp_group)
+                    v = torch.cat(gather_tensor, dim=partition_dim)
+
+                state_[k] = v.detach().clone().to(device)
+
+        return state_
+
+    @staticmethod
+    def _optimizer_sharder(
+        optimizer: OptimizerWrapper,
+        use_zero: bool,
+        dp_group: ProcessGroup,
+        tp_group: ProcessGroup,
+        size_per_shard: int = 1024,
+        only_moe_param: bool = False,
+    ):
+        # An internal method that breaks the state_dict of the optimizer into shards of limited size.
+
+        state_dict_sharder = StateDictSharder(size_per_shard)
+        param_info = optimizer.param_info
+        master_to_working_map = optimizer.get_master_to_working_map()
+
+        for param, state in optimizer.optim.state.items():
+            if param is None:
+                continue
+
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+
+            param_id = param_info["param2id"][id(working_param)]
+            original_shape = param_info["param2shape"][id(working_param)]
+            state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
+                state,
+                working_param,
+                original_shape=original_shape,
+                dp_group=dp_group,
+                tp_group=tp_group,
+                use_zero=use_zero,
+                inplace=False,
+                is_moe_param=is_moe_tensor(working_param),
+            )
+
+            if only_moe_param and not is_moe_tensor(working_param):
+                continue
+            block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
+            if block is not None:
+                yield block, block_size
+
+        # Return the last block in sharder.
+        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+    def save_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        size_per_shard: int = 1024,
+    ):
+        """
+        Save sharded optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files that store state tensors of optimizers.
+          If pipeline or expert parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-000XX-shard.bin".
+          If neither is used, "pytorch_optim.<prefix>-000XX.bin".
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+            checkpoint (str): Path to save optimizer state_dict
+            gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of the files to save
+            size_per_shard (int): Max file size of each file shard that store state tensors
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Devices along the same dp_group share the same copies of states when zero is not used.
+        # In this case only let the device with dp_rank == 0 save the optimizer states.
+        if not self.use_zero and self.real_dp_rank != 0:
+            return
+
+        # Then collect the sharded states along dp_group (if using zero) / tp_group.
+        # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+        state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
+            optimizer,
+            use_zero=self.use_zero,
+            dp_group=self.dp_group,
+            tp_group=self.tp_group,
+            size_per_shard=size_per_shard,
+            only_moe_param=self.ep_rank != 0,
+        )
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+        control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
+
+        if self.pp_size == 1 and self.ep_size == 1:
+            # When neither pipeline nor expert parallelism is used, save the optimizer shards as in the general CheckpointIO.
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=control_saving,
+            )
+
+            if control_saving:
+                # Store param groups.
+                index_file.append_meta_data("param_groups", param_group_file)
+                group_file_path = os.path.join(checkpoint, param_group_file)
+                save_param_groups(optimizer.param_info, group_file_path)
+                # Store index file.
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+                if self.verbose and self.coordinator.is_master():
+                    logging.info(
+                        f"The optimizer is going to be split to checkpoint shards. "
+                        f"You can find where each parameters has been saved in the "
+                        f"index located at {save_index_file}."
+                    )
+
+        else:
+            # When pipeline or expert parallelism is used, each stage and ep rank produces its own shard files and index files.
+            # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+            # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+            final_index_file_path = copy.deepcopy(save_index_file)
+            tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+            Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+            # Manage filenames of sharded weights and index file for each pipeline stage.
+            states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
+            save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
+            save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=control_saving,
+                use_pp_format=True,
+            )
+
+            if control_saving:
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+            else:
+                return
+
+            dist.barrier(self.pp_group)
+            dist.barrier(self.ep_group)
+
+            # The global master rank integrates the index files and cleans up the temporary folder.
+            if self.coordinator.is_master():
+                final_index_file = CheckpointIndexFile(checkpoint)
+                final_index_file.append_meta_data("total_size", 0)
+
+                for filename in os.listdir(tmp_index_file_folder):
+                    stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+                    final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+                    for param_id, state_filename in stage_index_file.weight_map.items():
+                        final_index_file.append_weight_map(param_id, state_filename)
+
+                # Store param groups.
+                final_index_file.append_meta_data("param_groups", param_group_file)
+                group_file_path = os.path.join(checkpoint, param_group_file)
+                save_param_groups(optimizer.param_info, group_file_path)
+
+                final_index_file.write_index_file(final_index_file_path)
+                rmtree(tmp_index_file_folder)
+
+                if self.verbose and self.coordinator.is_master():
+                    logging.info(
+                        f"The model is split into checkpoint shards. "
+                        f"You can find where each parameters has been saved in the "
+                        f"index located at {final_index_file_path}."
+                    )
+
+    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+        """
+        Load sharded optimizer with the given path to index file of checkpoint folder.
+
+        Args:
+            optimizer (OptimizerWrapper): The optimizer to be loaded.
+            checkpoint_index_file (str): Path to the index file of checkpointing folder.
+            prefix (str): Not used.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+        def _get_param_id_from_optimizer_param(
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+        ):
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            return optimizer.param_info["param2id"][id(working_param)]
+
+        # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
+        # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+        # IDs should be obtained through the param2id mapping stored in optimizer.param_info.
+        id_map = {}
+        master_to_working_map = optimizer.get_master_to_working_map()
+        for pg in optimizer.optim.param_groups:
+            for param in pg["params"]:
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                id_map[param_id] = param
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         ckpt_root_path = ckpt_index_file.root_path
         weight_map = ckpt_index_file.weight_map
-        strict = False
+        weight_map = {int(k): v for k, v in weight_map.items()}  # convert saved id from str to int
 
-        # Load params & buffers to model.
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(
+                f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+                               Lacking param group file under current directory."
+            )
+        saved_groups = torch.load(param_group_path)
+
+        updated_groups = []
+        for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+            # obtain updated param group
+            new_pg = copy.deepcopy(saved_pg)
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
+            updated_groups.append(new_pg)
+        # Handle the extra param group created for MoE (EP) parameters.
+        if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
+            new_pg = copy.deepcopy(saved_pg)
+            new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
+            updated_groups.append(new_pg)
+        optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+        # Load saved states to optimizer.
         # Keep a record of loaded files so that file will not be repeatedly loaded.
         loaded_file = set()
+        for pg in optimizer.optim.param_groups:
+            for param in pg["params"]:
+                if param is None:
+                    continue
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                if param_id not in weight_map:
+                    continue
+                filename = weight_map[param_id]
 
-        def _load(name: str):
-            if name not in weight_map:
-                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
-            filename = weight_map[name]
+                # If the file that stores this param's states has been loaded before, skip it.
+                if filename in loaded_file:
+                    continue
 
-            # If this param/buffer has been loaded before, directly return.
-            if filename in loaded_file:
-                return
+                file_path = os.path.join(ckpt_root_path, filename)
+                state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+                load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+                loaded_file.add(filename)
 
-            file_path = os.path.join(ckpt_root_path, filename)
-            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
-            state_dict = self.pre_load_model(model, state_dict)
-            missing_keys = []
-
-            load_state_dict_into_model(
-                model,
-                state_dict,
-                missing_keys=missing_keys,
-                strict=strict,
-                load_sub_module=True,
+        # Then shard the loaded optimizer states if using tp/zero.
+        for param, state in optimizer.optim.state.items():
+            device = param.device
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            original_shape = optimizer.param_info["param2shape"][id(working_param)]
+            sharded_state = self.shard_from_complete_optimizer_state(
+                state,
+                current_shape=working_param.shape,
+                original_shape=original_shape,
+                device=device,
+                inplace=True,
+                is_moe_param=is_moe_tensor(working_param),
             )
-            loaded_file.add(filename)
+            optimizer.optim.state[param] = sharded_state
 
-        # Load parameters.
-        for name, _ in model.named_parameters():
-            name = name.replace("module.", "")
-            name = name.replace(".gate_weight", ".gate.weight")
-            if ".experts.wi_gate" in name:
-                for i in range(8):
-                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
-                    _load(new_name)
-            elif ".experts.wi_up" in name:
-                for i in range(8):
-                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
-                    _load(new_name)
-            elif ".experts.wo" in name:
-                for i in range(8):
-                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
-                    _load(new_name)
-            else:
-                _load(name)
+        sharded_optimizer_loading_epilogue(optimizer.optim)
+        if self.verbose and self.coordinator.is_master():
+            logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
-        if self.verbose:
-            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+    def shard_from_complete_optimizer_state(
+        self,
+        state: OrderedDict,
+        current_shape: torch.Size,
+        original_shape: torch.Size,
+        device: torch.device,
+        inplace: bool,
+        is_moe_param: bool,
+    ) -> OrderedDict:
+        """
+        With complete optimizer states of a specific parameter loaded from checkpoint,
+        slice out the sharded optimizer states kept by current device.
 
-    @torch.no_grad()
-    def pre_save_model(self, model: nn.Module) -> dict:
-        torch.cuda.empty_cache()
-        state_dict = model.state_dict()
-        for name, param in list(model.named_parameters()):
-            if ".gate_weight" in name:
-                new_name = name.replace(".gate_weight", ".gate.weight")
-                state_dict[new_name] = state_dict.pop(name).cpu()
-            elif ".experts." in name:
-                ep_group = get_ep_group(param)
-                ep_rank = get_ep_rank(param)
-                ep_size = get_ep_size(param)
-                dp_rank = get_dp_rank(param)
+        Args:
+            state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
+            current_shape (torch.Size): The size of parameter after sharding.
+            original_shape (torch.Size): The size of parameter before sharding.
+            device (torch.device): The destination device of loaded optimizer states.
+            inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
 
-                if dp_rank == 0:
-                    param = param.data.cuda()
-                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
-                    # gather param from every ep rank
-                    dist.all_gather(all_param, param, group=ep_group)
-                    if ep_rank == 0:
-                        all_param = torch.cat(all_param, dim=0)
-                        assert all_param.shape[0] == 8
-                        for i in range(8):
-                            if ".wi_gate" in name:
-                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
-                            elif ".wi_up" in name:
-                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
-                            elif ".wo" in name:
-                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
-                            new_name = new_name.replace("module.", "")
-                            new_param = all_param[i].transpose(-1, -2)
-                            state_dict[new_name] = new_param.cpu()
-                        state_dict.pop(name)
-            else:
-                state_dict[name] = param.cpu()
+        Returns:
+            OrderedDict: The sharded optimizer state of the given parameter.
+        """
+        state_ = state if inplace else copy.deepcopy(state)
 
-        for name, param in list(state_dict.items()):
-            new_name = name.replace("module.", "")
-            state_dict[new_name] = state_dict.pop(name)
-        
-        torch.cuda.empty_cache()
-        if self.pp_size > 1:
-            if self.dp_rank == 0:
-                # gather state_dict from every pp rank
-                # because ckpt is large, we split it into 10 parts
-                # and gather them one by one
-                new_state_dict = {}
-                state_dict_keys = list(state_dict.keys())
-                gap_key_num = min(30, len(state_dict_keys))
-                gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num
-                for i in range(gap_key_num):
-                    cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
-                    cur_state_dict = {}
-                    for k in cur_keys:
-                        cur_state_dict[k] = state_dict[k]
-                    out = [None for _ in range(self.pp_size)]
-                    dist.all_gather_object(out, cur_state_dict, group=self.pp_group)
-                    if self.pp_rank == 0:
-                        for o in out:
-                            for k, v in o.items():
-                                new_state_dict[k] = v.cpu()
-                state_dict = new_state_dict
-        dist.barrier()
-        return state_dict
+        for k, v in state_.items():
+            if isinstance(v, torch.Tensor) and k != "step":
+                # Shard state along tensor parallel group.
+                partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
+                if partition_dim is not None:
+                    slice_size = current_shape[partition_dim]
+                    v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
+
+                # Shard state along data parallel group when using Zero.
+                if self.use_zero and not is_moe_param:
+                    padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+                    with torch.no_grad():
+                        v = v.flatten()
+                        if padding_size > 0:
+                            v = torch.nn.functional.pad(v, [0, padding_size])
+                        slice_size = v.numel() // self.dp_size
+                        v = v.split(slice_size, dim=0)[self.dp_rank]
+
+                state_[k] = v.detach().clone().to(device)
+
+        return state_
+
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        raise NotImplementedError
+
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+        raise NotImplementedError
+
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
+        raise NotImplementedError
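
For reference, the ZeRO slicing performed in shard_from_complete_optimizer_state can be
reproduced locally without a process group; a single-process sketch (dp_size, dp_rank and
the tensor contents below are made up):

    import torch

    def zero_shard(v: torch.Tensor, dp_size: int, dp_rank: int) -> torch.Tensor:
        # Pad the flattened state to a multiple of dp_size, then keep this rank's slice.
        padding_size = (dp_size - v.numel() % dp_size) % dp_size
        v = v.flatten()
        if padding_size > 0:
            v = torch.nn.functional.pad(v, [0, padding_size])
        slice_size = v.numel() // dp_size
        return v.split(slice_size, dim=0)[dp_rank]

    full_state = torch.arange(10, dtype=torch.float32)  # 10 elements, dp_size=4 pads to 12
    shards = [zero_shard(full_state, dp_size=4, dp_rank=r) for r in range(4)]
    assert torch.cat(shards)[:10].equal(full_state)  # concatenating the shards recovers the state

gather_from_sharded_optimizer_state is the inverse on the save path: it all-gathers the ZeRO
shards along dp_group and truncates the result to param.numel().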
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
index e395c8578..a2b78a2bd 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
@@ -1,80 +1,92 @@
 import torch
-import torch.nn as nn
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock
+import torch.distributed as dist
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe import SparseMLP
+from colossalai.moe import MOE_MANAGER
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
 
 
-class MixtralSparseMLP:
-    r"""
-    This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
-    """
+class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
+    def __init__(self, config):
+        super().__init__(config)
+        self.setup_ep()
 
-    def __init__(self) -> None:
-        raise NotImplementedError(
-            "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
-        )
+    def setup_ep(self):
+        _, moe_info = MOE_MANAGER.get_info(self.num_experts)
+        ep_group = moe_info.ep_group
+        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
+        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+        assert self.num_experts % self.ep_size == 0
+        self.ep_group = ep_group
+        self.num_experts_per_ep = self.num_experts // self.ep_size
+        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
+        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+        set_tensors_to_none(self.experts, exclude=set(held_experts))
+        for p in self.experts.parameters():
+            set_moe_tensor_info(p, moe_info)
 
     @staticmethod
-    def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module:
-        r"""
-        Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
-        and optionally marking parameters for gradient aggregation.
+    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
+        LazyInitContext.materialize(module)
+        module.__class__ = EPMixtralSparseMoeBlock
+        module.setup_ep()
+        return module
 
-        Args:
-            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
-            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
 
-        Returns:
-            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
 
-        Raises:
-            AssertionError: If the provided module is not an instance of nn.LayerNorm.
-        """
-        with torch.no_grad():
-            LazyInitContext.materialize(module)
+        selected_experts = selected_experts.t().reshape(-1)
+        selected_experts_idx = selected_experts.argsort()
+        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
+        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+        output_split_sizes = torch.zeros_like(input_split_sizes)
+        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
 
-            # get the attributes of the module
-            moe_kwargs = dict(
-                num_experts=8,
-                hidden_size=module.hidden_dim,
-                intermediate_size=module.ffn_dim,
-                router_top_k=module.top_k,
-                router_norm=True,
-                router_loss=False,
-                # router_capacity_factor_train=
-                # router_capacity_factor_eval=
-                mlp_activation="silu",
-                mlp_gated=True,
-                # enable_load_balance=
-                # load_balance_tolerance=
-                # load_balance_beam_width=
-                # load_balance_group_swap_factor=
-                enable_kernel=enable_kernel,
-                # enable_comm_overlap=
-                # enable_hierarchical_comm=
-                return_gate_logits=True,
-            )
-            dtype = module.gate.weight.dtype
-            device = module.gate.weight.device
-            sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
-
-        return sparse_mlp
-
-
-def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module:
-    """
-    Reverse the replace layer operation
-
-    Args:
-        module (torch.nn.Module): The object of layer to shard
-    """
-    if isinstance(model, MixtralDecoderLayer):
-        model.block_sparse_moe = MixtralSparseMLP.from_native_module(
-            model.block_sparse_moe, enable_kernel=enable_kernel
+        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
+        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        # compute expert output
+        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+        if output_states.size(0) > 0:
+            if self.num_experts_per_ep == 1:
+                # no need to split
+                expert = self.experts[self.expert_start_idx]
+                output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
+                output_states = expert.w2(output_states)
+            else:
+                output_states_splits = output_states.split(output_split_sizes.tolist())
+                output_states_list = []
+                for i, split_states in enumerate(output_states_splits):
+                    if split_states.size(0) == 0:
+                        continue
+                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+                    split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
+                    split_states = expert.w2(split_states)
+                    output_states_list.append(split_states)
+                output_states = torch.cat(output_states_list)
+        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        recover_experts_idx = torch.empty_like(selected_experts_idx)
+        recover_experts_idx[selected_experts_idx] = torch.arange(
+            selected_experts_idx.size(0), device=selected_experts_idx.device
         )
-    else:
-        for _, child in model.named_children():
-            replace_moe_layer(child, enable_kernel)
+        dispatch_states = dispatch_states[recover_experts_idx]
+        k_hidden_states = dispatch_states.chunk(self.top_k)
+        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
+        for i in range(1, self.top_k):
+            output_states += k_hidden_states[i] * routing_weights[:, i, None]
+        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
+        return output_states, router_logits
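
The split-size bookkeeping that feeds all_to_all_uneven above can be illustrated on a single
process (the routing result and the toy sizes below are made up):

    import torch

    num_experts, top_k, ep_size = 4, 2, 2
    num_experts_per_ep = num_experts // ep_size

    # Top-k routing result for 3 tokens, shape (tokens, top_k), as returned by torch.topk.
    selected_experts = torch.tensor([[0, 2], [1, 2], [0, 3]])
    selected_experts = selected_experts.t().reshape(-1)  # expert ids per slot: [0, 1, 0, 2, 2, 3]
    input_split_sizes = selected_experts.bincount(minlength=num_experts)
    print(input_split_sizes.tolist())  # tokens routed to each expert: [2, 1, 2, 1]
    input_split_list = input_split_sizes.view(ep_size, num_experts_per_ep).sum(dim=-1).tolist()
    print(input_split_list)  # tokens sent to each EP rank: [3, 3]

In the real forward pass the matching output_split_sizes come from dist.all_to_all_single on
the ep_group, so the exchange stays correct even when experts receive uneven token counts.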
diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
index 2f6021f2d..734695278 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
@@ -20,6 +20,8 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 from colossalai.shardformer.shard import ShardConfig
 
+from .mixtral_layer import EPMixtralSparseMoeBlock
+
 __all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
 
 
@@ -51,6 +53,18 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
 
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )
+
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
             self.append_or_create_submodule_replacement(
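
A quick way to confirm the replacement took effect after boosting (a sketch; model here is
assumed to be a Mixtral model boosted with this policy, as in infer.py):

    from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock

    # booster.boost wraps the model, so the HF module lives under .module
    for layer in model.module.model.layers:
        assert isinstance(layer.block_sparse_moe, EPMixtralSparseMoeBlock)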
diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py
index 70b827264..a2a0a7e78 100644
--- a/applications/ColossalMoE/colossal_moe/utils.py
+++ b/applications/ColossalMoE/colossal_moe/utils.py
@@ -3,7 +3,6 @@ import os
 from typing import Any, Dict, Tuple, Union
 
 import torch
-from huggingface_hub import snapshot_download
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 
@@ -15,23 +14,6 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
 
-@torch.no_grad()
-def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
-    # pytorch ckpt
-    if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
-        ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
-    # saved ckpt
-    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
-        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
-    # download
-    else:
-        ckpt_path = snapshot_download(ckpt_path)
-    booster.load_model(model, ckpt_path)
-    if optimizer is not None:
-        optimizer.sync_moe_master_param()
-        optimizer.update_master_params(model)
-
-
 def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
     """
     Load file in JSON format
@@ -90,7 +72,7 @@ def load_checkpoint(
     """
 
     # Update booster params states.
-    load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
+    booster.load_model(model, os.path.join(load_dir, "modeling"))
     booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
     booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
 
diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py
index d234fb628..46ff70ff3 100644
--- a/applications/ColossalMoE/infer.py
+++ b/applications/ColossalMoE/infer.py
@@ -2,10 +2,8 @@ import argparse
 
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_model
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
@@ -13,9 +11,6 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.moe import MOE_MANAGER
-from colossalai.moe.utils import skip_init
-from colossalai.utils import get_current_device
 
 
 def parse_args():
@@ -30,16 +25,10 @@ def parse_args():
     parser.add_argument(
         "--plugin",
         type=str,
-        default="hybrid",
+        default="ep",
         choices=["ep"],
         help="Parallel methos.",
     )
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        default="./outputs",
-        help="The path of your saved model after finetuning.",
-    )
     parser.add_argument(
         "--precision",
         type=str,
@@ -71,60 +60,38 @@ def main():
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()
 
+    config = MixtralConfig.from_pretrained(args.model_name)
+    ep_size = min(dist.get_world_size(), config.num_local_experts)
     # Set plugin
-    booster_kwargs = {}
-    hybrid_dict = {
-        "tp_size": 1,
-        "custom_policy": MixtralForCausalLMPolicy(),
-        "enable_fused_normalization": args.use_layernorm_kernel,
-        "enable_jit_fused": args.use_kernel,
-        "precision": args.precision,
-        "checkpoint_io": MixtralMoECheckpointIO,
-        "zero_stage": 1,
-    }
-    mgr_dict = {}
     if args.plugin == "ep":
-        dp_size = dist.get_world_size()
         plugin = MoeHybridParallelPlugin(
+            tp_size=1,
             pp_size=1,
-            **hybrid_dict,
-        )
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=dp_size,
-            **mgr_dict,
+            ep_size=ep_size,
+            zero_stage=1,
+            precision=args.precision,
+            custom_policy=MixtralForCausalLMPolicy(),
+            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+            enable_fused_normalization=args.use_layernorm_kernel,
+            enable_jit_fused=args.use_kernel,
         )
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
     coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
 
     # Build mixtral model
-    config = MixtralConfig.from_pretrained(args.model_name)
-    config.num_local_experts = 1  # dont change this. it will not affect model
-    with skip_init():
-        model = MixtralForCausalLM(config)
-    model.num_experts = 8
-    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
-    model = model.to(get_current_device())
-    coordinator.print_on_master(f"Finish init model with config:\n{config}")
-
-    # Replace moe
-    with skip_init():
-        replace_moe_layer(model)
-    model.eval()
-    coordinator.print_on_master(f"Finish replace moe module")
+    model = MixtralForCausalLM.from_pretrained(args.model_name)
+    coordinator.print_on_master(f"Finish load model")
 
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 
     # Set booster
-    booster = Booster(plugin=plugin, **booster_kwargs)
+    booster = Booster(plugin=plugin)
     model, _, _, _, _ = booster.boost(model=model)
     coordinator.print_on_master(f"Finish init booster")
 
-    # load ckpt
-    load_model(args.model_name, model, booster)
-    coordinator.print_on_master(f"Finish load ckpt")
+    model.eval()
 
     if coordinator.rank == 0:
         text = ["Hello my name is"]
@@ -132,10 +99,13 @@ def main():
         text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
     tokenizer.pad_token = tokenizer.unk_token
     inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
-    outputs = model.module.generate(**inputs, max_new_tokens=20)
-    outputs = tokenizer.batch_decode(outputs)
+
+    with torch.no_grad():
+        outputs = model.module.generate(**inputs, max_new_tokens=20)
+    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
     print(f"[{coordinator.rank}] {outputs}")
 
 
+
 if __name__ == "__main__":
     main()
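
For reference, the rewritten infer.py now derives the expert-parallel size from the launch configuration and the checkpoint's config instead of a hard-coded replacement. A quick worked example of that line; the concrete numbers below are illustrative, not taken from the patch:

    # ep_size is capped by the number of experts in the loaded config.
    world_size = 2                      # e.g. launched with 2 processes
    num_local_experts = 8               # Mixtral-8x7B has 8 experts per MoE layer
    ep_size = min(world_size, num_local_experts)
    assert ep_size == 2                 # each rank then holds 8 // 2 = 4 experts
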
diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py
new file mode 100644
index 000000000..57589ab20
--- /dev/null
+++ b/applications/ColossalMoE/tests/test_mixtral_layer.py
@@ -0,0 +1,63 @@
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed as dist
+from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
+from torch.testing import assert_close
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+import colossalai
+from colossalai.moe import MOE_MANAGER
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
+
+
+def check_mixtral_moe_layer():
+    torch.cuda.set_device(dist.get_rank())
+    MOE_MANAGER.setup(
+        parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
+    )
+    config = MixtralConfig(
+        hidden_size=hidden_size,
+        intermediate_size=hidden_size * 2,
+        num_local_experts=n_experts,
+        num_experts_per_tok=top_k,
+    )
+    torch.manual_seed(0)
+    orig_model = MixtralSparseMoeBlock(config).cuda()
+    x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
+    orig_output, orig_logits = orig_model(x)
+    model = deepcopy(orig_model)
+    model = EPMixtralSparseMoeBlock.from_native_module(model)
+    ep_output, ep_logits = model(x)
+    assert_close(orig_logits, ep_logits)
+    assert_close(orig_output, ep_output)
+    orig_loss = orig_output.mean()
+    orig_loss.backward()
+    ep_loss = ep_output.mean()
+    ep_loss.backward()
+    assert_close(orig_loss, ep_loss)
+    name_to_p = {n: p for n, p in orig_model.named_parameters()}
+    for n, ep_p in model.named_parameters():
+        p = name_to_p[n]
+        if ep_p.grad is not None:
+            assert_close(p.grad, ep_p.grad)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+    colossalai.launch({}, rank, world_size, "localhost", port)
+    check_mixtral_moe_layer()
+
+
+@pytest.mark.parametrize("world_size", [2, 4])
+def test_mixtral_moe_layer(world_size: int):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_mixtral_moe_layer(2)
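
A note on the `if ep_p.grad is not None` guard in the test above: under expert parallelism each rank materializes only its local slice of experts, so only those experts' parameters receive gradients on that rank (the shared router still gets gradients everywhere). A minimal sketch of the ownership arithmetic, assuming experts are split evenly and contiguously across the EP group; the actual layout inside EPMixtralSparseMoeBlock is an implementation detail of the patch:

    def local_expert_indices(ep_rank: int, ep_size: int, n_experts: int) -> list:
        # Even split: each EP rank owns n_experts // ep_size consecutive experts.
        per_rank = n_experts // ep_size
        return list(range(ep_rank * per_rank, (ep_rank + 1) * per_rank))

    # With n_experts=4 and ep_size=2 (the world_size=2 case of the test),
    # rank 0 owns experts [0, 1] and rank 1 owns experts [2, 3]; the other
    # experts' parameters never see a backward pass on that rank.
    assert local_expert_indices(0, 2, 4) == [0, 1]
    assert local_expert_indices(1, 2, 4) == [2, 3]
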
diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
index 7c6012a70..d3848bc14 100644
--- a/applications/ColossalMoE/tests/test_moe_checkpoint.py
+++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -1,185 +1,144 @@
-import os
-import shutil
+from copy import deepcopy
 
 import pytest
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from torch.optim import Adam
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
 
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.moe.manager import MOE_MANAGER
-from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
+from colossalai.testing.utils import spawn
+
+tokens, n_experts = 7, 4
+hidden_size = 8
+top_k = 2
 
 
-def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
-    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
-    attention_mask = torch.ones_like(input_ids)
+def check_model_equal(model1, model2):
+    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
+    for p1, p2 in zip(model1.parameters(), model2.parameters()):
+        assert torch.equal(p1.half(), p2.half())
+
+
+def get_optimizer_snapshot(optim):
+    state = {id(k): deepcopy(v) for k, v in optim.state.items()}
+    param_groups = []
+    for group in optim.param_groups:
+        params = [id(p) for p in group["params"]]
+        new_group = {"params": params}
+        for k, v in group.items():
+            if k != "params":
+                new_group[k] = v
+        param_groups.append(new_group)
     return {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-        "labels": input_ids,
+        "state": state,
+        "param_groups": param_groups,
     }
 
 
-def run_fwd_bwd(
-    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
-):
-    model.train()
-    if pipeline:
-        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
-        is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
-        y = booster.execute_pipeline(
-            train_dataloader_iter,
-            model,
-            lambda x, y: x.loss,
-            optimizer,
-            return_loss=True,
-            return_outputs=True,
-        )
-        # Backward and optimize
-        if is_pp_last_stage:
-            loss = y["loss"]
-    else:
-        if criterion:
-            y = model(data).logits
-            loss = criterion(y)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-
-        if optimizer is not None:
-            optimizer.backward(loss)
-        else:
-            loss.backward()
-    return y
+def check_optimizer_snapshot_equal(snapshot1, snapshot2):
+    # check param_groups
+    assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
+    for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
+        assert set(group1.keys()) == set(group2.keys())
+        for k in group1.keys():
+            assert group1[k] == group2[k]
+    # check state
+    assert set(snapshot1["state"].keys()) == set(
+        snapshot2["state"].keys()
+    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
+    for pid in snapshot1["state"].keys():
+        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
+        assert set(state1.keys()) == set(state2.keys())
+        for k in state1.keys():
+            if isinstance(state1[k], torch.Tensor):
+                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
+            else:
+                assert state1[k] == state2[k]
 
 
-def get_config():
+def check_mixtral_moe_layer():
+    torch.cuda.set_device(dist.get_rank())
     config = MixtralConfig(
-        vocab_size=300,
-        hidden_size=32,
-        intermediate_size=16,
-        num_hidden_layers=2,
-        dropout_rate=0.0,
+        hidden_size=hidden_size,
+        intermediate_size=hidden_size * 2,
+        num_local_experts=n_experts,
+        num_experts_per_tok=top_k,
+        num_attention_heads=2,
+        num_key_value_heads=2,
     )
-    return config
-
-
-def get_model(parallel):
-    config = get_config()
-    model = MixtralForCausalLM(config).to(torch.bfloat16)
-    replace_moe_layer(model)
-    optim = torch.optim.Adam(model.parameters())
-    args = dict(
-        precision="bf16",
+    torch.manual_seed(0)
+    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
+    orig_model = MixtralForCausalLM(config).cuda()
+    model = deepcopy(orig_model)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    plugin = MoeHybridParallelPlugin(
         tp_size=1,
-        zero_stage=1,
+        pp_size=2,
+        ep_size=2,
         custom_policy=MixtralForCausalLMPolicy(),
-        checkpoint_io=MixtralMoECheckpointIO,
+        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+        microbatch_size=1,
+        zero_stage=1,
     )
-    if parallel == "ep":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "hybrid":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=2,
-            microbatch_size=1,
-            **args,
-        )
     booster = Booster(plugin=plugin)
-    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
-    return model, booster, optim
-
-
-def _test_moe_checkpoint(parallel):
-    if dist.get_rank() == 0:
-        if os.path.exists("./tmp_ckpt1"):
-            shutil.rmtree("./tmp_ckpt1")
-        if os.path.exists("./tmp_ckpt2"):
-            shutil.rmtree("./tmp_ckpt2")
-    dist.barrier()
-
-    if parallel == None:
-        MOE_MANAGER.setup(
-            parallel=None,
-        )
-    elif parallel == "ep":
-        MOE_MANAGER.setup(
-            parallel="EP",
-        )
-    elif parallel == "hybrid":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=1,
-            fixed_ep_size=2,
-            fixed_pp_size=2,
-        )
-    model1, booster1, optim1 = get_model(parallel)
-    model2, booster2, optim2 = get_model(parallel)
-    # param ckpt
-    # check not equal
-    try:
-        check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
-        raise AssertionError("state_dict should not be equal")
-    except:
-        pass
-    # shard
-    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
-    booster2.load_model(model2, "./tmp_ckpt1")
-    # check
-    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
-
-    # optim ckpt
-    criterion = lambda x: x.mean()
-    data = torch.randint(0, 4, (2, 4)).cuda()
-    label = torch.randint(0, 4, (2,)).cuda()
-    if parallel == "hybrid":
-        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
-    else:
-        kwargs = {}
-    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
-    optim1.step()
-    optim1.zero_grad()
-    # shard
-    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
-    dist.barrier()
-    booster2.load_optimizer(optim2, "./tmp_ckpt2")
-    # check
-    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
-
-    if dist.get_rank() == 0:
-        shutil.rmtree("./tmp_ckpt1")
-        shutil.rmtree("./tmp_ckpt2")
-
-
-def _run_dist(rank, world_size, port, parallel):
-    colossalai.launch(
-        config=dict(),
-        rank=rank,
-        world_size=world_size,
-        host="localhost",
-        port=port,
-        backend="nccl",
+    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
+    # initialize grads
+    data_iter = iter(
+        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
     )
-    _test_moe_checkpoint(parallel)
+    booster.execute_pipeline(
+        data_iter,
+        model,
+        lambda outputs, inputs: outputs.loss,
+        optimizer,
+    )
+
+    # check save model
+    booster.save_model(model, "mixtral_model", shard=True)
+    dist.barrier()
+    if dist.get_rank() == 0:
+        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
+        check_model_equal(orig_model, saved_model)
+        saved_model.save_pretrained("mixtral_hf_model")
+    dist.barrier()
+
+    # check load model
+    new_model = MixtralForCausalLM(config).cuda()
+    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
+    booster.load_model(new_model, "mixtral_hf_model")
+    check_model_equal(model, new_model)
+
+    # check save optimizer
+    optimizer.step()
+    snapshot = get_optimizer_snapshot(optimizer.unwrap())
+    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
+    dist.barrier()
+    # reset optimizer state
+    for state in optimizer.unwrap().state.values():
+        for v in state.values():
+            if isinstance(v, torch.Tensor):
+                v.zero_()
+    booster.load_optimizer(optimizer, "mixtral_optim")
+    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
+    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+    colossalai.launch({}, rank, world_size, "localhost", port)
+    check_mixtral_moe_layer()
 
 
-@pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
-@rerun_if_address_is_in_use()
-def test_moe_checkpoint(world_size, parallel):
-    spawn(_run_dist, world_size, parallel=parallel)
+def test_mixtral_moe_layer(world_size: int):
+    spawn(run_dist, world_size)
 
 
 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="hybrid")
+    test_mixtral_moe_layer(4)
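
For readers unfamiliar with the snapshot helpers in this test: `get_optimizer_snapshot` keys the state by `id(param)` and deep-copies every tensor, so zeroing the live optimizer state afterwards cannot mutate the snapshot. A minimal sketch on a plain torch optimizer, illustrative only and not part of the test:

    import torch
    from torch.optim import Adam

    layer = torch.nn.Linear(4, 4)
    optim = Adam(layer.parameters(), lr=1e-3)
    layer(torch.randn(2, 4)).sum().backward()
    optim.step()

    snap = get_optimizer_snapshot(optim)      # {"state": {id(p): {...}}, "param_groups": [...]}
    for state in optim.state.values():        # wipe the live state, as the test does
        for v in state.values():
            if isinstance(v, torch.Tensor):
                v.zero_()
    # the snapshot still holds the original exp_avg / exp_avg_sq tensors
    assert any(v.abs().sum() > 0
               for s in snap["state"].values()
               for v in s.values() if isinstance(v, torch.Tensor))
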
diff --git a/applications/ColossalMoE/tests/test_moe_layer.py b/applications/ColossalMoE/tests/test_moe_layer.py
deleted file mode 100644
index 8b090c427..000000000
--- a/applications/ColossalMoE/tests/test_moe_layer.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import copy
-
-import torch
-from colossal_moe.models.mixtral_layer import MixtralSparseMLP
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-
-class Config:
-    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_local_experts = num_local_experts
-        self.num_experts_per_tok = num_experts_per_tok
-        self.hidden_act = hidden_act
-
-
-def test_moe_layer():
-    config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
-    mistral_moe = MixtralSparseMoeBlock(config).cuda()
-    colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()
-
-    data = torch.randn(2, 8, 4).cuda()
-    mistral_output = mistral_moe(data)[0]
-    colossal_output = colossal_moe(data)[0]
-    assert torch.allclose(
-        mistral_output, colossal_output
-    ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
-
-
-if __name__ == "__main__":
-    test_moe_layer()
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index 1d0441a5a..c567038ec 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -2,22 +2,18 @@ import argparse
 
 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
+from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
-from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from transformers.models.mixtral import MixtralForCausalLM
 
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.moe import MOE_MANAGER, apply_load_balance
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -153,45 +149,27 @@ def main():
     coordinator = DistCoordinator()
 
     # Set plugin
-    booster_kwargs = {}
-    hybrid_dict = {
-        "tp_size": 1,
-        "custom_policy": MixtralForCausalLMPolicy(),
-        "enable_fused_normalization": args.use_layernorm_kernel,
-        "enable_jit_fused": args.use_kernel,
-        "precision": args.precision,
-        "zero_stage": args.zero_stage,
-        "checkpoint_io": MixtralMoECheckpointIO,
-    }
-    mgr_dict = {}
     if args.plugin == "hybrid":
         plugin = MoeHybridParallelPlugin(
+            tp_size=1,
             pp_size=args.pp_size,
+            ep_size=args.ep_size,
             microbatch_size=args.microbatch_size,
-            **hybrid_dict,
-        )
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=args.dp_size,
-            fixed_ep_size=args.ep_size,
-            fixed_pp_size=args.pp_size,
-            **mgr_dict,
+            custom_policy=MixtralForCausalLMPolicy(),
+            enable_fused_normalization=args.use_layernorm_kernel,
+            enable_jit_fused=args.use_kernel,
+            precision=args.precision,
+            zero_stage=args.zero_stage,
+            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
         )
+
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
     coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
 
     # Build Mixtral model
-    config = MixtralConfig.from_pretrained(args.model_name)
-    config.use_cache = False
-    config.num_local_experts = 1
-    model = MixtralForCausalLM(config)
-    model.num_experts = 8
-    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
-    model = model.to(get_current_device())
-    replace_moe_layer(model, enable_kernel=args.use_kernel)
-    coordinator.print_on_master(f"Finish init model with config:\n{config}")
+    model = MixtralForCausalLM.from_pretrained(args.model_name)
+    coordinator.print_on_master(f"Finish init model")
 
     # Enable gradient checkpointing
     model.gradient_checkpointing_enable()
@@ -224,7 +202,7 @@ def main():
     )
 
     # Set booster
-    booster = Booster(plugin=plugin, **booster_kwargs)
+    booster = Booster(plugin=plugin)
     model, optimizer, _, dataloader, lr_scheduler = booster.boost(
         model=model,
         optimizer=optimizer,
@@ -236,10 +214,7 @@ def main():
     coordinator.print_on_master(f"Finish init booster")
 
     # Load ckpt
-    if args.load_checkpoint is None:
-        load_model(args.model_name, model, booster, optimizer)
-        coordinator.print_on_master(f"Finish load checkpoint")
-    else:
+    if args.load_checkpoint is not None:
         load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
         coordinator.print_on_master(f"Finish load optimizer")
 
@@ -286,13 +261,13 @@ def main():
                 optimizer.zero_grad()
 
                 # Apply load balance
-                if (
-                    args.load_balance
-                    and args.load_balance_interval > 0
-                    and (step + 1) % args.load_balance_interval == 0
-                ):
-                    coordinator.print_on_master(f"Apply load balance")
-                    apply_load_balance(model, optimizer)
+                # if (
+                #     args.load_balance
+                #     and args.load_balance_interval > 0
+                #     and (step + 1) % args.load_balance_interval == 0
+                # ):
+                #     coordinator.print_on_master(f"Apply load balance")
+                #     apply_load_balance(model, optimizer)
                 # save checkpoint
                 if (step + 1) % args.save_interval == 0:
                     coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 07cbc14a7..45e5a23c1 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self,
         tp_size: int,
         pp_size: int,
+        ep_size: int,
         extra_dp_size: int = 1,
         precision: str = "fp16",
         zero_stage: int = 0,
@@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         if enable_sequence_parallelism:
             assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
-
+        assert (
+            dist.get_world_size() % (tp_size * pp_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        assert (
+            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
+        MOE_MANAGER.setup(
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=self.real_dp_size,
+            fixed_ep_size=ep_size,
+            fixed_pp_size=pp_size,
+            use_ep_inside=use_ep_inside,
+        )
         self.tp_size = tp_size
         self.pp_size = pp_size
         self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+        self.ep_size = ep_size
+        self.moe_info = MOE_MANAGER.get_info(0)[1]
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
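
To make the new assertions and the distinction between `real_dp_size` and the base plugin's `dp_size` concrete, here is a worked example with assumed sizes (the numbers are illustrative, not from the patch):

    world_size, tp_size, pp_size, ep_size = 8, 1, 2, 2

    assert world_size % (tp_size * pp_size) == 0
    assert world_size % (tp_size * pp_size * ep_size) == 0

    real_dp_size = world_size // (tp_size * pp_size * ep_size)  # 2: replicas of each expert shard
    dp_size = world_size // (tp_size * pp_size)                 # 4: dp_size as seen by the base plugin
    assert (real_dp_size, dp_size) == (2, 4)
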
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index 780117598..712324215 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.interface import ModelWrapper
 
-from .utils import has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
 
 __all__ = ["CheckpointIO"]
 
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
-            self.load_unsharded_model(model, checkpoint, strict)
+            path = Path(checkpoint, SAFE_WEIGHTS_NAME)
+            if path.is_file():
+                self.load_unsharded_model(model, str(path), strict)
+            else:
+                path = Path(checkpoint, WEIGHTS_NAME)
+                if path.is_file():
+                    self.load_unsharded_model(model, str(path), strict)
+                else:
+                    self.load_unsharded_model(model, checkpoint, strict)
 
         return origin_model
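
The fallback added above lets `load_model` accept a directory that contains a single unsharded HuggingFace-style weight file. A condensed sketch of the equivalent resolution order; the helper name and the literal filenames are for illustration only, the patch itself uses the SAFE_WEIGHTS_NAME / WEIGHTS_NAME constants from checkpoint_io.utils:

    from pathlib import Path

    def resolve_unsharded_checkpoint(checkpoint: str) -> str:
        # Prefer a safetensors file, then the legacy .bin file, then the raw path itself.
        for name in ("model.safetensors", "pytorch_model.bin"):
            candidate = Path(checkpoint, name)
            if candidate.is_file():
                return str(candidate)
        return checkpoint

With such a helper the else-branch above would reduce to a single `self.load_unsharded_model(model, resolve_unsharded_checkpoint(checkpoint), strict)` call.
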
 
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 34342436f..01c837ee3 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
         if ctx.ep_size != 1:
             grad = grad / ctx.ep_size
         return grad, None
+
+
+def _all_to_all(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    async_op: bool = False,
+):
+    """
+    Returns:
+        outputs: Tensor
+        handle: Optional[Work], if async_op is True
+    """
+    outputs_shape = list(inputs.shape)
+    if output_split_sizes is not None:
+        outputs_shape[0] = sum(output_split_sizes)
+    outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
+    inputs = inputs.contiguous()
+    outputs = outputs.contiguous()
+    handle = dist.all_to_all_single(
+        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+    )
+    return outputs, handle
+
+
+class AllToAllUneven(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        inputs,
+        input_split_sizes=None,
+        output_split_sizes=None,
+        group=None,
+        overlap: bool = False,
+    ):
+        """
+        Returns:
+            outputs: Tensor
+            handle: Optional[Work], if overlap is True
+        """
+        ctx.input_split_sizes = input_split_sizes
+        ctx.output_split_sizes = output_split_sizes
+        ctx.group = group
+        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs):
+        return (
+            _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def all_to_all_uneven(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    overlap: bool = False,
+):
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
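
A small usage sketch for the new `all_to_all_uneven` primitive, assuming a 2-rank NCCL process group is already initialized (e.g. via colossalai.launch) and each process has called torch.cuda.set_device for its rank; the split sizes are made up for illustration:

    import torch
    import torch.distributed as dist
    from colossalai.moe._operation import all_to_all_uneven

    rank = dist.get_rank()
    # Send matrix S[i][j] = number of rows rank i sends to rank j:
    #   rank 0 sends [1, 3], rank 1 sends [2, 2].
    input_splits = [[1, 3], [2, 2]][rank]
    # Receive sizes are the transpose of the send matrix:
    #   rank 0 receives [1, 2], rank 1 receives [3, 2].
    output_splits = [[1, 2], [3, 2]][rank]

    x = torch.randn(sum(input_splits), 8, device="cuda", requires_grad=True)
    out, handle = all_to_all_uneven(x, input_splits, output_splits)  # handle is None unless overlap=True
    assert out.shape[0] == sum(output_splits)
    out.sum().backward()  # gradients flow back through the reverse uneven all-to-all
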
diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py
index ba6c77056..5ac3c2b3a 100644
--- a/colossalai/tensor/moe_tensor/moe_info.py
+++ b/colossalai/tensor/moe_tensor/moe_info.py
@@ -26,3 +26,5 @@ class MoeParallelInfo:
         self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
         self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
         self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
+        self.ep_rank = self.pg.coordinate(self.ep_axis)
+        self.dp_rank = self.pg.coordinate(self.dp_axis)
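
The two new coordinates record where the current process sits in the EP x DP mesh for a given MoE parameter. A rough sketch of the mapping for a 4-rank mesh with ep_size=2 and dp_size=2, assuming EP is the inner (fastest-varying) axis; the actual axis order follows the plugin's use_ep_inside flag:

    ep_size, dp_size = 2, 2
    for global_rank in range(ep_size * dp_size):
        dp_rank, ep_rank = divmod(global_rank, ep_size)  # EP varies fastest
        print(f"rank {global_rank}: dp_rank={dp_rank}, ep_rank={ep_rank}")
    # rank 0: dp_rank=0, ep_rank=0
    # rank 1: dp_rank=0, ep_rank=1
    # rank 2: dp_rank=1, ep_rank=0
    # rank 3: dp_rank=1, ep_rank=1
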
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 47bc7603a..511eb26e8 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
-    def sync_moe_master_param(self):
-        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-            master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
-
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
@@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
                 else:
                     master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+            master_moe_param.copy_(working_moe_param)
 
     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        return self._param_store.master_to_working_param
+        return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
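
For context on the two changes above: the ZeRO optimizer steps on fp32 master copies while the model computes with low-precision working parameters, so `update_master_params` must also refresh the MoE masters after a checkpoint load, and the master-to-working map must cover MoE parameters so checkpoint I/O can translate between the two. A minimal standalone sketch of that master/working relationship, not the optimizer's actual bookkeeping:

    import torch

    working = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))  # what the model holds
    master = working.detach().clone().float()                           # what the optimizer steps on

    # ... a checkpoint load later overwrites the working tensor in place ...
    working.data.copy_(torch.randn(4, dtype=torch.bfloat16))

    # the added loop performs the equivalent of this per MoE parameter:
    master.copy_(working)
    assert torch.allclose(master, working.float())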