2023-09-24 15:14:11 +00:00
import copy
import math
from typing import Any , Dict , List , Optional , Tuple
import numpy as np
import torch
from colossal_eval . utils import Conversation , get_batch_prompt , is_rank_0
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoConfig , AutoModel , AutoModelForCausalLM , AutoTokenizer
from colossalai . logging import DistributedLogger
2023-12-12 06:47:35 +00:00
from colossalai . shardformer import ShardConfig , ShardFormer
2024-02-06 02:53:03 +00:00
from colossalai . utils import get_current_device
2023-09-24 15:14:11 +00:00
from . base import BaseModel
class HuggingFaceModel(BaseModel):
    """
    Model wrapper around HuggingFace AutoModel models.

    Args:
        path: The path to a HuggingFace model.
        model_max_length: The maximum sequence length of the model.
        tokenizer_path: The path to the tokenizer. Falls back to ``path`` when not given.
        tokenizer_kwargs: Keyword arguments for the tokenizer. The dict is copied, never mutated.
        peft_path: The name or path to the HuggingFace's PEFT model.
        model_kwargs: Keyword arguments for the model. The dict is copied, never mutated.
        prompt_template: The model's prompt template.
        batch_size: Batch size for inference.
        logger: Logger for the model.
        shard_config: Shard config for tensor parallel.
    """

    # NOTE(review): IGNORE_INDEX is used by the loss/label code below but is not defined in this
    # class; it is expected to be defined at module level — confirm it exists in the full file.

    def __init__(
        self,
        path: str,
        model_max_length: int = 2048,
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: Optional[dict] = None,
        peft_path: Optional[str] = None,
        model_kwargs: Optional[Dict] = None,
        prompt_template: Conversation = None,
        batch_size: int = 1,
        logger: DistributedLogger = None,
        shard_config: Optional[ShardConfig] = None,
    ):
        super().__init__(
            path=path,
            model_max_length=model_max_length,
            prompt_template=prompt_template,
            batch_size=batch_size,
            logger=logger,
        )
        self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)

        self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path, shard_config=shard_config)

    def _get_choices_indices(self, language: str):
        """
        Get indices for each choice.

        Some tokenizers will insert BOS if you don't specify add_special_tokens=False, such as Llama-2.
        The indices for choices may be different given the context. For example, for the Llama-2 tokenizer,
        for the Chinese context "答案: {choice}", indices for choices A, B, C and D are 29909, 29933, 29907 and
        29928; for the English context "Answer: {choice}", indices for choices A, B, C and D are 319, 350, 315
        and 360.
        print(self.tokenizer("答案: A")) to see
        print(self.tokenizer("Answer: A")) to see

        Args:
            language: Currently unused; indices for both the English and the Chinese context are computed.

        Sets:
            self.indices_for_choices: Two lists (English context, Chinese context) holding the id of the
                last token of each choice in ``self.choices``.
        """

        # A trick for get "all" tokens ids related to given choices.
        self.indices_for_choices = [[] for _ in range(2)]
        for choice in self.choices:
            self.indices_for_choices[0].append(
                self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
            )
            self.indices_for_choices[1].append(self.tokenizer(f"答案: {choice}", add_special_tokens=False).input_ids[-1])

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: Optional[dict]):
        """
        Load tokenizer.

        Args:
            path: The path to the model. Usually it also serves as the path to the tokenizer.
            tokenizer_path: The path to the tokenizer.
            tokenizer_kwargs: Keyword arguments for the tokenizer. Copied before use, so neither the
                caller's dict nor any shared default is modified.
        """

        # Copy: the previous code mutated the (shared, mutable) default dict and the caller's dict.
        tokenizer_kwargs = dict(tokenizer_kwargs) if tokenizer_kwargs else {}

        if self.batch_size > 1:
            # Left padding makes every prompt in a batch end at the same position, which generate()
            # relies on when it slices the prompt off the generated sequences.
            tokenizer_kwargs["padding_side"] = "left"
            tokenizer_kwargs["truncation_side"] = "left"

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)

        if self.tokenizer.pad_token_id is None:
            self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.")
            if self.tokenizer.eos_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif hasattr(self.tokenizer, "eod_id"):
                # Qwen has an eod token "<|endoftext|>".
                self.tokenizer.pad_token_id = self.tokenizer.eod_id

    @staticmethod
    def _resolve_torch_dtype(dtype: Any) -> Any:
        """
        Convert a configured ``torch_dtype`` value into a value ``from_pretrained`` accepts.

        Replaces ``eval`` of the raw config string, which would execute arbitrary code.

        Args:
            dtype: A ``torch.dtype``, the string ``"auto"``, or a dtype name such as
                ``"torch.float16"`` or ``"float16"``.

        Returns:
            The matching ``torch.dtype``; non-string values and ``"auto"`` are returned unchanged.

        Raises:
            ValueError: If a string does not name a ``torch.dtype``.
        """
        if not isinstance(dtype, str) or dtype == "auto":
            return dtype
        dtype_name = dtype[len("torch.") :] if dtype.startswith("torch.") else dtype
        resolved = getattr(torch, dtype_name, None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unsupported torch_dtype: {dtype!r}")
        return resolved

    def _load_model(
        self,
        path: str,
        model_kwargs: Optional[dict],
        peft_path: Optional[str] = None,
        shard_config: Optional[ShardConfig] = None,
    ):
        """
        Load model.

        Args:
            path: The path to the model.
            model_kwargs: Keyword arguments for the model. Copied before use; ``None`` means no extra
                arguments. ``torch_dtype`` defaults to ``torch.float16``.
            peft_path: The path to the peft model.
            shard_config: Shard config for tensor parallel.

        Raises:
            NotImplementedError: If both ``shard_config`` and ``peft_path`` are given.
        """

        # Copy so the caller's dict is not mutated, and so a None default does not crash.
        model_kwargs = dict(model_kwargs) if model_kwargs else {}
        model_kwargs["torch_dtype"] = self._resolve_torch_dtype(model_kwargs.get("torch_dtype", torch.float16))

        if "config" in model_kwargs:
            model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])

        if shard_config is not None:
            # Fail before spending time loading weights that cannot be used.
            if peft_path is not None:
                raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
            self.model = AutoModel.from_pretrained(path, **model_kwargs)
            shard_former = ShardFormer(shard_config)
            # The second value returned by optimize() is not needed for inference.
            self.model, _ = shard_former.optimize(self.model)
            self.model.to(get_current_device())
        else:
            self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
            if peft_path is not None:
                self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)

        self.model.eval()

    def _calculate_loss(
        self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]
    ) -> Tuple[List[float], List[int]]:
        """
        Calculate loss only on target tokens.

        Hugging Face generate() function can't return per sample loss.
        It will only return the mean of the loss in a batch.
        In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.

        Args:
            input_ids_list: A batch of input token ids.
            labels: A batch of labels.

        Returns:
            The summed loss of each sample and the number of target tokens of each sample.
        """
        device = get_current_device()

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
        ).to(device)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(device)
        # Any position equal to pad_token_id is masked; note pad may reuse EOS (see _load_tokenizer).
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(device)

        outputs = self.model(input_ids, attention_mask=attention_mask)[0]

        # Shift so the logits at position t are scored against the token at position t + 1.
        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())

        lens = (labels[..., 1:] != IGNORE_INDEX).sum(-1).cpu().numpy()

        loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
        return loss_sum.tolist(), lens.tolist()

    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
        """
        Truncate the input sequence to fit model_max_length (we suggest truncating in the middle, since
        the left and right side may contain crucial instructions).
        https://github.com/THUDM/LongBench/blob/main/pred.py#L16

        Args:
            inputs: A batch of input prompts.
            max_new_tokens: Max new tokens for model to generate.

        Returns:
            Truncated prompts. A prompt that does not fit keeps its first and last
            ``(model_max_length - max_new_tokens) // 2`` tokens (none if that budget is not positive).
        """

        budget = self.model_max_length - max_new_tokens
        truncated_inputs = copy.deepcopy(inputs)
        for idx, text in enumerate(inputs):
            tokenized_prompt = self.tokenizer(text, truncation=False, return_tensors="pt").input_ids[0]
            num_tokens = len(tokenized_prompt)
            if num_tokens > budget:
                # Clamp at 0: with half == 0 the old ``[-half:]`` slice returned the WHOLE prompt,
                # and a negative half produced overlapping slices.
                half = max(budget // 2, 0)
                head = tokenized_prompt[:half]
                tail = tokenized_prompt[num_tokens - half :]
                truncated_inputs[idx] = self.tokenizer.decode(head, skip_special_tokens=True) + self.tokenizer.decode(
                    tail, skip_special_tokens=True
                )

        return truncated_inputs

    def _get_input_ids_and_labels_pretrain(
        self, batch_prompt: List[str]
    ) -> Tuple[List[torch.LongTensor], List[torch.LongTensor], List[int]]:
        """
        Get input_ids and labels for pretrain data.

        We only need batch_prompt because for pretrain dataset, we don't need to predict new tokens.

        Args:
            batch_prompt: A batch of prompt.

        Returns:
            Input_ids, labels (identical to the input ids) and the UTF-8 byte length of the decoded
            tokenized text for each prompt.
        """

        input_ids_list = []
        labels_list = []
        bytes_list = []
        # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only
        # tokenize 1/ratio of the data first to accelerate the tokenization process.
        # Once the length of the result is greater or equal to model_max_length, we stop iterating on
        # ratios and use the result as input_ids and labels.
        # After all, the rest of the original string doesn't need to be tokenized at the first place.
        ratios = [16, 8, 4, 2, 1]

        for text in batch_prompt:
            tokenized = None
            for r in ratios:
                tokenized = self.tokenizer(
                    [text[0 : len(text) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt"
                )
                if tokenized.input_ids.size(1) >= self.model_max_length:
                    # Long enough already: skip tokenizing larger prefixes (the intended early exit).
                    break

            input_ids = copy.deepcopy(tokenized["input_ids"])[0]
            target_ids = copy.deepcopy(input_ids)

            string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)

            bytes_list.append(len(string.encode("utf-8")))

            input_ids_list.append(input_ids)
            labels_list.append(target_ids)

        return input_ids_list, labels_list, bytes_list

    def _get_input_ids_and_labels(
        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
    ) -> Tuple[List[torch.LongTensor], List[torch.LongTensor], Optional[List[int]]]:
        """
        Get input_ids and labels for the given data.

        Args:
            batch_prompt: A batch of prompt.
            batch_target: A batch of target.
            pretrain: Whether the data is pretrain data (prompt and first target are concatenated).

        Returns:
            Input_ids and labels for the given batch, plus the byte counts for pretrain data
            (``None`` otherwise). One entry is produced per (prompt, target) pair.
        """
        if pretrain:
            batch = []
            # Concatenate prompt and target answers.
            # You should decide the concatenation character in the corresponding dataset script in dataset
            # folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
            for p, b in zip(batch_prompt, batch_target):
                batch.append(p + b[0])
            return self._get_input_ids_and_labels_pretrain(batch)

        input_ids_list = []
        labels_list = []

        for prompt, targets in zip(batch_prompt, batch_target):
            for target in targets:
                # TODO: Improve the labeling process. Should annotate the border by adding special tokens.
                target_tokenized = self.tokenizer(
                    [target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
                )

                # Get prompt with length model_max_length - len(target_tokenized).
                # Reserve some space for target answer tokens using max_new_tokens.
                # This will generate the correct start_idx and end_idx.
                max_new_tokens = target_tokenized["input_ids"][0].size(0)
                prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
                input_tokenized = self.tokenizer(
                    [prompt_with_correct_length],
                    truncation=True,
                    max_length=self.model_max_length - max_new_tokens,
                    return_tensors="pt",
                )

                target_tokenized = self.tokenizer(
                    [prompt_with_correct_length + target],
                    truncation=True,
                    max_length=self.model_max_length,
                    return_tensors="pt",
                )

                start_idx = input_tokenized["input_ids"][0].size(0)
                end_idx = target_tokenized["input_ids"][0].size(0)

                # Sometimes if the target is only an option such as A, B, C and D, the length of
                # input_tokenized is equal to the length of target_tokenized, so we need -1.
                # This is caused by the different behavior of tokenizers.
                # For example, the tokenizer for Baichuan and Llama will cause such problem in a plain
                # prompt setting. The length of the tokenized sequences for prompt "Answer: " and
                # "Answer: A" is the same.
                # Baichuan: [29394, 31143, 31106] [29394, 31143, 703]
                # Llama: [673, 29901, 29871] [673, 29901, 319]
                # The length for sequence "prompt" and "prompt + A" is equal.
                # For ChatGLM, the length of the tokenized sequences is different.
                # ChatGLM: [16583, 12] [16583, 12, 167]

                if start_idx == end_idx:
                    start_idx -= 1

                input_ids = copy.deepcopy(target_tokenized["input_ids"])[0]
                target_ids = copy.deepcopy(input_ids)

                # Only the target span contributes to the loss; everything else is ignored.
                mask = torch.zeros_like(target_ids, dtype=torch.bool)
                mask[start_idx:end_idx] = True

                target_ids[~mask] = IGNORE_INDEX

                input_ids_list.append(input_ids)
                labels_list.append(target_ids)

        return input_ids_list, labels_list, None

    def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
        """
        Infer the given data.

        This function will call self.generate() to get model outputs and also self.model() to get logits.

        Args:
            data: The data for inference.
            inference_kwargs: Arguments for inference. Reads "calculate_loss", "all_classes", "language",
                "pretrain", "max_new_tokens" and optionally "few_shot_data".
            debug: Whether to display generated prompt for debugging.

        Returns:
            Inference results: a deep copy of ``data`` with outputs/losses filled in. An empty ``data``
            yields an empty list.
        """
        if not data:
            # Nothing to do; also avoids indexing data[0] below.
            return []

        calculate_loss = inference_kwargs["calculate_loss"]
        classes = inference_kwargs["all_classes"]
        language = inference_kwargs["language"]
        pretrain = inference_kwargs["pretrain"]
        max_new_tokens = inference_kwargs["max_new_tokens"]
        few_shot_data = inference_kwargs.get("few_shot_data", None)

        # Some classification questions' options are texts not a single letter such as A, B, C and D.
        # If the text length is greater than 1, we won't calculate loss over choices.
        if classes is not None and any(len(c) > 1 for c in classes):
            classes = None

        self.choices = classes
        self.indices_for_choices = None
        if self.choices:
            # Get indices for each choice
            self._get_choices_indices(language)

            self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}

        turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
        turn_desc = "" if turn == 0 else f"-turn{turn}"

        bar = tqdm(
            range(math.ceil(len(data) / self.batch_size)),
            desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
            disable=not is_rank_0(),
        )
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        answers = copy.deepcopy(data)
        for i in range(0, len(data), self.batch_size):
            batch = data[i : i + self.batch_size]
            batch_prompt, batch_target = get_batch_prompt(
                self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length
            )

            if is_rank_0() and debug and i == 0:
                self.logger.info(
                    f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}"
                )
                self.logger.info("-" * 120)
                self.logger.info("An example prompt and prompt with target is:")
                self.logger.info("-" * 120)
                self.logger.info(batch_prompt[0])
                self.logger.info("-" * 120)
                self.logger.info(batch_prompt[0] + batch_target[0][0])

            if not pretrain:
                batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)

            if calculate_loss:
                batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
                    batch_prompt, batch_target, pretrain
                )

            probs = []
            # Choice scores only exist when generate() ran, i.e. for non-pretrain data.
            if self.indices_for_choices and not pretrain:
                scores = scores.to(torch.float32)
                # If we have indices_for_choices(must be single-choice question), there will be only one
                # target answer for one data sample. Otherwise this will violate the single-choice setting.

                if calculate_loss:
                    labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))]

                    loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()

                probs = [
                    {choice: row[self.str_label_map[choice]] for choice in self.choices}
                    for row in scores.numpy().tolist()
                ]

            for j in range(len(batch_prompt)):
                if not pretrain:
                    if isinstance(answers[i + j]["output"], list):
                        answers[i + j]["output"].append(batch_decodes[j].strip())
                    else:
                        answers[i + j]["output"] = batch_decodes[j].strip()

                    if isinstance(scores, torch.Tensor):
                        answers[i + j]["logits_over_choices"] = probs[j]

                        if calculate_loss:
                            answers[i + j]["loss_over_choices"] = loss_over_choices[j]

                if calculate_loss:
                    answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()

                    # loss_sum is specially used for pretrain dataset for calculating per-byte-perplexity.
                    # However, loss (which is per sample loss) suffices for most cases.
                    answers[i + j]["loss_sum"] = batch_losses[j]
                    answers[i + j]["token_num"] = batch_target_token_nums[j]

                    if batch_bytes_nums:
                        answers[i + j]["byte_num"] = batch_bytes_nums[j]

            bar.update()

        return answers

    @torch.no_grad()
    def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> Tuple[List[str], Any]:
        """Generate results given a list of inputs and get logits of the first new token over choices.

        Args:
            inputs: A list of strings.
            max_new_tokens: Max new tokens for generation.
            kwargs: Key arguments for generation

        Returns:
            A list of generated strings and logits over choices (a CPU tensor of shape
            (batch, num_choices) when ``self.indices_for_choices`` is set, otherwise an empty list).

        Note:
            Currently the function only returns the logits of the first new token.
            It is used for single choice question.
            For multiple choices question, please avoid using the loss over choices.
            You should set argument choices as None in self.inference().
        """
        truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)

        encoded_inputs = self.tokenizer(
            truncated_inputs,
            padding=True,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            max_length=self.model_max_length - max_new_tokens,
        ).to(get_current_device())

        # Set output_scores=True to get prediction scores.
        outputs = self.model.generate(
            **encoded_inputs,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
            do_sample=False,
            use_cache=True,
            **kwargs,
        )

        # We only need to decode predicted tokens.
        sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]

        scores = []
        if self.indices_for_choices:
            # If the question is a single-choice question, we will return the scores of specific indices for
            # the first predicted token. The indices are the tokenization results of the options for the
            # single-choice question. For example, if the options of the question are A, B, C and D, we only
            # return scores at indices of A, B, C and D.
            for option_indices in self.indices_for_choices:
                scores.append(outputs.scores[0][:, option_indices].detach().cpu())

            # Element-wise max over the English/Chinese index sets.
            scores = torch.max(torch.stack(scores), dim=0)[0]

        decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)

        return decoded_sequences, scores

    @torch.no_grad()
    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
    ) -> Tuple[List[List[float]], List[List[int]], Optional[List[List[int]]]]:
        """
        Calculate loss only on target tokens.

        Args:
            batch_prompt: A batch of prompt without target answer.
            batch_target: A batch of target answer. Sometimes one question can have multiple target answers.
            pretrain: Whether the data is pretrain data.

        Returns:
            Per-sample lists of summed losses, per-sample lists of target token counts, and per-sample
            lists of byte counts (pretrain data only, otherwise ``None``).
        """

        # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to
        # calculate loss. We don't need to generate new tokens.
        # Target answer's length is usually << model_max_length, but we still call it in case.
        # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's
        # length first to reserve some space for target answer's tokens.
        if not pretrain:
            batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

        # Get the number of target answers for different questions
        batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

        input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)

        # Because of multiple target answers, the final batch size may be greater than self.batch_size.
        # We will generate new batches.
        losses = []
        target_token_nums = []

        batched_input_ids = [
            input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
        ]
        batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]

        for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
            losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
            losses.extend(losses_per_batch)
            target_token_nums.extend(target_token_num_per_batch)

        # Regroup the flat per-target results into one list per question.
        start_indice = 0
        losses_per_sample = []

        target_token_nums_per_sample = []
        bytes_nums_per_sample = []
        for length in batch_target_nums:
            losses_per_sample.append(losses[start_indice : start_indice + length])
            target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])

            if bytes_list:
                bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length])

            start_indice += length

        if bytes_list:
            return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample

        return losses_per_sample, target_token_nums_per_sample, None
