From e4685832f82cbd265c5af92bc94ba20b1c3e4dcb Mon Sep 17 00:00:00 2001 From: Frank Lee <somerlee.9@gmail.com> Date: Thu, 26 May 2022 14:28:23 +0800 Subject: [PATCH] [engine] fixed bug in gradient accumulation dataloader to keep the last step (#1030) --- .../engine/gradient_accumulation/_gradient_accumulation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py index 5bfe3f449..89c28c3be 100644 --- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py @@ -145,6 +145,7 @@ class GradAccumDataloader: def __next__(self) -> Union[Tensor, Tuple[Tensor]]: if self._cur_step < self.steps_per_epoch: self._cur_step += 1 + data = next(self._dataiter) if self._cur_step == self.steps_per_epoch and self.consume_remain_data: # this is to handle non standard pytorch dataloader @@ -154,7 +155,7 @@ class GradAccumDataloader: _ = next(self._dataiter) except StopIteration: break - return next(self._dataiter) + return data else: raise StopIteration