add support for moe checkpoint

pull/375/head
Wenwen Qu 2023-08-24 17:01:14 +08:00
parent e32fbaaae2
commit 0e6b1f856c
4 changed files with 98 additions and 2 deletions

View File

@@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler):
                 loss = self._call_engine_criterion(engine, output, label)
                 self._call_hooks("after_criterion", loss)
                 moe_loss = sum(moe_losses) * moe_loss_coeff
-                loss += moe_loss
+                moe_loss /= scale_loss
                 loss /= scale_loss
+                loss += moe_loss

        # backward
        if not forward_only:
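As a sanity check on the reordering above (with made-up numbers, not values from the source): the new order is algebraically equivalent to the old one, the difference being that the scaled auxiliary MoE loss is now available as a separate term.

# Minimal sketch with hypothetical values; variable names mirror the diff above.
loss, moe_loss, scale_loss = 4.0, 1.0, 2.0

old_total = (loss + moe_loss) / scale_loss   # old ordering: 2.5
scaled_moe = moe_loss / scale_loss           # new ordering scales the aux loss first: 0.5
new_total = loss / scale_loss + scaled_moe   # 2.5 -- same total, aux term kept separate
assert old_total == new_total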

View File

@ -74,7 +74,7 @@ class MoE(torch.nn.Module):
drop_tokens: bool = True, drop_tokens: bool = True,
use_rts: bool = True, use_rts: bool = True,
using_default_moe: bool = True, using_default_moe: bool = True,
use_residual=True, use_residual=False,
residual_mlp=None, residual_mlp=None,
): ):

View File

@ -356,6 +356,8 @@ class TopKGate(Module):
# Only top-1 and top-2 are supported at the moment. # Only top-1 and top-2 are supported at the moment.
if k not in (1, 2): if k not in (1, 2):
raise ValueError("Only top-1 and top-2 gatings are supported.") raise ValueError("Only top-1 and top-2 gatings are supported.")
# TODO: can we use tensor parallel here?
# Deepspeed's mechisms, alway use fp32
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.k = k self.k = k
self.capacity_factor = capacity_factor self.capacity_factor = capacity_factor
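For context on the new comments: the DeepSpeed-style gate keeps its router weights in float32 even when the rest of the model runs in half precision, which is why self.wg is built with .float(). Below is a minimal sketch of how such a gate is typically used; casting the activations up to fp32 before the gating matmul is an assumption here (a later comment in this commit notes that the conversion also happens during forward).

import torch

# Sketch only: an fp32 gate inside an otherwise fp16 model.
model_dim, num_experts = 8, 4
wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()  # gate weights stay fp32

hidden = torch.randn(2, model_dim, dtype=torch.float16)  # hypothetical fp16 activations
logits = wg(hidden.float())                              # cast up before the gating matmul
probs = torch.softmax(logits, dim=-1)                    # routing probabilities per expert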

View File

@@ -4,8 +4,10 @@
 import copy
 import fcntl
 import os
+import re
 import socket
 import time
+from collections import defaultdict
 from enum import Enum
 from typing import Dict
@@ -14,6 +16,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.trainer import TrainState
+from internlm.model.moe import MoE
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import get_current_device
@@ -70,6 +73,8 @@ def save_model_checkpoint(folder, model):
     """
     states = model.state_dict()
+    # get non-moe parameters
+    states = get_non_moe_state_dict(states)
     topo = get_model_topology(model)

     if folder is not None:
@ -93,6 +98,9 @@ def save_model_checkpoint(folder, model):
topo_fp = os.path.join(folder, topo_fn) topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo) llm_save(topo_fp, saved_obj=topo)
# move the judgement logic into save_moe_checkpoint(.)
try_save_moe_checkpoint(folder, model)
torch.distributed.barrier() torch.distributed.barrier()
@@ -129,6 +137,18 @@ def load_model_checkpoint(folder, model):
     fp = os.path.join(folder, should_load_name)
     states = llm_load(fp, map_location=get_current_device())

+    """
+    # The gate parameters would need to be converted to float32 here to fit the
+    # DeepSpeed-style mechanism, which may cause round-off in gate.weight. Since the
+    # same conversion is also done during forward, this block is left commented out,
+    # which keeps the gate parameters in float16 until the forward pass.
+    for key in list(states.keys()):
+        if "moe_layer.gate.wg.weight" in key:
+            states[key] = states[key].float()
+            print("load: ", states[key].float(), flush=True)
+    """
+    try_load_moe_checkpoint(folder, model, states)
+
     missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
         logger.warning(f"Warning: missing keys {missing_k}")
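The commented-out block above mentions round-off in gate.weight. A tiny, self-contained illustration of the effect being discussed, using a made-up value: once a float32 weight has passed through float16, casting it back to float32 does not restore the original value, so the cast can just as well be deferred to the forward pass.

import torch

w32 = torch.tensor([0.1234567], dtype=torch.float32)  # hypothetical original gate weight
w16 = w32.half()                                       # value after an fp16 round trip
back = w16.float()                                     # cast back to fp32
print(w32.item(), back.item())                         # the two values differ slightly
assert not torch.equal(w32, back)                      # the round-off is not undone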
@@ -140,6 +160,58 @@ def load_model_checkpoint(folder, model):
     torch.cuda.empty_cache()


+def try_save_moe_checkpoint(folder, model):
+    # Using layer_#_expert_# to save the model's expert state_dict (a hack).
+    moe_layer_id = 0
+    for n_module, module in model.named_modules():
+        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
+            num_local_experts = module.num_local_experts
+            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
+
+            # gather all MoE parameters except the gate weights
+            moe_state_dict = {}
+            for n, p in module.state_dict().items():
+                if "expert" in n and "moe_layer.gate.wg.weight" not in n:
+                    moe_state_dict[n_module + "." + n] = p
+            moe_str_prefix = ".moe_layer.experts.experts."
+
+            # rename local expert ids to global ones so that each checkpoint file holds exactly one expert
+            experts_state_dict = defaultdict(dict)
+            for key in list(moe_state_dict.keys()):
+                m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
+                local_expert_id = None
+                if not m:
+                    logger.warning(f"No expert found in key {key}.")
+                else:
+                    local_expert_id = m.group(1)
+
+                global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
+                expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")
+
+                # detach into a fresh tensor to truncate any extra (shared) storage
+                truncated = moe_state_dict.pop(key).clone().detach()
+                experts_state_dict[str(global_expert_id)][expert_key] = truncated
+
+            # save the MoE parameters, one file per (layer, expert) pair
+            for global_expert_id, expert_state_dict in experts_state_dict.items():
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fp = os.path.join(folder, fn)
+                llm_save(fp, saved_obj=expert_state_dict)
+            moe_layer_id += 1
+
+
+def get_non_moe_state_dict(full_state_dict):
+    """
+    Get the state dict of the non-moe layers.
+    """
+    for key in list(full_state_dict.keys()):
+        if "expert" in key and "moe_layer.gate.wg.weight" not in key:
+            full_state_dict.pop(key)
+
+    return full_state_dict
+
+
 def save_optimizer_checkpoint(optim, state_path):
     """Store the state of the optimizer to the local file system or remote OSS.
@@ -168,6 +240,27 @@ def save_optimizer_checkpoint(optim, state_path):
     llm_save(os.path.join(state_path, fp), states)


+def try_load_moe_checkpoint(folder, model, state_dict):
+    moe_layer_id = 0
+    for _, module in model.named_modules():
+        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
+            num_local_experts = module.num_local_experts
+            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
+            # loop over all local experts
+            for local_expert_id in range(num_local_experts):
+                global_expert_id = expp_rank * num_local_experts + local_expert_id
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fp = os.path.join(folder, fn)
+                expert_state_dict = llm_load(fp, map_location=get_current_device())
+
+                # rename global expert ids back to local ones
+                moe_str_prefix = ".moe_layer.experts.experts."
+                for key in list(expert_state_dict.keys()):
+                    local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
+                    expert_state_dict[local_key] = expert_state_dict.pop(key)
+
+                state_dict.update(expert_state_dict)
+            moe_layer_id += 1
+
+
 def load_optimizer_checkpoint(folder, optim):
     """Load the optimizer state from the local file system or remote
     object storage Service (OSS).
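Taken together, try_save_moe_checkpoint and try_load_moe_checkpoint add one file per (MoE layer, global expert) pair alongside the existing model checkpoint files. A sketch of the expert file names the loader will look for, assuming, purely for illustration, a folder path of /path/to/ckpt with 2 MoE layers and 4 experts in total:

import os

folder = "/path/to/ckpt"              # hypothetical checkpoint folder
num_moe_layers, num_experts = 2, 4    # assumed sizes for illustration

for moe_layer_id in range(num_moe_layers):
    for global_expert_id in range(num_experts):
        fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
        print(os.path.join(folder, fn))
# /path/to/ckpt/model_moe_layer0_expert0.pt ... /path/to/ckpt/model_moe_layer1_expert3.pt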