From 5a0b3d5d9a7934ec96d82186db78f4cda6ad6829 Mon Sep 17 00:00:00 2001 From: Qu Wenwen Date: Tue, 19 Sep 2023 20:31:33 +0800 Subject: [PATCH] fix bugs if self.model is ModuleList --- internlm/core/naive_amp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index 5a7838a..c8cfbf7 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -152,6 +152,7 @@ class NaiveAMPModel(nn.Module): def _pre_forward_hook(model: nn.Module, inputs: tuple): # pylint: disable=W0613 inputs_fp32 = [] + assert isinstance(inputs, tuple) for input_data_ in inputs: if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype: inputs_fp32.append(input_data_.to(dtype)) @@ -173,7 +174,9 @@ class NaiveAMPModel(nn.Module): return outputs.to(self.dtype) # just want to share same for loop for ModuleList and Module - if not isinstance(self.model, nn.ModuleList): + if isinstance(self.model, nn.ModuleList): + model = self.model + else: model = [self.model] modules = []