import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

# -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so tokens labelled
# with IGNORE_INDEX are excluded from the language-modelling loss.
IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
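
# A minimal usage sketch of preprocess() (commented out; the "gpt2" tokenizer and the
# example strings are illustrative assumptions, not part of this module):
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#   batch = preprocess(["提问:1+1=? 回答:"], ["2" + tokenizer.eos_token], tokenizer)
#   # batch["labels"][0] equals batch["input_ids"][0] except that the positions
#   # covering the prompt (everything up to and including "回答:") are set to
#   # IGNORE_INDEX, so only the answer tokens contribute to the loss.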


class EasySupervisedDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and
        # including "回答:" ("Answer:"), the target is everything after it.
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Keep only the prompt part of each line: everything up to and including "回答:" ("Answer:").
        all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"


class EasyRewardDataset(Dataset):

    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # Read the whole train_file; each line is a JSON record with
        # "prompt", "chosen" and "rejected" fields.
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            prompt = "提问:" + data['prompt'] + " 回答:"  # "提问:" = "Question:", "回答:" = "Answer:"

            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + self.end_token
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        return len(self.chosen)

    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]

    # Python and string representations of the dataset.
    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"


'''
EasySFTDataset accepts a plain text file that can be read line by line. By default it
groups consecutive lines together up to max_length tokens so the LLM can learn from
longer, more coherent contexts. If individual lines are unrelated, set is_group_texts
to False.
'''


class EasySFTDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # Read the data_file line by line, tokenize each line, and split any encoding
        # longer than max_length into chunks of at most max_length tokens.
        raw_input_ids = []
        with open(data_file, "r", encoding="UTF-8") as f:
            for line in f:
                encoded_ids = tokenizer.encode(line)
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i:i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    # Pad the current group to max_length with tokenizer.pad_token_id
                    # and start a new group with the chunk that did not fit.
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    current_input_ids = list(input_ids)
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # Without grouping, just pad each chunk of raw_input_ids to max_length.
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # Get one item from the dataset.
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # Dataset description used by repr() and print().
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"