from colossalai.amp import AMP_TYPE

# Training configuration: plain module-level settings, one per line so that
# each assignment actually executes (a `#` comments out the rest of its line).

# --- hyperparameters ---
# BATCH_SIZE is per GPU;
# global batch size = BATCH_SIZE x data parallel size.
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 2
WARMUP_EPOCHS = 1

# --- model config ---
NUM_CLASSES = 10

# Mixed-precision setting selecting the NAIVE AMP mode.
# NOTE(review): presumably read by the Colossal-AI launcher by name -- keep
# the lowercase identifiers `fp16` and `clip_grad_norm` as they are.
fp16 = dict(mode=AMP_TYPE.NAIVE)

# Presumably the max gradient norm used for clipping -- confirm against the
# consumer of this config.
clip_grad_norm = 1.0