class HuggingFaceCausalLM(HuggingFaceModel):
    """
    Model wrapper around HuggingFace AutoModelForCausalLM models.

    Args:
        path: The path to a HuggingFace model.
        model_max_length: The maximum sequence length of the model.
        tokenizer_path: The path to the tokenizer.
        tokenizer_kwargs: Keyword arguments for the tokenizer.
        peft_path: The name or path to the HuggingFace's PEFT model.
        model_kwargs: Keyword arguments for the model.
        prompt_template: The model's prompt template.
        batch_size: Batch size for inference.
        logger: Logger for the model.
        shard_config: Shard config for tensor parallel.
    """

    def _load_model(
        self,
        path: str,
        model_kwargs: Optional[dict],
        peft_path: Optional[str] = None,
        shard_config: Optional[ShardConfig] = None,
    ):
        """
        Load model.

        Args:
            path: The path to the model.
            model_kwargs: Keyword arguments for the model. Copied before use; ``None`` means no extra
                arguments. ``torch_dtype`` may be a ``torch.dtype``, ``"auto"``, or a dtype name such as
                ``"torch.float16"``; it defaults to ``torch.float16``.
            peft_path: The path to the peft model.
            shard_config: Shard config for tensor parallel.

        Raises:
            ValueError: If ``torch_dtype`` is a string that does not name a ``torch.dtype``.
            NotImplementedError: If both ``shard_config`` and ``peft_path`` are given.
        """

        # Copy so the caller's dict is not mutated, and so a None value does not crash.
        model_kwargs = dict(model_kwargs) if model_kwargs else {}

        # Resolve torch_dtype by name instead of eval(), which would execute arbitrary config text.
        dtype = model_kwargs.get("torch_dtype", torch.float16)
        if isinstance(dtype, str) and dtype != "auto":
            dtype_name = dtype[len("torch.") :] if dtype.startswith("torch.") else dtype
            resolved = getattr(torch, dtype_name, None)
            if not isinstance(resolved, torch.dtype):
                raise ValueError(f"Unsupported torch_dtype: {dtype!r}")
            dtype = resolved
        model_kwargs["torch_dtype"] = dtype

        if "config" in model_kwargs:
            model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])

        if shard_config is not None:
            # Fail before spending time loading weights that cannot be used.
            if peft_path is not None:
                raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
            shard_former = ShardFormer(shard_config)
            # The second value returned by optimize() is not needed for inference.
            self.model, _ = shard_former.optimize(self.model)
            self.model.to(get_current_device())
        else:
            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
            if peft_path is not None:
                self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)

        self.model.eval()