fix(core/trainer.py): fix streaming train state load error (#247)

pull/251/head
huangting4201 2023-08-29 18:47:21 +08:00 committed by GitHub
parent fc4b8918c4
commit b84d937478
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -78,8 +78,9 @@ class TrainState:
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1 self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
# track the actual updates of sampler when using weighted sampling # track the actual updates of sampler when using weighted sampling
self.batch_sampler = train_dl.batch_sampler.copy() if hasattr(self, "batch_sampler"):
self.batch_sampler_iter = iter(self.batch_sampler) self.batch_sampler = train_dl.batch_sampler.copy()
self.batch_sampler_iter = iter(self.batch_sampler)
# resume tensorboard from older tensorboard_folder # resume tensorboard from older tensorboard_folder
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None) self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)