mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
1.9 KiB
56 lines
1.9 KiB
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def load_json(path: str):
    """Read and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    locale encoding (which is not UTF-8 on every system, e.g. some Windows
    setups), so the result does not depend on where the script runs.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
|
|
def parse_shape_info(flat_dir: str):
    """Group the parameters listed in ``shape.json`` into per-layer flat buckets.

    Reads ``<flat_dir>/shape.json``, whose top-level object maps parameter
    names to shapes (presumably lists of ints; anything ``np.prod`` accepts).
    A name starting with the literal prefix ``decoder.layers.<N>`` goes into
    the bucket ``decoder.layers.<N>.flat_param_0``. Every other name goes into
    the shared ``flat_param_0`` bucket.

    Args:
        flat_dir: Directory containing ``shape.json``.

    Returns:
        A nested ``defaultdict`` mapping each bucket key to three parallel
        lists in ``shape.json`` order: ``'names'``, ``'shapes'`` and
        ``'numels'`` (element count of each shape). This is the same per-key
        layout that ``convert`` reads from ``flat-meta.json``.
    """
    data = load_json(os.path.join(flat_dir, 'shape.json'))
    # Dots are escaped so that only a literal "decoder.layers.<N>" prefix
    # matches; an unescaped "." would accept any character in that position.
    layer_prefix = re.compile(r'decoder\.layers\.\d+')
    flat_info = defaultdict(lambda: defaultdict(list))
    for k, shape in data.items():
        matched = layer_prefix.match(k)
        if matched is None:
            flat_key = 'flat_param_0'
        else:
            flat_key = f'{matched[0]}.flat_param_0'
        bucket = flat_info[flat_key]
        bucket['names'].append(k)
        bucket['shapes'].append(shape)
        # np.prod of an empty shape (a scalar) is 1.0, so int() yields 1.
        bucket['numels'].append(int(np.prod(shape)))
    return flat_info
|
|
|
|
|
|
def convert(flat_dir: str, output_dir: str, part: int):
    """Unflatten one part of a flat-parameter checkpoint into a regular state dict.

    Reads ``<flat_dir>/reshard-model_part-<part>-shard0.pt`` and
    ``<flat_dir>/flat-meta.json``. Splits every flat tensor back into its named
    tensors, and saves the resulting ``{name: tensor}`` dict to
    ``<output_dir>/reshard-model_part-<part>.pt``.

    ``flat-meta.json`` maps each flat-parameter key to a dict with three
    parallel lists: ``names``, ``shapes`` and ``numels``. Each flat tensor is
    looked up as ``checkpoint['model'][flat_key]``.

    Args:
        flat_dir: Directory holding the flat checkpoint shard and its metadata.
        output_dir: Directory the unflattened state dict is written to. It is
            created if it does not exist.
        part: Part index embedded in the input and output file names.

    Raises:
        KeyError: If a key in ``flat-meta.json`` is missing from the
            checkpoint's ``'model'`` dict.
        ValueError: If a flat tensor's element count differs from the sum of
            the ``numels`` recorded for it.
    """
    flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt')
    output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt')
    flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json'))
    # map_location='cpu' lets a checkpoint saved from GPU tensors be converted
    # on a host without the same devices, and keeps the output device-agnostic.
    # NOTE(review): torch.load unpickles the file; only run this on trusted
    # checkpoints.
    flat_sd = torch.load(flat_path, map_location='cpu')
    print(f'Loaded flat state dict from {flat_path}')
    output_sd = {}
    for flat_key, param_meta in flat_meta.items():
        flat_param = flat_sd['model'][flat_key]
        numels = param_meta['numels']
        expected = sum(numels)
        # Explicit check rather than `assert`, so the validation still runs
        # under `python -O` and fails with a descriptive message before split().
        if expected != flat_param.numel():
            raise ValueError(f'flat {flat_key} {flat_param.numel()} vs {expected}')
        pieces = flat_param.split(numels)
        for name, shape, param in zip(param_meta['names'], param_meta['shapes'], pieces):
            # view() reshapes without copying; the result shares storage with
            # the flat tensor.
            output_sd[name] = param.view(shape)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(output_sd, output_path)
    print(f'Saved unflat state dict to {output_path}')
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: unflatten one checkpoint part.
    # Usage: <script> FLAT_DIR OUTPUT_DIR PART
    parser = argparse.ArgumentParser(
        description='Unflatten one part of a flat-parameter checkpoint into a regular state dict.')
    parser.add_argument(
        'flat_dir',
        help='directory containing reshard-model_part-<part>-shard0.pt and flat-meta.json')
    parser.add_argument(
        'output_dir',
        help='directory to write reshard-model_part-<part>.pt into')
    parser.add_argument(
        'part', type=int,
        help='part index used in the input and output checkpoint file names')
    args = parser.parse_args()
    convert(args.flat_dir, args.output_dir, args.part)
|