mirror of https://github.com/InternLM/InternLM
fix(core/trainer.py): fix streaming train state load error (#247)
parent
fc4b8918c4
commit
b84d937478
|
@ -78,6 +78,7 @@ class TrainState:
|
|||
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
|
||||
|
||||
# track the actual updates of sampler when using weighted sampling
|
||||
if hasattr(self, "batch_sampler"):
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||
|
||||
|
|
Loading…
Reference in New Issue