From 883160a558f6286eefda083cbdda0e02f598eda7 Mon Sep 17 00:00:00 2001 From: Qu Wenwen Date: Tue, 19 Sep 2023 20:42:16 +0800 Subject: [PATCH] add dtype condition for post hook --- internlm/core/naive_amp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index c8cfbf7..9e0fd5a 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -165,13 +165,13 @@ class NaiveAMPModel(nn.Module): assert isinstance(outputs, (Tensor, tuple)) if isinstance(outputs, tuple): for output_data_ in outputs: - if isinstance(output_data_, Tensor): + if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype: outputs_.append(output_data_.to(self.dtype)) else: outputs_.append(output_data_) return tuple(outputs_) else: - return outputs.to(self.dtype) + return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs # just want to share same for loop for ModuleList and Module if isinstance(self.model, nn.ModuleList):