From d3379f0be7e30854ee2353924d735642f4909aab Mon Sep 17 00:00:00 2001
From: Maruyama_Aya
Date: Tue, 6 Jun 2023 16:07:34 +0800
Subject: [PATCH] fixed model saving bugs

---
 examples/images/dreambooth/train_dreambooth_colossalai.py | 4 ++--
 .../images/dreambooth/train_dreambooth_colossalai_lora.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index 5436e7d6b..eae52b5ec 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -667,9 +667,9 @@ def main(args):
 
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                 if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                     if not os.path.exists(os.path.join(save_path, "config.json")):
                         shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
index 64cdd2a31..dce65ff51 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -693,9 +693,9 @@ def main(args):
 
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                 if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                     if not os.path.exists(os.path.join(save_path, "config.json")):
                         shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
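
Note: both hunks make the same fix. The old code called booster.save_model inside the
`if local_rank == 0:` guard; with ColossalAI's sharded plugins, save_model needs to run
on every rank so the distributed parameters can be gathered, so rank-guarding it is the
saving bug the subject line refers to. Only the one-off filesystem work (copying
config.json, logging) stays on rank 0. Below is a minimal sketch of the resulting
pattern; the helper name save_checkpoint, the os.makedirs call, and the
dist.get_rank() check standing in for the scripts' local_rank variable are assumptions
for self-containedness, not part of the patch.

import os
import shutil

import torch
import torch.distributed as dist


def save_checkpoint(booster, unet, args, global_step, logger):
    # Make sure pending GPU work is done before serializing weights,
    # mirroring the torch.cuda.synchronize() call in the patched scripts.
    torch.cuda.synchronize()
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    os.makedirs(save_path, exist_ok=True)  # assumed; idempotent across ranks

    # Runs on ALL ranks: with sharded plugins, booster.save_model gathers
    # the parameters collectively. Guarding it behind rank 0, as the old
    # code did, leaves the other ranks out of the save.
    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))

    # One-off filesystem work and logging stay on a single rank.
    if dist.get_rank() == 0:
        if not os.path.exists(os.path.join(save_path, "config.json")):
            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
        logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])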