2023-03-28 12:25:36 +00:00
|
|
|
import copy
|
|
|
|
import random
|
2023-04-26 08:32:40 +00:00
|
|
|
from collections import defaultdict
|
2023-03-28 12:25:36 +00:00
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from typing import Callable, Dict, Sequence
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import transformers
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
|
|
|
from .utils import is_rank_0, jload
|
|
|
|
|
|
|
|
# Module-level logger from ColossalAI's distributed logging utilities; used by
# PromptDataset below to report data-loading progress.
logger = get_dist_logger()
|
|
|
|
|
|
|
|
|
|
|
|
class PromptDataset(Dataset):
    """Dataset of tokenized prompts loaded from an instruction file.

    Every record loaded from ``data_path`` must provide an ``"instruction"``
    field. All instructions are tokenized in one batched call, padded and
    truncated to exactly ``max_length`` tokens, moved to ``device`` and split
    into one tensor per example.

    Each item is a dict mapping every tokenizer output name (e.g.
    ``"input_ids"``; the exact set of keys depends on the tokenizer) to that
    example's tensor.
    """

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 max_datasets_size: int = None,
                 max_length: int = 96,
                 device: "torch.device | int | str | None" = None):
        """Load, optionally truncate, and tokenize the prompt records.

        Args:
            data_path: Path handed to ``jload``; presumably a JSON file holding
                a list of dicts, each with an ``"instruction"`` key (TODO
                confirm against ``jload``).
            tokenizer: Hugging Face tokenizer used to encode the instructions.
            max_datasets_size: If not ``None``, keep only the first
                ``max_datasets_size`` records.
            max_length: Every prompt is padded/truncated to this many tokens.
            device: Where the token tensors are stored. ``None`` (default)
                keeps the historical behaviour of using
                ``torch.cuda.current_device()``; pass e.g. ``"cpu"`` to keep
                the data off the GPU.

        Raises:
            KeyError: If a record has no ``"instruction"`` field.
        """
        super().__init__()
        self.keyed_prompt = defaultdict(list)

        logger.info("Loading data...")
        list_data_dict = jload(data_path)
        logger.info(f"Loaded {len(list_data_dict)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            list_data_dict = list_data_dict[:max_datasets_size]

        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
        if not instructions:
            # Nothing to tokenize: leave the dataset empty without calling the
            # tokenizer or touching the CUDA device.
            return

        if device is None:
            device = torch.cuda.current_device()

        # A single batched call replaces one call per example. With
        # padding='max_length' + truncation every row has exactly max_length
        # tokens, so the rows equal the per-example encodings, and the
        # host-to-device copy happens once per key instead of once per example.
        tokens = tokenizer(instructions,
                           return_tensors='pt',
                           max_length=max_length,
                           padding='max_length',
                           truncation=True)
        for k, tensor in tokens.items():
            # unbind() splits the batched tensor along dim 0 into one tensor
            # per example.
            self.keyed_prompt[k].extend(tensor.to(device).unbind())

    def __len__(self):
        """Return the number of prompts.

        Uses ``dict.get`` so that querying the length of an empty dataset does
        not insert an ``"input_ids"`` key into the ``defaultdict``.
        """
        return len(self.keyed_prompt.get("input_ids", ()))

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the ``i``-th example as a ``{output_name: tensor}`` dict.

        Raises:
            IndexError: If ``i`` is out of range (including any index into an
                empty dataset, which would otherwise yield ``{}``).
        """
        if not self.keyed_prompt:
            raise IndexError(f"index {i} is out of range for an empty PromptDataset")
        return {k: v[i] for k, v in self.keyed_prompt.items()}
|