# Training configuration expressed as plain module-level constants.
# NOTE(review): the meanings below are inferred from the constant names only;
# confirm against the training script / launcher that reads this module.

# Optimisation hyperparameters.
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

# Schedule length. WARMUP_EPOCHS is presumably the number of initial epochs
# spent on learning-rate warm-up out of NUM_EPOCHS -- TODO confirm.
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

# Process-count settings. The two values are assigned independently (both 4);
# nothing here ties MOE_MODEL_PARALLEL_SIZE to WORLD_SIZE, so keep them in
# sync by hand if the consumer requires that.
WORLD_SIZE = 4
MOE_MODEL_PARALLEL_SIZE = 4

# Parallel layout: the "moe" entry takes its size from MOE_MODEL_PARALLEL_SIZE.
parallel = dict(
    moe=dict(size=MOE_MODEL_PARALLEL_SIZE),
)

# Relative path used for logging output (presumably a directory -- confirm).
# Plain string literal: there is nothing to interpolate, so no f-prefix.
LOG_PATH = "./cifar10_moe"