mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.8 KiB
75 lines
2.8 KiB
import math
|
|
from abc import abstractmethod
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
|
import os
|
|
import numpy as np
|
|
import cv2
|
|
|
|
class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable.

    Expected on-disk layout: ``file_path`` is a directory whose sub-folders
    (names without a '.') each hold ``<name>.jpg`` images, with the caption of
    every image stored next to it as ``<name>.txt``.

    Iterating yields ``{"caption": str, "image": torch.Tensor}`` dicts, where
    the image is the decoded RGB picture divided by 255.

    Args:
        file_path: root directory that contains the image sub-folders.
        rank: rank of the current process (stored, see note in ``__init__``).
        world_size: total number of processes (stored, see note in ``__init__``).
    '''

    def __init__(self, file_path: str, rank, world_size):
        super().__init__()
        self.file_path = file_path
        # Filled in by _get_file_info(); they must exist before it is called.
        self.folder_list = []
        self.file_list = []
        self.txt_list = []
        self.info = self._get_file_info(file_path)
        self.start = self.info['start']
        self.end = self.info['end']
        # NOTE: rank / world_size are only stored. No per-rank sharding is
        # applied here, so every process iterates over the full [start, end).
        self.rank = rank
        self.world_size = world_size
        self.num_records = self.end - self.start
        self.valid_ids = list(range(self.end))

        print(f'{self.__class__.__name__} dataset contains {len(self)} examples.')

    def __len__(self):
        '''Return the number of samples produced by one full iteration.'''
        return self.end - self.start

    def __iter__(self):
        '''Return a fresh generator over all samples in [start, end).'''
        return self._sample_generator(self.start, self.end)

    def _sample_generator(self, start, end):
        '''
        Lazily load and yield the samples with index in ``[start, end)``.

        Each sample is a dict with the caption text (``"caption"``) and the
        decoded image tensor (``"image"``).
        '''
        for idx in range(start, end):
            file_name = self.file_list[idx]
            txt_name = self.txt_list[idx]
            # Context manager so the handle is released even if read() fails.
            with open(txt_name, 'r') as caption_file:
                txt_ = caption_file.read()
            # Read the raw bytes and decode them in memory; flag 1 requests a
            # 3-channel colour image, which OpenCV returns in BGR order.
            image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Scale the 8-bit pixel values by 1/255.
            image = torch.from_numpy(image) / 255
            yield {"caption": txt_, "image": image}

    def _get_file_info(self, file_path):
        '''
        Scan ``file_path`` and collect every image / caption path pair.

        Populates ``self.folder_list``, ``self.file_list`` and
        ``self.txt_list`` (the latter two are index-aligned) and returns a
        dict ``{"start": first_index, "end": one_past_last_index}``.
        '''
        # Indices address self.file_list directly, so iteration starts at 0.
        # Starting at 1 would skip the first image and report a length of -1
        # for an empty dataset.
        info = {
            "start": 0,
            "end": 0,
        }
        # Sorted so that every process sees the samples in the same order;
        # os.listdir() returns entries in arbitrary order.
        self.folder_list = [
            os.path.join(file_path, name)
            for name in sorted(os.listdir(file_path))
            if '.' not in name and os.path.isdir(os.path.join(file_path, name))
        ]
        for folder in self.folder_list:
            for name in sorted(os.listdir(folder)):
                stem, ext = os.path.splitext(name)
                if ext != '.jpg':
                    continue
                self.file_list.append(os.path.join(folder, name))
                # Swap only the extension; a plain str.replace('jpg', 'txt')
                # would also rewrite 'jpg' inside folder or file names.
                self.txt_list.append(os.path.join(folder, stem + '.txt'))
        info['end'] = len(self.file_list)
        return info