mirror of https://github.com/hpcaitech/ColossalAI
Added rand augment and update the dataloader
parent
c7b8ece736
commit
d143396cac
46
examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py
Executable file → Normal file
46
examples/vit_b16_imagenet_data_parallel/dataloader/imagenet_dali_dataloader.py
Executable file → Normal file
|
@ -5,6 +5,7 @@ import nvidia.dali.types as types
|
|||
import nvidia.dali.tfrecord as tfrec
|
||||
import torch
|
||||
import numpy as np
|
||||
from .rand_augment import RandAugment
|
||||
|
||||
|
||||
class DaliDataloader(DALIClassificationIterator):
|
||||
|
@ -21,13 +22,17 @@ class DaliDataloader(DALIClassificationIterator):
|
|||
training=True,
|
||||
gpu_aug=False,
|
||||
cuda=True,
|
||||
mixup_alpha=0.0):
|
||||
mixup_alpha=0.0,
|
||||
randaug_magnitude=10,
|
||||
randaug_num_layers=0):
|
||||
self.mixup_alpha = mixup_alpha
|
||||
self.training = training
|
||||
self.randaug_magnitude = randaug_magnitude
|
||||
self.randaug_num_layers = randaug_num_layers
|
||||
pipe = Pipeline(batch_size=batch_size,
|
||||
num_threads=num_threads,
|
||||
device_id=torch.cuda.current_device() if cuda else None,
|
||||
seed=1024)
|
||||
seed=42)
|
||||
with pipe:
|
||||
inputs = fn.readers.tfrecord(
|
||||
path=tfrec_filenames,
|
||||
|
@ -44,38 +49,27 @@ class DaliDataloader(DALIClassificationIterator):
|
|||
'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
|
||||
})
|
||||
images = inputs["image/encoded"]
|
||||
|
||||
images = fn.decoders.image(images,
|
||||
device='mixed' if gpu_aug else 'cpu',
|
||||
output_type=types.RGB)
|
||||
if training:
|
||||
images = fn.decoders.image(images,
|
||||
device='mixed' if gpu_aug else 'cpu',
|
||||
output_type=types.RGB)
|
||||
images = fn.random_resized_crop(images,
|
||||
size=crop,
|
||||
device='gpu' if gpu_aug else 'cpu')
|
||||
flip_lr = fn.random.coin_flip(probability=0.5)
|
||||
if randaug_num_layers == 0:
|
||||
flip_lr = fn.random.coin_flip(probability=0.5)
|
||||
images = fn.flip(images, horizontal=flip_lr)
|
||||
else:
|
||||
# decode jpeg and resize
|
||||
images = fn.decoders.image(images,
|
||||
device='mixed' if gpu_aug else 'cpu',
|
||||
output_type=types.RGB)
|
||||
images = fn.resize(images,
|
||||
device='gpu' if gpu_aug else 'cpu',
|
||||
resize_x=resize,
|
||||
resize_y=resize,
|
||||
dtype=types.FLOAT,
|
||||
interp_type=types.INTERP_TRIANGULAR)
|
||||
flip_lr = False
|
||||
|
||||
# center crop and normalise
|
||||
images = fn.crop_mirror_normalize(images,
|
||||
dtype=types.FLOAT,
|
||||
crop=(crop, crop),
|
||||
mean=[127.5],
|
||||
std=[127.5],
|
||||
mirror=flip_lr)
|
||||
images = fn.crop(images,
|
||||
dtype=types.FLOAT,
|
||||
crop=(crop, crop))
|
||||
label = inputs["image/class/label"] - 1 # 0-999
|
||||
# LSG: element_extract will raise exception, let's flatten outside
|
||||
# label = fn.element_extract(label, element_map=0) # Flatten
|
||||
if cuda: # transfer data to gpu
|
||||
pipe.set_outputs(images.gpu(), label.gpu())
|
||||
else:
|
||||
|
@ -96,6 +90,10 @@ class DaliDataloader(DALIClassificationIterator):
|
|||
def __next__(self):
|
||||
data = super().__next__()
|
||||
img, label = data[0]['data'], data[0]['label']
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
if self.randaug_num_layers > 0 and self.training:
|
||||
img = RandAugment(img, num_layers=self.randaug_num_layers, magnitude=self.randaug_magnitude)
|
||||
img = (img - 127.5) / 127.5
|
||||
label = label.squeeze()
|
||||
if self.mixup_alpha > 0.0:
|
||||
if self.training:
|
||||
|
@ -106,7 +104,7 @@ class DaliDataloader(DALIClassificationIterator):
|
|||
lam = torch.tensor([lam], device=img.device, dtype=img.dtype)
|
||||
label = {'targets_a': label_a, 'targets_b': label_b, 'lam': lam}
|
||||
else:
|
||||
label = {'targets_a': label, 'targets_b': label,
|
||||
'lam': torch.ones(1, device=img.device, dtype=img.dtype)}
|
||||
label = {'targets_a': label, 'targets_b': label, 'lam': torch.ones(
|
||||
1, device=img.device, dtype=img.dtype)}
|
||||
return img, label
|
||||
return img, label
|
||||
|
|
|
@ -0,0 +1,209 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
# Magnitudes are divided by this value before being scaled to an op argument.
_MAX_LEVEL = 10

