From 17ed33350b387ff4666e18d2d2b0be446e9bf313 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 12 Jul 2022 14:20:02 +0800
Subject: [PATCH] [hotfix] fix an assertion bug in base schedule. (#1250)

---
 colossalai/engine/schedule/_base_schedule.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py
index b30aff784..ba797bad9 100644
--- a/colossalai/engine/schedule/_base_schedule.py
+++ b/colossalai/engine/schedule/_base_schedule.py
@@ -117,9 +117,9 @@ class BaseSchedule(ABC):
 
     @staticmethod
     def _call_engine_criterion(engine, outputs, labels):
-        assert isinstance(
-            outputs,
-            (torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
+        assert isinstance(outputs,
+                          (torch.Tensor, list, tuple,
+                           dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         if isinstance(labels, torch.Tensor):