mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/coati/dataset/sft_dataset.py code style (#4259)
parent
abe4f971e0
commit
b2debdc09b
|
@ -74,15 +74,10 @@ class SFTDataset(Dataset):
|
||||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||||
|
|
||||||
|
|
||||||
def _tokenize_fn(strings: Sequence[str],
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
|
||||||
tokenizer: transformers.PreTrainedTokenizer,
|
max_length: int) -> Dict[str, torch.Tensor]:
|
||||||
max_length: int
|
|
||||||
) -> Dict[str, torch.Tensor]:
|
|
||||||
"""Tokenize a list of strings."""
|
"""Tokenize a list of strings."""
|
||||||
tokenized_list = tokenizer(
|
tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
|
||||||
strings, return_tensors="pt", padding="longest",
|
|
||||||
max_length=max_length, truncation=True
|
|
||||||
)
|
|
||||||
input_ids = labels = tokenized_list["input_ids"]
|
input_ids = labels = tokenized_list["input_ids"]
|
||||||
input_ids_lens = labels_lens = \
|
input_ids_lens = labels_lens = \
|
||||||
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
|
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
|
||||||
|
@ -103,8 +98,7 @@ def preprocess(
|
||||||
"""Preprocess the data by tokenizing."""
|
"""Preprocess the data by tokenizing."""
|
||||||
examples = [s + t for s, t in zip(sources, targets)]
|
examples = [s + t for s, t in zip(sources, targets)]
|
||||||
examples_tokenized, sources_tokenized = [
|
examples_tokenized, sources_tokenized = [
|
||||||
_tokenize_fn(strings, tokenizer, max_length)
|
_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
|
||||||
for strings in (examples, sources)
|
|
||||||
]
|
]
|
||||||
input_ids = examples_tokenized["input_ids"]
|
input_ids = examples_tokenized["input_ids"]
|
||||||
labels = copy.deepcopy(input_ids)
|
labels = copy.deepcopy(input_ids)
|
||||||
|
@ -116,7 +110,11 @@ def preprocess(
|
||||||
class SupervisedDataset(Dataset):
|
class SupervisedDataset(Dataset):
|
||||||
"""Dataset for supervised fine-tuning."""
|
"""Dataset for supervised fine-tuning."""
|
||||||
|
|
||||||
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
|
def __init__(self,
|
||||||
|
data_path: str,
|
||||||
|
tokenizer: transformers.PreTrainedTokenizer,
|
||||||
|
max_datasets_size: int = None,
|
||||||
|
max_length: int = 512):
|
||||||
super(SupervisedDataset, self).__init__()
|
super(SupervisedDataset, self).__init__()
|
||||||
logger.info("Loading data...")
|
logger.info("Loading data...")
|
||||||
list_data_dict = jload(data_path)
|
list_data_dict = jload(data_path)
|
||||||
|
|
Loading…
Reference in New Issue