diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 1cbb5e7..7d945b4 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -28,7 +28,7 @@ ckpt = dict( # 'load_ckpt_info' setting guide: # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama". load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 9242548..abe66c4 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -51,6 +51,7 @@ class CheckpointSaveType(Enum): class CheckpointLoadType(Enum): INTERNLM = "internlm" HF_LLAMA = "hf_llama" + LLAMA = "llama" # The load method implemented by internlm by default does not use string representation types, @@ -58,6 +59,7 @@ class CheckpointLoadType(Enum): LOAD_TYPE_DICT = { "internlm": CheckpointLoadType.INTERNLM, "hf_llama": CheckpointLoadType.HF_LLAMA, + "llama": CheckpointLoadType.LLAMA, } @@ -92,7 +94,11 @@ class CheckpointLoadMethod: CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func}) - if load_type in (CheckpointLoadType.INTERNLM, CheckpointLoadType.HF_LLAMA): + if load_type in ( + CheckpointLoadType.INTERNLM, + CheckpointLoadType.HF_LLAMA, + CheckpointLoadType.LLAMA, + ): CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func) else: if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG and gpc.is_rank_for_log(): @@ -190,24 +196,32 @@ def load_shard_state_dict(shard_model, shard_state, **kwargs): return (missing_k, unexpected_keys) -def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState): # pylint: disable=W0613 +def process_load_info(load_info): load_content_str = "" load_ckpt_folder = load_info["path"] load_content: CheckpointLoadMask = load_info["content"] if gpc.is_rank_for_log(): logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") + return load_content_str, load_ckpt_folder, load_content + + +def try_load_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState): # pylint: disable=W0613 + load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) if load_content.need_load(CheckpointLoadContent.MODEL): - load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model.model) + load_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model) + load_content_str += f"{CheckpointLoadContent.MODEL}, " + + +def try_load_hf_LLAMA_ckpt(ckpt_mm, load_info, train_state: TrainState): # pylint: disable=W0613 + load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) + if load_content.need_load(CheckpointLoadContent.MODEL): + load_hf_llama_pretrained_weights(folder=load_ckpt_folder, model=ckpt_mm.model) load_content_str += f"{CheckpointLoadContent.MODEL}, " def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState): - load_content_str = "" - load_ckpt_folder = load_info["path"] - load_content: CheckpointLoadMask = load_info["content"] - if gpc.is_rank_for_log(): - logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") + load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) if load_content.need_load(CheckpointLoadContent.MODEL): load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) @@ -328,7 +342,69 @@ def save_model_checkpoint(folder, model): torch.distributed.barrier() +def load_llama_pretrained_weights(folder, model): + model = model.model + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")] + model_fns.sort() + + old_tp = len(model_fns) + cur_tp = gpc.get_world_size(ParallelMode.TENSOR) + # If the two tp are inconsistent, you need to consider the merge before splitting + if old_tp != cur_tp: + raise RuntimeError( + f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first" + ) + + states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu") + + current_states = {} + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + # Temporarily combine the loading logic that supports baichuan2's checkpoint with llama. This may change in + # the future. + if gpc.config.model_type in ("LLAMA", "BAICHUAN2"): + # LLAMA's w2 and w3 are in reverse order + w2 = states.pop(f"layers.{i}.feed_forward.w2.weight") + w3 = states.pop(f"layers.{i}.feed_forward.w3.weight") + states[f"layers.{i}.feed_forward.w2.weight"] = w3 + states[f"layers.{i}.feed_forward.w3.weight"] = w2 + if "rope.freqs" in states: + states[f"layers.{i}.attention.rotary_emb.inv_freq"] = states["rope.freqs"] + for name in list(states.keys()): + if f".{i}." in name: + current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) + + model_state_keys = set(list(model.state_dict().keys())) + + if "tok_embeddings.weight" in model_state_keys: + current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"] + assert model.first_layer == 0, f"Expect model.NaiveAMPModel to be 0, but got {model.first_layer}" + if "output.weight" in model_state_keys: + current_states["norm.weight"] = states["norm.weight"] + current_states["output.weight"] = states["output.weight"] + if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0: + for i in range(model.extra_pred_tokens): + current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone() + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + del states + del current_states + torch.cuda.empty_cache() + + def load_hf_llama_pretrained_weights(folder, model): + model = model.model assert folder is not None, "Please specify the folder of the pretrained model" if gpc.is_rank_for_log(): logger.info(f"Loading pretrained model from {folder}") @@ -811,6 +887,7 @@ class CheckpointManager: self.defalut_load_type_func = { CheckpointLoadType.INTERNLM: try_load_internlm_ckpt, CheckpointLoadType.HF_LLAMA: try_load_hf_LLAMA_ckpt, + CheckpointLoadType.LLAMA: try_load_LLAMA_ckpt, } for ckpt_load_type in CheckpointLoadType: CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])