From 0e6b1f856cbc71b4926d975a55b0ae3e80a1d46d Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 24 Aug 2023 17:01:14 +0800
Subject: [PATCH] add support for moe checkpoint

---
 .../core/scheduler/no_pipeline_scheduler.py |  3 +-
 internlm/model/moe.py                       |  2 +-
 internlm/moe/sharded_moe.py                 |  2 +
 internlm/utils/model_checkpoint.py          | 93 +++++++++++++++++++
 4 files changed, 98 insertions(+), 2 deletions(-)

diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index dd9b49a..c1e8830 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler):
                 loss = self._call_engine_criterion(engine, output, label)
                 self._call_hooks("after_criterion", loss)
                 moe_loss = sum(moe_losses) * moe_loss_coeff
-                loss += moe_loss
+                moe_loss /= scale_loss
                 loss /= scale_loss
+                loss += moe_loss
 
         # backward
         if not forward_only:
diff --git a/internlm/model/moe.py b/internlm/model/moe.py
index 75beb14..1504838 100644
--- a/internlm/model/moe.py
+++ b/internlm/model/moe.py
@@ -74,7 +74,7 @@ class MoE(torch.nn.Module):
         drop_tokens: bool = True,
         use_rts: bool = True,
         using_default_moe: bool = True,
-        use_residual=True,
+        use_residual=False,
         residual_mlp=None,
     ):
 
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 1ae68e0..3bd529b 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -356,6 +356,8 @@ class TopKGate(Module):
         # Only top-1 and top-2 are supported at the moment.
         if k not in (1, 2):
             raise ValueError("Only top-1 and top-2 gatings are supported.")
+        # TODO: can we use tensor parallelism here?
+        # following DeepSpeed's mechanism, the gate always uses fp32
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
         self.k = k
         self.capacity_factor = capacity_factor
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 08d9db7..2c7a8f4 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -4,8 +4,10 @@
 import copy
 import fcntl
 import os
+import re
 import socket
 import time
+from collections import defaultdict
 from enum import Enum
 from typing import Dict
 
@@ -14,6 +16,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.trainer import TrainState
+from internlm.model.moe import MoE
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import get_current_device
@@ -70,6 +73,8 @@ def save_model_checkpoint(folder, model):
     """
 
     states = model.state_dict()
+    # exclude expert parameters; they are saved separately by try_save_moe_checkpoint()
+    states = get_non_moe_state_dict(states)
     topo = get_model_topology(model)
 
     if folder is not None:
@@ -93,6 +98,9 @@ def save_model_checkpoint(folder, model):
         topo_fp = os.path.join(folder, topo_fn)
         llm_save(topo_fp, saved_obj=topo)
 
+    # whether the model actually contains MoE layers is checked inside try_save_moe_checkpoint()
+    try_save_moe_checkpoint(folder, model)
+
     torch.distributed.barrier()
 
 
@@ -129,6 +137,18 @@ def load_model_checkpoint(folder, model):
         fp = os.path.join(folder, should_load_name)
         states = llm_load(fp, map_location=get_current_device())
 
+    """
+    # The gate parameters would need to be cast to float32 here (to fit the DeepSpeed-style mechanism), which may
+    # introduce round-off in gate.weight. Since the same conversion is done again in the forward pass, this block
+    # stays commented out and the gate parameters remain float16 until forward.
+    for key in list(states.keys()):
+        if 'moe_layer.gate.wg.weight' in key:
+            states[key] = states[key].float()
+            print("load: ", states[key].float(), flush=True)
+    """
+
+    try_load_moe_checkpoint(folder, model, states)
+
     missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
         logger.warning(f"Warning: missing keys {missing_k}")
@@ -140,6 +160,58 @@ def load_model_checkpoint(folder, model):
     torch.cuda.empty_cache()
 
 
+def try_save_moe_checkpoint(folder, model):
+    # Save each expert's state_dict to its own model_moe_layer<i>_expert<j>.pt file; a bit of a hack.
+    moe_layer_id = 0
+    for n_module, module in model.named_modules():
+        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
+            num_local_experts = module.num_local_experts
+            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
+
+            # get all moe parameters (the gate is excluded and stays in the main checkpoint)
+            moe_state_dict = {}
+            for n, p in module.state_dict().items():
+                if "expert" in n and "moe_layer.gate.wg.weight" not in n:
+                    moe_state_dict[n_module + "." + n] = p
+            moe_str_prefix = ".moe_layer.experts.experts."
+            # renumber local expert ids to global ids so that each checkpoint file holds exactly one expert
+            experts_state_dict = defaultdict(dict)
+            for key in list(moe_state_dict.keys()):
+                m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
+
+                if not m:
+                    # should not happen: every collected key belongs to an expert
+                    logger.warning(f"No expert found in key {key}.")
+                    continue
+                local_expert_id = m.group(1)
+
+                global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
+                expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")
+
+                # detach and clone to drop any shared (extra) tensor storage
+                truncated = moe_state_dict.pop(key).clone().detach()
+                experts_state_dict[str(global_expert_id)][expert_key] = truncated
+
+            # now save the expert parameters
+            for global_expert_id, expert_state_dict in experts_state_dict.items():
+                # one checkpoint file per (moe layer, global expert) pair
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fp = os.path.join(folder, fn)
+                llm_save(fp, saved_obj=expert_state_dict)
+            moe_layer_id += 1
+
+
+def get_non_moe_state_dict(full_state_dict):
+    """
+    Get the state dict of the non-MoE layers (expert parameters removed).
+    """
+    for key in list(full_state_dict.keys()):
+        if "expert" in key and "moe_layer.gate.wg.weight" not in key:
+            full_state_dict.pop(key)
+
+    return full_state_dict
+
+
 def save_optimizer_checkpoint(optim, state_path):
     """Store the state of the optimizer to the local file system or remote OSS.
 
@@ -168,6 +240,27 @@ def save_optimizer_checkpoint(optim, state_path):
     llm_save(os.path.join(state_path, fp), states)
 
 
+def try_load_moe_checkpoint(folder, model, state_dict):
+    moe_layer_id = 0
+    for _, module in model.named_modules():
+        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
+            num_local_experts = module.num_local_experts
+            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
+            # loop over all local experts
+            for local_expert_id in range(num_local_experts):
+                global_expert_id = expp_rank * num_local_experts + local_expert_id
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fp = os.path.join(folder, fn)
+                expert_state_dict = llm_load(fp, map_location=get_current_device())
+                # map global expert ids back to local ids
+                moe_str_prefix = ".moe_layer.experts.experts."
+                for key in list(expert_state_dict.keys()):
+                    local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
+                    expert_state_dict[local_key] = expert_state_dict.pop(key)
+                state_dict.update(expert_state_dict)
+            moe_layer_id += 1
+
+
 def load_optimizer_checkpoint(folder, optim):
     """Load the optimizer state from the local file system or remote object storage Service (OSS).