mirror of https://github.com/hpcaitech/ColossalAI
[amp] included dict for type casting of model output (#1102)
parent
5a9d8ef4d5
commit
72bd7c696b
|
@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module):
|
||||||
out = self._convert_to_fp32(out)
|
out = self._convert_to_fp32(out)
|
||||||
elif isinstance(out, (tuple, list)):
|
elif isinstance(out, (tuple, list)):
|
||||||
out = [self._convert_to_fp32(val) for val in out]
|
out = [self._convert_to_fp32(val) for val in out]
|
||||||
|
elif isinstance(out, dict):
|
||||||
|
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
||||||
return out
|
return out
|
||||||
|
|
Loading…
Reference in New Issue