mirror of https://github.com/InternLM/InternLM

support pipeline convert
parent 934f60b753
commit b5ce6825ce
@@ -6,7 +6,7 @@ from tqdm import tqdm
 from transformers import AutoConfig


-def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
+def revert(src, tgt, tp_size, pp_size, embed_split_hidden, adapt_hf, use_flash):
     hf_state = {}
     print("Loading HF checkpoints...")
     for filename in tqdm(os.listdir(src)):
@@ -33,75 +33,81 @@ def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
         return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

-    # revert
-    states = [{} for _ in range(tp_size)]
+    # no-moe is stored according to tp and pp ranks
+    states = [[{} for _ in range(tp_size)] for _ in range(pp_size)]
+    # moe is stored according to layer id, expert id and tp rank
     moe_states = [
         [[{} for _ in range(tp_size)] for _ in range(config.num_experts)] for _ in range(config.num_hidden_layers)
     ]

     # layers
-    for layer_i in tqdm(range(config.num_hidden_layers)):
-        # no-moe
-        for i in range(tp_size):
-            states[i][f"model.layers.{layer_i}.attention_norm.weight"] = hf_state[
-                f"model.layers.{layer_i}.input_layernorm.weight"
-            ].clone()
-            states[i][f"model.layers.{layer_i}.ffn_norm.weight"] = hf_state[
-                f"model.layers.{layer_i}.post_attention_layernorm.weight"
-            ].clone()
-            states[i][f"model.layers.{layer_i}.feed_forward.moe_layer.gate.wg.weight"] = hf_state[
-                f"model.layers.{layer_i}.mlp.gate.weight"
-            ].clone()
-
-        # mha
-        wqs = (
-            permute(hf_state[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
-            # .view(-1, dims_per_head, dim)
-            .chunk(tp_size, 0)
-        )
-        wks = (
-            permute(hf_state[f"model.layers.{layer_i}.self_attn.k_proj.weight"], n_kv_heads, -1, dim)
-            # .view(-1, dims_per_head, dim)
-            .chunk(tp_size, 0)
-        )
-        wvs = (
-            hf_state[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
-            # .view(-1, dims_per_head, dim)
-            .chunk(tp_size, 0)
-        )
-        wos = hf_state[f"model.layers.{layer_i}.self_attn.o_proj.weight"].chunk(tp_size, 1)
-        for i in range(tp_size):
-            states[i][f"model.layers.{layer_i}.attention.wq.weight"] = wqs[i].reshape(-1, dim).clone()
-            states[i][f"model.layers.{layer_i}.attention.wk.weight"] = wks[i].reshape(-1, dim).clone()
-            states[i][f"model.layers.{layer_i}.attention.wv.weight"] = wvs[i].reshape(-1, dim).clone()
-            states[i][f"model.layers.{layer_i}.attention.wo.weight"] = wos[i].clone()
-
-        # moe
-        for expert_id in range(config.num_experts):
-            w1s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w1.weight"].chunk(tp_size, 0)
-            w2s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w3.weight"].chunk(tp_size, 0)
-            w3s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w2.weight"].chunk(tp_size, 1)
-            for i in range(tp_size):
-                moe_states[layer_i][expert_id][i][
-                    f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w1.weight"
-                ] = w1s[i].clone()
-                moe_states[layer_i][expert_id][i][
-                    f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w2.weight"
-                ] = w2s[i].clone()
-                moe_states[layer_i][expert_id][i][
-                    f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight"
-                ] = w3s[i].clone()
+    assert config.num_hidden_layers % pp_size == 0
+    num_layer_per_stage = config.num_hidden_layers // pp_size
+    for p_i in range(pp_size):
+        for layer_i in tqdm(range(num_layer_per_stage)):
+            # no-moe
+            for i in range(tp_size):
+                states[p_i][i][f"model.layers.{layer_i}.attention_norm.weight"] = hf_state[
+                    f"model.layers.{layer_i}.input_layernorm.weight"
+                ].clone()
+                states[p_i][i][f"model.layers.{layer_i}.ffn_norm.weight"] = hf_state[
+                    f"model.layers.{layer_i}.post_attention_layernorm.weight"
+                ].clone()
+                states[p_i][i][f"model.layers.{layer_i}.feed_forward.moe_layer.gate.wg.weight"] = hf_state[
+                    f"model.layers.{layer_i}.mlp.gate.weight"
+                ].clone()
+
+            # mha
+            wqs = (
+                permute(hf_state[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
+                # .view(-1, dims_per_head, dim)
+                .chunk(tp_size, 0)
+            )
+            wks = (
+                permute(hf_state[f"model.layers.{layer_i}.self_attn.k_proj.weight"], n_kv_heads, -1, dim)
+                # .view(-1, dims_per_head, dim)
+                .chunk(tp_size, 0)
+            )
+            wvs = (
+                hf_state[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
+                # .view(-1, dims_per_head, dim)
+                .chunk(tp_size, 0)
+            )
+            wos = hf_state[f"model.layers.{layer_i}.self_attn.o_proj.weight"].chunk(tp_size, 1)
+            for i in range(tp_size):
+                states[p_i][i][f"model.layers.{layer_i}.attention.wq.weight"] = wqs[i].reshape(-1, dim).clone()
+                states[p_i][i][f"model.layers.{layer_i}.attention.wk.weight"] = wks[i].reshape(-1, dim).clone()
+                states[p_i][i][f"model.layers.{layer_i}.attention.wv.weight"] = wvs[i].reshape(-1, dim).clone()
+                states[p_i][i][f"model.layers.{layer_i}.attention.wo.weight"] = wos[i].clone()
+
+            # moe
+            global_layer_i = p_i * num_layer_per_stage + layer_i
+            for expert_id in range(config.num_experts):
+                w1s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w1.weight"].chunk(tp_size, 0)
+                w2s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w3.weight"].chunk(tp_size, 0)
+                w3s = hf_state[f"model.layers.{layer_i}.mlp.experts.{expert_id}.w2.weight"].chunk(tp_size, 1)
+                for i in range(tp_size):
+                    moe_states[global_layer_i][expert_id][i][
+                        f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w1.weight"
+                    ] = w1s[i].clone()
+                    moe_states[global_layer_i][expert_id][i][
+                        f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w2.weight"
+                    ] = w2s[i].clone()
+                    moe_states[global_layer_i][expert_id][i][
+                        f"model.layers.{layer_i}.feed_forward.moe_layer.experts.experts.{expert_id}.w3.weight"
+                    ] = w3s[i].clone()

     for i in range(tp_size):
         if embed_split_hidden:
             embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 1)
-            states[i]["model.tok_embeddings.weight"] = embeds[i].clone()
+            states[0][i]["model.tok_embeddings.weight"] = embeds[i].clone()
         else:
             embeds = hf_state["model.embed_tokens.weight"].chunk(tp_size, 0)
-            states[i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()
+            states[0][i]["model.tok_embeddings.word_embeddings.weight"] = embeds[i].clone()

     outputs = hf_state["lm_head.weight"].chunk(tp_size, 0)
     for i in range(tp_size):
-        states[i]["model.norm.weight"] = hf_state["model.norm.weight"].clone()
-        states[i]["model.output.weight"] = outputs[i].clone()
+        states[pp_size - 1][i]["model.norm.weight"] = hf_state["model.norm.weight"].clone()
+        states[pp_size - 1][i]["model.output.weight"] = outputs[i].clone()

     mlp_ratio = round((config.intermediate_size - 255) / config.hidden_size + 0.01, 2)
     if "rotary" in config.to_dict():
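Note (not part of the patch): with the new nested loop, each pipeline stage owns num_layer_per_stage = config.num_hidden_layers // pp_size consecutive layers, the keys inside a stage's state dict use the stage-local layer index, and moe_states is indexed by the global layer id. A minimal sketch of that mapping, with a hypothetical helper name:

# Sketch only; mirrors the arithmetic in the patched loop, where
# global_layer_i = p_i * num_layer_per_stage + layer_i and the patch asserts
# that num_hidden_layers is divisible by pp_size.
def locate_layer(global_layer_i, num_hidden_layers, pp_size):
    assert num_hidden_layers % pp_size == 0
    num_layer_per_stage = num_hidden_layers // pp_size
    p_i, local_layer_i = divmod(global_layer_i, num_layer_per_stage)
    # Non-MoE weights for this layer sit in states[p_i][tp] under keys that use
    # the local index, e.g. f"model.layers.{local_layer_i}.attention.wq.weight".
    return p_i, local_layer_i

As in the patch, the token embeddings are placed only on stage 0, and the final norm plus output head only on stage pp_size - 1.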
@@ -134,8 +140,9 @@ def revert(src, tgt, tp_size, embed_split_hidden, adapt_hf, use_flash):
     # save
     os.makedirs(tgt, exist_ok=True)
     print(f"Saving to {tgt}...")
-    for tp in tqdm(range(tp_size)):
-        torch.save(states[tp], os.path.join(tgt, f"model_tp{tp}_pp0.pt"))
+    for pp in range(pp_size):
+        for tp in tqdm(range(tp_size)):
+            torch.save(states[pp][tp], os.path.join(tgt, f"model_tp{tp}_pp{pp}.pt"))
     for moe_layer_id in range(config.num_hidden_layers):
         for expert_id in range(config.num_experts):
             for tp in tqdm(range(tp_size)):
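For illustration (values not taken from the patch), with tp_size=2 and pp_size=2 the new save loop writes model_tp0_pp0.pt, model_tp1_pp0.pt, model_tp0_pp1.pt and model_tp1_pp1.pt. A small sketch that enumerates the expected non-MoE shard paths:

# Sketch only: list the non-MoE shard files produced by the patched save loop.
# (The MoE expert files saved further down are cut off in this hunk.)
import os

def expected_shards(tgt, tp_size, pp_size):
    return [
        os.path.join(tgt, f"model_tp{tp}_pp{pp}.pt")
        for pp in range(pp_size)
        for tp in range(tp_size)
    ]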
@@ -151,6 +158,7 @@ def print_args(args):
     print(f"Source Path: {args.src}")
     print(f"Target Path: {args.tgt}")
     print(f"TP Size: {args.tp_size}")
+    print(f"PP Size: {args.pp_size}")
     print(f"Embeb Split Hidden: {args.embed_split}")
     print(f"Adapt HF: {args.adapt_hf}")
     print(f"Use Flash Attn: {args.use_flash}")
@@ -163,6 +171,7 @@ def parse_args():
     parser.add_argument("--src", type=str, help="Input folder")
     parser.add_argument("--tgt", type=str, help="Output folder")
     parser.add_argument("--tp_size", type=int, help="world_size of tensor parallel")
+    parser.add_argument("--pp_size", type=int, help="world_size of pipeline parallel")
     parser.add_argument("--embed_split", action="store_true", help="embed_split_hidden of InternLM")
     parser.add_argument("--adapt_hf", action="store_true", help="adapt_hf of InternLM")
     parser.add_argument("--use_flash", action="store_true", help="use_flash_attn of InternLM")
@@ -178,4 +187,4 @@ if __name__ == "__main__":
     args = parse_args()
     print_args(args)

-    revert(args.src, args.tgt, args.tp_size, args.embed_split, args.adapt_hf, args.use_flash)
+    revert(args.src, args.tgt, args.tp_size, args.pp_size, args.embed_split, args.adapt_hf, args.use_flash)
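A hedged usage sketch of the updated entry point (paths and sizes below are placeholders, not taken from the patch); the keyword names follow the new revert() signature and correspond to the CLI flags, including the added --pp_size:

# Sketch only: direct call equivalent to running the script with the new flag.
revert(
    src="/path/to/hf_checkpoint",  # --src
    tgt="/path/to/internlm_ckpt",  # --tgt
    tp_size=2,                     # --tp_size 2
    pp_size=2,                     # --pp_size 2 (added in this change)
    embed_split_hidden=True,       # --embed_split
    adapt_hf=False,
    use_flash=True,                # --use_flash
)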