support hf llama

pull/532/head
lijiaxing 2023-12-08 12:43:16 +08:00
parent 81ffb3d824
commit 9d824d66ec
3 changed files with 1286 additions and 4 deletions


@@ -5,6 +5,7 @@ from .embedding import Embedding1D, RotaryEmbedding
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
from .metrics import AccPerplex
from .modeling_internlm import build_model_with_cfg
from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg
from .modeling_moe import build_model_with_moe_cfg
from .moe import MoE
from .multi_head_attention import MHA
@@ -22,4 +23,5 @@ __all__ = [
"gather_forward_split_backward",
"build_model_with_cfg",
"build_model_with_moe_cfg",
"build_model_with_llama_cfg",
]
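
For illustration, a minimal import sketch of the newly exported builder; the package path internlm.model is an assumption inferred from the relative imports above:

# Hedged sketch: the LLaMA builder alias is exported next to the existing one.
from internlm.model import build_model_with_cfg, build_model_with_llama_cfg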

File diff suppressed because it is too large.


@@ -50,12 +50,14 @@ class CheckpointSaveType(Enum):
class CheckpointLoadType(Enum):
INTERNLM = "internlm"
HF_LLAMA = "hf_llama"
# The load methods implemented by internlm by default are keyed by predefined
# enumeration types rather than raw string identifiers.
LOAD_TYPE_DICT = {
"internlm": CheckpointLoadType.INTERNLM,
"hf_llama": CheckpointLoadType.HF_LLAMA,
}
@@ -74,7 +76,7 @@ class CheckpointLoadMethod:
LOAD_TYPE_FUNC = {}
@staticmethod
def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
def convert_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
if load_type.lower() in LOAD_TYPE_DICT:
# The ckpt load method implemented by internlm by default.
return LOAD_TYPE_DICT[load_type.lower()]
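
For illustration, a small sketch of how the conversion above behaves for the registered strings, using the enum and LOAD_TYPE_DICT defined earlier in this file:

# Hedged sketch: known strings resolve case-insensitively to enum members,
# so the new "hf_llama" entry maps to CheckpointLoadType.HF_LLAMA.
assert CheckpointLoadMethod.convert_load_type("internlm") is CheckpointLoadType.INTERNLM
assert CheckpointLoadMethod.convert_load_type("HF_LLAMA") is CheckpointLoadType.HF_LLAMA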
@@ -90,7 +92,7 @@ class CheckpointLoadMethod:
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
if load_type == CheckpointLoadType.INTERNLM:
if load_type in (CheckpointLoadType.INTERNLM, CheckpointLoadType.HF_LLAMA):
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
else:
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log():
@@ -188,6 +190,18 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs):
return (missing_k, unexpected_keys)
def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState): # pylint: disable=W0613
load_content_str = ""
load_ckpt_folder = load_info["path"]
load_content: CheckpointLoadMask = load_info["content"]
if gpc.is_rank_for_log():
logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
if load_content.need_load(CheckpointLoadContent.MODEL):
load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model.model)
load_content_str += f"{CheckpointLoadContent.MODEL}, "
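
For illustration, a hedged sketch of a load_ckpt_info entry that would route loading through try_load_hf_LLAMA_ckpt; the field names mirror the load_info["path"] and load_info["content"] accesses above, and the folder path is a placeholder:

# Hedged sketch (path is a placeholder; field names assumed from the usage above):
load_ckpt_info = dict(
    path="/path/to/hf_llama_checkpoint",  # folder holding pytorch_model*.bin shards
    content=("model",),                   # only model weights are handled for hf_llama
    ckpt_type="hf_llama",                 # resolved via LOAD_TYPE_DICT to CheckpointLoadType.HF_LLAMA
)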
def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
load_content_str = ""
load_ckpt_folder = load_info["path"]
@@ -314,6 +328,118 @@ def save_model_checkpoint(folder, model):
torch.distributed.barrier()
def load_hf_llama_pretrained_weights(folder, model):
assert folder is not None, "Please specify the folder of the pretrained model"
if gpc.is_rank_for_log():
logger.info(f"Loading pretrained model from {folder}")
fns = get_fns(folder)
model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".bin") and fn.startswith("pytorch_model")]
model_fns.sort()
states = {}
for model_fn in model_fns:
states.update(llm_load(model_fn, map_location="cpu"))
deep_split = getattr(model, "deep_split", False)
if deep_split:
print("using deep split when loading pretrained weights!")
current_states = {}
for idx, i in enumerate(range(model.first_layer, model.last_layer)):
if gpc.config.model_type == "LLAMA":
if deep_split:
layer_ids = i // 2
else:
layer_ids = i
if not deep_split or (i + 2) % 2 == 0:
states[f"layers.{i}.attention.wq.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention.wk.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention.wv.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention.wo.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=1,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.attention_norm.weight"] = states.pop(
f"model.layers.{layer_ids}.input_layernorm.weight"
)
if not deep_split or (i + 2) % 2 == 1:
states[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=0,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
states.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
gpc.get_world_size(ParallelMode.TENSOR),
dim=1,
)[gpc.get_local_rank(ParallelMode.TENSOR)]
states[f"layers.{i}.ffn_norm.weight"] = states.pop(
f"model.layers.{layer_ids}.post_attention_layernorm.weight"
)
if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states:
states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq")
for name in list(states.keys()):
if name.startswith(f"layers.{i}"):
current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
model_state_keys = set(list(model.state_dict().keys()))
if "tok_embeddings.weight" in model_state_keys or "tok_embeddings.word_embeddings.weight" in model_state_keys:
if gpc.config.model.get("embed_split_hidden", True):
current_states["tok_embeddings.weight"] = torch.chunk(
states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
)[gpc.get_local_rank(ParallelMode.TENSOR)]
else:
current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
)[gpc.get_local_rank(ParallelMode.TENSOR)]
# current_states["tok_embeddings.weight"] = states["model.embed_tokens.weight"]
assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
if "output.weight" in model_state_keys:
current_states["norm.weight"] = states["model.norm.weight"]
current_states["output.weight"] = torch.chunk(
states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
)[gpc.get_local_rank(ParallelMode.TENSOR)]
# current_states["output.weight"] = states["lm_head.weight"]
if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
for i in range(model.extra_pred_tokens):
current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()
missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
if gpc.get_local_rank(ParallelMode.DATA) == 0:
pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
)
torch.cuda.empty_cache()
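
As a self-contained illustration of the tensor-parallel splitting used throughout the loader above, the q/k/v and gate/up projection weights are chunked along dim=0 while the o_proj and down_proj weights are chunked along dim=1, each rank keeping only its own chunk; a minimal sketch with stand-in sizes:

# Hedged sketch of the torch.chunk splitting pattern used above.
import torch

tp_world_size, tp_rank = 4, 1      # stand-ins for gpc.get_world_size/get_local_rank(ParallelMode.TENSOR)
q_proj = torch.randn(4096, 4096)   # e.g. model.layers.0.self_attn.q_proj.weight from the HF state dict
o_proj = torch.randn(4096, 4096)   # e.g. model.layers.0.self_attn.o_proj.weight

wq_shard = torch.chunk(q_proj, tp_world_size, dim=0)[tp_rank]  # column-parallel split: (1024, 4096)
wo_shard = torch.chunk(o_proj, tp_world_size, dim=1)[tp_rank]  # row-parallel split:    (4096, 1024)
print(wq_shard.shape, wo_shard.shape)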
def load_model_checkpoint(folder, model):
"""
There should be weights with names similar to the following under the folder.
@@ -682,7 +808,10 @@ class CheckpointManager:
self.model_config_file = model_config_file
# Register default internlm ckpt load type.
self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
self.defalut_load_type_func = {
CheckpointLoadType.INTERNLM: try_load_internlm_ckpt,
CheckpointLoadType.HF_LLAMA: try_load_hf_LLAMA_ckpt,
}
for ckpt_load_type in CheckpointLoadType:
CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])
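
For illustration, a hedged sketch of registering an additional load function through the same mechanism; the type name and function below are hypothetical, with the signature mirroring try_load_internlm_ckpt and try_load_hf_LLAMA_ckpt above:

# Hedged sketch: a hypothetical extra load type registered alongside the defaults.
def try_load_my_format_ckpt(ckpt_mm, load_info, train_state: TrainState):
    # load_info carries "path" and "content", as in the built-in loaders above
    ...

CheckpointLoadMethod.register_ckpt_load_type("my_format", try_load_my_format_ckpt)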
@@ -718,7 +847,7 @@ class CheckpointManager:
# replace load_ckpt
self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])
self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convert_load_type(self.load_ckpt_info["ckpt_type"])
torch.distributed.barrier()
# test storage setting is ok.