diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index cc3722e..c4b0c3c 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -384,9 +384,6 @@ def load_llama_pretrained_weights(folder, model): if "output.weight" in model_state_keys: current_states["norm.weight"] = states["norm.weight"] current_states["output.weight"] = states["output.weight"] - if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0: - for i in range(model.extra_pred_tokens): - current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone() missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) if gpc.get_local_rank(ParallelMode.DATA) == 0: @@ -497,9 +494,6 @@ def load_hf_llama_pretrained_weights(folder, model): current_states["output.weight"] = torch.chunk( states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0 )[gpc.get_local_rank(ParallelMode.TENSOR)] - if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0: - for i in range(model.extra_pred_tokens): - current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone() missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)