mirror of https://github.com/InternLM/InternLM
fix(core/trainer.py): fix streaming train state load error (#247)
parent
fc4b8918c4
commit
b84d937478
|
@ -78,8 +78,9 @@ class TrainState:
|
||||||
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
|
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
|
||||||
|
|
||||||
# track the actual updates of sampler when using weighted sampling
|
# track the actual updates of sampler when using weighted sampling
|
||||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
if hasattr(self, "batch_sampler"):
|
||||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||||
|
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||||
|
|
||||||
# resume tensorboard from older tensorboard_folder
|
# resume tensorboard from older tensorboard_folder
|
||||||
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
|
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
|
||||||
|
|
Loading…
Reference in New Issue