mirror of https://github.com/InternLM/InternLM
fix bugs if self.model is ModuleList
parent
fa4e973725
commit
5a0b3d5d9a
|
@ -152,6 +152,7 @@ class NaiveAMPModel(nn.Module):
|
|||
|
||||
def _pre_forward_hook(model: nn.Module, inputs: tuple): # pylint: disable=W0613
|
||||
inputs_fp32 = []
|
||||
assert isinstance(inputs, tuple)
|
||||
for input_data_ in inputs:
|
||||
if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
|
||||
inputs_fp32.append(input_data_.to(dtype))
|
||||
|
@ -173,7 +174,9 @@ class NaiveAMPModel(nn.Module):
|
|||
return outputs.to(self.dtype)
|
||||
|
||||
# just want to share same for loop for ModuleList and Module
|
||||
if not isinstance(self.model, nn.ModuleList):
|
||||
if isinstance(self.model, nn.ModuleList):
|
||||
model = self.model
|
||||
else:
|
||||
model = [self.model]
|
||||
|
||||
modules = []
|
||||
|
|
Loading…
Reference in New Issue