|
|
|
@ -85,8 +85,7 @@ class BaseSchedule(ABC):
|
|
|
|
|
data_iter: Iterable,
|
|
|
|
|
forward_only: bool,
|
|
|
|
|
return_loss: bool = True,
|
|
|
|
|
return_output_label: bool = True
|
|
|
|
|
):
|
|
|
|
|
return_output_label: bool = True):
|
|
|
|
|
"""The process function over a batch of dataset for training or evaluation.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -107,8 +106,9 @@ class BaseSchedule(ABC):
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _call_engine_criterion(engine, outputs, labels):
|
|
|
|
|
assert isinstance(outputs, (torch.Tensor, list, tuple)
|
|
|
|
|
), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
|
|
|
|
|
assert isinstance(
|
|
|
|
|
outputs,
|
|
|
|
|
(torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
|
|
|
|
|
if isinstance(outputs, torch.Tensor):
|
|
|
|
|
outputs = (outputs,)
|
|
|
|
|
if isinstance(labels, torch.Tensor):
|
|
|
|
|