Browse Source

[NFC] polish applications/Chat/coati/dataset/sft_dataset.py code style (#4259)

pull/4338/head
Zheng Zangwei (Alex Zheng) 1 year ago committed by binmakeswell
parent
commit
b2debdc09b
  1. 20
      applications/Chat/coati/dataset/sft_dataset.py

20
applications/Chat/coati/dataset/sft_dataset.py

@ -74,15 +74,10 @@ class SFTDataset(Dataset):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
max_length: int
) -> Dict[str, torch.Tensor]:
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
max_length: int) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
tokenized_list = tokenizer(
strings, return_tensors="pt", padding="longest",
max_length=max_length, truncation=True
)
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
@ -103,8 +98,7 @@ def preprocess(
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length)
for strings in (examples, sources)
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
@ -116,7 +110,11 @@ def preprocess(
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
def __init__(self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 512):
super(SupervisedDataset, self).__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)

Loading…
Cancel
Save