support hf llama

pull/532/head
lijiaxing 2023-12-08 16:43:56 +08:00
parent 6def66fb07
commit 41edd074a6
1 changed files with 1 additions and 5 deletions

View File

@ -364,9 +364,7 @@ def load_llama_pretrained_weights(folder, model):
current_states = {} current_states = {}
for idx, i in enumerate(range(model.first_layer, model.last_layer)): for idx, i in enumerate(range(model.first_layer, model.last_layer)):
# Temporarily combine the loading logic that supports baichuan2's checkpoint with llama. This may change in if gpc.config.model_type == "LLAMA":
# the future.
if gpc.config.model_type in ("LLAMA", "BAICHUAN2"):
# LLAMA's w2 and w3 are in reverse order # LLAMA's w2 and w3 are in reverse order
w2 = states.pop(f"layers.{i}.feed_forward.w2.weight") w2 = states.pop(f"layers.{i}.feed_forward.w2.weight")
w3 = states.pop(f"layers.{i}.feed_forward.w3.weight") w3 = states.pop(f"layers.{i}.feed_forward.w3.weight")
@ -493,14 +491,12 @@ def load_hf_llama_pretrained_weights(folder, model):
current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk( current_states["tok_embeddings.word_embeddings.weight"] = torch.chunk(
states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1 states["model.embed_tokens.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=1
)[gpc.get_local_rank(ParallelMode.TENSOR)] )[gpc.get_local_rank(ParallelMode.TENSOR)]
# current_states["tok_embeddings.weight"] = states["model.embed_tokens.weight"]
assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}" assert model.first_layer == 0, f"Expect model.first_layer to be 0, but got {model.first_layer}"
if "output.weight" in model_state_keys: if "output.weight" in model_state_keys:
current_states["norm.weight"] = states["model.norm.weight"] current_states["norm.weight"] = states["model.norm.weight"]
current_states["output.weight"] = torch.chunk( current_states["output.weight"] = torch.chunk(
states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0 states["lm_head.weight"], gpc.get_world_size(ParallelMode.TENSOR), dim=0
)[gpc.get_local_rank(ParallelMode.TENSOR)] )[gpc.get_local_rank(ParallelMode.TENSOR)]
# current_states["output.weight"] = states["lm_head.weight"]
if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0: if hasattr(model, "extra_pred_tokens") and model.extra_pred_tokens > 0:
for i in range(model.extra_pred_tokens): for i in range(model.extra_pred_tokens):
current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone() current_states[f"extra_outputs.{i}.weight"] = current_states["output.weight"].clone()