# Mirror of https://github.com/hpcaitech/ColossalAI
# 59 lines, 1.9 KiB, Python
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def load_json(path: str):
    """Read and parse the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 (the JSON interchange encoding)
    so the result does not depend on the platform's default locale encoding.

    Args:
        path: Filesystem path of the JSON file.

    Returns:
        The deserialized JSON value, exactly as produced by ``json.load``.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
def parse_shape_info(flat_dir: str):
    """Group the entries of ``<flat_dir>/shape.json`` into flat-parameter buckets.

    ``shape.json`` is read as a mapping from name to shape. Every name that
    starts with ``decoder.layers.<N>`` is placed in the bucket
    ``"decoder.layers.<N>.flat_param_0"``; every other name goes to the bucket
    ``"flat_param_0"``. Within a bucket, entries keep the order in which they
    appear in the loaded mapping.

    Args:
        flat_dir: Directory containing ``shape.json``.

    Returns:
        A ``defaultdict`` mapping bucket key -> ``defaultdict(list)`` holding
        the parallel lists ``"names"``, ``"shapes"`` and ``"numels"`` (the
        element count of each shape, as ``int``).
    """
    # Dots are escaped: an unescaped "." matches any character, so a name
    # such as "decoder_layers_3" would otherwise be treated as a layer key.
    layer_prefix = re.compile(r"decoder\.layers\.\d+")

    data = load_json(os.path.join(flat_dir, "shape.json"))
    flat_info = defaultdict(lambda: defaultdict(list))
    for name, shape in data.items():
        matched = layer_prefix.match(name)
        if matched is None:
            flat_key = "flat_param_0"
        else:
            flat_key = f"{matched[0]}.flat_param_0"
        bucket = flat_info[flat_key]
        bucket["names"].append(name)
        bucket["shapes"].append(shape)
        # np.prod of an empty shape is 1.0, so a scalar counts as one element.
        bucket["numels"].append(int(np.prod(shape)))
    return flat_info
|
|
|
|
|
|
def convert(flat_dir: str, output_dir: str, part: int):
    """Unflatten one checkpoint part into a ``{name: tensor}`` state dict.

    Reads ``reshard-model_part-<part>-shard0.pt`` and ``flat-meta.json`` from
    *flat_dir*. Every flat tensor under ``["model"][flat_key]`` is split back
    into the named tensors described by the metadata. The resulting dict is
    written to ``<output_dir>/reshard-model_part-<part>.pt``.

    Args:
        flat_dir: Directory holding the flat shard file and ``flat-meta.json``.
        output_dir: Destination directory; created if it does not exist.
        part: Part index used in both the input and output file names.

    Raises:
        ValueError: If the metadata for a flat parameter is internally
            inconsistent or does not match the stored tensor's element count.
    """
    flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt")
    output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt")
    flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json"))

    # Map onto the CPU so the conversion does not require the device(s) the
    # checkpoint was saved from. NOTE: torch.load unpickles arbitrary objects;
    # only run this on checkpoints from a trusted source.
    flat_sd = torch.load(flat_path, map_location="cpu")
    print(f"Loaded flat state dict from {flat_path}")

    output_sd = {}
    for flat_key, param_meta in flat_meta.items():
        flat_param = flat_sd["model"][flat_key]
        names = param_meta["names"]
        shapes = param_meta["shapes"]
        numels = param_meta["numels"]

        # zip() below would silently drop trailing entries if these disagree.
        if not len(names) == len(shapes) == len(numels):
            raise ValueError(
                f"{flat_key}: {len(names)} names, {len(shapes)} shapes, "
                f"{len(numels)} numels in flat-meta.json"
            )
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        total = sum(numels)
        if total != flat_param.numel():
            raise ValueError(f"flat {flat_key} {flat_param.numel()} vs {total}")

        # Each piece is a view into the flat tensor's storage; torch.save
        # preserves that sharing, so the storage is serialized once.
        for name, shape, param in zip(names, shapes, flat_param.split(numels)):
            output_sd[name] = param.view(shape)

    # Create the destination up front so the save does not fail after all
    # the loading and splitting work has already been done.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(output_sd, output_path)
    print(f"Saved unflat state dict to {output_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: unflatten a single part of a flat checkpoint.
    parser = argparse.ArgumentParser(
        description="Unflatten a flat-parameter checkpoint part into a per-parameter state dict."
    )
    parser.add_argument(
        "flat_dir",
        help="directory containing reshard-model_part-<part>-shard0.pt and flat-meta.json",
    )
    parser.add_argument(
        "output_dir",
        help="directory in which reshard-model_part-<part>.pt is written",
    )
    parser.add_argument(
        "part",
        type=int,
        help="part index used in the input and output file names",
    )
    args = parser.parse_args()
    convert(args.flat_dir, args.output_dir, args.part)
|