mirror of https://github.com/hpcaitech/ColossalAI
fix format (#570)
parent
2a915a8b62
commit
b0f708dfc1
|
@ -27,8 +27,7 @@ class _MultiDeviceReplicator(object):
|
|||
def get(self, device) -> torch.Tensor:
|
||||
retval = self._per_device_tensors.get(device, None)
|
||||
if retval is None:
|
||||
retval = self.master.to(
|
||||
device=device, non_blocking=True, copy=True)
|
||||
retval = self.master.to(device=device, non_blocking=True, copy=True)
|
||||
self._per_device_tensors[device] = retval
|
||||
return retval
|
||||
|
||||
|
@ -116,15 +115,9 @@ class GradScaler(object):
|
|||
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
init_scale=2.**16,
|
||||
growth_factor=2.0,
|
||||
backoff_factor=0.5,
|
||||
growth_interval=2000,
|
||||
enabled=True):
|
||||
def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
|
||||
if enabled and not torch.cuda.is_available():
|
||||
warnings.warn(
|
||||
"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
|
||||
warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
|
||||
self._enabled = False
|
||||
else:
|
||||
self._enabled = enabled
|
||||
|
@ -142,23 +135,18 @@ class GradScaler(object):
|
|||
self._init_growth_tracker = 0
|
||||
# self._growth_tracker will be lazily initialized during the first call to scale()
|
||||
self._growth_tracker = None
|
||||
self._per_optimizer_states = defaultdict(
|
||||
_refresh_per_optimizer_state)
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||
|
||||
def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
|
||||
assert self._scale is not None, "Attempted {} but _scale is None. ".format(
|
||||
funcname) + fix
|
||||
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
|
||||
funcname) + fix
|
||||
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
|
||||
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
|
||||
return (self._scale, self._growth_tracker)
|
||||
|
||||
def _lazy_init_scale_growth_tracker(self, dev):
|
||||
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
|
||||
self._scale = torch.full(
|
||||
(1,), self._init_scale, dtype=torch.float32, device=dev)
|
||||
self._growth_tracker = torch.full(
|
||||
(1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
|
||||
self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
|
||||
self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
|
||||
|
||||
def scale(self, outputs):
|
||||
"""
|
||||
|
@ -201,8 +189,7 @@ class GradScaler(object):
|
|||
else:
|
||||
return iterable
|
||||
else:
|
||||
raise ValueError(
|
||||
"outputs must be a Tensor or an iterable of Tensors")
|
||||
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
|
||||
|
||||
return apply_scale(outputs)
|
||||
|
||||
|
@ -216,16 +203,14 @@ class GradScaler(object):
|
|||
|
||||
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
||||
# Google says mypy struggles with defaultdicts type annotations.
|
||||
per_device_and_dtype_grads = defaultdict(
|
||||
lambda: defaultdict(list)) # type: ignore[var-annotated]
|
||||
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
|
||||
with torch.no_grad():
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
||||
raise ValueError(
|
||||
"Attempting to unscale FP16 gradients.")
|
||||
raise ValueError("Attempting to unscale FP16 gradients.")
|
||||
if param.grad.is_sparse:
|
||||
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
||||
# coalesce() deduplicates indices and adds all values that have the same index.
|
||||
|
@ -238,22 +223,17 @@ class GradScaler(object):
|
|||
to_unscale = param.grad
|
||||
|
||||
# TODO: is there a way to split by device and dtype without appending in the inner loop?
|
||||
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
|
||||
to_unscale)
|
||||
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
|
||||
|
||||
for device, per_dtype_grads in per_device_and_dtype_grads.items():
|
||||
for grads in per_dtype_grads.values():
|
||||
torch._amp_foreach_non_finite_check_and_unscale_(grads,
|
||||
per_device_found_inf.get(
|
||||
device),
|
||||
torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
|
||||
per_device_inv_scale.get(device))
|
||||
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
|
||||
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
vals = [val for val in per_device_found_inf._per_device_tensors.values()]
|
||||
coalesced = _flatten_dense_tensors(vals)
|
||||
dist.all_reduce(coalesced,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.MODEL))
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
|
||||
for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
|
||||
buf.copy_(synced)
|
||||
return per_device_found_inf._per_device_tensors
|
||||
|
@ -298,19 +278,16 @@ class GradScaler(object):
|
|||
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
||||
|
||||
if optimizer_state["stage"] is OptState.UNSCALED:
|
||||
raise RuntimeError(
|
||||
"unscale_() has already been called on this optimizer since the last update().")
|
||||
raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
|
||||
elif optimizer_state["stage"] is OptState.STEPPED:
|
||||
raise RuntimeError("unscale_() is being called after step().")
|
||||
|
||||
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||
assert self._scale is not None
|
||||
inv_scale = self._scale.double().reciprocal().float()
|
||||
found_inf = torch.full(
|
||||
(1,), 0.0, dtype=torch.float32, device=self._scale.device)
|
||||
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
|
||||
|
||||
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
||||
optimizer, inv_scale, found_inf, False)
|
||||
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
|
||||
optimizer_state["stage"] = OptState.UNSCALED
|
||||
|
||||
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
|
||||
|
@ -344,16 +321,14 @@ class GradScaler(object):
|
|||
return optimizer.step(*args, **kwargs)
|
||||
|
||||
if "closure" in kwargs:
|
||||
raise RuntimeError(
|
||||
"Closure use is not currently supported if GradScaler is enabled.")
|
||||
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
|
||||
|
||||
self._check_scale_growth_tracker("step")
|
||||
|
||||
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
||||
|
||||
if optimizer_state["stage"] is OptState.STEPPED:
|
||||
raise RuntimeError(
|
||||
"step() has already been called since the last update().")
|
||||
raise RuntimeError("step() has already been called since the last update().")
|
||||
|
||||
retval = None
|
||||
|
||||
|
@ -369,11 +344,9 @@ class GradScaler(object):
|
|||
if optimizer_state["stage"] is OptState.READY:
|
||||
self.unscale_(optimizer)
|
||||
|
||||
assert len(optimizer_state["found_inf_per_device"]
|
||||
) > 0, "No inf checks were recorded for this optimizer."
|
||||
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
|
||||
|
||||
retval = self._maybe_opt_step(
|
||||
optimizer, optimizer_state, *args, **kwargs)
|
||||
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
|
||||
|
||||
optimizer_state["stage"] = OptState.STEPPED
|
||||
|
||||
|
@ -407,35 +380,32 @@ class GradScaler(object):
|
|||
if new_scale is not None:
|
||||
# Accept a new user-defined scale.
|
||||
if isinstance(new_scale, float):
|
||||
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
|
||||
# type: ignore[attr-defined]
|
||||
assert isinstance(new_scale, torch.cuda.FloatTensor), reason
|
||||
assert new_scale.numel() == 1, reason
|
||||
assert new_scale.requires_grad is False, reason
|
||||
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
||||
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
# Consume shared inf/nan data collected from optimizers to update the scale.
|
||||
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
||||
found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
|
||||
for state in self._per_optimizer_states.values()
|
||||
for found_inf in state["found_inf_per_device"].values()]
|
||||
found_infs = [
|
||||
found_inf.to(device=_scale.device, non_blocking=True)
|
||||
for state in self._per_optimizer_states.values()
|
||||
for found_inf in state["found_inf_per_device"].values()
|
||||
]
|
||||
|
||||
assert len(
|
||||
found_infs) > 0, "No inf checks were recorded prior to update."
|
||||
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
||||
|
||||
found_inf_combined = found_infs[0]
|
||||
if len(found_infs) > 1:
|
||||
for i in range(1, len(found_infs)):
|
||||
found_inf_combined += found_infs[i]
|
||||
|
||||
torch._amp_update_scale_(_scale,
|
||||
_growth_tracker,
|
||||
found_inf_combined,
|
||||
self._growth_factor,
|
||||
self._backoff_factor,
|
||||
self._growth_interval)
|
||||
torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
|
||||
self._backoff_factor, self._growth_interval)
|
||||
|
||||
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
||||
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
||||
|
@ -522,11 +492,13 @@ class GradScaler(object):
|
|||
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
|
||||
should be called after :meth:`update`.
|
||||
"""
|
||||
return {"scale": self.get_scale(),
|
||||
"growth_factor": self._growth_factor,
|
||||
"backoff_factor": self._backoff_factor,
|
||||
"growth_interval": self._growth_interval,
|
||||
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
|
||||
return {
|
||||
"scale": self.get_scale(),
|
||||
"growth_factor": self._growth_factor,
|
||||
"backoff_factor": self._backoff_factor,
|
||||
"growth_interval": self._growth_interval,
|
||||
"_growth_tracker": self._get_growth_tracker()
|
||||
} if self._enabled else {}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
r"""
|
||||
|
@ -572,10 +544,8 @@ class GradScaler(object):
|
|||
def _check_inf_per_device(self, optimizer):
|
||||
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
|
||||
|
||||
dummy_inv_scale = torch.full(
|
||||
(1,), 1.0, dtype=torch.float32, device=_scale.device)
|
||||
found_inf = torch.full(
|
||||
(1,), 0.0, dtype=torch.float32, device=_scale.device)
|
||||
dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
|
||||
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
|
||||
|
||||
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
|
||||
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
|
||||
|
|
Loading…
Reference in New Issue