From c4108d3431296e9dbc6a16a5d6a4359ca60636f7 Mon Sep 17 00:00:00 2001 From: Yang Gao Date: Wed, 10 Apr 2024 17:06:18 +0800 Subject: [PATCH] [Tool]: Update tools/convert2llama.py to support `safetensors` format (#730) --- tools/convert2llama.py | 55 +++++++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/tools/convert2llama.py b/tools/convert2llama.py index 48368b7..1462924 100644 --- a/tools/convert2llama.py +++ b/tools/convert2llama.py @@ -9,6 +9,28 @@ from tqdm import tqdm from transformers import AutoConfig, LlamaConfig, LlamaTokenizer +def weight_load(fp, **kwargs): + """Load weights from a file.""" + is_safetensors = kwargs.pop('is_safetensors', False) + + if is_safetensors: + try: + from safetensors import safe_open + except ImportError: + raise ImportError( + 'Before loading ckpts in the `safetensors` format, ' + 'please install the `safetensors` package first.') + + model = safe_open(fp, framework='pt') + state_dict = {} + for k in model.keys(): + state_dict[k] = model.get_tensor(k) + return state_dict + + else: + return torch.load(fp, **kwargs) + + def save_conifg(config, tgt): config_dict = config.to_dict() unnecessary_keys = [ @@ -41,19 +63,29 @@ def convert(src, tgt): // config.num_key_value_heads # load index json file - index_file = os.path.join(src, 'pytorch_model.bin.index.json') - if os.path.exists(index_file): - with open(index_file) as fp: + index_file = 'pytorch_model.bin.index.json' + if os.path.exists(os.path.join(src, index_file)): + with open(os.path.join(src, index_file)) as fp: index_dict = json.load(fp) index_dict['weight_map'] = {} else: - index_dict = None + index_file = 'model.safetensors.index.json' + if os.path.exists(os.path.join(src, index_file)): + with open(os.path.join(src, index_file)) as fp: + index_dict = json.load(fp) + index_dict['weight_map'] = {} + else: + index_dict = None os.makedirs(tgt, exist_ok=True) for filename in tqdm(os.listdir(src)): - if not filename.endswith('.bin'): + if not any(filename.endswith(ext) for ext in ('.bin', '.safetensors')): continue - states = torch.load(os.path.join(src, filename)) + + print(f'Loading {os.path.join(src, filename)}...', flush=True) + states = weight_load(os.path.join(src, filename), + is_safetensors=filename.endswith('.safetensors')) + llama_states = {} for k, v in states.copy().items(): if 'wqkv' in k: @@ -104,15 +136,15 @@ def convert(src, tgt): if index_dict is not None: for k in llama_states: index_dict['weight_map'][k] = filename - print(f"Saving to {os.path.join(tgt, filename)}...", flush=True) + + print(f'Saving to {os.path.join(tgt, filename)}...', flush=True) torch.save(llama_states, os.path.join(tgt, filename)) del states - print('Saving config and tokenizer...') + print('Saving config and tokenizer...', flush=True) # index.json if index_dict is not None: - with open(os.path.join(tgt, 'pytorch_model.bin.index.json'), - 'w') as fp: + with open(os.path.join(tgt, index_file), 'w') as fp: json.dump(index_dict, fp, indent=2) # tokenizer tokenizer = LlamaTokenizer.from_pretrained(src) @@ -120,7 +152,8 @@ def convert(src, tgt): tokenizer.save_pretrained(tgt) # config save_conifg(config, tgt) - print('Done!') + + print('Done!', flush=True) def parse_args():