add support for moe checkpoint

pull/375/head
Wenwen Qu 2023-08-24 17:01:14 +08:00
parent e32fbaaae2
commit 0e6b1f856c
4 changed files with 98 additions and 2 deletions

View File

@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler):
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
moe_loss = sum(moe_losses) * moe_loss_coeff
loss += moe_loss
moe_loss /= scale_loss
loss /= scale_loss
loss += moe_loss
# backward
if not forward_only:

View File

@ -74,7 +74,7 @@ class MoE(torch.nn.Module):
drop_tokens: bool = True,
use_rts: bool = True,
using_default_moe: bool = True,
use_residual=True,
use_residual=False,
residual_mlp=None,
):

View File

@ -356,6 +356,8 @@ class TopKGate(Module):
# Only top-1 and top-2 are supported at the moment.
if k not in (1, 2):
raise ValueError("Only top-1 and top-2 gatings are supported.")
# TODO: can we use tensor parallel here?
# Deepspeed's mechisms, alway use fp32
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.k = k
self.capacity_factor = capacity_factor

View File

@ -4,8 +4,10 @@
import copy
import fcntl
import os
import re
import socket
import time
from collections import defaultdict
from enum import Enum
from typing import Dict
@ -14,6 +16,7 @@ import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.model.moe import MoE
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
@ -70,6 +73,8 @@ def save_model_checkpoint(folder, model):
"""
states = model.state_dict()
# get non-moe parameters
states = get_non_moe_state_dict(states)
topo = get_model_topology(model)
if folder is not None:
@ -93,6 +98,9 @@ def save_model_checkpoint(folder, model):
topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo)
# move the judgement logic into save_moe_checkpoint(.)
try_save_moe_checkpoint(folder, model)
torch.distributed.barrier()
@ -129,6 +137,18 @@ def load_model_checkpoint(folder, model):
fp = os.path.join(folder, should_load_name)
states = llm_load(fp, map_location=get_current_device())
"""
# need convert the gate parameters to float32 (to fit deepspeed style mechanism), it may cause round-off in
# gate.weight. The conversion will also be done when doing forward. so we can just comment it out. this make
# the gate parameters to be float16 before forward.
for key in list(states.keys()):
if 'moe_layer.gate.wg.weight' in key:
states[key] = states[key].float()
print("load: ", states[key].float(),flush=True)
"""
try_load_moe_checkpoint(folder, model, states)
missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
if len(missing_k) != 0:
logger.warning(f"Warning: missing keys {missing_k}")
@ -140,6 +160,58 @@ def load_model_checkpoint(folder, model):
torch.cuda.empty_cache()
def try_save_moe_checkpoint(folder, model):
# Using layer_#_expert_# to save the model's expert state_dicta hack.
moe_layer_id = 0
for n_module, module in model.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
num_local_experts = module.num_local_experts
expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
# get all moe parameters
moe_state_dict = {}
for n, p in module.state_dict().items():
if "expert" in n and "moe_layer.gate.wg.weight" not in n:
moe_state_dict[n_module + "." + n] = p
moe_str_prefix = ".moe_layer.experts.experts."
# Reorder the moe name rank, so that each checkpoint only has one expert
experts_state_dict = defaultdict(dict)
for key in list(moe_state_dict.keys()):
m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
local_expert_id = None
if not m:
logger.warning(f"No expert found in key {key}.")
else:
local_expert_id = m.group(1)
global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")
# truncating extra tensor (shared) storage
truncated = moe_state_dict.pop(key).clone().detach()
experts_state_dict[str(global_expert_id)][expert_key] = truncated
# let save the moe parameters
for global_expert_id, expert_state_dict in experts_state_dict.items():
# save the moe parameters
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
fp = os.path.join(folder, fn)
llm_save(fp, saved_obj=expert_state_dict)
moe_layer_id += 1
def get_non_moe_state_dict(full_state_dict):
"""
Get the state dict of the non-moe layers
"""
for key in list(full_state_dict.keys()):
if "expert" in key and "moe_layer.gate.wg.weight" not in key:
full_state_dict.pop(key)
return full_state_dict
def save_optimizer_checkpoint(optim, state_path):
"""Store the state of the optimizer to the local file system or remote OSS.
@ -168,6 +240,27 @@ def save_optimizer_checkpoint(optim, state_path):
llm_save(os.path.join(state_path, fp), states)
def try_load_moe_checkpoint(folder, model, state_dict):
moe_layer_id = 0
for _, module in model.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
num_local_experts = module.num_local_experts
expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
# loop all local_experts
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
fp = os.path.join(folder, fn)
expert_state_dict = llm_load(fp, map_location=get_current_device())
# Updating global -> local expert ids
moe_str_prefix = ".moe_layer.experts.experts."
for key in list(expert_state_dict.keys()):
local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict)
moe_layer_id += 1
def load_optimizer_checkpoint(folder, optim):
"""Load the optimizer state from the local file system or remote
object storage Service (OSS).