# Upper bounds (reached at magnitude == _MAX_LEVEL) for the cutout pad size
# and the translation offset.
_HPARAMS = dict(cutout_const=40, translate_const=40)

# Default per-channel (RGB) fill value used by the geometric ops and cutout.
_FILL = (128, 128, 128)
|
||||
|
||||
|
||||
def blend(image0, image1, factor):
    """Interpolate (or extrapolate) from ``image0`` towards ``image1``.

    The result is ``image0 + (image1 - image0) * factor``, computed in
    float32 and returned as uint8. In this file it is only called by
    ``color``.

    :param image0: image returned unchanged when ``factor == 0.0``
    :type image0: Tensor
    :param image1: image returned unchanged when ``factor == 1.0``
    :type image1: Tensor
    :param factor: blending weight; values outside [0, 1] extrapolate
    :type factor: float
    :return: the blended image
    :rtype: Tensor
    """
    # Both end points are returned as-is, without any dtype conversion.
    if factor == 0.0:
        return image0
    if factor == 1.0:
        return image1

    base = image0.type(torch.float32)
    target = image1.type(torch.float32)
    mixed = base + (target - base) * factor

    # A weight strictly between 0 and 1 lands between the two inputs, so it
    # is cast directly; extrapolated values are clamped to [0, 255] first.
    if 0.0 < factor < 1.0:
        return mixed.type(torch.uint8)
    return torch.clamp(mixed, 0, 255).type(torch.uint8)
|
||||
|
||||
|
||||
def autocontrast(image):
    """Return ``image`` after torchvision's ``TF.autocontrast``.

    :param image: the input image
    :type image: Tensor
    """
    return TF.autocontrast(image)
|
||||
|
||||
|
||||
def equalize(image):
    """Return ``image`` after torchvision's ``TF.equalize``.

    :param image: the input image
    :type image: Tensor
    """
    return TF.equalize(image)
|
||||
|
||||
|
||||
def rotate(image, degree, fill=_FILL):
    """Rotate ``image`` with ``TF.rotate``.

    :param image: the input image
    :type image: Tensor
    :param degree: rotation angle, forwarded as ``angle``
    :type degree: float
    :param fill: fill value forwarded to ``TF.rotate``, default=_FILL
    :type fill: tuple
    """
    return TF.rotate(image, angle=degree, fill=fill)
|
||||
|
||||
|
||||
def posterize(image, bits):
    """Posterize ``image`` with ``TF.posterize``.

    :param image: the input image
    :type image: Tensor
    :param bits: bit count forwarded to ``TF.posterize``
    :type bits: int
    """
    return TF.posterize(image, bits)
|
||||
|
||||
|
||||
def sharpness(image, factor):
    """Adjust the sharpness of ``image`` with ``TF.adjust_sharpness``.

    :param image: the input image
    :type image: Tensor
    :param factor: value forwarded as ``sharpness_factor``
    :type factor: float
    """
    return TF.adjust_sharpness(image, sharpness_factor=factor)
