mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.6 KiB
79 lines
2.6 KiB
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import Dict
|
|
|
|
import transformers
|
|
|
|
from ..models.llama.llama_lm import LlamaLM
|
|
|
|
# Special-token strings used by the tokenizer helpers in this file.
# DEFAULT_PAD_TOKEN is the default pad token added when a tokenizer has none.
DEFAULT_PAD_TOKEN = "[PAD]"
|
|
# End-of-sequence token string passed to tokenizer.add_special_tokens().
DEFAULT_EOS_TOKEN = "</s>"
|
|
# NOTE(review): BOS and UNK below reuse the same "</s>" string as EOS.
# Confirm this is intentional for the target tokenizer rather than distinct
# begin-of-sequence / unknown tokens.
DEFAULT_BOS_TOKEN = "</s>"
|
|
DEFAULT_UNK_TOKEN = "</s>"
|
|
|
|
|
|
def prepare_llama_tokenizer_and_embedding(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """Prepare a LLaMA tokenizer and the matching model embeddings.

    When the tokenizer has no pad token, ``special_tokens_dict`` is handed to
    ``smart_tokenizer_and_embedding_resize``, which adds the tokens and resizes
    the model's token embeddings. Afterwards the EOS, BOS and UNK special
    tokens are set to the module-level ``DEFAULT_*_TOKEN`` values.

    Args:
        tokenizer: Tokenizer to update; it is modified in place.
        model: Model whose token embeddings are resized when tokens are added.
        special_tokens_dict: Special tokens to add when the tokenizer lacks a
            pad token. Defaults to ``{"pad_token": DEFAULT_PAD_TOKEN}``.

    Returns:
        The same ``tokenizer`` object that was passed in.
    """
    if tokenizer.pad_token is None:
        # Forward the caller's mapping rather than a hard-coded one so that a
        # custom ``special_tokens_dict`` is honoured; the default value keeps
        # the previous behaviour.
        smart_tokenizer_and_embedding_resize(
            tokenizer=tokenizer,
            model=model,
            special_tokens_dict=special_tokens_dict,
        )

    tokenizer.add_special_tokens({
        "eos_token": DEFAULT_EOS_TOKEN,
        "bos_token": DEFAULT_BOS_TOKEN,
        "unk_token": DEFAULT_UNK_TOKEN,
    })

    return tokenizer
|
|
|
|
|
|
def smart_tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """Resize tokenizer and embedding.

    If the tokenizer has no pad token, add ``special_tokens_dict`` to it,
    resize the model's token embeddings to ``len(tokenizer)``, and fill the
    rows of the newly added tokens with the mean of the pre-existing rows.
    If the tokenizer already has a pad token, nothing is changed.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.

    Args:
        tokenizer: Tokenizer to extend; it is modified in place.
        model: Model to resize. A ``LlamaLM`` instance is replaced by the
            result of its ``get_base_model()`` before resizing.
        special_tokens_dict: Special tokens to add. Defaults to
            ``{"pad_token": DEFAULT_PAD_TOKEN}``.
    """
    if tokenizer.pad_token is not None:
        # Nothing to add. Returning here also guarantees ``num_new_tokens`` is
        # never read on a path where it was not assigned.
        return

    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)

    if isinstance(model, LlamaLM):
        # Operate on the underlying model returned by the LlamaLM instance.
        model = model.get_base_model()

    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    # Apply the same mean-initialisation to the input and output embeddings.
    # Only the trailing ``num_new_tokens`` rows are written, so the mean taken
    # over the leading rows is unaffected even if both share one weight tensor.
    for embedding in (model.get_input_embeddings(), model.get_output_embeddings()):
        if embedding is None:
            # NOTE(review): presumably a model without an output projection
            # returns None here — skip instead of raising AttributeError.
            continue
        weights = embedding.weight.data
        weights[-num_new_tokens:] = weights[:-num_new_tokens].mean(dim=0, keepdim=True)
|