diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 5436e7d6b..eae52b5ec 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -667,9 +667,9 @@ def main(args): if global_step % args.save_steps == 0: torch.cuda.synchronize() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin")) if local_rank == 0: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin")) if not os.path.exists(os.path.join(save_path, "config.json")): shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path) logger.info(f"Saving model checkpoint to {save_path}", ranks=[0]) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index 64cdd2a31..dce65ff51 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -693,9 +693,9 @@ def main(args): if global_step % args.save_steps == 0: torch.cuda.synchronize() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin")) if local_rank == 0: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin")) if not os.path.exists(os.path.join(save_path, "config.json")): shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path) logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])