mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1.2 KiB
37 lines
1.2 KiB
import torch
|
|
from datasets import load_dataset
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
class BeansDataset(Dataset):
    """Eagerly preprocessed split of the ``beans`` dataset loaded via ``load_dataset``.

    Every example is passed through ``image_processor`` once at construction
    time and cached in ``self.inputs``, so ``__getitem__`` is a list lookup.

    Args:
        image_processor: Callable invoked as
            ``image_processor(image, return_tensors="pt")``. Its result must
            support item assignment, because a ``"labels"`` entry is added to it.
            Presumably a Hugging Face image processor; confirm with callers.
        tp_size: The label list is padded with dummy ``pad_label_<i>`` entries
            until its length is a multiple of this value. Must be >= 1.
        split: Name of the split selected from the loaded dataset (e.g. ``"train"``).

    Attributes:
        image_processor: The callable passed in.
        ds: The selected split of the ``beans`` dataset.
        label_names: Copy of the dataset's label names, plus any padding entries.
        num_labels: ``len(label_names)``, including padding.
        inputs: Preprocessed examples, one per row of ``ds``.
    """

    def __init__(self, image_processor, tp_size=1, split="train"):
        super().__init__()
        self.image_processor = image_processor
        self.ds = load_dataset("beans")[split]
        # Ensure the number of labels is a multiple of tp_size (presumably so
        # the classification head can be split evenly across tensor-parallel
        # ranks). The helper pads a copy, leaving the dataset's metadata intact.
        self.label_names = self._pad_label_names(self.ds.features["labels"].names, tp_size)
        self.num_labels = len(self.label_names)
        self.inputs = [self.process_example(example) for example in self.ds]

    @staticmethod
    def _pad_label_names(label_names, tp_size):
        """Return a copy of ``label_names`` padded to a multiple of ``tp_size``.

        Each padding entry is named ``pad_label_<index>``, where ``<index>`` is
        its position in the returned list. The input sequence is not modified.

        Raises:
            ValueError: If ``tp_size`` is less than 1.
        """
        if tp_size < 1:
            raise ValueError(f"tp_size must be a positive integer, got {tp_size}")
        # Copy first: the incoming list may be the dataset's own ClassLabel
        # names, and appending to it in place would alter that metadata.
        padded = list(label_names)
        while len(padded) % tp_size != 0:
            padded.append(f"pad_label_{len(padded)}")
        return padded

    def __len__(self):
        """Return the number of preprocessed examples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the cached, preprocessed example at position ``idx``."""
        return self.inputs[idx]

    def process_example(self, example):
        """Preprocess one raw example into a model-input mapping.

        Args:
            example: Raw dataset row with ``"image"`` and ``"labels"`` keys.

        Returns:
            The output of ``image_processor`` (called with ``return_tensors="pt"``)
            with a ``"labels"`` entry set to the example's label.
        """
        processed = self.image_processor(example["image"], return_tensors="pt")
        processed["labels"] = example["labels"]
        return processed
|
|
|
|
|
|
def beans_collator(batch):
    """Collate preprocessed Beans examples into a single batch dict.

    Args:
        batch: Sequence of mappings, each with a ``"pixel_values"`` tensor and
            an integer-like ``"labels"`` value.

    Returns:
        A dict whose ``"pixel_values"`` entry holds the per-example tensors
        concatenated along dim 0, and whose ``"labels"`` entry is an int64
        tensor of the labels in batch order.
    """
    stacked_pixels = torch.cat([sample["pixel_values"] for sample in batch], dim=0)
    label_tensor = torch.tensor([sample["labels"] for sample in batch], dtype=torch.int64)
    return {"pixel_values": stacked_pixels, "labels": label_tensor}
|