From b7229fd9fbac64660a9f3498688eccb94e6b72f3 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 22 Sep 2023 14:03:42 +0800
Subject: [PATCH] refactor code for pre/post hook

---
 internlm/core/naive_amp.py | 25 +++++++++----------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 9e0fd5a..2f4d832 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -150,28 +150,21 @@ class NaiveAMPModel(nn.Module):
     def _register_fp32_parameters_hook(self) -> None:
         dtype = torch.float32
 
+        def to_fp32(x, dtype=dtype):
+            if isinstance(x, Tensor) and x.dtype != dtype:
+                return x.to(dtype)
+            return x
+
         def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
-            inputs_fp32 = []
             assert isinstance(inputs, tuple)
-            for input_data_ in inputs:
-                if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
-                    inputs_fp32.append(input_data_.to(dtype))
-                else:
-                    inputs_fp32.append(input_data_)
-            return tuple(inputs_fp32)
+            return tuple(map(to_fp32, inputs))
 
         def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
-            outputs_ = []
-            assert isinstance(outputs, (Tensor, tuple))
+            assert isinstance(outputs, (tuple, Tensor))
             if isinstance(outputs, tuple):
-                for output_data_ in outputs:
-                    if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
-                        outputs_.append(output_data_.to(self.dtype))
-                    else:
-                        outputs_.append(output_data_)
-                return tuple(outputs_)
+                return tuple(to_fp32(out, self.dtype) for out in outputs)
             else:
-                return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
+                return to_fp32(outputs, self.dtype)
 
         # just want to share same for loop for ModuleList and Module
         if isinstance(self.model, nn.ModuleList):
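
The pattern being refactored above is: a forward pre-hook casts every positional Tensor input up to fp32 before the wrapped module runs, and a forward (post) hook casts the outputs back to the model's working dtype (self.dtype). Below is a minimal, self-contained sketch of that pattern outside of NaiveAMPModel, assuming only the standard PyTorch hook API; the class name Fp32HookExample, the helper cast_if_needed, and the outer_dtype argument are illustrative stand-ins, not names from this patch.

import torch
from torch import Tensor, nn


def cast_if_needed(x, dtype):
    """Cast Tensors whose dtype differs from `dtype`; pass everything else through."""
    if isinstance(x, Tensor) and x.dtype != dtype:
        return x.to(dtype)
    return x


class Fp32HookExample(nn.Module):
    """Run a wrapped module in fp32 while the surrounding model exchanges lower-precision tensors."""

    def __init__(self, module: nn.Module, outer_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.module = module.to(torch.float32)
        self.outer_dtype = outer_dtype
        # Pre-hook: cast incoming tensors up to fp32 before forward runs.
        self.module.register_forward_pre_hook(self._pre_forward_hook)
        # Post-hook: cast outgoing tensors back down to the outer dtype.
        self.module.register_forward_hook(self._post_forward_hook)

    def _pre_forward_hook(self, module: nn.Module, inputs: tuple):
        # Positional arguments always arrive at a pre-hook as a tuple.
        assert isinstance(inputs, tuple)
        return tuple(cast_if_needed(x, torch.float32) for x in inputs)

    def _post_forward_hook(self, module: nn.Module, inputs: tuple, outputs):
        # Outputs may be a single Tensor or a tuple of Tensors.
        assert isinstance(outputs, (Tensor, tuple))
        if isinstance(outputs, tuple):
            return tuple(cast_if_needed(x, self.outer_dtype) for x in outputs)
        return cast_if_needed(outputs, self.outer_dtype)

    def forward(self, *args):
        return self.module(*args)


if __name__ == "__main__":
    layer = Fp32HookExample(nn.Linear(4, 4), outer_dtype=torch.float16)
    out = layer(torch.randn(2, 4, dtype=torch.float16))
    print(out.dtype)  # torch.float16: fp32 compute inside, fp16 tensors outside

Running the sketch prints torch.float16: the wrapped linear layer computes in fp32, while callers keep exchanging fp16 tensors, which is the same contract the refactored hooks preserve.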