diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index a027fed..2839ad9 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -78,8 +78,9 @@ class TrainState: self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1 # track the actual updates of sampler when using weighted sampling - self.batch_sampler = train_dl.batch_sampler.copy() - self.batch_sampler_iter = iter(self.batch_sampler) + if hasattr(self, "batch_sampler"): + self.batch_sampler = train_dl.batch_sampler.copy() + self.batch_sampler_iter = iter(self.batch_sampler) # resume tensorboard from older tensorboard_folder self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)