diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index aff4d925d..b95353d9b 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -484,7 +484,7 @@ def main(args): unet.enable_gradient_checkpointing() if args.scale_lr: - args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2 + args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) unet = gemini_zero_dpp(unet, pg, args.placement)