mirror of https://github.com/InternLM/InternLM

commit 0e6b1f856c (parent e32fbaaae2): add support for moe checkpoint
@@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler):
                loss = self._call_engine_criterion(engine, output, label)
                self._call_hooks("after_criterion", loss)
                moe_loss = sum(moe_losses) * moe_loss_coeff
                loss += moe_loss
                moe_loss /= scale_loss
                loss /= scale_loss
                loss += moe_loss

        # backward
        if not forward_only:
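The hunk above folds the MoE gate loss into the training loss and divides both by scale_loss before backward, so the reported moe_loss stays on the same scale as its contribution to the total loss. A minimal sketch of that arithmetic with made-up numbers (not taken from the repo):

moe_losses = [0.02, 0.03]     # per-MoE-layer gate losses (hypothetical)
moe_loss_coeff = 0.01         # weighting coefficient (hypothetical)
scale_loss = 4                # e.g. gradient accumulation steps (hypothetical)
loss = 2.0                    # main criterion output (hypothetical)

moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= scale_loss
loss /= scale_loss
loss += moe_loss
print(loss, moe_loss)         # ~0.500125  ~0.000125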
@@ -74,7 +74,7 @@ class MoE(torch.nn.Module):
        drop_tokens: bool = True,
        use_rts: bool = True,
        using_default_moe: bool = True,
        use_residual=True,
        use_residual=False,
        residual_mlp=None,
    ):
@@ -356,6 +356,8 @@ class TopKGate(Module):
        # Only top-1 and top-2 are supported at the moment.
        if k not in (1, 2):
            raise ValueError("Only top-1 and top-2 gatings are supported.")
        # TODO: can we use tensor parallel here?
        # Deepspeed's mechanism: always use fp32 for the gate
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
        self.k = k
        self.capacity_factor = capacity_factor
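A small illustration (not repo code) of the effect of the .float() call above: the gate projection weights live in fp32, so the routing logits are computed in full precision even when the activations arrive in fp16.

import torch

model_dim, num_experts = 8, 4                       # hypothetical sizes
wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()  # fp32 gate weights

x = torch.randn(2, model_dim).half()                # fp16 activations
logits = wg(x.float())                              # gate math runs in fp32
topk = torch.topk(logits, k=2, dim=-1)              # only top-1/top-2 gating is supported
print(topk.indices.shape)                           # torch.Size([2, 2])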
@@ -4,8 +4,10 @@
import copy
import fcntl
import os
import re
import socket
import time
from collections import defaultdict
from enum import Enum
from typing import Dict
@@ -14,6 +16,7 @@ import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.model.moe import MoE
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
@@ -70,6 +73,8 @@ def save_model_checkpoint(folder, model):
    """

    states = model.state_dict()
    # get non-moe parameters
    states = get_non_moe_state_dict(states)
    topo = get_model_topology(model)

    if folder is not None:
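For context, a hypothetical example (made-up keys) of the filtering that get_non_moe_state_dict performs further down in this change: expert weights are dropped from the shared model file, while the gate weight and everything else stay with the non-MoE parameters.

full = {
    "blocks.0.mlp.moe_layer.gate.wg.weight": "kept with the model shard",       # hypothetical key
    "blocks.0.mlp.moe_layer.experts.experts.0.w1.weight": "saved per expert",   # hypothetical key
    "blocks.0.attention.wqkv.weight": "kept with the model shard",              # hypothetical key
}
non_moe = {
    k: v
    for k, v in full.items()
    if not ("expert" in k and "moe_layer.gate.wg.weight" not in k)
}
print(sorted(non_moe))  # the expert weight is filtered out, gate and attention remain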
@@ -93,6 +98,9 @@ def save_model_checkpoint(folder, model):
        topo_fp = os.path.join(folder, topo_fn)
        llm_save(topo_fp, saved_obj=topo)

    # the judgement logic is moved into try_save_moe_checkpoint()
    try_save_moe_checkpoint(folder, model)

    torch.distributed.barrier()
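Putting the save path together, a checkpoint folder written by this code holds the filtered (non-MoE) model states, the topology file, and one file per MoE layer and global expert, named by the pattern used below. An illustrative layout, assuming 2 MoE layers and 4 experts; all file names other than the expert files are placeholders, not the repo's actual naming:

folder/
    model_tp0_pp0.pt              # non-MoE weights, expert keys filtered out (placeholder name)
    topo_tp0_pp0.json             # model topology (placeholder name)
    model_moe_layer0_expert0.pt ... model_moe_layer0_expert3.pt
    model_moe_layer1_expert0.pt ... model_moe_layer1_expert3.pt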
@@ -129,6 +137,18 @@ def load_model_checkpoint(folder, model):
    fp = os.path.join(folder, should_load_name)
    states = llm_load(fp, map_location=get_current_device())

    """
    # The gate parameters would need to be converted to float32 here (to fit the
    # deepspeed-style mechanism), which may cause round-off in gate.weight. The same
    # conversion is done during forward anyway, so we can leave this commented out;
    # that keeps the gate parameters in float16 until forward.
    for key in list(states.keys()):
        if 'moe_layer.gate.wg.weight' in key:
            states[key] = states[key].float()
            print("load: ", states[key].float(), flush=True)
    """

    try_load_moe_checkpoint(folder, model, states)

    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
    if len(missing_k) != 0:
        logger.warning(f"Warning: missing keys {missing_k}")
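A small, self-contained demonstration (not repo code) of the precision asymmetry the commented-out block worries about: fp16 values survive a round trip through fp32 exactly, but fp32 values generally lose precision when pushed through fp16.

import torch

w16 = torch.randn(8).half()                   # fp16 weights as stored in the checkpoint
print(torch.equal(w16.float().half(), w16))   # True: fp16 -> fp32 -> fp16 is exact

w32 = torch.randn(8)                          # fp32 weights (e.g. after .float())
print(torch.equal(w32.half().float(), w32))   # almost always False: fp16 rounds them off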
@@ -140,6 +160,58 @@ def load_model_checkpoint(folder, model):
    torch.cuda.empty_cache()


def try_save_moe_checkpoint(folder, model):
    # Using layer_#_expert_# to save the model's expert state_dict, a hack.
    moe_layer_id = 0
    for n_module, module in model.named_modules():
        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
            num_local_experts = module.num_local_experts
            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)

            # get all moe parameters
            moe_state_dict = {}
            for n, p in module.state_dict().items():
                if "expert" in n and "moe_layer.gate.wg.weight" not in n:
                    moe_state_dict[n_module + "." + n] = p
            moe_str_prefix = ".moe_layer.experts.experts."
            # Rename keys by global expert rank, so that each checkpoint only has one expert
            experts_state_dict = defaultdict(dict)
            for key in list(moe_state_dict.keys()):
                m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)

                local_expert_id = None
                if not m:
                    logger.warning(f"No expert found in key {key}.")
                else:
                    local_expert_id = m.group(1)

                global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
                expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")

                # clone/detach to truncate extra (shared) tensor storage
                truncated = moe_state_dict.pop(key).clone().detach()
                experts_state_dict[str(global_expert_id)][expert_key] = truncated

            # now save the moe parameters, one file per global expert
            for global_expert_id, expert_state_dict in experts_state_dict.items():
                # save the moe parameters
                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
                fp = os.path.join(folder, fn)
                llm_save(fp, saved_obj=expert_state_dict)
            moe_layer_id += 1


def get_non_moe_state_dict(full_state_dict):
    """
    Get the state dict of the non-moe layers
    """
    for key in list(full_state_dict.keys()):
        if "expert" in key and "moe_layer.gate.wg.weight" not in key:
            full_state_dict.pop(key)

    return full_state_dict


def save_optimizer_checkpoint(optim, state_path):
    """Store the state of the optimizer to the local file system or remote OSS.
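To make the file ownership explicit, a short sketch (hypothetical sizes, not repo code) of how each expert-parallel rank maps its local experts to global expert ids, and therefore to disjoint checkpoint files:

num_local_experts = 2        # experts held by each expert-parallel rank (hypothetical)
expert_parallel_size = 4     # number of expert-parallel ranks (hypothetical)

for expp_rank in range(expert_parallel_size):
    files = []
    for local_expert_id in range(num_local_experts):
        global_expert_id = expp_rank * num_local_experts + local_expert_id
        files.append(f"model_moe_layer0_expert{global_expert_id}.pt")
    print(f"rank {expp_rank} writes {files}")
# rank 0 writes experts 0-1, rank 1 writes 2-3, and so on: no two ranks touch the same file.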
@@ -168,6 +240,27 @@ def save_optimizer_checkpoint(optim, state_path):
    llm_save(os.path.join(state_path, fp), states)


def try_load_moe_checkpoint(folder, model, state_dict):
    moe_layer_id = 0
    for _, module in model.named_modules():
        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
            num_local_experts = module.num_local_experts
            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
            # loop over all local experts
            for local_expert_id in range(num_local_experts):
                global_expert_id = expp_rank * num_local_experts + local_expert_id
                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
                fp = os.path.join(folder, fn)
                expert_state_dict = llm_load(fp, map_location=get_current_device())
                # Updating global -> local expert ids
                moe_str_prefix = ".moe_layer.experts.experts."
                for key in list(expert_state_dict.keys()):
                    local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
                    expert_state_dict[local_key] = expert_state_dict.pop(key)
                state_dict.update(expert_state_dict)
            moe_layer_id += 1


def load_optimizer_checkpoint(folder, optim):
    """Load the optimizer state from the local file system or remote
    Object Storage Service (OSS).
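A hypothetical illustration (made-up key and ids) of the global -> local renaming that try_load_moe_checkpoint applies before merging the expert weights into the full state dict:

moe_str_prefix = ".moe_layer.experts.experts."
expert_state_dict = {"blocks.0.mlp.moe_layer.experts.experts.5.w1.weight": "tensor"}  # hypothetical key

# e.g. expp_rank = 2 and num_local_experts = 2, so global expert 5 is local expert 1
global_expert_id, local_expert_id = 5, 1
for key in list(expert_state_dict.keys()):
    local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
    expert_state_dict[local_key] = expert_state_dict.pop(key)

print(expert_state_dict)  # key now ends with ...moe_layer.experts.experts.1.w1.weight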