[amp] included dict for type casting of model output (#1102)

pull/1106/head
Frank Lee 2022-06-13 14:18:04 +08:00 committed by GitHub
parent 5a9d8ef4d5
commit 72bd7c696b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
return out