mirror of https://github.com/hpcaitech/ColossalAI
Update train_dreambooth_colossalai.py
accelerator.num_processes -> gpc.get_world_size(ParallelMode.DATA)pull/2338/head
parent
f1bc2418c4
commit
9edd0aa75e
|
@ -484,7 +484,7 @@ def main(args):
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
|
||||||
if args.scale_lr:
|
if args.scale_lr:
|
||||||
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
|
args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
|
||||||
|
|
||||||
unet = gemini_zero_dpp(unet, pg, args.placement)
|
unet = gemini_zero_dpp(unet, pg, args.placement)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue