diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index abe66c4..cc3722e 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -364,9 +364,7 @@ def load_llama_pretrained_weights(folder, model):
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        # Temporarily combine the loading logic that supports baichuan2's checkpoint with llama. This may change in
-        # the future.
-        if gpc.config.model_type in ("LLAMA", "BAICHUAN2"):
+        if gpc.config.model_type == "LLAMA":
             # LLAMA's w2 and w3 are in reverse order
             w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
             w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
@@ -493,14 +491,12 @@ def load_hf_llama_pretrained_weights(folder, model):
         current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
             states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
         )[gpc.get_local_rank(ParallelMode.TENSOR)]
-        # current_states["tok_embeddings.weight"] = states["model.embed_tokens.weight"]
         assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
     if "output.weight" in model_state_keys:
         current_states["norm.weight"] = states["model.norm.weight"]
         current_states["output.weight"] = torch.chunk(
             states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
         )[gpc.get_local_rank(ParallelMode.TENSOR)]
-        # current_states["output.weight"] = states["lm_head.weight"]
         if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
             for i in range(model.extra_pred_tokens):
                 current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()