mirror of https://github.com/InternLM/InternLM

fix bugs if self.model is ModuleList

parent fa4e973725
commit 5a0b3d5d9a
@@ -152,6 +152,7 @@ class NaiveAMPModel(nn.Module):
 
         def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
             inputs_fp32 = []
+            assert isinstance(inputs, tuple)
             for input_data_ in inputs:
                 if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
                     inputs_fp32.append(input_data_.to(dtype))
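The first hunk only adds a sanity check: PyTorch forward pre-hooks receive a module's positional arguments as a tuple, and the new assert makes that assumption explicit before the fp32 cast loop runs. Below is a minimal, self-contained sketch of the same hook pattern, assuming an illustrative layer and input shapes that are not taken from the repository:

import torch
from torch import Tensor, nn

dtype = torch.float32

def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
    # Forward pre-hooks always receive the positional inputs as a tuple;
    # the assert added by this commit makes that expectation explicit.
    assert isinstance(inputs, tuple)
    inputs_fp32 = []
    for input_data_ in inputs:
        if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
            inputs_fp32.append(input_data_.to(dtype))
        else:
            inputs_fp32.append(input_data_)
    # Returning a tuple replaces the module's positional inputs.
    return tuple(inputs_fp32)

layer = nn.Linear(4, 4)  # fp32 weights (illustrative layer, not from the repo)
layer.register_forward_pre_hook(_pre_forward_hook)
out = layer(torch.randn(2, 4, dtype=torch.float16))  # hook upcasts the fp16 input
print(out.dtype)  # torch.float32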
@@ -173,7 +174,9 @@ class NaiveAMPModel(nn.Module):
             return outputs.to(self.dtype)
 
         # just want to share same for loop for ModuleList and Module
-        if not isinstance(self.model, nn.ModuleList):
+        if isinstance(self.model, nn.ModuleList):
+            model = self.model
+        else:
             model = [self.model]
 
         modules = []
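The second hunk is the bug fix named in the commit message: the old code only assigned the local `model` when `self.model` was not an nn.ModuleList, so wrapping a ModuleList left `model` unbound before the loop that collects sub-modules. The fix normalizes both cases to something iterable. Here is a minimal sketch of that normalization, using a toy wrapper class (TinyAMPWrapper is a name invented for this example, not the repository's NaiveAMPModel):

from torch import nn

class TinyAMPWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Before this commit, `model` was only assigned when self.model was NOT
        # a ModuleList, so passing a ModuleList left the local unbound and the
        # loop below failed. The fix handles both cases with the same loop.
        if isinstance(self.model, nn.ModuleList):
            model = self.model
        else:
            model = [self.model]

        modules = []
        for sub_model in model:
            modules.extend(sub_model.modules())
        print(f"collected {len(modules)} sub-modules")

# The same loop now works for a single Module and for a ModuleList:
TinyAMPWrapper(nn.Linear(4, 4))
TinyAMPWrapper(nn.ModuleList([nn.Linear(4, 4), nn.ReLU()]))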