diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 9e0fd5a..2f4d832 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -150,28 +150,21 @@ class NaiveAMPModel(nn.Module):
     def _register_fp32_parameters_hook(self) -> None:
         dtype = torch.float32
 
+        def to_fp32(x, dtype=dtype):
+            if isinstance(x, Tensor) and x.dtype != dtype:
+                return x.to(dtype)
+            return x
+
         def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
-            inputs_fp32 = []
             assert isinstance(inputs, tuple)
-            for input_data_ in inputs:
-                if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
-                    inputs_fp32.append(input_data_.to(dtype))
-                else:
-                    inputs_fp32.append(input_data_)
-            return tuple(inputs_fp32)
+            return tuple(map(to_fp32, inputs))
 
         def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
-            outputs_ = []
             assert isinstance(outputs, (Tensor, tuple))
             if isinstance(outputs, tuple):
-                for output_data_ in outputs:
-                    if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
-                        outputs_.append(output_data_.to(self.dtype))
-                    else:
-                        outputs_.append(output_data_)
-                return tuple(outputs_)
+                return tuple(to_fp32(output, self.dtype) for output in outputs)
             else:
-                return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
+                return to_fp32(outputs, self.dtype)
 
         # just want to share same for loop for ModuleList and Module
         if isinstance(self.model, nn.ModuleList):
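
For reviewers who want to try the refactored hook pattern in isolation: below is a minimal, self-contained sketch of the same fp32-in / low-precision-out scheme outside `NaiveAMPModel`. The toy `head` module and the bf16 compute dtype are illustrative assumptions, not part of this patch; the two hooks mirror `_pre_forward_hook` and `_post_forward_hook` above, registered with PyTorch's standard `register_forward_pre_hook` / `register_forward_hook`.

```python
# Minimal sketch of the hook pattern in this hunk. The `head` module and the
# bf16 compute dtype are assumptions for illustration; `compute_dtype` stands
# in for `self.dtype` in the patch.
from typing import Union

import torch
from torch import Tensor, nn

compute_dtype = torch.bfloat16


def to_dtype(x, dtype):
    # Cast only Tensors that are not already in the target dtype.
    if isinstance(x, Tensor) and x.dtype != dtype:
        return x.to(dtype)
    return x


def pre_hook(module: nn.Module, inputs: tuple):
    # Cast all tensor inputs up to fp32 before the module runs.
    assert isinstance(inputs, tuple)
    return tuple(to_dtype(i, torch.float32) for i in inputs)


def post_hook(module: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):
    # Cast outputs back down to the compute dtype afterwards.
    assert isinstance(outputs, (Tensor, tuple))
    if isinstance(outputs, tuple):
        return tuple(to_dtype(o, compute_dtype) for o in outputs)
    return to_dtype(outputs, compute_dtype)


head = nn.Linear(8, 2).float()            # fp32 parameters, like an fp32 head
head.register_forward_pre_hook(pre_hook)  # returned tuple replaces the inputs
head.register_forward_hook(post_hook)     # returned value replaces the output

x = torch.randn(4, 8, dtype=compute_dtype)
y = head(x)
print(y.dtype)  # torch.bfloat16: fp32 inside the module, compute dtype outside
```

The registration loop over `nn.ModuleList` vs. a single `nn.Module` referenced in the trailing context of the hunk would attach these same two hooks per submodule.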