[engine] fixed bug in gradient accumulation dataloader to keep the last step (#1030)

pull/1034/head
Frank Lee 2022-05-26 14:28:23 +08:00 committed by GitHub
parent 32291dd73f
commit e4685832f8
1 changed file with 2 additions and 1 deletion

@@ -145,6 +145,7 @@ class GradAccumDataloader:
     def __next__(self) -> Union[Tensor, Tuple[Tensor]]:
         if self._cur_step < self.steps_per_epoch:
             self._cur_step += 1
+            data = next(self._dataiter)
             if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
                 # this is to handle non standard pytorch dataloader
@@ -154,7 +155,7 @@ class GradAccumDataloader:
                     _ = next(self._dataiter)
                 except StopIteration:
                     break
-            return next(self._dataiter)
+            return data
         else:
             raise StopIteration
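
Why the fix matters: with consume_remain_data enabled, the old code drained the leftover batches first and only then called next(self._dataiter), so the batch for the final accumulation step had already been consumed and the last step was lost. The fix fetches the batch before draining. Below is a minimal, self-contained sketch of the fixed logic; the class name GradAccumLoaderSketch and the plain list standing in for a PyTorch DataLoader are illustrative assumptions, not the actual colossalai.engine implementation.

from typing import Iterable, Iterator


class GradAccumLoaderSketch:
    """Hypothetical sketch of the fixed __next__ logic (not the real class)."""

    def __init__(self, dataloader: Iterable, steps_per_epoch: int,
                 consume_remain_data: bool = True) -> None:
        self._dataiter: Iterator = iter(dataloader)
        self.steps_per_epoch = steps_per_epoch
        self.consume_remain_data = consume_remain_data
        self._cur_step = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._cur_step < self.steps_per_epoch:
            self._cur_step += 1
            # Fetch the batch BEFORE draining leftovers, so the last
            # step's data is not consumed by the drain loop below.
            data = next(self._dataiter)
            if self._cur_step == self.steps_per_epoch and self.consume_remain_data:
                # Drain batches that do not fit a full accumulation cycle.
                while True:
                    try:
                        _ = next(self._dataiter)
                    except StopIteration:
                        break
            return data
        else:
            raise StopIteration


# With 5 batches and steps_per_epoch=4, the old code drained the iterator
# and then called next() again, losing the 4th step; the fixed logic
# yields all 4 accumulation steps.
batches = list(range(5))
print(list(GradAccumLoaderSketch(batches, steps_per_epoch=4)))  # [0, 1, 2, 3]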