mirror of https://github.com/InternLM/InternLM
add dtype condition for post hook
parent
5a0b3d5d9a
commit
883160a558
|
@ -165,13 +165,13 @@ class NaiveAMPModel(nn.Module):
|
|||
assert isinstance(outputs, (Tensor, tuple))
|
||||
if isinstance(outputs, tuple):
|
||||
for output_data_ in outputs:
|
||||
if isinstance(output_data_, Tensor):
|
||||
if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
|
||||
outputs_.append(output_data_.to(self.dtype))
|
||||
else:
|
||||
outputs_.append(output_data_)
|
||||
return tuple(outputs_)
|
||||
else:
|
||||
return outputs.to(self.dtype)
|
||||
return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
|
||||
|
||||
# just want to share same for loop for ModuleList and Module
|
||||
if isinstance(self.model, nn.ModuleList):
|
||||
|
|
Loading…
Reference in New Issue