mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
5.4 KiB
147 lines
5.4 KiB
1 year ago
|
import datasets
|
||
|
from torch.utils.data import DataLoader
|
||
|
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||
|
|
||
|
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
|
||
|
|
||
|
|
||
|
class GLUEDataBuilder:
    """Load a GLUE task, tokenize it, and build train/validation/test dataloaders.

    Dataloaders are created through ``plugin.prepare_dataloader`` when a
    ColossalAI data-parallel plugin is supplied, and through a plain
    ``torch.utils.data.DataLoader`` otherwise.
    """

    # Text column(s) passed to the tokenizer for each task; two entries are
    # encoded as a sentence pair.
    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    # Number of model output labels per task (``stsb`` has a single output,
    # presumably a regression score).
    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    # Columns that are kept (when present after tokenization) and exposed as
    # torch tensors by the dataloaders.
    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        plugin: "DPPluginBase | None" = None,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        """Store the configuration, load the tokenizer and prepare all splits.

        Args:
            model_name_or_path: Model id or local path handed to
                ``AutoTokenizer.from_pretrained``.
            plugin: Optional ColossalAI DP plugin whose ``prepare_dataloader``
                builds the loaders; when ``None`` a plain ``DataLoader`` is used.
            task_name: GLUE task name; must be a key of ``task_text_field_map``.
            max_seq_length: Length every example is padded/truncated to.
            train_batch_size: Batch size of the training loader.
            eval_batch_size: Batch size of the validation and test loaders.
            **kwargs: Accepted and ignored (kept for call-site compatibility).

        Raises:
            KeyError: If ``task_name`` is not a supported GLUE task.
        """
        super().__init__()
        if task_name not in self.task_text_field_map:
            raise KeyError(
                f"Unsupported GLUE task {task_name!r}; expected one of {sorted(self.task_text_field_map)}"
            )
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.plugin = plugin

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
        self.setup()

    def setup(self):
        """Load the task, tokenize every split and record the eval/test split names.

        Each split is mapped through ``convert_to_features`` in batches. The raw
        ``label`` column is dropped after being copied to ``labels``. The split
        is then formatted as torch tensors restricted to ``loader_columns``.
        """
        self.dataset = datasets.load_dataset("glue", self.task_name)

        for split in list(self.dataset.keys()):
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        # Some tasks (e.g. mnli) expose several validation/test splits such as
        # ``validation_matched`` / ``validation_mismatched``.
        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
        self.test_splits = [x for x in self.dataset.keys() if "test" in x]

    def prepare_data(self):
        """Fetch the dataset and tokenizer once; the results are discarded.

        NOTE(review): presumably meant to warm the local download cache before
        ``setup`` runs — confirm against the caller.
        """
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def _prepare_dataloader(self, dataset, batch_size, **kwargs):
        """Build one loader via the plugin when set, otherwise natively."""
        if self.plugin is None:
            return self.native_prepare_dataloader(dataset, batch_size=batch_size, **kwargs)
        return self.plugin.prepare_dataloader(dataset, batch_size=batch_size, **kwargs)

    def _split_dataloaders(self, splits):
        """Return one eval-sized loader for a single split, a list for several, ``None`` for none."""
        loaders = [self._prepare_dataloader(self.dataset[name], batch_size=self.eval_batch_size) for name in splits]
        if not loaders:
            return None
        return loaders[0] if len(loaders) == 1 else loaders

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete final batches are dropped."""
        return self._prepare_dataloader(
            self.dataset["train"],
            batch_size=self.train_batch_size,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Return the loader(s) for every validation split (single loader or list)."""
        return self._split_dataloaders(self.eval_splits)

    def test_dataloader(self):
        """Return the loader(s) for every test split (single loader or list)."""
        return self._split_dataloaders(self.test_splits)

    def convert_to_features(self, example_batch):
        """Tokenize a batch of raw examples and attach their labels.

        Args:
            example_batch: Column-oriented batch (column name -> list of
                values), as produced by ``Dataset.map(batched=True)`` in
                ``setup``. It must contain the task's text field(s) and ``label``.

        Returns:
            The tokenizer output with an added ``labels`` entry copied from
            ``label``.
        """
        texts = example_batch[self.text_fields[0]]
        # Sentence-pair tasks pass the second column as ``text_pair``.
        text_pairs = example_batch[self.text_fields[1]] if len(self.text_fields) > 1 else None

        features = self.tokenizer(
            texts,
            text_pairs,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
        )

        # Rename label -> labels so the batch can be passed to the model forward.
        features["labels"] = example_batch["label"]
        return features

    def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False):
        """Wrap ``dataset`` in a plain ``torch.utils.data.DataLoader`` (no distributed sampler)."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            pin_memory=pin_memory,
        )
|