# ColossalAI/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py
#
# Source listing metadata: 59 lines, 1.9 KiB, Python.

import argparse
import json
import os
import re
from collections import defaultdict
import numpy as np
import torch
def load_json(path: str):
    """Read the file at *path* and return its parsed JSON content.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default (locale) encoding.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def parse_shape_info(flat_dir: str):
    """Group the parameter shapes in ``<flat_dir>/shape.json`` by flat parameter.

    ``shape.json`` is read as a mapping from parameter name to shape (a list of
    dimension sizes). A name that starts with ``decoder.layers.<N>`` is
    assigned to the flat key ``decoder.layers.<N>.flat_param_0``. Every other
    name goes to ``flat_param_0``.

    Args:
        flat_dir: Directory containing ``shape.json``.

    Returns:
        A mapping ``{flat_key: {"names": [...], "shapes": [...], "numels": [...]}}``.
        The three lists are parallel and follow the key order of ``shape.json``.
        Both levels are ``defaultdict`` instances.
    """
    data = load_json(os.path.join(flat_dir, "shape.json"))
    # Dots are escaped so that only a literal "decoder.layers.<N>" prefix
    # matches; an unescaped "." would accept any character in those positions.
    layer_prefix = re.compile(r"decoder\.layers\.\d+")
    flat_info = defaultdict(lambda: defaultdict(list))
    for k, shape in data.items():
        matched = layer_prefix.match(k)
        flat_key = "flat_param_0" if matched is None else f"{matched.group(0)}.flat_param_0"
        flat_info[flat_key]["names"].append(k)
        flat_info[flat_key]["shapes"].append(shape)
        # Element count of the tensor; np.prod of an empty shape is 1 (a scalar).
        flat_info[flat_key]["numels"].append(int(np.prod(shape)))
    return flat_info
def convert(flat_dir: str, output_dir: str, part: int):
    """Unflatten one part of a flat-parameter checkpoint into named tensors.

    Reads ``reshard-model_part-{part}-shard0.pt`` and ``flat-meta.json`` from
    *flat_dir*. Each flat tensor stored under ``["model"][flat_key]`` is split
    back into per-parameter tensors using the parallel ``names`` / ``shapes`` /
    ``numels`` lists in the metadata. The resulting ``{name: tensor}`` dict is
    saved to ``reshard-model_part-{part}.pt`` in *output_dir*.

    Args:
        flat_dir: Directory holding the flat checkpoint shard and ``flat-meta.json``.
        output_dir: Directory the unflattened state dict is written to.
            It is created if it does not exist.
        part: Part index used in the input and output checkpoint file names.

    Raises:
        ValueError: If a flat parameter's metadata lists have different
            lengths, or their element counts do not add up to the size of the
            stored flat tensor.
    """
    flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt")
    output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt")
    flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json"))
    # Load onto CPU so the conversion neither requires nor exhausts GPU memory,
    # whatever device the tensors were saved from.
    flat_sd = torch.load(flat_path, map_location="cpu")
    print(f"Loaded flat state dict from {flat_path}")
    output_sd = {}
    for flat_key, param_meta in flat_meta.items():
        flat_param = flat_sd["model"][flat_key]
        names = param_meta["names"]
        shapes = param_meta["shapes"]
        numels = param_meta["numels"]
        # zip() below would silently drop trailing entries on a length mismatch.
        if not len(names) == len(shapes) == len(numels):
            raise ValueError(
                f"flat {flat_key}: metadata length mismatch "
                f"({len(names)} names, {len(shapes)} shapes, {len(numels)} numels)"
            )
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if sum(numels) != flat_param.numel():
            raise ValueError(f"flat {flat_key} {flat_param.numel()} vs {sum(numels)}")
        for name, shape, param in zip(names, shapes, flat_param.split(numels)):
            output_sd[name] = param.view(shape)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(output_sd, output_path)
    print(f"Saved unflat state dict to {output_path}")
if __name__ == "__main__":
    # Command-line entry point: convert a single checkpoint part.
    parser = argparse.ArgumentParser(
        description="Unflatten one part of a flat-parameter checkpoint into a name -> tensor state dict."
    )
    parser.add_argument(
        "flat_dir",
        help="directory containing reshard-model_part-<part>-shard0.pt and flat-meta.json",
    )
    parser.add_argument(
        "output_dir",
        help="directory to write reshard-model_part-<part>.pt to (created if missing)",
    )
    parser.add_argument(
        "part",
        type=int,
        help="part index used in the checkpoint file names",
    )
    args = parser.parse_args()
    convert(args.flat_dir, args.output_dir, args.part)