pull/532/head
lijiaxing 2023-12-11 15:43:59 +08:00
parent a83b02acf4
commit b63b8e58bd
1 changed file with 0 additions and 6 deletions

@@ -384,9 +384,6 @@ def load_llama_pretrained_weights(folder, model):
     if "output.weight" in model_state_keys:
         current_states["norm.weight"] = states["norm.weight"]
         current_states["output.weight"] = states["output.weight"]
-        if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
-            for i in range(model.extra_pred_tokens):
-                current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()

     missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
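
For context, the lines removed in this hunk initialized each extra prediction head from the main output head. Below is a minimal, self-contained sketch of that initialization pattern; the TinyHead module, its sizes, and the plain nn.Linear heads are illustrative assumptions, not InternLM's actual model definition.

import torch
from torch import nn

# Hypothetical stand-in for a model exposing `extra_pred_tokens`, a main
# `output` head, and an `extra_outputs` ModuleList, as referenced in the diff.
class TinyHead(nn.Module):
    def __init__(self, hidden_size=16, vocab_size=32, extra_pred_tokens=2):
        super().__init__()
        self.extra_pred_tokens = extra_pred_tokens
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)
        self.extra_outputs = nn.ModuleList(
            nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(extra_pred_tokens)
        )

model = TinyHead()
current_states = {"output.weight": torch.randn(32, 16)}

# Same pattern as the removed lines: copy the main head's weight into each extra head.
if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
    for i in range(model.extra_pred_tokens):
        current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()

# strict=False tolerates parameters not present in current_states.
missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
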
@@ -497,9 +494,6 @@ def load_hf_llama_pretrained_weights(folder, model):
         current_states["output.weight"] = torch.chunk(
             states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
         )[gpc.get_local_rank(ParallelMode.TENSOR)]
-        if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
-            for i in range(model.extra_pred_tokens):
-                current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()

     missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
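
The surviving context in this hunk shards the Hugging Face lm_head.weight row-wise across tensor-parallel ranks. Below is a minimal sketch of that torch.chunk slicing; the shapes and the hard-coded rank layout are assumptions for illustration, whereas the real code reads both from gpc.

import torch

vocab_size, hidden_size = 32000, 4096   # assumed LLaMA-7B-like shapes
tp_world_size, tp_rank = 4, 1           # assumed tensor-parallel layout

lm_head_weight = torch.empty(vocab_size, hidden_size)

# Split along the vocabulary dimension (dim=0); each rank keeps only its shard.
local_shard = torch.chunk(lm_head_weight, tp_world_size, dim=0)[tp_rank]
assert local_shard.shape == (vocab_size // tp_world_size, hidden_size)  # (8000, 4096)
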