diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py
index 772cbb977..7c6012a70 100644
--- a/applications/ColossalMoE/tests/test_moe_checkpoint.py
+++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py
@@ -126,6 +126,12 @@ def _test_moe_checkpoint(parallel):
     model1, booster1, optim1 = get_model(parallel)
     model2, booster2, optim2 = get_model(parallel)
     # param ckpt
+    # check not equal
+    try:
+        check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
+        raise AssertionError("state_dict should not be equal")
+    except:
+        pass
     # shard
     booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
     booster2.load_model(model2, "./tmp_ckpt1")
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index a4cf438aa..7c8807c24 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -1,7 +1,7 @@
 import argparse
-import torch.distributed as dist
 
 import torch
+import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
 from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
@@ -10,7 +10,6 @@ from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 
 import colossalai
 from colossalai.booster import Booster
@@ -19,11 +18,11 @@ from colossalai.cluster import DistCoordinator
 from colossalai.moe import MOE_MANAGER, apply_load_balance
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
 
 
-
 @torch.no_grad()
 def get_global_loss(loss, booster):
     global_loss = loss.clone().detach()
@@ -31,6 +30,7 @@ def get_global_loss(loss, booster):
     global_loss.div_(booster.plugin.dp_size)
     return global_loss
 
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
         self.num_samples = num_samples
@@ -97,7 +97,7 @@ def parse_args():
     # optim
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
-    
+
     # lr scheduler
     parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
@@ -197,7 +197,7 @@ def main():
 
     # Prepare tokenizer and dataloader
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
+    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
     collate_fn = None
     dataloader = plugin.prepare_dataloader(
         dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
@@ -211,7 +211,7 @@ def main():
         weight_decay=args.weight_decay,
         adamw_mode=True,
     )
-    
+
     # Set lr scheduler
     lr_scheduler = CosineAnnealingWarmupLR(
         optimizer=optimizer,
@@ -264,7 +264,7 @@ def main():
                     if is_pp_last_stage:
                         loss = outputs["loss"]
                         global_loss = get_global_loss(loss, booster)
-                        if coordinator._local_rank == '0':
+                        if coordinator._local_rank == "0":
                             pbar.set_postfix({"Loss": global_loss.item()})
                 else:
                     # Forward pass
diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
index 9e0d53aeb..9928c801d 100644
--- a/colossalai/moe/checkpoint.py
+++ b/colossalai/moe/checkpoint.py
@@ -334,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
 
         def _get_param_id_from_optimizer_param(
-            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
         ):
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
             return optimizer.param_info["param2id"][id(working_param)]
@@ -349,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 id_map[param_id] = param
 
         # Read checkpoint index file.
@@ -373,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             new_pg = copy.deepcopy(saved_pg)
             new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
             updated_groups.append(new_pg)
-        # ep extra group
-        if MOE_MANAGER.parallel == "EP":
+        # ep param group
+        if len(optimizer.optim.param_groups) > len(saved_groups):
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = optimizer.optim.param_groups[-1][
-                "params"
-            ]  # Only keep the parameters kept by current pipeline stage.
-            for param in new_pg["params"]:
-                param.data = param.data.to(torch.float32)
+            new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
             updated_groups.append(new_pg)
         optimizer.optim.__dict__.update({"param_groups": updated_groups})
 
@@ -391,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -410,12 +408,14 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             device = param.device
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
             sharded_state = self.pre_load_optim(
                 state,
-                param,
+                working_param,
                 current_shape=working_param.shape,
                 original_shape=original_shape,
                 device=device,
@@ -578,6 +578,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
 
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
 
@@ -620,6 +622,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        torch.cuda.empty_cache()
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -725,6 +728,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
                         f"You can find where each parameters has been saved in the "
                         f"index located at {final_index_file_path}."
                     )
+        torch.cuda.empty_cache()
 
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 8d2346a3c..553383f4c 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -175,12 +175,18 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if len(self.working_moe_params) > 0:
             self._sync_master_param = False
             param_group = dict()
+            # create fp32 master param
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
             self.master_moe_params = []
             for param in self.working_moe_params:
                 self.master_moe_params.append(param.clone().to(torch.float32).detach())
+            # create mapping from master to working for optimizer io
+            self.moe_master_to_working_map = {}
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
+            # add to optim
             param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)