|
|
|
@ -74,15 +74,10 @@ class SFTDataset(Dataset):
|
|
|
|
|
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str], |
|
|
|
|
tokenizer: transformers.PreTrainedTokenizer, |
|
|
|
|
max_length: int |
|
|
|
|
) -> Dict[str, torch.Tensor]: |
|
|
|
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, |
|
|
|
|
max_length: int) -> Dict[str, torch.Tensor]: |
|
|
|
|
"""Tokenize a list of strings.""" |
|
|
|
|
tokenized_list = tokenizer( |
|
|
|
|
strings, return_tensors="pt", padding="longest", |
|
|
|
|
max_length=max_length, truncation=True |
|
|
|
|
) |
|
|
|
|
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True) |
|
|
|
|
input_ids = labels = tokenized_list["input_ids"] |
|
|
|
|
input_ids_lens = labels_lens = \ |
|
|
|
|
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1) |
|
|
|
@ -103,8 +98,7 @@ def preprocess(
|
|
|
|
|
"""Preprocess the data by tokenizing.""" |
|
|
|
|
examples = [s + t for s, t in zip(sources, targets)] |
|
|
|
|
examples_tokenized, sources_tokenized = [ |
|
|
|
|
_tokenize_fn(strings, tokenizer, max_length) |
|
|
|
|
for strings in (examples, sources) |
|
|
|
|
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources) |
|
|
|
|
] |
|
|
|
|
input_ids = examples_tokenized["input_ids"] |
|
|
|
|
labels = copy.deepcopy(input_ids) |
|
|
|
@ -116,7 +110,11 @@ def preprocess(
|
|
|
|
|
class SupervisedDataset(Dataset): |
|
|
|
|
"""Dataset for supervised fine-tuning.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512): |
|
|
|
|
def __init__(self, |
|
|
|
|
data_path: str, |
|
|
|
|
tokenizer: transformers.PreTrainedTokenizer, |
|
|
|
|
max_datasets_size: int = None, |
|
|
|
|
max_length: int = 512): |
|
|
|
|
super(SupervisedDataset, self).__init__() |
|
|
|
|
logger.info("Loading data...") |
|
|
|
|
list_data_dict = jload(data_path) |
|
|
|
|