mirror of https://github.com/hpcaitech/ColossalAI
[hotfix]fix some bugs caused by refactored schedule. (#1148)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [hotfix]fix some bugs caused by refactored schedule.
pull/1129/head
parent
8cdce0399c
commit
f1f51990b9
|
@ -36,7 +36,13 @@ class BaseSchedule(ABC):
|
|||
if isinstance(data, torch.Tensor):
|
||||
data = data.to(get_current_device())
|
||||
elif isinstance(data, (list, tuple)):
|
||||
data = [self._move_tensor(v) for v in data]
|
||||
data_to_return = []
|
||||
for element in data:
|
||||
if isinstance(element, dict):
|
||||
data_to_return.append({k: self._move_tensor(v) for k, v in element.items()})
|
||||
else:
|
||||
data_to_return.append(self._move_tensor(element))
|
||||
data = data_to_return
|
||||
elif isinstance(data, dict):
|
||||
data = {k: self._move_tensor(v) for k, v in data.items()}
|
||||
else:
|
||||
|
|
|
@ -66,7 +66,6 @@ class NonPipelineSchedule(BaseSchedule):
|
|||
assert forward_only or return_loss, \
|
||||
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
batch_data = self.load_batch(data_iter)
|
||||
|
||||
if self.data_process_func:
|
||||
data, label = self.data_process_func(batch_data)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue