From 9edd0aa75e2c2b5f308a1b19c3b691d58f23aae3 Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Thu, 5 Jan 2023 15:49:57 +0800 Subject: [PATCH] Update train_dreambooth_colossalai.py accelerator.num_processes -> gpc.get_world_size(ParallelMode.DATA) --- examples/images/dreambooth/train_dreambooth_colossalai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index aff4d925d..b95353d9b 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -484,7 +484,7 @@ def main(args): unet.enable_gradient_checkpointing() if args.scale_lr: - args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2 + args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) unet = gemini_zero_dpp(unet, pg, args.placement)