|
|
|
@ -53,29 +53,25 @@ class SFTDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
|
|
|
|
|
super().__init__()
|
|
|
|
|
# self.prompts = []
|
|
|
|
|
self.input_ids = []
|
|
|
|
|
|
|
|
|
|
for data in tqdm(dataset, disable=not is_rank_0()):
|
|
|
|
|
prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
|
|
|
|
|
prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
|
|
|
|
|
prompt_token = tokenizer(prompt,
|
|
|
|
|
max_length=max_length,
|
|
|
|
|
padding="max_length",
|
|
|
|
|
truncation=True,
|
|
|
|
|
return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
# self.prompts.append(prompt_token)s
|
|
|
|
|
self.input_ids.append(prompt_token)
|
|
|
|
|
self.labels = copy.deepcopy(self.input_ids)
|
|
|
|
|
self.input_ids.append(prompt_token['input_ids'][0])
|
|
|
|
|
self.labels = copy.deepcopy(self.input_ids)
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
length = len(self.prompts)
|
|
|
|
|
length = len(self.input_ids)
|
|
|
|
|
return length
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
|
# dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
|
|
|
|
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
|
|
|
|
# return dict(self.prompts[idx], self.prompts[idx])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
|
|
|
|
|