mirror of https://github.com/hpcaitech/ColossalAI
fixed model saving bugs
parent
b29e1f0722
commit
d3379f0be7
|
@ -667,9 +667,9 @@ def main(args):
|
||||||
|
|
||||||
if global_step % args.save_steps == 0:
|
if global_step % args.save_steps == 0:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||||
|
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
|
||||||
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
|
||||||
if not os.path.exists(os.path.join(save_path, "config.json")):
|
if not os.path.exists(os.path.join(save_path, "config.json")):
|
||||||
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
|
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
|
||||||
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
|
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
|
||||||
|
|
|
@ -693,9 +693,9 @@ def main(args):
|
||||||
|
|
||||||
if global_step % args.save_steps == 0:
|
if global_step % args.save_steps == 0:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||||
|
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
|
||||||
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
|
|
||||||
if not os.path.exists(os.path.join(save_path, "config.json")):
|
if not os.path.exists(os.path.join(save_path, "config.json")):
|
||||||
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
|
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
|
||||||
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
|
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
|
||||||
|
|
Loading…
Reference in New Issue