|
||||
|
||||
|
||||
def contrast(image, factor):
    """Adjust the contrast of ``image`` with ``TF.adjust_contrast``.

    :param image: the input image
    :type image: Tensor
    :param factor: contrast factor forwarded to torchvision
    :type factor: float
    """
    return TF.adjust_contrast(image, factor)
|
||||
|
||||
|
||||
def brightness(image, factor):
    """Adjust the brightness of ``image`` with ``TF.adjust_brightness``.

    :param image: the input image
    :type image: Tensor
    :param factor: brightness factor forwarded to torchvision
    :type factor: float
    """
    return TF.adjust_brightness(image, factor)
|
||||
|
||||
|
||||
def invert(image):
    """Return the negative of ``image``, i.e. ``255 - image`` element-wise.

    :param image: the input image; presumably uint8 in [0, 255] -- confirm
    :type image: Tensor
    """
    negative = 255 - image
    return negative
|
||||
|
||||
|
||||
def solarize(image, threshold=128):
    """Invert every element of ``image`` that is not below ``threshold``.

    :param image: the input image; presumably uint8 in [0, 255] -- confirm
    :type image: Tensor
    :param threshold: elements below this value are kept, default=128
    :type threshold: int
    """
    keep = image < threshold
    return torch.where(keep, image, 255 - image)
|
||||
|
||||
|
||||
def solarize_add(image, addition=0, threshold=128):
    """Add ``addition`` to every element of ``image`` below ``threshold``.

    The sum is computed in int64 and clamped to [0, 255] before being cast
    to uint8; elements at or above ``threshold`` are left unchanged.

    :param image: the input image; presumably uint8 in [0, 255] -- confirm
    :type image: Tensor
    :param addition: value added to the selected elements, default=0
    :type addition: int
    :param threshold: only elements below this value are changed, default=128
    :type threshold: int
    """
    # Widen before adding so the sum is clamped rather than wrapped.
    shifted = (image.long() + addition).clamp(0, 255).type(torch.uint8)
    below = image < threshold
    return torch.where(below, shifted, image)
|
||||
|
||||
|
||||
def color(image, factor):
    """Blend the 3-channel grayscale version of ``image`` with ``image``.

    ``factor == 0.0`` yields the grayscale image and ``factor == 1.0`` the
    original; see ``blend`` for other values.

    :param image: the input image
    :type image: Tensor
    :param factor: blending weight passed to ``blend``
    :type factor: float
    """
    gray = TF.rgb_to_grayscale(image, num_output_channels=3)
    return blend(gray, image, factor=factor)
|
||||
|
||||
|
||||
def shear_x(image, level, fill=_FILL):
    """Shear ``image`` parallel to the x-axis.

    :param image: the input image
    :type image: Tensor
    :param level: shear coefficient (the off-diagonal term of the affine
        matrix, as in Google's RandAugment), e.g. a value in [-0.3, 0.3]
    :type level: float
    :param fill: fill value forwarded to ``TF.affine``, default=_FILL
    :type fill: tuple
    """
    # Imported locally so this function does not rely on the module header.
    import math

    # TF.affine interprets ``shear`` as an angle in degrees, so the
    # coefficient must be converted; passing it through unchanged would
    # shear by a fraction of a degree only.
    shear_deg = math.degrees(math.atan(level))
    return TF.affine(image, 0, [0, 0], 1.0, [shear_deg, 0.0], fill=fill)
|
||||
|
||||
|
||||
def shear_y(image, level, fill=_FILL):
    """Shear ``image`` parallel to the y-axis.

    :param image: the input image
    :type image: Tensor
    :param level: shear coefficient (the off-diagonal term of the affine
        matrix, as in Google's RandAugment), e.g. a value in [-0.3, 0.3]
    :type level: float
    :param fill: fill value forwarded to ``TF.affine``, default=_FILL
    :type fill: tuple
    """
    # Imported locally so this function does not rely on the module header.
    import math

    # TF.affine interprets ``shear`` as an angle in degrees, so the
    # coefficient must be converted; passing it through unchanged would
    # shear by a fraction of a degree only.
    shear_deg = math.degrees(math.atan(level))
    return TF.affine(image, 0, [0, 0], 1.0, [0.0, shear_deg], fill=fill)
