|
|
|
@ -74,15 +74,10 @@ class SFTDataset(Dataset):
|
|
|
|
|
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str],
|
|
|
|
|
tokenizer: transformers.PreTrainedTokenizer,
|
|
|
|
|
max_length: int
|
|
|
|
|
) -> Dict[str, torch.Tensor]:
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
|
|
|
|
|
max_length: int) -> Dict[str, torch.Tensor]:
|
|
|
|
|
"""Tokenize a list of strings."""
|
|
|
|
|
tokenized_list = tokenizer(
|
|
|
|
|
strings, return_tensors="pt", padding="longest",
|
|
|
|
|
max_length=max_length, truncation=True
|
|
|
|
|
)
|
|
|
|
|
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
|
|
|
|
|
input_ids = labels = tokenized_list["input_ids"]
|
|
|
|
|
input_ids_lens = labels_lens = \
|
|
|
|
|
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
|
|
|
|
@ -103,8 +98,7 @@ def preprocess(
|
|
|
|
|
"""Preprocess the data by tokenizing."""
|
|
|
|
|
examples = [s + t for s, t in zip(sources, targets)]
|
|
|
|
|
examples_tokenized, sources_tokenized = [
|
|
|
|
|
_tokenize_fn(strings, tokenizer, max_length)
|
|
|
|
|
for strings in (examples, sources)
|
|
|
|
|
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
|
|
|
|
|
]
|
|
|
|
|
input_ids = examples_tokenized["input_ids"]
|
|
|
|
|
labels = copy.deepcopy(input_ids)
|
|
|
|
@ -116,7 +110,11 @@ def preprocess(
|
|
|
|
|
class SupervisedDataset(Dataset):
|
|
|
|
|
"""Dataset for supervised fine-tuning."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
data_path: str,
|
|
|
|
|
tokenizer: transformers.PreTrainedTokenizer,
|
|
|
|
|
max_datasets_size: int = None,
|
|
|
|
|
max_length: int = 512):
|
|
|
|
|
super(SupervisedDataset, self).__init__()
|
|
|
|
|
logger.info("Loading data...")
|
|
|
|
|
list_data_dict = jload(data_path)
|
|
|
|
|