mirror of https://github.com/InternLM/InternLM

Commit 0e6b1f856c (parent e32fbaaae2): add support for moe checkpoint
@@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler):
                 loss = self._call_engine_criterion(engine, output, label)
                 self._call_hooks("after_criterion", loss)
                 moe_loss = sum(moe_losses) * moe_loss_coeff
-                loss += moe_loss
+                moe_loss /= scale_loss
                 loss /= scale_loss
+                loss += moe_loss

         # backward
         if not forward_only:
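A quick numeric sketch of the reordering above (values are made up; the names follow the diff): the MoE auxiliary loss is now divided by scale_loss before it is added, so both terms are scaled by the gradient-accumulation factor consistently.

    # hypothetical values for illustration only
    moe_losses = [0.02, 0.03]      # per-layer MoE load-balancing losses
    moe_loss_coeff = 0.01
    scale_loss = 4                 # e.g. number of gradient accumulation steps
    loss = 1.25                    # criterion loss for this micro-batch

    moe_loss = sum(moe_losses) * moe_loss_coeff
    moe_loss /= scale_loss         # new: scale the auxiliary loss as well
    loss /= scale_loss
    loss += moe_loss               # add after both terms are scaled
    # the old order (loss += moe_loss, then loss /= scale_loss) produced the same
    # total, but left the moe_loss variable itself unscaled for anything reported downstream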
@@ -74,7 +74,7 @@ class MoE(torch.nn.Module):
         drop_tokens: bool = True,
         use_rts: bool = True,
         using_default_moe: bool = True,
-        use_residual=True,
+        use_residual=False,
         residual_mlp=None,
     ):

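use_residual now defaults to False, so the residual (dense + expert) variant becomes opt-in. For context, a residual MoE in the DeepSpeed sense runs a dense MLP branch alongside the gated experts and blends the two outputs; a rough, hypothetical sketch of the idea, not InternLM's actual forward code:

    import torch

    def residual_moe_mix(experts_out, dense_out, coef_logits):
        # hypothetical helper: blend gated-expert output with a dense residual branch
        coef = torch.softmax(coef_logits, dim=-1)      # learned two-way mixing weights
        return experts_out * coef[..., 0:1] + dense_out * coef[..., 1:2]

    n_tokens, model_dim = 16, 8                        # made-up sizes
    mixed = residual_moe_mix(torch.randn(n_tokens, model_dim),
                             torch.randn(n_tokens, model_dim),
                             torch.randn(n_tokens, 2))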
@@ -356,6 +356,8 @@ class TopKGate(Module):
         # Only top-1 and top-2 are supported at the moment.
         if k not in (1, 2):
             raise ValueError("Only top-1 and top-2 gatings are supported.")
+        # TODO: can we use tensor parallel here?
+        # DeepSpeed's mechanism: always use fp32
         self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
         self.k = k
         self.capacity_factor = capacity_factor
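The new comments note that, following DeepSpeed's convention, the gate's linear layer stays in float32 even when the rest of the model runs in float16, so routing logits and softmax are computed in full precision. A minimal standalone sketch (dimensions made up):

    import torch

    model_dim, num_experts = 8, 4                          # made-up sizes
    wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()

    tokens = torch.randn(16, model_dim, dtype=torch.float16)
    logits = wg(tokens.float())                            # upcast activations, fp32 routing logits
    gates = torch.softmax(logits, dim=-1)                  # routing probabilities in fp32
    top2_vals, top2_idx = torch.topk(gates, k=2, dim=-1)   # top-2 gating, as TopKGate supports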
@@ -4,8 +4,10 @@
 import copy
 import fcntl
 import os
+import re
 import socket
 import time
+from collections import defaultdict
 from enum import Enum
 from typing import Dict

@@ -14,6 +16,7 @@ import torch
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.trainer import TrainState
+from internlm.model.moe import MoE
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import get_current_device
@@ -70,6 +73,8 @@ def save_model_checkpoint(folder, model):
     """

     states = model.state_dict()
+    # get non-moe parameters
+    states = get_non_moe_state_dict(states)
     topo = get_model_topology(model)

     if folder is not None:
@@ -93,6 +98,9 @@ def save_model_checkpoint(folder, model):
        topo_fp = os.path.join(folder, topo_fn)
        llm_save(topo_fp, saved_obj=topo)

+    # the judgement logic is moved into try_save_moe_checkpoint()
+    try_save_moe_checkpoint(folder, model)
+
    torch.distributed.barrier()


@@ -129,6 +137,18 @@ def load_model_checkpoint(folder, model):
     fp = os.path.join(folder, should_load_name)
     states = llm_load(fp, map_location=get_current_device())

+    """
+    # need to convert the gate parameters to float32 (to fit the deepspeed-style mechanism); this may cause
+    # round-off in gate.weight. The same conversion is done during forward anyway, so it is commented out here,
+    # which keeps the gate parameters in float16 until the forward pass.
+    for key in list(states.keys()):
+        if 'moe_layer.gate.wg.weight' in key:
+            states[key] = states[key].float()
+            print("load: ", states[key].float(), flush=True)
+    """
+
+    try_load_moe_checkpoint(folder, model, states)
+
     missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
     if len(missing_k) != 0:
         logger.warning(f"Warning: missing keys {missing_k}")
@@ -140,6 +160,58 @@ def load_model_checkpoint(folder, model):
     torch.cuda.empty_cache()


+def try_save_moe_checkpoint(folder, model):
+    # Using layer_#_expert_# to save the model's expert state_dict, a hack.
+    moe_layer_id = 0
+    for n_module, module in model.named_modules():
+        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
+            num_local_experts = module.num_local_experts
+            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
+
+            # get all moe parameters
+            moe_state_dict = {}
+            for n, p in module.state_dict().items():
+                if "expert" in n and "moe_layer.gate.wg.weight" not in n:
+                    moe_state_dict[n_module + "." + n] = p
+            moe_str_prefix = ".moe_layer.experts.experts."
+            # Reorder the moe name rank, so that each checkpoint only has one expert
+            experts_state_dict = defaultdict(dict)
+            for key in list(moe_state_dict.keys()):
+                m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
+
+                local_expert_id = None
+                if not m:
+                    logger.warning(f"No expert found in key {key}.")
+                else:
+                    local_expert_id = m.group(1)
+
+                global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
+                expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")
+
+                # truncating extra tensor (shared) storage
+                truncated = moe_state_dict.pop(key).clone().detach()
+                experts_state_dict[str(global_expert_id)][expert_key] = truncated
+
+            # let's save the moe parameters
+            for global_expert_id, expert_state_dict in experts_state_dict.items():
+                # save the moe parameters
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fp = os.path.join(folder, fn)
+                llm_save(fp, saved_obj=expert_state_dict)
+            moe_layer_id += 1
+
+
+def get_non_moe_state_dict(full_state_dict):
+    """
+    Get the state dict of the non-moe layers
+    """
+    for key in list(full_state_dict.keys()):
+        if "expert" in key and "moe_layer.gate.wg.weight" not in key:
+            full_state_dict.pop(key)
+
+    return full_state_dict
+
+
 def save_optimizer_checkpoint(optim, state_path):
     """Store the state of the optimizer to the local file system or remote OSS.

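To make the renaming in try_save_moe_checkpoint concrete, here is a toy walk-through with a hypothetical parameter key (real key names depend on the model definition): local expert 1 on expert-parallel rank 2, with 4 experts per rank, is written out as global expert 9.

    import os
    import re

    moe_str_prefix = ".moe_layer.experts.experts."
    key = "blocks.3.mlp.moe_layer.experts.experts.1.w1.weight"    # hypothetical key
    expp_rank, num_local_experts, moe_layer_id = 2, 4, 0          # made-up ranks/sizes

    local_expert_id = re.match(f".*{moe_str_prefix}([0-9]+).*", key).group(1)   # "1"
    global_expert_id = expp_rank * num_local_experts + int(local_expert_id)     # 9
    expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}",
                             f"{moe_str_prefix}{global_expert_id}")
    fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"

    print(expert_key)                 # blocks.3.mlp.moe_layer.experts.experts.9.w1.weight
    print(os.path.join("ckpt", fn))   # ckpt/model_moe_layer0_expert9.pt

try_load_moe_checkpoint (below) simply inverts this mapping before merging the expert tensors back into the main state dict.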
@@ -168,6 +240,27 @@ def save_optimizer_checkpoint(optim, state_path):
     llm_save(os.path.join(state_path, fp), states)


+def try_load_moe_checkpoint(folder, model, state_dict):
+    moe_layer_id = 0
+    for _, module in model.named_modules():
+        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
+            num_local_experts = module.num_local_experts
+            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
+            # loop over all local experts
+            for local_expert_id in range(num_local_experts):
+                global_expert_id = expp_rank * num_local_experts + local_expert_id
+                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
+                fp = os.path.join(folder, fn)
+                expert_state_dict = llm_load(fp, map_location=get_current_device())
+                # Updating global -> local expert ids
+                moe_str_prefix = ".moe_layer.experts.experts."
+                for key in list(expert_state_dict.keys()):
+                    local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
+                    expert_state_dict[local_key] = expert_state_dict.pop(key)
+                state_dict.update(expert_state_dict)
+            moe_layer_id += 1
+
+
 def load_optimizer_checkpoint(folder, optim):
     """Load the optimizer state from the local file system or remote
     object storage Service (OSS).
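The loading loop above reads exactly the per-expert files owned by the current expert-parallel rank: with num_local_experts experts hosted per rank, rank expp_rank owns global ids expp_rank * num_local_experts through expp_rank * num_local_experts + num_local_experts - 1. A small illustration with made-up sizes:

    num_experts, ep_size = 8, 4                  # hypothetical: 8 experts, expert-parallel group of 4
    num_local_experts = num_experts // ep_size   # 2 experts hosted per rank

    for expp_rank in range(ep_size):
        owned = [expp_rank * num_local_experts + i for i in range(num_local_experts)]
        print(f"rank {expp_rank} loads experts {owned}")
    # rank 0 loads experts [0, 1]
    # rank 1 loads experts [2, 3]
    # rank 2 loads experts [4, 5]
    # rank 3 loads experts [6, 7]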