|
||||
|
||||
|
||||
def translate_x(image, level, fill=_FILL):
    """Translate ``image`` along the x-axis with ``TF.affine``.

    :param image: the input image
    :type image: Tensor
    :param level: horizontal offset, forwarded as the first ``translate`` item
    :type level: float
    :param fill: fill value forwarded to ``TF.affine``, default=_FILL
    :type fill: tuple
    """
    offset = [level, 0]
    no_shear = [0, 0]
    return TF.affine(image, 0, offset, 1.0, no_shear, fill=fill)
|
||||
|
||||
|
||||
def translate_y(image, level, fill=_FILL):
    """Translate ``image`` along the y-axis with ``TF.affine``.

    :param image: the input image
    :type image: Tensor
    :param level: vertical offset, forwarded as the second ``translate`` item
    :type level: float
    :param fill: fill value forwarded to ``TF.affine``, default=_FILL
    :type fill: tuple
    """
    offset = [0, level]
    no_shear = [0, 0]
    return TF.affine(image, 0, offset, 1.0, no_shear, fill=fill)
|
||||
|
||||
|
||||
def cutout(image, pad_size, fill=_FILL):
    """Overwrite one random square patch of a batch of images with ``fill``.

    A single patch of side ``2 * pad_size`` is drawn with ``np.random`` so
    that it lies fully inside the image, and the same patch is applied to
    every image in the batch. The input tensor is not modified.

    :param image: batch of images with shape (B, C, H, W)
    :type image: Tensor
    :param pad_size: half of the patch side length, in pixels
    :type pad_size: int
    :param fill: per-channel fill value, indexed by channel, default=_FILL
    :type fill: tuple
    :return: a new tensor with the patch filled in
    :rtype: Tensor
    :raises ValueError: from ``np.random.randint`` when ``2 * pad_size`` is
        not smaller than the image height or width
    """
    _, channels, height, width = image.shape
    center_y = np.random.randint(pad_size, height - pad_size)
    center_x = np.random.randint(pad_size, width - pad_size)
    top, bottom = center_y - pad_size, center_y + pad_size
    left, right = center_x - pad_size, center_x + pad_size

    # Write the fill into a copy of the input: the result stays on the
    # input's device and dtype (no hard-coded ``.cuda()``), and no sentinel
    # mask value can collide with a fill component.
    result = image.clone()
    for channel in range(channels):
        result[:, channel, top:bottom, left:right] = fill[channel]
    return result
|
||||
|
||||
|
||||
def _randomly_negate_tensor(level):
    """Return ``level`` or ``-level``, each with 50% probability.

    The coin flip is a single ``np.random.randint(0, 2)`` draw.

    :param level: the value whose sign may be flipped
    :type level: float
    """
    negate = np.random.randint(0, 2)
    if negate:
        return -level
    return level
|
||||
|
||||
|
||||
def _rotate_level_to_arg(level):
    """Convert a magnitude into a rotation angle with a random sign.

    The angle is ``(level / _MAX_LEVEL) * 30`` before the sign flip.

    :param level: the augmentation magnitude
    :type level: float
    """
    fraction = level / _MAX_LEVEL
    return _randomly_negate_tensor(fraction * 30.)
|
||||
|
||||
|
||||
def _shear_level_to_arg(level):
    """Convert a magnitude into a shear amount with a random sign.

    The amount is ``(level / _MAX_LEVEL) * 0.3`` before the sign flip.

    :param level: the augmentation magnitude
    :type level: float
    """
    fraction = level / _MAX_LEVEL
    # The sign is negated with 50% chance.
    return _randomly_negate_tensor(fraction * 0.3)
|
||||
|
||||
|
||||
def _translate_level_to_arg(level, translate_const):
    """Convert a magnitude into a translation offset with a random sign.

    The offset is ``(level / _MAX_LEVEL) * translate_const`` before the
    sign flip.

    :param level: the augmentation magnitude
    :type level: float
    :param translate_const: offset reached at ``level == _MAX_LEVEL``
    :type translate_const: float
    """
    fraction = level / _MAX_LEVEL
    # The sign is negated with 50% chance.
    return _randomly_negate_tensor(fraction * float(translate_const))
