mirror of https://github.com/hpcaitech/ColossalAI
[amp] included dict for type casting of model output (#1102)
parent
5a9d8ef4d5
commit
72bd7c696b
|
@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module):
|
|||
out = self._convert_to_fp32(out)
|
||||
elif isinstance(out, (tuple, list)):
|
||||
out = [self._convert_to_fp32(val) for val in out]
|
||||
elif isinstance(out, dict):
|
||||
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||
return out
|
||||
|
|
Loading…
Reference in New Issue