fix bugs if self.model is ModuleList

pull/319/head
Qu Wenwen 2023-09-19 20:31:33 +08:00
parent fa4e973725
commit 5a0b3d5d9a
1 changed files with 4 additions and 1 deletions

View File

@ -152,6 +152,7 @@ class NaiveAMPModel(nn.Module):
def _pre_forward_hook(model: nn.Module, inputs: tuple): # pylint: disable=W0613
inputs_fp32 = []
assert isinstance(inputs, tuple)
for input_data_ in inputs:
if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
inputs_fp32.append(input_data_.to(dtype))
@ -173,7 +174,9 @@ class NaiveAMPModel(nn.Module):
return outputs.to(self.dtype)
# just want to share same for loop for ModuleList and Module
if not isinstance(self.model, nn.ModuleList):
if isinstance(self.model, nn.ModuleList):
model = self.model
else:
model = [self.model]
modules = []