mirror of https://github.com/InternLM/InternLM
add dtype condition for post hook
parent
5a0b3d5d9a
commit
883160a558
|
@ -165,13 +165,13 @@ class NaiveAMPModel(nn.Module):
|
||||||
assert isinstance(outputs, (Tensor, tuple))
|
assert isinstance(outputs, (Tensor, tuple))
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, tuple):
|
||||||
for output_data_ in outputs:
|
for output_data_ in outputs:
|
||||||
if isinstance(output_data_, Tensor):
|
if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
|
||||||
outputs_.append(output_data_.to(self.dtype))
|
outputs_.append(output_data_.to(self.dtype))
|
||||||
else:
|
else:
|
||||||
outputs_.append(output_data_)
|
outputs_.append(output_data_)
|
||||||
return tuple(outputs_)
|
return tuple(outputs_)
|
||||||
else:
|
else:
|
||||||
return outputs.to(self.dtype)
|
return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
|
||||||
|
|
||||||
# just want to share same for loop for ModuleList and Module
|
# just want to share same for loop for ModuleList and Module
|
||||||
if isinstance(self.model, nn.ModuleList):
|
if isinstance(self.model, nn.ModuleList):
|
||||||
|
|
Loading…
Reference in New Issue