From 17ed33350b387ff4666e18d2d2b0be446e9bf313 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 12 Jul 2022 14:20:02 +0800 Subject: [PATCH] [hotfix] fix an assertion bug in base schedule. (#1250) --- colossalai/engine/schedule/_base_schedule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index b30aff784..ba797bad9 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -117,9 +117,9 @@ class BaseSchedule(ABC): @staticmethod def _call_engine_criterion(engine, outputs, labels): - assert isinstance( - outputs, - (torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + assert isinstance(outputs, + (torch.Tensor, list, tuple, + dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' if isinstance(outputs, torch.Tensor): outputs = (outputs,) if isinstance(labels, torch.Tensor):