[hotfix] fix an assertion bug in base schedule. (#1250)

pull/1253/head
YuliangLiu0306 2022-07-12 14:20:02 +08:00 committed by GitHub
parent 97d713855a
commit 17ed33350b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -117,9 +117,9 @@ class BaseSchedule(ABC):
@staticmethod
def _call_engine_criterion(engine, outputs, labels):
assert isinstance(
outputs,
(torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
assert isinstance(outputs,
(torch.Tensor, list, tuple,
dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
if isinstance(labels, torch.Tensor):