mirror of https://github.com/hpcaitech/ColossalAI
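Fixed model saving bugs: changed where `booster.save_model` is called relative to the `local_rank == 0` check in the checkpoint-saving code of `main`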
parent
b29e1f0722
commit
d3379f0be7
|
@ -667,9 +667,9 @@ def main(args):
|
|||
|
||||
if global_step % args.save_steps == 0:
|
||||
torch.cuda.synchronize()
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
||||
if local_rank == 0:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
||||
if not os.path.exists(os.path.join(save_path, "config.json")):
|
||||
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
|
||||
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
|
||||
|
|
|
@ -693,9 +693,9 @@ def main(args):
|
|||
|
||||
if global_step % args.save_steps == 0:
|
||||
torch.cuda.synchronize()
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
||||
if local_rank == 0:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
||||
if not os.path.exists(os.path.join(save_path, "config.json")):
|
||||
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
|
||||
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
|
||||
|
|
Loading…
Reference in New Issue