mirror of https://github.com/hpcaitech/ColossalAI
[bug] fixed grad scaler compatibility with torch 1.8 (#735)
parent
53cb584808
commit
a4e91bc87f
|
@ -12,6 +12,7 @@ from colossalai.context import ParallelMode
|
|||
import torch.distributed as dist
|
||||
from colossalai.core import global_context as gpc
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from packaging import version
|
||||
|
||||
|
||||
class _MultiDeviceReplicator(object):
|
||||
|
@ -122,6 +123,14 @@ class GradScaler(object):
|
|||
else:
|
||||
self._enabled = enabled
|
||||
|
||||
# check version
|
||||
torch_version = version.parse(torch.__version__)
|
||||
assert torch_version.major == 1
|
||||
if torch_version.minor > 8:
|
||||
self._higher_than_torch18 = True
|
||||
else:
|
||||
self._higher_than_torch18 = False
|
||||
|
||||
if self._enabled:
|
||||
assert growth_factor > 1.0, "The growth factor must be > 1.0."
|
||||
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
|
||||
|
@ -404,8 +413,12 @@ class GradScaler(object):
|
|||
for i in range(1, len(found_infs)):
|
||||
found_inf_combined += found_infs[i]
|
||||
|
||||
torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
|
||||
self._backoff_factor, self._growth_interval)
|
||||
if self._higher_than_torch18:
|
||||
torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
|
||||
self._backoff_factor, self._growth_interval)
|
||||
else:
|
||||
self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor,
|
||||
self._backoff_factor, self._growth_interval)
|
||||
|
||||
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||
|
|
Loading…
Reference in New Issue