mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
import torch
from datasets import load_dataset
from torch.utils.data import Dataset


class BeansDataset(Dataset):
    """Image-classification dataset built from the Hugging Face "beans" dataset.

    Labels are padded so that their count is divisible by ``tp_size`` (the tensor
    parallel size).
    """

    def __init__(self, image_processor, tp_size=1, split="train"):
        super().__init__()
        self.image_processor = image_processor
        self.ds = load_dataset("beans")[split]
        self.label_names = self.ds.features["labels"].names
        while len(self.label_names) % tp_size != 0:
            # Ensure that the number of labels is a multiple of tp_size.
            self.label_names.append(f"pad_label_{len(self.label_names)}")
        self.num_labels = len(self.label_names)
        self.inputs = []
        for example in self.ds:
            self.inputs.append(self.process_example(example))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx]

    def process_example(self, example):
        # Preprocess a single image and attach its label.
        processed = self.image_processor(example["image"], return_tensors="pt")
        processed["labels"] = example["labels"]
        return processed


def beans_collator(batch):
    # Stack per-example tensors into a single batch dict.
    return {
        "pixel_values": torch.cat([data["pixel_values"] for data in batch], dim=0),
        "labels": torch.tensor([data["labels"] for data in batch], dtype=torch.int64),
    }
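

# Minimal usage sketch: wrap the dataset and collator in a standard torch DataLoader.
# This assumes a Hugging Face ViT image processor; the checkpoint name below is
# illustrative, not prescribed by this file.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import ViTImageProcessor

    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    train_set = BeansDataset(processor, tp_size=1, split="train")
    loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=beans_collator)
    batch = next(iter(loader))
    print(batch["pixel_values"].shape, batch["labels"].shape)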