mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
6.4 KiB
171 lines
6.4 KiB
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import copy
|
|
import random
|
|
from dataclasses import dataclass, field
|
|
from typing import Callable, Dict, Sequence
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import transformers
|
|
from torch.utils.data import Dataset
|
|
from tqdm import tqdm
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
from .utils import is_rank_0, jload
|
|
|
|
# Module-level logger shared by this file, created via colossalai.logging.get_dist_logger().
logger = get_dist_logger()
|
|
|
|
# Sentinel written into label positions for prompt tokens and for padding
# (presumably matches the ignore index of the training loss -- confirm).
IGNORE_INDEX = -100

# Prompt templates, rendered with ``str.format_map`` on a data record:
#   * "prompt_input"    -- needs the ``instruction`` and ``input`` fields
#   * "prompt_no_input" -- needs only the ``instruction`` field
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
)
|
|
|
|
|
|
class SFTDataset(Dataset):
    """
    Dataset for sft model

    Each record's ``prompt`` and ``completion`` fields are concatenated, the
    literal ``<|endoftext|>`` marker is appended, and the text is tokenized
    with padding/truncation to ``max_length``. Labels are a deep copy of the
    tokenized inputs.

    Args:
        dataset: iterable of records indexable by ``'prompt'`` and ``'completion'``
        tokenizer: tokenizer for supervised model
        max_length: max length of input
    """

    def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
        super().__init__()
        self.input_ids = []

        # The progress bar is disabled whenever is_rank_0() is false.
        for data in tqdm(dataset, disable=not is_rank_0()):
            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
            prompt_token = tokenizer(prompt,
                                     max_length=max_length,
                                     padding="max_length",
                                     truncation=True,
                                     return_tensors="pt")

            # NOTE(review): the whole tokenizer return value is stored here, not
            # just its input ids -- confirm downstream consumers expect that.
            self.input_ids.append(prompt_token)

        # Labels start as an independent copy of the inputs.
        self.labels = copy.deepcopy(self.input_ids)

    def __len__(self):
        """Return the number of tokenized samples."""
        # ``self.prompts`` is never assigned, so the size must come from
        # ``self.input_ids``; the old lookup raised AttributeError.
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``input_ids`` / ``labels`` pair stored at position ``idx``."""
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
|
|
|
|
|
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
|
|
"""Tokenize a list of strings."""
|
|
tokenized_list = [
|
|
tokenizer(
|
|
text,
|
|
return_tensors="pt",
|
|
padding="longest",
|
|
max_length=max_length,
|
|
truncation=True,
|
|
) for text in strings
|
|
]
|
|
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
|
input_ids_lens = labels_lens = [
|
|
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
|
]
|
|
return dict(
|
|
input_ids=input_ids,
|
|
labels=labels,
|
|
input_ids_lens=input_ids_lens,
|
|
labels_lens=labels_lens,
|
|
)
|
|
|
|
|
|
def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
) -> Dict:
    """Preprocess the data by tokenizing.

    Each source is concatenated with its target and tokenized; the sources
    are also tokenized alone so their token counts can be used to mask the
    leading part of every label sequence with ``IGNORE_INDEX``.

    Args:
        sources: prompt texts.
        targets: completion texts, paired positionally with ``sources``.
        tokenizer: tokenizer forwarded to ``_tokenize_fn``.
        max_length: truncation length forwarded to ``_tokenize_fn``.

    Returns:
        A dict with ``input_ids`` (tokenized source+target) and ``labels``
        (a deep copy of ``input_ids`` with the source prefix masked).
    """
    full_texts = [source + target for source, target in zip(sources, targets)]

    # Tokenize the full examples first, then the bare sources.
    examples_tokenized = _tokenize_fn(full_texts, tokenizer, max_length)
    sources_tokenized = _tokenize_fn(sources, tokenizer, max_length)

    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)

    # Overwrite the prompt portion of every label sequence in place.
    prompt_lengths = sources_tokenized["input_ids_lens"]
    for label_seq, prompt_len in zip(labels, prompt_lengths):
        label_seq[:prompt_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": labels}
|
|
|
|
|
|
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Records are loaded with ``jload``, rendered into prompts with the
    templates in ``PROMPT_DICT``, paired with their ``output`` field plus the
    tokenizer's EOS token, and tokenized through ``preprocess``.

    Args:
        data_path: location handed to ``jload``.
        tokenizer: tokenizer used for the EOS token and for tokenization.
        max_datasets_size: when not ``None``, only the first
            ``max_datasets_size`` records are kept.
        max_length: truncation length forwarded to ``preprocess``.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
        super().__init__()
        logger.info("Loading data...")
        records = jload(data_path)
        logger.info(f"Loaded {len(records)} examples.")

        if max_datasets_size is not None:
            logger.info(f"Limiting dataset to {max_datasets_size} examples.")
            records = records[:max_datasets_size]

        logger.info("Formatting inputs...")
        template_with_input = PROMPT_DICT["prompt_input"]
        template_without_input = PROMPT_DICT["prompt_no_input"]

        sources = []
        for record in records:
            # A missing or empty "input" field selects the shorter template.
            if record.get("input", "") != "":
                chosen_template = template_with_input
            else:
                chosen_template = template_without_input
            sources.append(chosen_template.format_map(record))

        targets = [f"{record['output']}{tokenizer.eos_token}" for record in records]

        logger.info("Tokenizing inputs... This may take some time...")
        tokenized = preprocess(sources, targets, tokenizer, max_length)

        self.input_ids = tokenized["input_ids"]
        self.labels = tokenized["labels"]

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the ``input_ids`` / ``labels`` pair at index ``i``."""
        return {"input_ids": self.input_ids[i], "labels": self.labels[i]}
|
|
|
|
|
|
@dataclass
class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning.

    Pads the ``input_ids`` of a batch with the tokenizer's pad token id, pads
    the ``labels`` with ``IGNORE_INDEX``, and derives an attention mask that
    is true wherever the padded ``input_ids`` differ from the pad token id.
    """

    # Tokenizer that supplies ``pad_token_id``.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Turn a sequence of ``{"input_ids", "labels"}`` items into one padded batch."""
        pad_sequence = torch.nn.utils.rnn.pad_sequence

        id_sequences = [instance["input_ids"] for instance in instances]
        label_sequences = [instance["labels"] for instance in instances]

        pad_id = self.tokenizer.pad_token_id
        batch_ids = pad_sequence(id_sequences, batch_first=True, padding_value=pad_id)
        batch_labels = pad_sequence(label_sequences, batch_first=True, padding_value=IGNORE_INDEX)

        return {
            "input_ids": batch_ids,
            "labels": batch_labels,
            # Positions equal to the pad id are masked out.
            "attention_mask": batch_ids.ne(pad_id),
        }
|