ColossalAI/applications/Chat/coati/dataset/prompt_dataset.py

52 lines
1.7 KiB
Python
Raw Normal View History

2023-03-28 12:25:36 +00:00
import copy
import random
from collections import defaultdict
2023-03-28 12:25:36 +00:00
from dataclasses import dataclass, field
from typing import Callable, Dict, Sequence
import torch
import torch.distributed as dist
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
logger = get_dist_logger()
class PromptDataset(Dataset):
    """Dataset of tokenized instruction prompts.

    Records are loaded from ``data_path`` with ``jload``; the ``"instruction"``
    field of each record is tokenized in a single batch, padded/truncated to
    ``max_length``, and stored as one tensor per example for every key the
    tokenizer returns.

    Args:
        data_path: Path handed to ``jload``; the loaded object must be a
            sliceable sequence of dicts that each contain ``"instruction"``.
        tokenizer: Callable tokenizer invoked on the list of instructions with
            ``return_tensors='pt'``, ``padding='max_length'`` and
            ``truncation=True``.
        max_datasets_size: If not ``None``, only the first
            ``max_datasets_size`` records are kept.
        max_length: Length every prompt is padded or truncated to.
    """

    def __init__(self,
                 data_path: str,
                 tokenizer: "transformers.PreTrainedTokenizer",
                 max_datasets_size: int = None,
                 max_length: int = 96):
        super().__init__()
        self.keyed_prompt = defaultdict(list)

        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
        if not instructions:
            # Nothing to tokenize: leave the dataset empty (len() == 0) rather
            # than handing the tokenizer an empty batch.
            return

        tokens = tokenizer(instructions,
                           return_tensors='pt',
                           max_length=max_length,
                           padding='max_length',
                           truncation=True)

        device = self._target_device()
        for k, tensor in tokens.items():
            # unbind() splits the batch along dim 0 so that indexing by example
            # in __getitem__ is a plain tuple lookup.
            self.keyed_prompt[k] = tensor.to(device).unbind()

    @staticmethod
    def _target_device() -> torch.device:
        """Return the current CUDA device, or the CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.device("cuda", torch.cuda.current_device())
        return torch.device("cpu")

    def __len__(self):
        """Return the number of stored prompts (0 for an empty dataset)."""
        # .get avoids inserting an "input_ids" entry into the defaultdict as a
        # side effect of asking for the length.
        return len(self.keyed_prompt.get("input_ids", ()))

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the ``i``-th example as a dict of tokenizer key -> tensor."""
        return {k: v[i] for k, v in self.keyed_prompt.items()}