'fix/format' (#573)

pull/673/head
yuxuan-lou 2022-03-31 15:46:11 +08:00 committed by binmakeswell
parent b0f708dfc1
commit cfb41297ff
2 changed files with 6 additions and 6 deletions

View File

@ -85,8 +85,7 @@ class BaseSchedule(ABC):
data_iter: Iterable,
forward_only: bool,
return_loss: bool = True,
return_output_label: bool = True
):
return_output_label: bool = True):
"""The process function over a batch of dataset for training or evaluation.
Args:
@ -107,8 +106,9 @@ class BaseSchedule(ABC):
@staticmethod
def _call_engine_criterion(engine, outputs, labels):
assert isinstance(outputs, (torch.Tensor, list, tuple)
), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
assert isinstance(
outputs,
(torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
if isinstance(labels, torch.Tensor):