From e4685832f82cbd265c5af92bc94ba20b1c3e4dcb Mon Sep 17 00:00:00 2001
From: Frank Lee <somerlee.9@gmail.com>
Date: Thu, 26 May 2022 14:28:23 +0800
Subject: [PATCH] [engine] fix bug in gradient accumulation dataloader that
 dropped the last step (#1030)

Previously, when consume_remain_data was set, __next__ drained the
leftover batches on the final step of an epoch before fetching the batch
to return, so next(self._dataiter) hit an exhausted iterator and raised
StopIteration, silently dropping the last step. Fetch the batch first,
then drain, and return the saved batch.
---
 .../engine/gradient_accumulation/_gradient_accumulation.py     | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
index 5bfe3f449..89c28c3be 100644
--- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
+++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py
@@ -145,6 +145,7 @@ class GradAccumDataloader:
     def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
         if self._cur_step < self.steps_per_epoch:
             self._cur_step += 1
+            data = next(self._dataiter)
 
             if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
                 # this is to handle non standard pytorch dataloader
@@ -154,7 +155,7 @@ class GradAccumDataloader:
                         _ = next(self._dataiter)
                     except StopIteration:
                         break
-            return next(self._dataiter)
+            return data
         else:
             raise StopIteration
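
For context, below is a minimal, runnable sketch of the patched logic. MiniGradAccumDataloader is a hypothetical stand-in for colossalai's GradAccumDataloader that keeps only the __next__ path touched by this patch; the names steps_per_epoch, consume_remain_data, _cur_step, and _dataiter mirror the diff, while the constructor and __iter__ are simplifying assumptions, not the library's actual API.

class MiniGradAccumDataloader:
    """Simplified stand-in: truncates an iterable to steps_per_epoch items,
    optionally draining the leftovers so a non-standard dataloader can
    start the next epoch from a fresh iterator."""

    def __init__(self, dataloader, steps_per_epoch, consume_remain_data=True):
        self.dataloader = dataloader
        self.steps_per_epoch = steps_per_epoch
        self.consume_remain_data = consume_remain_data

    def __iter__(self):
        self._cur_step = 0
        self._dataiter = iter(self.dataloader)
        return self

    def __next__(self):
        if self._cur_step < self.steps_per_epoch:
            self._cur_step += 1
            # Fetch the batch BEFORE draining: this is the fix. Draining
            # first would exhaust the iterator, so the next() call below
            # would raise StopIteration and the last step would be lost.
            data = next(self._dataiter)

            if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
                # Handle non-standard dataloaders: throw away whatever is
                # left so the following epoch restarts cleanly.
                while True:
                    try:
                        _ = next(self._dataiter)
                    except StopIteration:
                        break
            return data
        else:
            raise StopIteration


loader = MiniGradAccumDataloader(range(10), steps_per_epoch=4)
print(list(loader))  # [0, 1, 2, 3]

With 10 samples and steps_per_epoch=4, this prints [0, 1, 2, 3]. Under the pre-patch ordering (drain first, fetch after), the fourth call to __next__ would drain items 3..9 and then raise StopIteration, so the loop would end after [0, 1, 2], one step short.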