ColossalAI/applications/ColossalChat/coati/dataset/loader.py

347 lines
15 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dataloader for sft, dpo, ppo
"""
import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
from datasets import Dataset as HFDataset
from datasets import dataset_dict, load_from_disk
from torch.utils.data import ConcatDataset, Dataset, DistributedSampler
from transformers.tokenization_utils import PreTrainedTokenizer
DatasetType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
PathType = Union[str, os.PathLike]
def load_tokenized_dataset(
    dataset_paths: Union[PathType, List[PathType]], mode: str = "train", **kwargs
) -> Optional[DatasetType]:
    """
    Load one or several pre-tokenized datasets that were saved to disk.

    Args:
        dataset_paths: A single path or a list of paths accepted by `load_from_disk`.
        mode: Logical split to load; must be a key of the mode map.
        **kwargs: Only `mode_map` is read. It maps a logical mode to the split name stored
            on disk and defaults to `{"train": "train", "dev": "validation", "test": "test"}`.

    Returns:
        `None` when `dataset_paths` is empty or no path contributes a dataset, the dataset
        itself when exactly one was collected, otherwise a `ConcatDataset` over all of them.

    Raises:
        AssertionError: If `mode` is not a key of the mode map or a path does not exist.
    """
    if not dataset_paths:
        return None

    mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"})
    assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"

    # Normalise a single path into a one-element list so the loop below handles both cases.
    if isinstance(dataset_paths, (str, os.PathLike)):
        dataset_paths = [dataset_paths]

    collected = []
    for raw_path in dataset_paths:
        abs_path = os.path.abspath(raw_path)
        assert os.path.exists(abs_path), f"Not existed file path {abs_path}"
        loaded = load_from_disk(dataset_path=abs_path, keep_in_memory=False)
        if isinstance(loaded, HFDataset):
            # A bare dataset has no splits: take it as-is regardless of `mode`.
            collected.append(loaded)
        elif mode_map[mode] in loaded:
            # A split container: keep only the requested split; paths lacking it are skipped.
            collected.append(loaded[mode_map[mode]])

    if not collected:
        return None
    if len(collected) == 1:
        return collected[0]
    return ConcatDataset(datasets=collected)
@dataclass
class DataCollatorForSupervisedDataset(object):
    """
    Turn a list of tokenized supervised samples into padded batch tensors.

    Every instance must provide `input_ids` (List[int]) and `labels` (List[int]);
    any other field in the instance is not read by this collator.
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    ignore_index: int = -100

    @staticmethod
    def _left_pad(sequences: List[torch.Tensor], padding_value: int) -> torch.Tensor:
        """Pad 1-D tensors on the left up to the longest one and stack them into a 2-D tensor."""
        # `pad_sequence` only pads on the right, so reverse each sequence, pad, then reverse back.
        mirrored = [seq.flip(dims=(0,)) for seq in sequences]
        padded = torch.nn.utils.rnn.pad_sequence(
            sequences=mirrored,
            batch_first=True,
            padding_value=padding_value,
        )
        return torch.flip(padded, dims=(1,))

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary.

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`: `torch.Tensor` of shape (bsz, seq_len), padded with `pad_token_id`;
                `attention_mask`: `torch.BoolTensor` of shape (bsz, seq_len), `False` wherever
                    `input_ids` equals `pad_token_id`;
                `labels`: `torch.Tensor` of shape (bsz, seq_len), padded with `ignore_index`.
            With right padding `seq_len` equals `max_length`; with left padding it is the
            longest (truncated) sequence in the batch.

        Raises:
            AssertionError: If the tokenizer's `pad_token_id` is not a non-negative int.
            RuntimeError: If the tokenizer's `padding_side` is neither `left` nor `right`.
        """
        pad_id = self.tokenizer.pad_token_id
        assert isinstance(pad_id, int) and pad_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{pad_id}`"
        )

        # Slicing is a no-op for sequences that already fit, so truncate unconditionally.
        token_seqs = [torch.LongTensor(sample["input_ids"][: self.max_length]) for sample in instances]
        label_seqs = [torch.LongTensor(sample["labels"][: self.max_length]) for sample in instances]

        side = self.tokenizer.padding_side
        if side == "right":
            input_ids = torch.nn.utils.rnn.pad_sequence(
                sequences=token_seqs,
                batch_first=True,
                padding_value=pad_id,
            )
            labels = torch.nn.utils.rnn.pad_sequence(
                sequences=label_seqs,
                batch_first=True,
                padding_value=self.ignore_index,
            )
            # Extend both tensors by the same amount (derived from `input_ids`) up to `max_length`.
            right_pad = (0, self.max_length - input_ids.size(1))
            input_ids = F.pad(input_ids, right_pad, value=pad_id)
            labels = F.pad(labels, right_pad, value=self.ignore_index)
        elif side == "left":
            input_ids = self._left_pad(token_seqs, pad_id)
            labels = self._left_pad(label_seqs, self.ignore_index)
        else:
            raise RuntimeError(
                f"`{self.tokenizer.__class__.__name__}.padding_side` can only be `left` or `right`, "
                f"but now `{self.tokenizer.padding_side}`"
            )

        attention_mask = input_ids.ne(pad_id)  # `torch.BoolTensor`
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
@dataclass
class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset):
    """
    Collate prompt-only samples by delegating to the supervised collator and then
    left-padding the result up to `max_length`. No labels are returned.
    """

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary
                that provides `input_ids`.

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`: `torch.Tensor` of shape (bsz, max_len);
                `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
        """
        # The parent collator requires a `labels` entry; the prompt tokens stand in for it
        # and the resulting label tensor is dropped below.
        proxied = []
        for sample in instances:
            tokens = sample["input_ids"]
            proxied.append({"input_ids": tokens, "labels": tokens})

        collated = super().__call__(instances=proxied)
        ids = collated["input_ids"]
        mask = collated["attention_mask"]

        # Fill whatever width is still missing on the left side.
        input_ids = F.pad(ids, (self.max_length - ids.size(1), 0), value=self.tokenizer.pad_token_id)
        attention_mask = F.pad(mask, (self.max_length - mask.size(1), 0), value=False)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
@dataclass
class DataCollatorForPreferenceDataset(object):
    """
    Collate instances for a preference (chosen / rejected) dataset.
    Each instance is a tokenized dictionary with fields
    `chosen_input_ids`, `chosen_loss_mask`, `rejected_input_ids` and `rejected_loss_mask`.
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary.

        Returns:
            A dictionary with the keys `chosen_input_ids`, `chosen_attention_mask`,
            `chosen_loss_mask`, `reject_input_ids`, `reject_attention_mask` and
            `reject_loss_mask`, each holding the output of `pad_to_max_len`
            (presumably a (bsz, max_length) tensor -- confirm against the helper).

        Raises:
            AssertionError: If the tokenizer's `pad_token_id` is not a non-negative int.
        """
        pad_id = self.tokenizer.pad_token_id
        assert isinstance(pad_id, int) and pad_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{pad_id}`"
        )
        limit = self.max_length

        def truncated(field, dtype):
            # NOTE(review): `chuncate_sequence` is assumed to cut each sequence to `limit`
            # and return per-sample tensors of `dtype`; verify in coati.dataset.utils.
            return chuncate_sequence([sample[field] for sample in instances], limit, dtype)

        chosen_input_ids = truncated("chosen_input_ids", torch.int64)
        chosen_loss_mask = truncated("chosen_loss_mask", torch.bool)
        reject_input_ids = truncated("rejected_input_ids", torch.int64)
        reject_loss_mask = truncated("rejected_loss_mask", torch.bool)

        padding_side = self.tokenizer.padding_side

        # Before padding, every real token is attended to.
        chosen_attention_mask = [torch.ones_like(seq).bool() for seq in chosen_input_ids]
        reject_attention_mask = [torch.ones_like(seq).bool() for seq in reject_input_ids]

        def padded(sequences, fill_value):
            return pad_to_max_len(sequences, limit, fill_value, padding_side=padding_side)

        # Token ids are filled with the pad id, both kinds of mask with False.
        return {
            "chosen_input_ids": padded(chosen_input_ids, pad_id),
            "chosen_attention_mask": padded(chosen_attention_mask, False),
            "chosen_loss_mask": padded(chosen_loss_mask, False),
            "reject_input_ids": padded(reject_input_ids, pad_id),
            "reject_attention_mask": padded(reject_attention_mask, False),
            "reject_loss_mask": padded(reject_loss_mask, False),
        }
@dataclass
class DataCollatorForKTODataset(object):
    """
    Collate instances for kto dataset.
    Each input instance is a tokenized dictionary with fields
    `prompt`(List[int]), `completion`(List[int]) and `label`(bool).
    The collated batch contains
    `input_ids`, `attention_mask`, `loss_mask`, `label` for the original (prompt, completion) pairs and
    `kl_input_ids`, `kl_attention_mask`, `kl_loss_mask` for mismatched pairs used as KL reference.
    """

    tokenizer: PreTrainedTokenizer
    max_length: int = 4096
    # Not read anywhere in this collator; kept so the constructor signature stays unchanged.
    ignore_index: int = -100

    def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances (`Sequence[Dict[str, List[int]]]`):
                Mini-batch samples, each sample is stored in an individual dictionary contains the following fields:
                `prompt`(List[int]), `completion`(List[int]) and `label`(bool, if the sample is desirable or not).

        Returns:
            (`Dict[str, torch.Tensor]`): Contains the following `torch.Tensor`:
                `input_ids`: prompt followed by completion, right-padded with `pad_token_id`, (bsz, max_length);
                `attention_mask`: `torch.BoolTensor`, `False` wherever `input_ids` equals `pad_token_id`;
                `loss_mask`: integer tensor, 1 on completion tokens, 0 on prompt tokens and padding;
                `label`: `torch.BoolTensor` of shape (bsz,);
                `kl_input_ids`, `kl_attention_mask`, `kl_loss_mask`: same layout, but each prompt is
                    joined with the completion found at the mirrored batch position.

        Raises:
            AssertionError: If the tokenizer's `pad_token_id` is not a non-negative int.
        """
        assert isinstance(self.tokenizer.pad_token_id, int) and self.tokenizer.pad_token_id >= 0, (
            f"`{self.tokenizer.__class__.__name__}.pad_token_id` must be a valid non-negative integer index value, "
            f"but now `{self.tokenizer.pad_token_id}`"
        )

        prompt = [torch.LongTensor(instance["prompt"]) for instance in instances]
        completion = [torch.LongTensor(instance["completion"]) for instance in instances]
        label = [torch.tensor(instance["label"], dtype=torch.bool) for instance in instances]

        # Original (x, y) pairs.
        input_ids, attention_mask, loss_mask = self._concat_and_pad(prompt, completion)

        # Mismatched (x, y') pairs: reverse the completions across the batch. With an odd
        # batch size the middle sample is paired with its own completion.
        kl_input_ids, kl_attention_mask, kl_loss_mask = self._concat_and_pad(prompt, completion[::-1])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "label": torch.stack(label),
            "kl_input_ids": kl_input_ids,
            "kl_attention_mask": kl_attention_mask,
            "kl_loss_mask": kl_loss_mask,
        }

    def _concat_and_pad(self, prompts: List[torch.Tensor], completions: List[torch.Tensor]):
        """Join each prompt with its completion and right-pad ids and loss mask to `max_length`.

        Returns a tuple `(input_ids, attention_mask, loss_mask)`.
        """
        pad_id = self.tokenizer.pad_token_id
        pairs = list(zip(prompts, completions))

        joined_ids = [torch.cat([p, c], dim=-1) for p, c in pairs]
        # Loss is only taken on completion tokens: zeros over the prompt, ones over the completion.
        joined_mask = [torch.cat([torch.zeros_like(p), torch.ones_like(c)], dim=-1) for p, c in pairs]

        # First pad to the longest sample in the batch ...
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=joined_ids,
            batch_first=True,
            padding_value=pad_id,
        )
        loss_mask = torch.nn.utils.rnn.pad_sequence(sequences=joined_mask, batch_first=True, padding_value=0)

        # ... then up to `max_length`.
        # NOTE(review): when a sample is longer than `max_length` this width is negative;
        # `F.pad` is expected to crop on the right in that case -- confirm for the torch version in use.
        to_pad = self.max_length - input_ids.size(1)
        input_ids = F.pad(input_ids, (0, to_pad), value=pad_id)
        loss_mask = F.pad(loss_mask, (0, to_pad), value=0)

        attention_mask = input_ids.ne(pad_id)  # `torch.BoolTensor`
        return input_ids, attention_mask, loss_mask
class StatefulDistributedSampler(DistributedSampler):
    """
    A `DistributedSampler` that can skip the first `start_index` indices of this
    rank's shard, e.g. to resume iteration part-way through an epoch.
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(
            dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
        # Number of leading indices to drop from this rank's shard on iteration.
        self.start_index: int = 0

    def __iter__(self) -> Iterator:
        """Yield this rank's indices, leaving out the first `start_index` of them."""
        indices = list(super().__iter__())
        return iter(indices[self.start_index :])

    def __len__(self) -> int:
        """Return how many indices `__iter__` will yield.

        The count is taken from the same slice that `__iter__` applies, so it can never
        be negative (an offset beyond the shard gives 0) and stays in agreement with
        iteration for any offset.
        """
        return len(range(self.num_samples)[self.start_index :])

    def set_start_index(self, start_index: int) -> None:
        """Set the number of leading indices to skip on subsequent iterations."""
        self.start_index = start_index