import copy
import json
from typing import Dict, Sequence

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

IGNORE_INDEX = -100


def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
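
# How the masking works: for a line such as "提问:今天天气如何? 回答:晴天。" (an
# illustrative example, not taken from any data file), EasySupervisedDataset below feeds
# preprocess() the source "提问:今天天气如何? 回答:" and the target "晴天。" plus the EOS
# token. The labels start as a copy of input_ids and the first source_len positions are
# overwritten with IGNORE_INDEX, so the loss is computed only on the answer tokens.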

class EasySupervisedDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
        super(EasySupervisedDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        # Split each line into source and target: the source is everything up to and
        # including "回答:", the target is everything after "回答:".
        sources, targets = [], []
        for line in all_lines:
            if "回答:" in line:
                sep_index = line.index("回答:")
                sources.append(line[:sep_index + 3])
                targets.append(line[sep_index + 3:] + tokenizer.eos_token)
            else:
                sources.append(line)
                targets.append("" + tokenizer.eos_token)
        data_dict = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.data_file = data_file

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def __repr__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"

    def __str__(self):
        return f"EasySupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})"


class EasyPromptsDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
        super(EasyPromptsDataset, self).__init__()
        with open(data_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
            # Keep only the prompt part of each line (everything up to and including "回答:").
            all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
        self.prompts = [
            tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
                      truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
            for line in tqdm(all_lines)
        ]
        self.data_file = data_file

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

    def __repr__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"

    def __str__(self):
        return f"EasyPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})"


class EasyRewardDataset(Dataset):

    def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
        super(EasyRewardDataset, self).__init__()
        self.chosen = []
        self.reject = []
        if special_token is None:
            self.end_token = tokenizer.eos_token
        else:
            self.end_token = special_token
        print(self.end_token)
        # Read all lines in the train_file into a list.
        with open(train_file, "r", encoding="UTF-8") as f:
            all_lines = f.readlines()
        for line in tqdm(all_lines):
            data = json.loads(line)
            prompt = "提问:" + data['prompt'] + " 回答:"

            chosen = prompt + data['chosen'] + self.end_token
            chosen_token = tokenizer(chosen,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.chosen.append({
                "input_ids": chosen_token['input_ids'],
                "attention_mask": chosen_token['attention_mask']
            })

            reject = prompt + data['rejected'] + self.end_token
            reject_token = tokenizer(reject,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")
            self.reject.append({
                "input_ids": reject_token['input_ids'],
                "attention_mask": reject_token['attention_mask']
            })

    def __len__(self):
        length = len(self.chosen)
        return length

    def __getitem__(self, idx):
        return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
            "input_ids"], self.reject[idx]["attention_mask"]

    # Python representation of the object and the string representation of the object.
    def __repr__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"

    def __str__(self):
        return f"EasyRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
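
# Data format note: each line of the EasyRewardDataset input file is a JSON object with
# "prompt", "chosen" and "rejected" fields, for example (illustrative sample, not taken
# from any real file):
# {"prompt": "今天天气如何?", "chosen": "今天是晴天,适合出门。", "rejected": "不知道。"}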

'''
    Easy SFT just accepts a text file which can be read line by line. However, the dataset
    will group texts together up to max_length so the LLM will learn the texts' meaning better.
    If individual lines are not related, just set is_group_texts to False.
'''


class EasySFTDataset(Dataset):

    def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
        super().__init__()
        # Read the data_file line by line, encode each line, and keep only the raw python
        # lists of input ids in raw_input_ids.
        with open(data_file, "r", encoding="UTF-8") as f:
            raw_input_ids = []
            for line in f:
                encoded_ids = tokenizer.encode(line)
                # If the encoded ids are longer than max_length, split them into several chunks.
                if len(encoded_ids) > max_length:
                    for i in range(0, len(encoded_ids), max_length):
                        raw_input_ids.append(encoded_ids[i:i + max_length])
                else:
                    raw_input_ids.append(encoded_ids)

        grouped_input_ids = []
        current_input_ids = []
        attention_mask = []
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if is_group_texts:
            for input_ids in raw_input_ids:
                if len(current_input_ids) + len(input_ids) > max_length:
                    # Pad the current group to max_length with tokenizer.pad_token_id and flush it.
                    padded_length = max_length - len(current_input_ids)
                    current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                    grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                    attention_mask.append(
                        torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                    # Start the next group with the chunk that did not fit, so no lines are dropped.
                    current_input_ids = list(input_ids)
                else:
                    current_input_ids.extend(input_ids)
            if len(current_input_ids) > 0:
                padded_length = max_length - len(current_input_ids)
                current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
                grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
        else:
            # Without grouping, just pad each encoded line to max_length.
            for input_ids in raw_input_ids:
                padded_length = max_length - len(input_ids)
                input_ids.extend([tokenizer.pad_token_id] * padded_length)
                attention_mask.append(
                    torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
                grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
        self.input_ids = grouped_input_ids
        self.labels = copy.deepcopy(self.input_ids)
        self.file_name = data_file
        self.attention_mask = attention_mask

    def __len__(self):
        return len(self.input_ids)

    # Get one item from the dataset.
    def __getitem__(self, idx):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])

    # Generate the dataset description to be printed by print() in python.
    def __repr__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

    # Generate the dataset description to be printed by print() in python.
    def __str__(self):
        return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
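

# ---------------------------------------------------------------------------
# Minimal usage sketch. The checkpoint name and the file paths below are
# illustrative placeholders (assumptions), not artifacts shipped with this
# module. EasyPromptsDataset is omitted because it moves every prompt to the
# current CUDA device and therefore needs a GPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    if tokenizer.pad_token is None:
        # _tokenize_fn and EasyRewardDataset pad their inputs, so a pad token is required.
        tokenizer.pad_token = tokenizer.eos_token

    # Supervised fine-tuning data: one "提问:... 回答:..." line per example,
    # with the loss masked on the prompt part.
    sft_dataset = EasySupervisedDataset("sft_data.txt", tokenizer, max_length=512)
    print(sft_dataset, sft_dataset[0]["input_ids"].shape)

    # Plain text grouped into fixed-length blocks of 512 tokens.
    grouped_dataset = EasySFTDataset("corpus.txt", tokenizer, max_length=512, is_group_texts=True)
    print(grouped_dataset)

    # Reward-model pairs read from a JSONL file with "prompt"/"chosen"/"rejected" fields.
    reward_dataset = EasyRewardDataset("reward_data.jsonl", tokenizer, max_length=512)
    print(reward_dataset)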