mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/coati/dataset/sft_dataset.py code style (#4259)
parent
abe4f971e0
commit
b2debdc09b
|
@ -74,15 +74,10 @@ class SFTDataset(Dataset):
|
|||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||
|
||||
|
||||
def _tokenize_fn(strings: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int) -> Dict[str, torch.Tensor]:
|
||||
"""Tokenize a list of strings."""
|
||||
tokenized_list = tokenizer(
|
||||
strings, return_tensors="pt", padding="longest",
|
||||
max_length=max_length, truncation=True
|
||||
)
|
||||
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
|
||||
input_ids = labels = tokenized_list["input_ids"]
|
||||
input_ids_lens = labels_lens = \
|
||||
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
|
||||
|
@ -103,8 +98,7 @@ def preprocess(
|
|||
"""Preprocess the data by tokenizing."""
|
||||
examples = [s + t for s, t in zip(sources, targets)]
|
||||
examples_tokenized, sources_tokenized = [
|
||||
_tokenize_fn(strings, tokenizer, max_length)
|
||||
for strings in (examples, sources)
|
||||
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
|
||||
]
|
||||
input_ids = examples_tokenized["input_ids"]
|
||||
labels = copy.deepcopy(input_ids)
|
||||
|
@ -116,7 +110,11 @@ def preprocess(
|
|||
class SupervisedDataset(Dataset):
|
||||
"""Dataset for supervised fine-tuning."""
|
||||
|
||||
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_datasets_size: int = None,
|
||||
max_length: int = 512):
|
||||
super(SupervisedDataset, self).__init__()
|
||||
logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
|
|
Loading…
Reference in New Issue