Merge pull request #2338 from haofanwang/patch-1

Fix a typo in train_dreambooth_colossalai.py
pull/2379/head
Fazzie-Maqianli 2023-01-06 11:50:18 +08:00 committed by GitHub
commit 7a332b1734
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -484,7 +484,7 @@ def main(args):
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
if args.scale_lr: if args.scale_lr:
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2 args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
unet = gemini_zero_dpp(unet, pg, args.placement) unet = gemini_zero_dpp(unet, pg, args.placement)