mirror of https://github.com/InternLM/InternLM
modeling
parent a83b02acf4 · commit b63b8e58bd
@@ -384,9 +384,6 @@ def load_llama_pretrained_weights(folder, model):
     if "output.weight" in model_state_keys:
         current_states["norm.weight"] = states["norm.weight"]
         current_states["output.weight"] = states["output.weight"]
-    if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
-        for i in range(model.extra_pred_tokens):
-            current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()
     missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)

     if gpc.get_local_rank(ParallelMode.DATA) == 0:
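For context, the three removed lines populated one extra output head per multi-token-prediction step by cloning the main output projection into current_states before the non-strict state-dict load. A minimal standalone sketch of that pattern follows; the attribute name extra_pred_tokens and the key pattern extra_outputs.{i}.weight mirror the diff, while the helper function itself is illustrative only:

# Sketch of the removed behaviour, assuming a model that exposes an
# `extra_pred_tokens` count and extra heads named `extra_outputs.{i}.weight`.
import torch

def fill_extra_pred_heads(model, current_states):
    # Each extra prediction head starts from a copy of the main output projection.
    if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
        for i in range(model.extra_pred_tokens):
            current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()
    # strict=False tolerates keys present in the checkpoint but absent from the
    # model (and vice versa); the mismatches are returned so they can be logged.
    missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
    return missing_keys, unexpected_keys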
@@ -497,9 +494,6 @@ def load_hf_llama_pretrained_weights(folder, model):
         current_states["output.weight"] = torch.chunk(
             states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
         )[gpc.get_local_rank(ParallelMode.TENSOR)]
-    if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
-        for i in range(model.extra_pred_tokens):
-            current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()

     missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)

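The surrounding context lines shard the Hugging Face lm_head.weight row-wise across tensor-parallel ranks. A self-contained sketch of that chunking, with the world size and rank passed as plain integers instead of the gpc/ParallelMode lookups used in the file:

import torch

def shard_lm_head(lm_head_weight: torch.Tensor, tp_world_size: int, tp_rank: int) -> torch.Tensor:
    # torch.chunk splits the vocabulary dimension (dim=0) into tp_world_size
    # nearly equal pieces; each tensor-parallel rank keeps only its own slice.
    return torch.chunk(lm_head_weight, tp_world_size, dim=0)[tp_rank]

# Example: an (8, 4) "lm_head" split across 4 ranks gives each rank a (2, 4) shard.
full = torch.randn(8, 4)
shard = shard_lm_head(full, tp_world_size=4, tp_rank=1)
assert shard.shape == (2, 4)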