# Optimisation / schedule hyperparameters.
# NOTE(review): these module-level names are presumably looked up by name by
# the training script or launcher that loads this config — confirm before renaming.
BATCH_SIZE = 512
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

# Process layout.
# NOTE(review): presumably WORLD_SIZE is the total number of processes and
# MOE_MODEL_PARALLEL_SIZE the size of each MoE parallel group — confirm.
WORLD_SIZE = 4
MOE_MODEL_PARALLEL_SIZE = 4

# Parallelism settings: a single "moe" entry whose "size" mirrors
# MOE_MODEL_PARALLEL_SIZE above, so the two cannot drift apart.
parallel = {"moe": {"size": MOE_MODEL_PARALLEL_SIZE}}
# Output directory for logs, relative to the current working directory.
# Plain string: the previous f-string had no placeholders (pyflakes F541).
LOG_PATH = "./cifar10_moe"