|
|
|
@ -54,7 +54,8 @@ class SFTDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
    """Tokenize a supervised fine-tuning dataset of prompt/completion pairs.

    Args:
        dataset: iterable of dicts, each with ``'prompt'`` and ``'completion'``
            string keys (evident from the access pattern in the loop body).
        tokenizer: callable used to tokenize text; HuggingFace-style keyword
            arguments (``truncation``, ``return_tensors``) are passed to it.
        max_length: maximum token length handed to the tokenizer.
    """
    super().__init__()
    self.prompts = []
    self.input_ids = []

    # Progress bar only on rank 0 to avoid duplicated output in
    # distributed training.
    for data in tqdm(dataset, disable=not is_rank_0()):
        # Concatenate prompt and completion, terminated by an explicit
        # end-of-text token so the model learns where generation stops.
        prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
        # NOTE(review): the original call was partially garbled in this
        # chunk — only `truncation=True` and `return_tensors="pt"` were
        # visible. Reconstructed with the conventional max_length/padding
        # arguments; confirm against the upstream source.
        prompt_token = tokenizer(prompt,
                                 max_length=max_length,
                                 padding="max_length",
                                 truncation=True,
                                 return_tensors="pt")
        # Both lists hold the same token objects: `prompts` backs
        # __len__/__getitem__, `input_ids` feeds the label copy below.
        self.prompts.append(prompt_token)
        self.input_ids.append(prompt_token)
    # Causal-LM objective: labels are an independent deep copy of the
    # inputs so later label masking cannot mutate the inputs.
    self.labels = copy.deepcopy(self.input_ids)
|
|
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
|
length = len(self.prompts) |
|
|
|
|
return length |
|
|
|
|
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
|
return self.prompts[idx] |
|
|
|
|
# dict(input_ids=self.input_ids[i], labels=self.labels[i]) |
|
|
|
|
return dict(input_ids=self.input_ids[i], labels=self.labels[i]) |
|
|
|
|
# return dict(self.prompts[idx], self.prompts[idx]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: |
|
|
|
|