2023-04-26 08:32:40 +00:00
|
|
|
from collections import defaultdict
|
2023-08-02 02:17:36 +00:00
|
|
|
from typing import Dict, Optional, Union
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import transformers
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
2023-08-02 02:17:36 +00:00
|
|
|
from .utils import jload
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
class PromptDataset(Dataset):
    """Dataset of tokenized instruction prompts.

    Loads a list of records with ``jload`` and keeps each record's
    ``"instruction"`` field. All instructions are tokenized in one batch,
    padded and truncated to ``max_length``. The resulting per-example
    tensors are stored in ``keyed_prompt``, keyed by tokenizer output name.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        max_datasets_size: Optional[int] = None,
        max_length: int = 96,
        device: Optional[Union[int, str, torch.device]] = None,
    ):
        """Load, optionally truncate, and tokenize the prompts.

        Args:
            data_path: Path passed to ``jload``. Every loaded record must
                have an ``"instruction"`` key.
            tokenizer: Tokenizer applied to the batch of instructions.
            max_datasets_size: If not ``None``, keep only the first
                ``max_datasets_size`` records.
            max_length: Each prompt is padded/truncated to this many tokens.
            device: Where to place the token tensors. ``None`` (the default)
                keeps the original behaviour of using
                ``torch.cuda.current_device()``. Pass e.g. ``"cpu"`` to avoid
                requiring CUDA.

        Raises:
            KeyError: If a loaded record has no ``"instruction"`` key.
        """
        super().__init__()
        self.keyed_prompt = defaultdict(list)
        self.logger = get_dist_logger()

        self.logger.info("Loading data...")
        list_data_dict = jload(data_path)
        self.logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            self.logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
        if not instructions:
            # Nothing to tokenize. Tokenizers may reject an empty batch, so
            # leave keyed_prompt empty and let __len__ report 0.
            return

        if device is None:
            device = torch.cuda.current_device()

        tokens = tokenizer(
            instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
        )
        for k, tensor in tokens.items():
            # unbind() splits the batched tensor along dim 0 into a tuple
            # holding one tensor per example, so __getitem__ can index it.
            self.keyed_prompt[k] = tensor.to(device).unbind()

    def __len__(self) -> int:
        """Return the number of prompts (0 if nothing was tokenized)."""
        # .get avoids the defaultdict inserting an empty "input_ids" entry.
        return len(self.keyed_prompt.get("input_ids", ()))

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return prompt ``i`` as a dict mapping each tokenizer output key to its tensor."""
        return {k: v[i] for k, v in self.keyed_prompt.items()}
|