[Fix] Fix wrong string cutoff in the script for SFT text tokenizing (#106)

pull/121/head
Miao Zheng 2023-07-19 12:12:41 +08:00 committed by GitHub
parent efbf533570
commit 1095263082
1 changed file with 37 additions and 44 deletions
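What the commit actually changes: get_chat_format_data() builds the bot string as f"<|Bot|>:{output_str}", an 8-character template prefix, but the old tokenize() encoded '<|Assistant|>:' as the template and then sliced ass_s[14:], silently dropping the first six characters of every bot reply before encoding. The fix encodes "<|Bot|>:" instead and slices ass_s[8:], so the slice removes exactly the template and nothing else. A minimal illustration of the mismatch (plain Python, not part of the commit):

    ass_s = "<|Bot|>:The answer is 42."
    ass_s[14:]  # 'swer is 42.'       old slice: 14 == len('<|Assistant|>:'), reply truncated
    ass_s[8:]   # 'The answer is 42.' new slice: 8 == len('<|Bot|>:'), only the template removed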


@@ -1,10 +1,11 @@
 import argparse
 import json
-import sentencepiece as spm
-from tqdm import tqdm
 import os.path as osp
 from pathlib import Path
+
 import numpy as np
+import sentencepiece as spm
+from tqdm import tqdm


 def process(dataset_path, sp_model):
@@ -33,15 +34,15 @@ def get_chat_format_data(ori_data):
     Returns:
         dict: data sample with chat format.
     """
-    input_str = ori_data['input']
-    instruction_str = ori_data['instruction']
-    output_str = ori_data['output']
+    input_str = ori_data["input"]
+    instruction_str = ori_data["instruction"]
+    output_str = ori_data["output"]
     data = dict()
     if input_str != "":
-        data['user'] = f'<|User|>:{instruction_str}\n{input_str}'
+        data["user"] = f"<|User|>:{instruction_str}\n{input_str}"
     else:
-        data['user'] = f'<|User|>:{instruction_str}'
-    data['bot'] = f'<|Bot|>:{output_str}'
+        data["user"] = f"<|User|>:{instruction_str}"
+    data["bot"] = f"<|Bot|>:{output_str}"
     return data
@@ -55,27 +56,27 @@ def tokenize(sample, sp_model):
     Returns:
         tuple: dumped processed data sample and length of tokens.
     """
-    special_tokens_map = {'<eoh>': 103167, '<eoa>': 103166, 'nl_id': 13}
+    special_tokens_map = {"<eoh>": 103167, "<eoa>": 103166, "nl_id": 13}
     token_ids = [sp_model.bos_id()]
-    human_s = sample['user']
-    ass_s = sample['bot']
-    human_ids = sp_model.encode(human_s) + [
-        special_tokens_map["<eoh>"], special_tokens_map['nl_id']
-    ]
+    human_s = sample["user"]
+    ass_s = sample["bot"]
+    human_ids = sp_model.encode(human_s) + [special_tokens_map["<eoh>"], special_tokens_map["nl_id"]]
     human_ids_ignore = [-token_id for token_id in human_ids]
-    ass_template_ids = sp_model.encode('<|Assistant|>:')
+    ass_template_ids = sp_model.encode("<|Bot|>:")
     ass_template_ids_ignore = [-token_ids for token_ids in ass_template_ids]
-    ass_ids = ass_template_ids_ignore + sp_model.encode(ass_s[14:]) + [
-        special_tokens_map["<eoa>"], special_tokens_map['nl_id']
-    ]
+    ass_ids = (
+        ass_template_ids_ignore
+        + sp_model.encode(ass_s[8:])
+        + [special_tokens_map["<eoa>"], special_tokens_map["nl_id"]]
+    )
     token_ids += human_ids_ignore + ass_ids
     if len(token_ids) > 2047:
         token_ids = token_ids[:2047]
     token_ids += [sp_model.eos_id()]
-    line = str.encode(json.dumps({'tokens': token_ids}) + '\n')
+    line = str.encode(json.dumps({"tokens": token_ids}) + "\n")
     return line, len(token_ids)
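The hard-coded 8 in ass_s[8:] is simply len("<|Bot|>:"), just as the old 14 was len('<|Assistant|>:'). A hedged hardening sketch, not part of this commit, that derives the slice from the template string so the two cannot drift apart if the prompt format changes again:

    # Sketch only: assumes ass_s is always built as f"<|Bot|>:{output_str}" by get_chat_format_data().
    BOT_TEMPLATE = "<|Bot|>:"
    ass_template_ids_ignore = [-token_id for token_id in sp_model.encode(BOT_TEMPLATE)]
    ass_ids = (
        ass_template_ids_ignore
        + sp_model.encode(ass_s[len(BOT_TEMPLATE):])  # 8 today, but derived rather than hard-coded
        + [special_tokens_map["<eoa>"], special_tokens_map["nl_id"]]
    )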
@@ -93,14 +94,14 @@ def dump_bin_meta_bin(samples, path, split_ratio=0.1):
            number of train/valid samples of processed dataset.
     """
-    train_path = osp.join(path, 'train/en/')
-    valid_path = osp.join(path, 'valid/en/')
+    train_path = osp.join(path, "train/en/")
+    valid_path = osp.join(path, "valid/en/")
     train_dir = Path(train_path)
     valid_dir = Path(valid_path)
     train_dir.mkdir(exist_ok=True, parents=True)
     valid_dir.mkdir(exist_ok=True, parents=True)
-    train_f = open(train_dir.joinpath('dataset.bin'), 'wb')
-    valid_f = open(valid_dir.joinpath('dataset.bin'), 'wb')
+    train_f = open(train_dir.joinpath("dataset.bin"), "wb")
+    valid_f = open(valid_dir.joinpath("dataset.bin"), "wb")
     train_tokens = 0
     valid_tokens = 0
@@ -113,8 +114,7 @@ def dump_bin_meta_bin(samples, path, split_ratio=0.1):
     sample_length = len(samples)
     np.random.seed(0)
-    valid_indices = np.random.choice(
-        range(sample_length), int(sample_length * split_ratio)).tolist()
+    valid_indices = np.random.choice(range(sample_length), int(sample_length * split_ratio)).tolist()

     count = -1
     for line, token_num in samples:
@@ -134,25 +134,19 @@ def dump_bin_meta_bin(samples, path, split_ratio=0.1):
     train_f.close()
     valid_f.close()
-    np.save(open(train_dir.joinpath('dataset.bin.meta'), 'wb'), train_meta)
-    np.save(open(valid_dir.joinpath('dataset.bin.meta'), "wb"), valid_meta)
+    np.save(open(train_dir.joinpath("dataset.bin.meta"), "wb"), train_meta)
+    np.save(open(valid_dir.joinpath("dataset.bin.meta"), "wb"), valid_meta)

     return train_tokens, valid_tokens, train_samples, valid_samples


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        'dataset_path', type=str, help='path of dataset json file')
-    parser.add_argument(
-        'output_path', type=str, help='path of processed dataset')
-    parser.add_argument('tokenizer_path', type=str, help='path of tokenizer')
-    parser.add_argument(
-        '--split_ratio',
-        type=float,
-        default=0.1,
-        help='ratio for validation dataset splitting')
+    parser.add_argument("dataset_path", type=str, help="path of dataset json file")
+    parser.add_argument("output_path", type=str, help="path of processed dataset")
+    parser.add_argument("tokenizer_path", type=str, help="path of tokenizer")
+    parser.add_argument("--split_ratio", type=float, default=0.1, help="ratio for validation dataset splitting")

     args = parser.parse_args()
     sp_model = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
@@ -163,9 +157,8 @@ if __name__ == '__main__':
     for sample in tqdm(dataset):
         samples.append(sample)
-    train_tokens, valid_tokens, train_samples, valid_samples = \
-        dump_bin_meta_bin(samples, args.output_path, args.split_ratio)
-    print(f'number of train dataset: {train_samples}, '
-          'number of train dataset token: {train_tokens}')
-    print(f'number of validation dataset: {valid_samples}, '
-          'number of validation dataset token: {valid_tokens}')
+    train_tokens, valid_tokens, train_samples, valid_samples = dump_bin_meta_bin(
+        samples, args.output_path, args.split_ratio
+    )
+    print(f"number of train dataset: {train_samples}, " "number of train dataset token: {train_tokens}")
+    print(f"number of validation dataset: {valid_samples}, " "number of validation dataset token: {valid_tokens}")