From b2debdc09bc9dd7d87a409e7a4485d1eca74cb61 Mon Sep 17 00:00:00 2001
From: "Zheng Zangwei (Alex Zheng)"
Date: Tue, 18 Jul 2023 10:59:38 +0800
Subject: [PATCH] [NFC] polish applications/Chat/coati/dataset/sft_dataset.py code style (#4259)

---
 .../Chat/coati/dataset/sft_dataset.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 3702d00cc..3038fbe07 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -74,15 +74,10 @@ class SFTDataset(Dataset):
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
 
 
-def _tokenize_fn(strings: Sequence[str],
-                 tokenizer: transformers.PreTrainedTokenizer,
-                 max_length: int
-                 ) -> Dict[str, torch.Tensor]:
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
+                 max_length: int) -> Dict[str, torch.Tensor]:
     """Tokenize a list of strings."""
-    tokenized_list = tokenizer(
-        strings, return_tensors="pt", padding="longest",
-        max_length=max_length, truncation=True
-    )
+    tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
     input_ids = labels = tokenized_list["input_ids"]
     input_ids_lens = labels_lens = \
         tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
@@ -103,8 +98,7 @@ def preprocess(
     """Preprocess the data by tokenizing."""
     examples = [s + t for s, t in zip(sources, targets)]
     examples_tokenized, sources_tokenized = [
-        _tokenize_fn(strings, tokenizer, max_length)
-        for strings in (examples, sources)
+        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
     ]
     input_ids = examples_tokenized["input_ids"]
     labels = copy.deepcopy(input_ids)
@@ -116,7 +110,11 @@ def preprocess(
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""
 
-    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
+    def __init__(self,
+                 data_path: str,
+                 tokenizer: transformers.PreTrainedTokenizer,
+                 max_datasets_size: int = None,
+                 max_length: int = 512):
         super(SupervisedDataset, self).__init__()
         logger.info("Loading data...")
         list_data_dict = jload(data_path)
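
Note (not part of the patch): the functions reformatted here follow the common SFT preprocessing pattern of tokenizing prompt+response pairs and measuring each prompt's unpadded length so the prompt tokens can be excluded from the loss. The sketch below illustrates that pattern with a Hugging Face tokenizer; the helper names tokenize_batch and build_sft_example are illustrative assumptions, not identifiers from this file, and IGNORE_INDEX = -100 is the conventional value ignored by torch.nn.CrossEntropyLoss.

import copy
from typing import Dict, Sequence

import torch
import transformers

IGNORE_INDEX = -100  # label id ignored by torch.nn.CrossEntropyLoss by default


def tokenize_batch(strings: Sequence[str],
                   tokenizer: transformers.PreTrainedTokenizer,
                   max_length: int) -> Dict[str, torch.Tensor]:
    """Tokenize a list of strings and record each row's unpadded length."""
    tokenized = tokenizer(strings, return_tensors="pt", padding="longest",
                          max_length=max_length, truncation=True)
    input_ids = tokenized["input_ids"]
    # Count non-pad tokens per row, mirroring the .ne(pad_token_id).sum(dim=-1) idiom in the patch.
    lengths = input_ids.ne(tokenizer.pad_token_id).sum(dim=-1)
    return dict(input_ids=input_ids, lengths=lengths)


def build_sft_example(sources: Sequence[str],
                      targets: Sequence[str],
                      tokenizer: transformers.PreTrainedTokenizer,
                      max_length: int = 512) -> Dict[str, torch.Tensor]:
    """Concatenate prompt + response, then mask the prompt portion of the labels."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tok = tokenize_batch(examples, tokenizer, max_length)
    sources_tok = tokenize_batch(sources, tokenizer, max_length)
    input_ids = examples_tok["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Loss is computed only on response tokens; prompt tokens are masked out.
    for label, source_len in zip(labels, sources_tok["lengths"]):
        label[:int(source_len)] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


if __name__ == "__main__":
    # Example usage; "gpt2" is just a convenient public checkpoint.
    tok = transformers.AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
    batch = build_sft_example(["Q: 1+1?\nA: "], ["2"], tok)
    print(batch["input_ids"].shape, batch["labels"])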