From b84d937478afeff052d089162d551d9d5c1373f6 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Tue, 29 Aug 2023 18:47:21 +0800 Subject: [PATCH] fix(core/trainer.py): fix streaming train state load error (#247) --- internlm/core/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index a027fed..2839ad9 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -78,8 +78,9 @@ class TrainState: self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1 # track the actual updates of sampler when using weighted sampling - self.batch_sampler = train_dl.batch_sampler.copy() - self.batch_sampler_iter = iter(self.batch_sampler) + if hasattr(self, "batch_sampler"): + self.batch_sampler = train_dl.batch_sampler.copy() + self.batch_sampler_iter = iter(self.batch_sampler) # resume tensorboard from older tensorboard_folder self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)