import argparse import json import os import re from collections import defaultdict import numpy as np import torch def load_json(path: str): with open(path) as f: return json.load(f) def parse_shape_info(flat_dir: str): data = load_json(os.path.join(flat_dir, "shape.json")) flat_info = defaultdict(lambda: defaultdict(list)) for k, shape in data.items(): matched = re.match(r"decoder.layers.\d+", k) if matched is None: flat_key = "flat_param_0" else: flat_key = f"{matched[0]}.flat_param_0" flat_info[flat_key]["names"].append(k) flat_info[flat_key]["shapes"].append(shape) flat_info[flat_key]["numels"].append(int(np.prod(shape))) return flat_info def convert(flat_dir: str, output_dir: str, part: int): flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt") output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt") flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json")) flat_sd = torch.load(flat_path) print(f"Loaded flat state dict from {flat_path}") output_sd = {} for flat_key, param_meta in flat_meta.items(): flat_param = flat_sd["model"][flat_key] assert ( sum(param_meta["numels"]) == flat_param.numel() ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' for name, shape, param in zip( param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"]) ): output_sd[name] = param.view(shape) torch.save(output_sd, output_path) print(f"Saved unflat state dict to {output_path}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("flat_dir") parser.add_argument("output_dir") parser.add_argument("part", type=int) args = parser.parse_args() convert(args.flat_dir, args.output_dir, args.part)