@@ -1,19 +1,17 @@
import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100
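# NOTE: -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so label positions
# set to IGNORE_INDEX are excluded from the language-modeling loss.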


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
@@ -36,15 +34,12 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int
    )


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
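    # mask out the prompt (source) part of each label with IGNORE_INDEX so that only the target tokens contribute to the loss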
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
@@ -53,59 +48,60 @@ def preprocess(
class EasySupervisedDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # split each line into source and target: source is everything up to and including "回答:", target is everything after "回答:"
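        # e.g. a hypothetical line "<question> 回答:<answer>" would be split into
        #   source = "<question> 回答:"  and  target = "<answer>" + tokenizer.eos_token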
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
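        # NOTE: every prompt is tokenized and moved to the current CUDA device up front, so this dataset assumes a GPU is available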
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"
@@ -114,8 +110,9 @@ class EasyPromptsDataset(Dataset):
class EasyRewardDataset(Dataset):

    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
@@ -124,11 +121,11 @@ class EasyRewardDataset(Dataset):
            self.end_token = special_token
        print(self.end_token)
        # read all lines in the train_file to a list
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
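            # each JSONL line is expected to provide at least 'prompt' and 'chosen' fields;
            # the rejected answer is read in the part of this method that is not shown here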
            prompt = "提问:" + data['prompt'] + "回答:"
            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
@@ -159,7 +156,7 @@ class EasyRewardDataset(Dataset):
    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]

    # python representation of the object and the string representation of the object
    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
@@ -167,27 +164,30 @@ class EasyRewardDataset(Dataset):
    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


'''
EasySFTDataset just accepts a text file that can be read line by line. However, the dataset will
group texts together up to max_length so the LLM can learn the meaning of the texts better.
If individual lines are not related, just set is_group_texts to False.
'''
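# A hypothetical illustration of the grouping behaviour with max_length=512 and is_group_texts=True:
#   - two lines that tokenize to 300 and 180 tokens are packed into one 480-token sample and padded
#     to 512 (attention mask: 480 ones followed by 32 zeros);
#   - a single line that tokenizes to 1100 tokens is first split into chunks of 512, 512 and 76 tokens.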
class EasySFTDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # read the data_file line by line
        with open(data_file, "r", encoding="UTF-8") as f:
            # encode each line and collect the raw input_ids lists into raw_input_ids
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # if the encoded_ids is longer than max_length, split it into several parts
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i:i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)
        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
@@ -199,44 +199,42 @@ class EasySFTDataset(Dataset):
                    # pad the current_input_ids to max_length with tokenizer.pad_token_id
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    # start the next group with the line that did not fit, so it is not silently dropped
                    current_input_ids = list(input_ids)
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # no grouping: just pad each raw_input_ids entry to max_length
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # get one item from the dataset
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # generate the dataset description to be printed by print in python
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    # generate the dataset description to be printed by print in python
    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
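

if __name__ == "__main__":
    # Minimal usage sketch: how these datasets might be constructed, assuming a local text file
    # ("sft_data.txt") and the "gpt2" tokenizer; both names are placeholders, not part of the original file.
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # gpt2 has no pad token by default; reuse the eos token for padding
        tokenizer.pad_token = tokenizer.eos_token

    dataset = EasySFTDataset("sft_data.txt", tokenizer, max_length=512, is_group_texts=True)
    print(dataset)

    # each batch is a dict of (batch_size, max_length) tensors: input_ids, labels, attention_mask
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["attention_mask"].shape)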