[hotfix]fix some bugs caused by refactored schedule. (#1148)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [hotfix]fix some bugs caused by refactored schedule.
pull/1129/head
YuliangLiu0306 2022-06-21 22:46:30 +08:00 committed by GitHub
parent 8cdce0399c
commit f1f51990b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 2 deletions

View File

@ -36,7 +36,13 @@ class BaseSchedule(ABC):
if isinstance(data, torch.Tensor):
data = data.to(get_current_device())
elif isinstance(data, (list, tuple)):
data = [self._move_tensor(v) for v in data]
data_to_return = []
for element in data:
if isinstance(element, dict):
data_to_return.append({k: self._move_tensor(v) for k, v in element.items()})
else:
data_to_return.append(self._move_tensor(element))
data = data_to_return
elif isinstance(data, dict):
data = {k: self._move_tensor(v) for k, v in data.items()}
else:

View File

@ -66,7 +66,6 @@ class NonPipelineSchedule(BaseSchedule):
assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data = self.load_batch(data_iter)
if self.data_process_func:
data, label = self.data_process_func(batch_data)
else: