mirror of https://github.com/hpcaitech/ColossalAI
[engine] fixed bug in gradient accumulation dataloader to keep the last step (#1030)
parent
32291dd73f
commit
e4685832f8
|
@ -145,6 +145,7 @@ class GradAccumDataloader:
|
|||
def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
|
||||
if self._cur_step < self.steps_per_epoch:
|
||||
self._cur_step += 1
|
||||
data = next(self._dataiter)
|
||||
|
||||
if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
|
||||
# this is to handle non standard pytorch dataloader
|
||||
|
@ -154,7 +155,7 @@ class GradAccumDataloader:
|
|||
_ = next(self._dataiter)
|
||||
except StopIteration:
|
||||
break
|
||||
return next(self._dataiter)
|
||||
return data
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
|
|
Loading…
Reference in New Issue