mirror of https://github.com/InternLM/InternLM

support hf llama

parent 9d824d66ec
commit 6def66fb07
@@ -28,7 +28,7 @@ ckpt = dict(
     # 'load_ckpt_info' setting guide:
     # 1. the 'path' indicates the ckpt path,
     # 2. the 'content' means which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
-    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported.
+    # 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama".
     load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
     # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
     # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
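For reference, the three supported ckpt_type values map onto configs like the following sketch; the folder paths are illustrative placeholders, and only the "internlm" form appears in this diff:

```python
# Native InternLM checkpoint (the form shown in the diff above).
load_ckpt_info = dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm")

# Meta-format LLaMA weights: one consolidated .pth/.pt shard per tensor-parallel rank.
load_ckpt_info = dict(path="/ckpts/llama-7b/", content=("model",), ckpt_type="llama")

# HuggingFace-format LLaMA weights.
load_ckpt_info = dict(path="/ckpts/llama-7b-hf/", content=("model",), ckpt_type="hf_llama")
```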
@@ -51,6 +51,7 @@ class CheckpointSaveType(Enum):

 class CheckpointLoadType(Enum):
     INTERNLM = "internlm"
     HF_LLAMA = "hf_llama"
+    LLAMA = "llama"


 # The load method implemented by internlm by default does not use string representation types,
@@ -58,6 +59,7 @@ class CheckpointLoadType(Enum):

 LOAD_TYPE_DICT = {
     "internlm": CheckpointLoadType.INTERNLM,
     "hf_llama": CheckpointLoadType.HF_LLAMA,
+    "llama": CheckpointLoadType.LLAMA,
 }
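The string in the config is resolved through LOAD_TYPE_DICT before dispatch. A minimal sketch of that lookup, using a hypothetical helper name (the real resolution site is outside this hunk):

```python
def resolve_ckpt_type(ckpt_type: str) -> CheckpointLoadType:
    # Map the user-facing string from load_ckpt_info["ckpt_type"] to its enum member.
    if ckpt_type not in LOAD_TYPE_DICT:
        raise ValueError(f"unknown ckpt_type {ckpt_type!r}, expected one of {sorted(LOAD_TYPE_DICT)}")
    return LOAD_TYPE_DICT[ckpt_type]

assert resolve_ckpt_type("llama") is CheckpointLoadType.LLAMA
```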
@@ -92,7 +94,11 @@ class CheckpointLoadMethod:

         CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})

-        if load_type in (CheckpointLoadType.INTERNLM, CheckpointLoadType.HF_LLAMA):
+        if load_type in (
+            CheckpointLoadType.INTERNLM,
+            CheckpointLoadType.HF_LLAMA,
+            CheckpointLoadType.LLAMA,
+        ):
             CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
         else:
             if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log():
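The three built-in types now share one reference signature: whichever of them registers first captures LOAD_FUNC_SIG, and any later non-built-in loader is checked against it, with a warning logged on mismatch. A sketch of registering a custom loader under that contract (the "my_format" type, and whether a plain string is accepted here, are both assumptions):

```python
def try_load_my_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
    # Keep the exact (ckpt_mm, load_info, train_state) signature; a mismatch with
    # LOAD_FUNC_SIG triggers the warning branch above on the logging rank.
    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
    ...  # read weights from load_ckpt_folder into ckpt_mm.model

CheckpointLoadMethod.register_ckpt_load_type("my_format", try_load_my_ckpt)
```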
@@ -190,24 +196,32 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs):
     return (missing_k, unexpected_keys)


-def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
+def process_load_info(load_info):
     load_content_str = ""
     load_ckpt_folder = load_info["path"]
     load_content: CheckpointLoadMask = load_info["content"]
     if gpc.is_rank_for_log():
         logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")

+    return load_content_str, load_ckpt_folder, load_content
+
+
+def try_load_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
     if load_content.need_load(CheckpointLoadContent.MODEL):
-        load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model.model)
+        load_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model)
         load_content_str += f"{CheckpointLoadContent.MODEL}, "


+def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState):  # pylint: disable=W0613
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)
+    if load_content.need_load(CheckpointLoadContent.MODEL):
+        load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model)
+        load_content_str += f"{CheckpointLoadContent.MODEL}, "
+
+
 def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
-    load_content_str = ""
-    load_ckpt_folder = load_info["path"]
-    load_content: CheckpointLoadMask = load_info["content"]
-    if gpc.is_rank_for_log():
-        logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
+    load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)

     if load_content.need_load(CheckpointLoadContent.MODEL):
         load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
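The refactor hoists the shared prologue of the three try_load_* loaders into process_load_info, so each loader now differs only in which weight-loading routine it calls. A sketch of the helper's contract (the path is a placeholder, and constructing CheckpointLoadMask directly from the content tuple is an assumption):

```python
load_info = {"path": "/ckpts/llama-7b/", "content": CheckpointLoadMask(("model",))}
load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)

assert load_content_str == ""                 # accumulator each loader appends to
assert load_ckpt_folder == load_info["path"]
if load_content.need_load(CheckpointLoadContent.MODEL):
    ...  # each loader calls its format-specific weight loader here
```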
@@ -328,7 +342,69 @@ def save_model_checkpoint(folder, model):
     torch.distributed.barrier()


+def load_llama_pretrained_weights(folder, model):
+    model = model.model
+    assert folder is not None, "Please specify the folder of the pretrained model"
+    if gpc.is_rank_for_log():
+        logger.info(f"Loading pretrained model from {folder}")
+
+    fns = get_fns(folder)
+    model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")]
+    model_fns.sort()
+
+    old_tp = len(model_fns)
+    cur_tp = gpc.get_world_size(ParallelMode.TENSOR)
+    # If the two tp sizes are inconsistent, the checkpoint must be merged and re-split before loading.
+    if old_tp != cur_tp:
+        raise RuntimeError(
+            f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first"
+        )
+
+    states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu")
+
+    current_states = {}
+    for idx, i in enumerate(range(model.first_layer, model.last_layer)):
+        # Temporarily combine the loading logic that supports baichuan2's checkpoint with llama's.
+        # This may change in the future.
+        if gpc.config.model_type in ("LLAMA", "BAICHUAN2"):
+            # LLAMA's w2 and w3 are in reverse order
+            w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
+            w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
+            states[f"layers.{i}.feed_forward.w2.weight"] = w3
+            states[f"layers.{i}.feed_forward.w3.weight"] = w2
+            if "rope.freqs" in states:
+                states[f"layers.{i}.attention.rotary_emb.inv_freq"] = states["rope.freqs"]
+        for name in list(states.keys()):
+            if f".{i}." in name:
+                current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
+
+    model_state_keys = set(list(model.state_dict().keys()))
+
+    if "tok_embeddings.weight" in model_state_keys:
+        current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"]
+        assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
+    if "output.weight" in model_state_keys:
+        current_states["norm.weight"] = states["norm.weight"]
+        current_states["output.weight"] = states["output.weight"]
+    if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
+        for i in range(model.extra_pred_tokens):
+            current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()
+    missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
+
+    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+        pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+        logger.info(
+            f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+            f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
+        )
+
+    del states
+    del current_states
+    torch.cuda.empty_cache()
+
+
 def load_hf_llama_pretrained_weights(folder, model):
     model = model.model
     assert folder is not None, "Please specify the folder of the pretrained model"
     if gpc.is_rank_for_log():
         logger.info(f"Loading pretrained model from {folder}")
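Two details of the new loader are easy to miss: Meta-format LLaMA keeps w2 and w3 in the opposite order from this codebase's feed_forward, hence the swap, and each pipeline stage rebases its slice of global layer indices to local ones before load_state_dict. A self-contained toy run of the rebasing loop:

```python
# Stage owning global layers 4..5 renames them to local layers 0..1.
states = {
    "layers.4.attention.wq.weight": "W4",
    "layers.5.attention.wq.weight": "W5",
}
first_layer, last_layer = 4, 6
current_states = {}
for idx, i in enumerate(range(first_layer, last_layer)):
    for name in list(states.keys()):
        if f".{i}." in name:
            current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)

assert sorted(current_states) == [
    "layers.0.attention.wq.weight",
    "layers.1.attention.wq.weight",
]
```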
@@ -811,6 +887,7 @@ class CheckpointManager:

         self.defalut_load_type_func = {
             CheckpointLoadType.INTERNLM: try_load_internlm_ckpt,
             CheckpointLoadType.HF_LLAMA: try_load_hf_LLAMA_ckpt,
+            CheckpointLoadType.LLAMA: try_load_LLAMA_ckpt,
         }
         for ckpt_load_type in CheckpointLoadType:
             CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])
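Because the registration loop iterates over every CheckpointLoadType member, adding LLAMA to the enum without the matching entry in self.defalut_load_type_func (the misspelling is in the source) would raise a KeyError during CheckpointManager construction; the dict entry added here keeps the two in sync. A hypothetical post-init sanity check:

```python
# Every load type, including the new LLAMA, now has a default loader registered.
for ckpt_load_type in CheckpointLoadType:
    assert ckpt_load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC
```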