[bug] fixed grad scaler compatibility with torch 1.8 (#735)

pull/739/head
Frank Lee authored 2022-04-12 16:04:21 +08:00, committed by GitHub
parent 53cb584808
commit a4e91bc87f
1 changed file with 15 additions and 2 deletions


@@ -12,6 +12,7 @@ from colossalai.context import ParallelMode
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from packaging import version
 
 
 class _MultiDeviceReplicator(object):
@@ -122,6 +123,14 @@ class GradScaler(object):
         else:
             self._enabled = enabled
 
+        # check version
+        torch_version = version.parse(torch.__version__)
+        assert torch_version.major == 1
+        if torch_version.minor > 8:
+            self._higher_than_torch18 = True
+        else:
+            self._higher_than_torch18 = False
+
         if self._enabled:
             assert growth_factor > 1.0, "The growth factor must be > 1.0."
             assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
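
Note: the check added above relies on packaging.version, which parses torch.__version__ (including local suffixes such as "1.8.1+cu111") into comparable major/minor fields. A minimal standalone sketch of the same gate, outside the diff (variable names are illustrative, not from the commit):

import torch
from packaging import version

torch_version = version.parse(torch.__version__)
# e.g. version.parse("1.8.1+cu111") -> major == 1, minor == 8
higher_than_torch18 = torch_version.major == 1 and torch_version.minor > 8
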
@@ -404,8 +413,12 @@ class GradScaler(object):
                 for i in range(1, len(found_infs)):
                     found_inf_combined += found_infs[i]
 
-            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
-                                     self._backoff_factor, self._growth_interval)
+            if self._higher_than_torch18:
+                torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
+                                         self._backoff_factor, self._growth_interval)
+            else:
+                self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor,
+                                                      self._backoff_factor, self._growth_interval)
 
         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
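
For context on the branch above: torch 1.9 replaced the private out-of-place kernel torch._amp_update_scale (growth tracker first, returns the new scale tensor) with the in-place torch._amp_update_scale_ (scale tensor first, mutated in place), which is why torch 1.8 broke. A minimal sketch of the dispatch, using the same argument orders as the diff (the helper name is illustrative, not from the commit):

import torch

def _update_scale_compat(scale, growth_tracker, found_inf,
                         growth_factor, backoff_factor, growth_interval,
                         higher_than_torch18):
    if higher_than_torch18:
        # torch >= 1.9: in-place variant, scale tensor comes first
        torch._amp_update_scale_(scale, growth_tracker, found_inf,
                                 growth_factor, backoff_factor, growth_interval)
        return scale
    # torch <= 1.8: out-of-place variant, growth tracker comes first,
    # and the updated scale is returned rather than written in place
    return torch._amp_update_scale(growth_tracker, scale, found_inf,
                                   growth_factor, backoff_factor, growth_interval)
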