|
||||
|
||||
|
||||
def level(hparams):
    """Build the table of magnitude-to-argument converters.

    Each value is a one-argument callable that takes a magnitude and returns
    the argument for the augmentation of the same name, or ``None`` when the
    augmentation takes no argument.

    :param hparams: mapping read for ``'cutout_const'`` and
        ``'translate_const'`` when the matching converter is called
    :type hparams: dict
    :return: mapping from augmentation name to converter
    :rtype: dict
    """
    def _no_arg(level):
        return None

    def _scaled_int(upper):
        # Converter yielding the integer part of (level / _MAX_LEVEL) * upper.
        return lambda level: int((level / _MAX_LEVEL) * upper)

    def _enhance(level):
        return (level / _MAX_LEVEL) * 1.8 + 0.1

    def _cutout(level):
        return int((level / _MAX_LEVEL) * hparams['cutout_const'])

    def _translate(level):
        return _translate_level_to_arg(level, hparams['translate_const'])

    table = {}
    for name in ('AutoContrast', 'Equalize', 'Invert'):
        table[name] = _no_arg
    table['Rotate'] = _rotate_level_to_arg
    table['Posterize'] = _scaled_int(4)
    table['Solarize'] = _scaled_int(200)
    table['SolarizeAdd'] = _scaled_int(110)
    for name in ('Color', 'Contrast', 'Brightness', 'Sharpness'):
        table[name] = _enhance
    table['ShearX'] = _shear_level_to_arg
    table['ShearY'] = _shear_level_to_arg
    table['Cutout'] = _cutout
    table['TranslateX'] = _translate
    table['TranslateY'] = _translate
    return table
|
||||
|
||||
|
||||
# Augmentation name -> implementation. The names match the keys of the table
# returned by ``level``; insertion order is the order RandAugment samples from.
AUGMENTS = dict(
    AutoContrast=autocontrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    Solarize=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    Contrast=contrast,
    Brightness=brightness,
    Sharpness=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x,
    TranslateY=translate_y,
    Cutout=cutout,
)
|
||||
|
||||
|
||||
def RandAugment(image, num_layers=2, magnitude=_MAX_LEVEL, augments=AUGMENTS):
    """Randomly apply ``num_layers`` distinct augmentations to ``image``.

    Follows Google's RandAugment and the paper
    (https://arxiv.org/abs/2106.10270). With probability 0.5 the input is
    returned untouched; otherwise ``num_layers`` different augmentations are
    sampled without replacement and applied one after another, all at the
    same magnitude.

    :param image: the input images. NOTE(review): ``cutout`` unpacks a 4-D
        (B, C, H, W) shape and several ops assume values in [0, 255],
        so this is presumably a batched uint8 tensor -- confirm
    :type image: Tensor
    :param num_layers: how many augmentations to apply, default=2
    :type num_layers: int
    :param magnitude: the magnitude of the augmentations, default=10
    :type magnitude: int
    :param augments: mapping from augmentation name to callable; every name
        must also be a key of the table returned by ``level``
    :type augments: dict
    :return: the augmented (or untouched) images
    :rtype: Tensor
    :raises ValueError: from ``np.random.choice`` when ``num_layers``
        exceeds the number of available augmentations
    """
    if np.random.random() < 0.5:
        return image

    chosen = np.random.choice(a=list(augments.keys()),
                              size=num_layers,
                              replace=False)
    magnitude = float(magnitude)
    # The converter table does not depend on the chosen op: build it once.
    converters = level(_HPARAMS)
    for name in chosen:
        arg = converters[name](magnitude)
        op = augments[name]
        # Ops whose converter yields None take the image only.
        image = op(image) if arg is None else op(image, arg)
    return image
|
|
@ -26,10 +26,10 @@ def build_dali_train():
|
|||
batch_size=gpc.config.BATCH_SIZE,
|
||||
shard_id=gpc.get_local_rank(ParallelMode.DATA),
|
||||
num_shards=gpc.get_world_size(ParallelMode.DATA),
|
||||
training=True,
|
||||
gpu_aug=gpc.config.dali.gpu_aug,
|
||||
cuda=True,
|
||||
mixup_alpha=gpc.config.dali.mixup_alpha
|
||||
mixup_alpha=gpc.config.dali.mixup_alpha,
|
||||
randaug_num_layers=2
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue