mirror of https://github.com/InternLM/InternLM
refactor code for pre/post hook
parent 89d0373a8c
commit b7229fd9fb
@@ -150,28 +150,21 @@ class NaiveAMPModel(nn.Module):
     def _register_fp32_parameters_hook(self) -> None:
         dtype = torch.float32
 
         def to_fp32(x, dtype=dtype):
             if isinstance(x, Tensor) and x.dtype != dtype:
                 return x.to(dtype)
             return x
 
         def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
-            inputs_fp32 = []
-            assert isinstance(inputs, tuple)
-            for input_data_ in inputs:
-                if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
-                    inputs_fp32.append(input_data_.to(dtype))
-                else:
-                    inputs_fp32.append(input_data_)
-            return tuple(inputs_fp32)
+            return tuple(map(to_fp32, inputs))
 
         def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
-            outputs_ = []
             assert isinstance(outputs, (Tensor, tuple))
-            assert isinstance(inputs, Union[tuple, Tensor])
             if isinstance(outputs, tuple):
-                for output_data_ in outputs:
-                    if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
-                        outputs_.append(output_data_.to(self.dtype))
-                    else:
-                        outputs_.append(output_data_)
-                return tuple(outputs_)
+                return tuple(map(to_fp32, outputs, self.dtype))
             else:
-                return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
+                return to_fp32(outputs, self.dtype)
 
         # just want to share same for loop for ModuleList and Module
         if isinstance(self.model, nn.ModuleList):
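For readers who want to see the refactored pattern in isolation, below is a minimal, self-contained sketch of the same idea: a forward pre-hook that promotes inputs to fp32 and a forward hook that casts outputs back to the wrapper's dtype. The class name DemoAMPWrapper and the helper to_given_dtype are hypothetical stand-ins, not identifiers from the InternLM source, and the per-element cast is written as a comprehension so the target dtype does not itself need to be an iterable, as the three-argument map form would require.

# Sketch only: names below are hypothetical and not taken from the InternLM source.
import torch
from torch import Tensor, nn


def to_given_dtype(x, dtype):
    # Cast only tensors whose dtype differs; leave everything else untouched.
    if isinstance(x, Tensor) and x.dtype != dtype:
        return x.to(dtype)
    return x


class DemoAMPWrapper(nn.Module):
    """Run the wrapped module in fp32 while exchanging self.dtype tensors outside."""

    def __init__(self, model: nn.Module, dtype=torch.half):
        super().__init__()
        self.model = model
        self.dtype = dtype

        def _pre_forward_hook(module, inputs):
            # Promote incoming activations to fp32 before the module runs.
            return tuple(to_given_dtype(x, torch.float32) for x in inputs)

        def _post_forward_hook(module, inputs, outputs):
            # Demote outputs back to the wrapper dtype after the module runs.
            if isinstance(outputs, tuple):
                return tuple(to_given_dtype(x, self.dtype) for x in outputs)
            return to_given_dtype(outputs, self.dtype)

        # Share the same registration loop for a ModuleList and a plain Module.
        sub_modules = model if isinstance(model, nn.ModuleList) else [model]
        for sub_module in sub_modules:
            sub_module.register_forward_pre_hook(_pre_forward_hook)
            sub_module.register_forward_hook(_post_forward_hook)

    def forward(self, x):
        return self.model(x)


if __name__ == "__main__":
    wrapper = DemoAMPWrapper(nn.Linear(4, 4).float(), dtype=torch.half)
    out = wrapper(torch.randn(2, 4, dtype=torch.half))
    print(out.dtype)  # torch.float16: the post-hook cast the fp32 output back

The registration loop mirrors the trailing context of the diff: when self.model is an nn.ModuleList, NaiveAMPModel walks the list and attaches the same pair of hooks to each sub-module, which is what the "share same for loop for ModuleList and Module" comment refers to.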