|
|
|
@ -188,7 +188,7 @@ class EasySFTDataset(Dataset):
|
|
|
|
|
else: |
|
|
|
|
raw_input_ids.append(encoded_ids) |
|
|
|
|
|
|
|
|
|
grouped_inpup_ids = [] |
|
|
|
|
grouped_input_ids = [] |
|
|
|
|
current_input_ids = [] |
|
|
|
|
attention_mask = [] |
|
|
|
|
if tokenizer.pad_token_id is None: |
|
|
|
@ -199,7 +199,7 @@ class EasySFTDataset(Dataset):
|
|
|
|
|
#pad the current_input_ids to max_length with tokenizer.pad_token_id |
|
|
|
|
padded_length = max_length - len(current_input_ids) |
|
|
|
|
current_input_ids.extend([tokenizer.pad_token_id] * padded_length) |
|
|
|
|
grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) |
|
|
|
|
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) |
|
|
|
|
attention_mask.append( |
|
|
|
|
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) |
|
|
|
|
current_input_ids = [] |
|
|
|
@ -208,7 +208,7 @@ class EasySFTDataset(Dataset):
|
|
|
|
|
if len(current_input_ids) > 0: |
|
|
|
|
padded_length = max_length - len(current_input_ids) |
|
|
|
|
current_input_ids.extend([tokenizer.pad_token_id] * padded_length) |
|
|
|
|
grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) |
|
|
|
|
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) |
|
|
|
|
attention_mask.append( |
|
|
|
|
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) |
|
|
|
|
else: |
|
|
|
@ -218,8 +218,8 @@ class EasySFTDataset(Dataset):
|
|
|
|
|
input_ids.extend([tokenizer.pad_token_id] * padded_length) |
|
|
|
|
attention_mask.append( |
|
|
|
|
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) |
|
|
|
|
grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long)) |
|
|
|
|
self.input_ids = grouped_inpup_ids |
|
|
|
|
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) |
|
|
|
|
self.input_ids = grouped_input_ids |
|
|
|
|
self.labels = copy.deepcopy(self.input_ids) |
|
|
|
|
self.file_name = data_file |
|
|
|
|
self.attention_mask = attention_mask |
|
|
|
|