feat(model): support llama model with checkpoint loading (#532)

* support hf llama

* support hf llama

* support hf llama

* support hf llama

* importerror

* importerror

* modeling

* modeling
jiaxingli 2023-12-11 16:25:24 +08:00 committed by GitHub
parent 81ffb3d824
commit 6c0ff4820f
5 changed files with 1352 additions and 7 deletions


@@ -28,7 +28,7 @@ ckpt = dict(
    # 'load_ckpt_info' setting guide:
    # 1. the 'path' indicate ckpt path,
    # 2. the 'content means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
-    # 3. the ckpt_type means the type of checkpoint to be loaded, now only 'normal' type is supported.
+    # 3. the ckpt_type means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama".
    load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
    # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
    # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
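To make the new options concrete, here is a minimal hedged sketch of the corresponding config entry; the folder path is a placeholder, and the keys follow the setting guide shown above:

    # Hypothetical path; only the model weights are restored from a HuggingFace-format LLaMA checkpoint.
    load_ckpt_info=dict(path="/path/to/hf_llama_ckpt/", content=("model",), ckpt_type="hf_llama"),
    # Use ckpt_type="llama" for meta-style .pth/.pt shards, or "internlm" for checkpoints saved by this trainer.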


@@ -5,6 +5,7 @@ from .embedding import Embedding1D, RotaryEmbedding
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
from .metrics import AccPerplex
from .modeling_internlm import build_model_with_cfg
+from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg
from .modeling_moe import build_model_with_moe_cfg
from .moe import MoE
from .multi_head_attention import MHA

@@ -22,4 +23,5 @@ __all__ = [
    "gather_forward_split_backward",
    "build_model_with_cfg",
    "build_model_with_moe_cfg",
+    "build_model_with_llama_cfg",
]

File diff suppressed because it is too large.


@@ -50,12 +50,16 @@ class CheckpointSaveType(Enum):
class CheckpointLoadType(Enum):
    INTERNLM = "internlm"
+    HF_LLAMA = "hf_llama"
+    LLAMA = "llama"

# The load method implemented by internlm by default does not use string representation types,
# but uses enumeration types defined in advance.
LOAD_TYPE_DICT = {
    "internlm": CheckpointLoadType.INTERNLM,
+    "hf_llama": CheckpointLoadType.HF_LLAMA,
+    "llama": CheckpointLoadType.LLAMA,
}
@ -74,7 +78,7 @@ class CheckpointLoadMethod:
LOAD_TYPE_FUNC = {} LOAD_TYPE_FUNC = {}
@staticmethod @staticmethod
def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]: def convert_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
if load_type.lower() in LOAD_TYPE_DICT: if load_type.lower() in LOAD_TYPE_DICT:
# The ckpt load method implemented by internlm by default. # The ckpt load method implemented by internlm by default.
return LOAD_TYPE_DICT[load_type.lower()] return LOAD_TYPE_DICT[load_type.lower()]
@ -90,7 +94,11 @@ class CheckpointLoadMethod:
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func}) CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
if load_type == CheckpointLoadType.INTERNLM: if load_type in (
CheckpointLoadType.INTERNLM,
CheckpointLoadType.HF_LLAMA,
CheckpointLoadType.LLAMA,
):
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func) CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
else: else:
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log(): if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log():
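The string-to-enum mapping above is what lets the config keep using plain strings for ckpt_type. A small hedged sketch of the conversion (the module path internlm.utils.model_checkpoint is an assumption):

    # Illustrative only; lookup is case-insensitive because of load_type.lower().
    from internlm.utils.model_checkpoint import CheckpointLoadMethod, CheckpointLoadType

    assert CheckpointLoadMethod.convert_load_type("hf_llama") is CheckpointLoadType.HF_LLAMA
    assert CheckpointLoadMethod.convert_load_type("LLAMA") is CheckpointLoadType.LLAMA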
@@ -188,13 +196,33 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs):
    return (missing_k, unexpected_keys)

-def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
+def process_load_info(load_info):
    load_content_str = ""
    load_ckpt_folder = load_info["path"]
    load_content: CheckpointLoadMask = load_info["content"]
    if gpc.is_rank_for_log():
        logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
+    return load_content_str, load_ckpt_folder, load_content

+def try_load_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
+    if load_content.need_load(CheckpointLoadContent.MODEL):
+        load_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model)
+        load_content_str += f"{CheckpointLoadContent.MODEL}, "

+def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
+    if load_content.need_load(CheckpointLoadContent.MODEL):
+        load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model)
+        load_content_str += f"{CheckpointLoadContent.MODEL}, "

+def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
    if load_content.need_load(CheckpointLoadContent.MODEL):
        load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
        load_content_str += f"{CheckpointLoadContent.MODEL}, "
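Taken together with the type registry above, these loaders give a simple dispatch path from the config string to the right weight loader. A hedged sketch of that flow (ckpt_mm and train_state stand for an existing checkpoint manager and training state; for the llama/hf_llama types only model weights are restored):

    # Illustrative dispatch only; the CheckpointManager wires this up internally.
    load_info = dict(path="/path/to/llama_ckpt/", content=CheckpointLoadMask(("model",)), ckpt_type="llama")
    load_type = CheckpointLoadMethod.convert_load_type(load_info["ckpt_type"])  # -> CheckpointLoadType.LLAMA
    load_func = CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type]                  # -> try_load_LLAMA_ckpt
    load_func(ckpt_mm, load_info, train_state)                                  # restores model weights only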
@@ -314,6 +342,170 @@ def save_model_checkpoint(folder, model):
    torch.distributed.barrier()

+def load_llama_pretrained_weights(folder, model):
+    model = model.model
+    assert folder is not None, "Please specify the folder of the pretrained model"
+    if gpc.is_rank_for_log():
+        logger.info(f"Loading pretrained model from {folder}")
+
+    fns = get_fns(folder)
+    model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")]
+    model_fns.sort()
+
+    old_tp = len(model_fns)
+    cur_tp = gpc.get_world_size(ParallelMode.TENSOR)
+    # If the two tp are inconsistent, you need to consider the merge before splitting
+    if old_tp != cur_tp:
+        raise RuntimeError(
+            f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first"
+        )
+
+    states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu")
+
+    current_states = {}
+    for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+        if gpc.config.model_type == "LLAMA":
+            # LLAMA's w2 and w3 are in reverse order
+            w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
+            w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
+            states[f"layers.{i}.feed_forward.w2.weight"] = w3
+            states[f"layers.{i}.feed_forward.w3.weight"] = w2
+            if "rope.freqs" in states:
+                states[f"layers.{i}.attention.rotary_emb.inv_freq"] = states["rope.freqs"]
+        for name in list(states.keys()):
+            if f".{i}." in name:
+                current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
+
+    model_state_keys = set(list(model.state_dict().keys()))
+
+    if "tok_embeddings.weight" in model_state_keys:
+        current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"]
+        assert model.first_layer == 0, f"Expect model.NaiveAMPModel to be 0, but got {model.first_layer}"
+    if "output.weight" in model_state_keys:
+        current_states["norm.weight"] = states["norm.weight"]
+        current_states["output.weight"] = states["output.weight"]
+
+    missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
+
+    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+        pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+        logger.info(
+            f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+        )
+
+    del states
+    del current_states
+    torch.cuda.empty_cache()
+def load_hf_llama_pretrained_weights(folder, model):
+    model = model.model
+    assert folder is not None, "Please specify the folder of the pretrained model"
+    if gpc.is_rank_for_log():
+        logger.info(f"Loading pretrained model from {folder}")
+
+    fns = get_fns(folder)
+    model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") and fn.startswith("pytorch_model")]
+    model_fns.sort()
+
+    states = {}
+    for model_fn in model_fns:
+        states.update(llm_load(model_fn, map_location="cpu"))
+
+    deep_split = getattr(model, "deep_split", False)
+    if deep_split:
+        print("using deep split when loading pretrained weights!")
+
+    current_states = {}
+    for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+        if gpc.config.model_type == "LLAMA":
+            if deep_split:
+                layer_ids = i // 2
+            else:
+                layer_ids = i
+
+            if not deep_split or (i + 2) % 2 == 0:
+                states[f"layers.{i}.attention.wq.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention.wk.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention.wv.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention.wo.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=1,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.attention_norm.weight"] = states.pop(
+                    f"model.layers.{layer_ids}.input_layernorm.weight"
+                )
+
+            if not deep_split or (i + 2) % 2 == 1:
+                states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=0,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
+                    states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
+                    gpc.get_world_size(ParallelMode.TENSOR),
+                    dim=1,
+                )[gpc.get_local_rank(ParallelMode.TENSOR)]
+                states[f"layers.{i}.ffn_norm.weight"] = states.pop(
+                    f"model.layers.{layer_ids}.post_attention_layernorm.weight"
+                )
+
+            if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states:
+                states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")
+
+        for name in list(states.keys()):
+            if name.startswith(f"layers.{i}"):
+                current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
+
+    model_state_keys = set(list(model.state_dict().keys()))
+
+    if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys:
+        if gpc.config.model.get("embed_split_hidden", True):
+            current_states["tok_embeddings.weight"] = torch.chunk(
+                states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+        else:
+            current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
+                states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
+            )[gpc.get_local_rank(ParallelMode.TENSOR)]
+        assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
+
+    if "output.weight" in model_state_keys:
+        current_states["norm.weight"] = states["model.norm.weight"]
+        current_states["output.weight"] = torch.chunk(
+            states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
+        )[gpc.get_local_rank(ParallelMode.TENSOR)]
+
+    missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
+
+    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+        pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+        logger.info(
+            f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+        )
+
+    torch.cuda.empty_cache()
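The recurring pattern above is the heart of the HF conversion: each full HuggingFace weight is split across the tensor-parallel group with torch.chunk, along dim=0 for the q/k/v, gate and up projections and the output head, and along dim=1 for the o_proj/down projections and the token embedding. A minimal, self-contained sketch of that splitting pattern with toy shapes (the names here are illustrative, not part of the codebase):

    import torch

    def tp_shard(weight: torch.Tensor, tp_world_size: int, tp_rank: int, dim: int) -> torch.Tensor:
        """Return this tensor-parallel rank's shard of a full weight, mirroring the torch.chunk calls above."""
        return torch.chunk(weight, tp_world_size, dim=dim)[tp_rank]

    # Toy model: hidden size 8, tensor parallelism 2.
    q_proj = torch.randn(8, 8)   # q_proj.weight -> split along dim=0 (output rows)
    o_proj = torch.randn(8, 8)   # o_proj.weight -> split along dim=1 (input columns)
    print(tp_shard(q_proj, 2, 0, dim=0).shape)  # torch.Size([4, 8])
    print(tp_shard(o_proj, 2, 1, dim=1).shape)  # torch.Size([8, 4])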
def load_model_checkpoint(folder, model):
    """
    There should be weights with names similar to the following under the folder.
@@ -682,7 +874,11 @@ class CheckpointManager:
        self.model_config_file = model_config_file

        # Register defalut internlm ckpt load type.
-        self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
+        self.defalut_load_type_func = {
+            CheckpointLoadType.INTERNLM: try_load_internlm_ckpt,
+            CheckpointLoadType.HF_LLAMA: try_load_hf_LLAMA_ckpt,
+            CheckpointLoadType.LLAMA: try_load_LLAMA_ckpt,
+        }
        for ckpt_load_type in CheckpointLoadType:
            CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])
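The registration loop above also suggests how a non-default checkpoint format could be plugged in. A hedged sketch follows; the "my_format" name and loader body are hypothetical, the loader is expected to keep the same signature as the built-in ones, and how custom string types are resolved at load time is not shown in this diff:

    # Hypothetical custom loader; signature mirrors try_load_internlm_ckpt above.
    def try_load_my_format_ckpt(ckpt_mm, load_info, train_state):
        load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
        if load_content.need_load(CheckpointLoadContent.MODEL):
            ...  # populate ckpt_mm.model from files under load_ckpt_folder
            load_content_str += f"{CheckpointLoadContent.MODEL}, "

    CheckpointLoadMethod.register_ckpt_load_type("my_format", try_load_my_format_ckpt)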
@@ -718,7 +914,7 @@ class CheckpointManager:
        # replace load_ckpt
        self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
-        self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
+        self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])

        torch.distributed.barrier()
        # test storage setting is ok.


@@ -28,7 +28,6 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,

@@ -42,6 +41,11 @@ from transformers.utils import (
    replace_return_docstrings,
)

+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
from .configuration_internlm import InternLMConfig

logger = logging.get_logger(__name__)
@@ -113,6 +117,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
        base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
        device (Any, optional): Running device. Defaults to None.
    """
+
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
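For intuition about the inv_freq line above: it builds a geometric ladder of rotation frequencies, one per pair of embedding channels. A tiny standalone sketch with a toy dim (values are approximate):

    import torch

    dim, base = 8, 10000
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    print(inv_freq)  # approximately [1.0, 0.1, 0.01, 0.001] -- one frequency per channel pair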