add dtype condition for post hook

pull/319/head
Qu Wenwen 2023-09-19 20:42:16 +08:00
parent 5a0b3d5d9a
commit 883160a558
1 changed files with 2 additions and 2 deletions

View File

@ -165,13 +165,13 @@ class NaiveAMPModel(nn.Module):
assert isinstance(outputs, (Tensor, tuple))
if isinstance(outputs, tuple):
for output_data_ in outputs:
if isinstance(output_data_, Tensor):
if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
outputs_.append(output_data_.to(self.dtype))
else:
outputs_.append(output_data_)
return tuple(outputs_)
else:
return outputs.to(self.dtype)
return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
# just want to share same for loop for ModuleList and Module
if isinstance(self.model, nn.ModuleList):