fixed model saving bugs

pull/3905/head
Maruyama_Aya 2023-06-06 16:07:34 +08:00
parent b29e1f0722
commit d3379f0be7
2 changed files with 4 additions and 4 deletions

@@ -667,9 +667,9 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
                     if not os.path.exists(os.path.join(save_path, "config.json")):
                         shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])

@@ -693,9 +693,9 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
-                if local_rank == 0:
-                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                    booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
+                if local_rank == 0:
                     if not os.path.exists(os.path.join(save_path, "config.json")):
                         shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                     logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])