refactor code for pre/post hook

pull/319/head
Wenwen Qu 2023-09-22 14:03:42 +08:00
parent 89d0373a8c
commit b7229fd9fb
1 changed file with 9 additions and 16 deletions

@@ -150,28 +150,21 @@ class NaiveAMPModel(nn.Module):
     def _register_fp32_parameters_hook(self) -> None:
         dtype = torch.float32
 
+        def to_fp32(x, dtype=dtype):
+            if isinstance(x, Tensor) and x.dtype != dtype:
+                return x.to(dtype)
+            return x
+
         def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
-            inputs_fp32 = []
             assert isinstance(inputs, tuple)
-            for input_data_ in inputs:
-                if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
-                    inputs_fp32.append(input_data_.to(dtype))
-                else:
-                    inputs_fp32.append(input_data_)
-            return tuple(inputs_fp32)
+            return tuple(map(to_fp32, inputs))
 
         def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
-            outputs_ = []
-            assert isinstance(outputs, (Tensor, tuple))
+            assert isinstance(outputs, (tuple, Tensor))
             if isinstance(outputs, tuple):
-                for output_data_ in outputs:
-                    if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
-                        outputs_.append(output_data_.to(self.dtype))
-                    else:
-                        outputs_.append(output_data_)
-                return tuple(outputs_)
+                return tuple(to_fp32(out, self.dtype) for out in outputs)
             else:
-                return to_fp32(outputs, self.dtype)
-                return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
+                return to_fp32(outputs, self.dtype)
 
         # just want to share same for loop for ModuleList and Module
         if isinstance(self.model, nn.ModuleList):
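For context, a minimal self-contained sketch of the hook pattern this diff refactors: inputs are upcast to fp32 before the forward pass and outputs are cast back to the model's compute dtype afterwards. The ToyModel-style stand-in (a plain nn.Linear) and the register_hooks helper below are illustrative assumptions, not part of the commit; register_forward_pre_hook and register_forward_hook are the standard PyTorch APIs the surrounding registration loop presumably uses.

import torch
from torch import Tensor, nn


def to_fp32(x, dtype=torch.float32):
    # Cast only Tensors that are not already the target dtype; pass everything else through.
    if isinstance(x, Tensor) and x.dtype != dtype:
        return x.to(dtype)
    return x


def register_hooks(module: nn.Module, out_dtype: torch.dtype):
    # Hypothetical helper mirroring the diff's hooks on a single module.
    def _pre_forward_hook(model, inputs):
        # Returning a tuple from a pre-hook replaces the forward inputs.
        return tuple(map(to_fp32, inputs))

    def _post_forward_hook(model, inputs, outputs):
        # Returning a non-None value from a forward hook replaces the output.
        if isinstance(outputs, tuple):
            return tuple(to_fp32(out, out_dtype) for out in outputs)
        return to_fp32(outputs, out_dtype)

    module.register_forward_pre_hook(_pre_forward_hook)
    module.register_forward_hook(_post_forward_hook)


model = nn.Linear(4, 4)  # fp32 parameters; stand-in for one NaiveAMPModel sub-module
register_hooks(model, torch.float16)
out = model(torch.randn(2, 4).to(torch.float16))  # fp16 in
print(out.dtype)  # torch.float16: upcast to fp32 for compute, cast back after

Keeping the casting logic in one to_fp32 closure, as the commit does, removes the duplicated append loops from both hooks and guarantees the pre- and post-hook conversion rules cannot drift apart.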