fix(core/trainer.py): fix streaming train state load error (#247)

pull/251/head
huangting4201 2023-08-29 18:47:21 +08:00 committed by GitHub
parent fc4b8918c4
commit b84d937478
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -78,6 +78,7 @@ class TrainState:
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
# track the actual updates of sampler when using weighted sampling
if hasattr(self, "batch_sampler"):
self.batch_sampler = train_dl.batch_sampler.copy()
self.batch_sampler_iter = iter(self.batch_sampler)