mirror of https://github.com/InternLM/InternLM
support hf llama
parent 6def66fb07
commit 41edd074a6
@@ -364,9 +364,7 @@ def load_llama_pretrained_weights(folder, model):
    current_states = {}
    for idx, i in enumerate(range(model.first_layer, model.last_layer)):
        # Temporarily combine the loading logic that supports baichuan2's checkpoint with llama.
        # This may change in the future.
        if gpc.config.model_type in ("LLAMA", "BAICHUAN2"):
            if gpc.config.model_type == "LLAMA":
                # LLAMA's w2 and w3 are in reverse order
                w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
                w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
@@ -493,14 +491,12 @@ def load_hf_llama_pretrained_weights(folder, model):
    current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
        states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
    )[gpc.get_local_rank(ParallelMode.TENSOR)]
    # current_states["tok_embeddings.weight"] = states["model.embed_tokens.weight"]
    assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
    if "output.weight" in model_state_keys:
        current_states["norm.weight"] = states["model.norm.weight"]
        current_states["output.weight"] = torch.chunk(
            states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
        )[gpc.get_local_rank(ParallelMode.TENSOR)]
        # current_states["output.weight"] = states["lm_head.weight"]
    if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
        for i in range(model.extra_pred_tokens):
            current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()
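The two torch.chunk calls above implement the tensor-parallel split: the embedding table is sliced along the hidden dimension (dim=1) and the output projection along the vocabulary dimension (dim=0), with each rank keeping only its own slice. A minimal, self-contained sketch of that slicing, using plain integers in place of gpc.get_world_size/get_local_rank and made-up tensor sizes:

import torch

# Stand-ins for gpc.get_world_size(ParallelMode.TENSOR) and
# gpc.get_local_rank(ParallelMode.TENSOR); values are illustrative only.
tp_world_size = 4
tp_rank = 1

vocab_size, hidden_size = 32000, 4096
embed_tokens = torch.randn(vocab_size, hidden_size)  # plays the role of "model.embed_tokens.weight"
lm_head = torch.randn(vocab_size, hidden_size)       # plays the role of "lm_head.weight"

# Embedding weights: split the hidden dimension across tensor-parallel ranks.
tok_embeddings_shard = torch.chunk(embed_tokens, tp_world_size, dim=1)[tp_rank]
# Output projection: split the vocabulary dimension across tensor-parallel ranks.
output_shard = torch.chunk(lm_head, tp_world_size, dim=0)[tp_rank]

print(tok_embeddings_shard.shape)  # torch.Size([32000, 1024])
print(output_shard.shape)          # torch.Size([8000, 4096])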