mirror of https://github.com/InternLM/InternLM
support hf llama
parent 6def66fb07
commit 41edd074a6
@@ -364,9 +364,7 @@ def load_llama_pretrained_weights(folder, model):
     current_states = {}
     for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-        # Temporarily combine the loading logic that supports baichuan2's checkpoint with llama. This may change in
-        # the future.
-        if gpc.config.model_type in ("LLAMA", "BAICHUAN2"):
+        if gpc.config.model_type == "LLAMA":
             # LLAMA's w2 and w3 are in reverse order
             w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
             w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
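For context on the hunk above: the inline comment says LLAMA's w2 and w3 feed-forward (SwiGLU) projections are stored in the reverse order from this repo's naming, which is why the loader pops both weights before writing them back. Below is a minimal sketch of that remapping, assuming the key layout shown in the diff; the helper name remap_feed_forward and the swapped destination keys are illustrative assumptions, not code from this commit.

def remap_feed_forward(states: dict, layer_idx: int) -> dict:
    # Hypothetical helper: pop LLAMA's w2/w3 and swap them into the
    # destination naming, as the "reverse order" comment in the hunk
    # above describes. The swap direction is an assumption.
    w2 = states.pop(f"layers.{layer_idx}.feed_forward.w2.weight")
    w3 = states.pop(f"layers.{layer_idx}.feed_forward.w3.weight")
    return {
        f"layers.{layer_idx}.feed_forward.w2.weight": w3,  # swapped
        f"layers.{layer_idx}.feed_forward.w3.weight": w2,  # swapped
    }

# Tiny usage check with placeholder values instead of real tensors:
fake = {
    "layers.0.feed_forward.w2.weight": "W2",
    "layers.0.feed_forward.w3.weight": "W3",
}
print(remap_feed_forward(fake, 0))  # the w2 key now holds W3, and vice versa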
@@ -493,14 +491,12 @@ def load_hf_llama_pretrained_weights(folder, model):
         current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
             states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
         )[gpc.get_local_rank(ParallelMode.TENSOR)]
-        # current_states["tok_embeddings.weight"] = states["model.embed_tokens.weight"]
     assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
     if "output.weight" in model_state_keys:
         current_states["norm.weight"] = states["model.norm.weight"]
         current_states["output.weight"] = torch.chunk(
             states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
         )[gpc.get_local_rank(ParallelMode.TENSOR)]
-        # current_states["output.weight"] = states["lm_head.weight"]
         if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
             for i in range(model.extra_pred_tokens):
                 current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()
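The hunk above shards the HF llama embedding and output head across tensor-parallel ranks with torch.chunk: the embedding is split along dim=1 (hidden dimension) and the lm_head along dim=0 (vocab dimension). Below is a self-contained sketch of that splitting; the hard-coded tp_world_size/tp_rank stand in for gpc.get_world_size(ParallelMode.TENSOR) and gpc.get_local_rank(ParallelMode.TENSOR), and the tensor shapes are illustrative only.

import torch

tp_world_size = 4  # stand-in for gpc.get_world_size(ParallelMode.TENSOR)
tp_rank = 1        # stand-in for gpc.get_local_rank(ParallelMode.TENSOR)

vocab_size, hidden = 32000, 4096
embed_tokens = torch.randn(vocab_size, hidden)  # model.embed_tokens.weight
lm_head = torch.randn(vocab_size, hidden)       # lm_head.weight

# The embedding is split along dim=1, so each rank keeps every vocab row
# but only hidden // tp_world_size columns of the embedding.
embed_shard = torch.chunk(embed_tokens, tp_world_size, dim=1)[tp_rank]
assert embed_shard.shape == (vocab_size, hidden // tp_world_size)

# The output head is split along dim=0, so each rank produces logits for
# only vocab_size // tp_world_size tokens.
head_shard = torch.chunk(lm_head, tp_world_size, dim=0)[tp_rank]
assert head_shard.shape == (vocab_size // tp_world_size, hidden)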