From 72bd7c696b9dddbff43815137defeb93dd993578 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 13 Jun 2022 14:18:04 +0800 Subject: [PATCH] [amp] included dict for type casting of model output (#1102) --- colossalai/amp/naive_amp/naive_amp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py index d8bbaad8f..02eae80b9 100644 --- a/colossalai/amp/naive_amp/naive_amp.py +++ b/colossalai/amp/naive_amp/naive_amp.py @@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module): out = self._convert_to_fp32(out) elif isinstance(out, (tuple, list)): out = [self._convert_to_fp32(val) for val in out] + elif isinstance(out, dict): + out = {key: self._convert_to_fp32(val) for key, val in out.items()} return out