# ColossalAI/applications/Colossal-LLaMA-2/colossal_llama2/model/init_model.py
# (112 lines, 4.1 KiB, Python)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
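Initialize a new model for an expanded (target) tokenizer: each token that exists only in the target vocabulary gets input-embedding and LM-head rows equal to the mean of the original model's rows for the sub-tokens the original tokenizer splits it into.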
"""
import argparse
import numpy as np
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from colossalai.logging import get_dist_logger
# Module-level ColossalAI distributed logger; main() uses it for progress messages.
logger = get_dist_logger()
def _load_tokenizer(path):
    """Load a ``LlamaTokenizer`` from ``path`` with BOS/EOS insertion disabled.

    If the tokenizer defines no pad token, its unk token is used as the pad token.
    """
    tokenizer = LlamaTokenizer.from_pretrained(path)
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token
    return tokenizer


def _mean_init_new_rows(
    source_tokenizer,
    target_inverted_vocab,
    first_new_id,
    target_vocab_size,
    input_weight,
    output_weight,
    device,
):
    """Compute initial input-embedding and LM-head rows for tokens only the target vocab has.

    Every target id in ``[first_new_id, target_vocab_size)`` is mapped back to its token
    string, re-tokenized with ``source_tokenizer``, and given the mean of the matching
    source rows.

    Args:
        source_tokenizer: tokenizer of the source model; called as
            ``source_tokenizer([text])["input_ids"][0]``.
        target_inverted_vocab: mapping from target token id to token string.
        first_new_id: first target id without a source counterpart (the source vocab size).
        target_vocab_size: number of tokens in the target vocab; must exceed ``first_new_id``.
        input_weight: source input-embedding matrix, one row per source token id.
        output_weight: source LM-head weight matrix, one row per source token id.
        device: device on which the row gathers and means are computed.

    Returns:
        ``(new_input_rows, new_output_rows)``: CPU tensors with
        ``target_vocab_size - first_new_id`` rows each, in target-id order.
    """
    with torch.no_grad():
        input_weight = input_weight.to(device)
        output_weight = output_weight.to(device)
        # Used only when a token re-tokenizes to nothing: mean() over zero rows would be NaN.
        input_fallback = input_weight.mean(dim=0)
        output_fallback = output_weight.mean(dim=0)

        new_input_rows = []
        new_output_rows = []
        for token_id in range(first_new_id, target_vocab_size):
            if token_id % 500 == 0:
                logger.info(f"processing {token_id}/{target_vocab_size} target tokens")
            target_token = target_inverted_vocab[token_id]
            # With BOS/EOS disabled (as _load_tokenizer does) these are only the sub-token ids.
            source_ids = source_tokenizer([target_token])["input_ids"][0]
            if not source_ids:
                logger.warning(
                    f"target token {target_token!r} (id {token_id}) produced no source tokens; "
                    "initializing it with the mean of all source embeddings"
                )
                new_input_rows.append(input_fallback)
                new_output_rows.append(output_fallback)
                continue
            index = torch.tensor(source_ids, dtype=torch.long, device=device)
            new_input_rows.append(input_weight[index].mean(dim=0))
            new_output_rows.append(output_weight[index].mean(dim=0))

        # Stack once at the end instead of re-concatenating the full matrix for every token.
        return torch.stack(new_input_rows).cpu(), torch.stack(new_output_rows).cpu()


def main():
    """Expand a LLaMA model to a larger target tokenizer and save it in float16.

    Command-line arguments:
        --source_model_and_tokenizer_path: directory/hub id of the source model and tokenizer.
        --target_tokenizer_path: directory/hub id of the target (expanded) tokenizer.
        --target_model_path: output directory for the expanded model.

    Raises:
        ValueError: if the target vocab is not strictly larger than the source vocab.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--source_model_and_tokenizer_path",
        type=str,
        required=True,
        help="Source path of model & tokenizer",
    )
    parser.add_argument("--target_tokenizer_path", type=str, required=True, help="Target tokenizer path")
    parser.add_argument("--target_model_path", type=str, required=True, help="Target model path")
    args = parser.parse_args()

    source_tokenizer = _load_tokenizer(args.source_model_and_tokenizer_path)
    source_vocab = source_tokenizer.get_vocab()
    target_tokenizer = _load_tokenizer(args.target_tokenizer_path)
    target_vocab = target_tokenizer.get_vocab()
    target_inverted_vocab = {v: k for k, v in target_vocab.items()}
    # Raise instead of assert so this input check survives `python -O`.
    if len(target_vocab) <= len(source_vocab):
        raise ValueError(
            f"Target vocab size({len(target_vocab)}) must be greater than source vocab size({len(source_vocab)})"
        )

    # Only the two embedding matrices are moved to the accelerator (the model stays on CPU);
    # fall back to CPU when CUDA is unavailable.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    source_model = LlamaForCausalLM.from_pretrained(args.source_model_and_tokenizer_path)
    source_model.eval()
    source_input_embeddings = source_model.get_input_embeddings()
    assert isinstance(source_input_embeddings, torch.nn.Embedding)
    assert source_input_embeddings.weight.shape[0] == len(source_vocab)
    source_output_embeddings = source_model.get_output_embeddings()
    assert isinstance(source_output_embeddings, torch.nn.Linear)
    assert source_output_embeddings.bias is None
    assert source_output_embeddings.weight.shape[0] == len(source_vocab)

    # Compute in float32 so the means are not rounded to a low-precision dtype;
    # the model is cast to float16 only right before saving.
    source_input_weight = source_input_embeddings.weight.detach().float().cpu()
    source_output_weight = source_output_embeddings.weight.detach().float().cpu()
    new_input_rows, new_output_rows = _mean_init_new_rows(
        source_tokenizer,
        target_inverted_vocab,
        len(source_vocab),
        len(target_vocab),
        source_input_weight,
        source_output_weight,
        device,
    )
    input_embeddings = torch.cat((source_input_weight, new_input_rows), dim=0)
    output_embeddings = torch.cat((source_output_weight, new_output_rows), dim=0)

    # Grow both matrices to the target vocab size, then overwrite them with the computed rows.
    source_model.resize_token_embeddings(new_num_tokens=len(target_vocab))
    with torch.no_grad():  # in-place writes to parameters that require grad
        source_model.get_input_embeddings().weight.copy_(input_embeddings)
        source_model.get_output_embeddings().weight.copy_(output_embeddings)
    source_model = source_model.half()
    source_model.save_pretrained(save_directory=args.target_model_path)
# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()