mirror of https://github.com/hpcaitech/ColossalAI
fix format (#570)
parent 2a915a8b62
commit b0f708dfc1
@@ -27,8 +27,7 @@ class _MultiDeviceReplicator(object):
     def get(self, device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
-            retval = self.master.to(
-                device=device, non_blocking=True, copy=True)
+            retval = self.master.to(device=device, non_blocking=True, copy=True)
             self._per_device_tensors[device] = retval
         return retval
 
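For context, the method reformatted here lazily replicates one master tensor (such as the inverse scale or the found_inf flag) across devices, keeping a single cached copy per device. A minimal standalone sketch of that pattern, with illustrative names that are not part of the diff:

import torch

class PerDeviceCache:
    """Caches copies of a master tensor, one per device, created on first request."""

    def __init__(self, master: torch.Tensor):
        self.master = master
        self._per_device = {}

    def get(self, device) -> torch.Tensor:
        cached = self._per_device.get(device, None)
        if cached is None:
            # First request from this device: copy the master tensor over and cache it.
            cached = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device[device] = cached
        return cached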
@@ -116,15 +115,9 @@ class GradScaler(object):
     invokes the underlying ``optimizer.step()``, and other methods become no-ops.
     """
 
-    def __init__(self,
-                 init_scale=2.**16,
-                 growth_factor=2.0,
-                 backoff_factor=0.5,
-                 growth_interval=2000,
-                 enabled=True):
+    def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
         if enabled and not torch.cuda.is_available():
-            warnings.warn(
-                "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
+            warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
             self._enabled = False
         else:
             self._enabled = enabled
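The constructor collapsed onto one line keeps the stock defaults (init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000). For reference, a typical mixed-precision training step with this scaler looks roughly like the sketch below; model, optimizer, criterion and loader are assumptions, not names from this file:

import torch

# model, optimizer, criterion and loader are assumed to exist already.
scaler = GradScaler()    # same defaults as the one-line signature above
for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()    # backward on the scaled loss
    scaler.step(optimizer)           # unscales grads; skips the step on inf/nan
    scaler.update()                  # grows or backs off the scale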
@@ -142,23 +135,18 @@ class GradScaler(object):
             self._init_growth_tracker = 0
             # self._growth_tracker will be lazily initialized during the first call to scale()
             self._growth_tracker = None
-            self._per_optimizer_states = defaultdict(
-                _refresh_per_optimizer_state)
+            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
 
     def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
-        assert self._scale is not None, "Attempted {} but _scale is None. ".format(
-            funcname) + fix
-        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
-            funcname) + fix
+        assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
+        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
         return (self._scale, self._growth_tracker)
 
     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
-        self._scale = torch.full(
-            (1,), self._init_scale, dtype=torch.float32, device=dev)
-        self._growth_tracker = torch.full(
-            (1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
+        self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
 
     def scale(self, outputs):
         """
@@ -201,8 +189,7 @@ class GradScaler(object):
                 else:
                     return iterable
             else:
-                raise ValueError(
-                    "outputs must be a Tensor or an iterable of Tensors")
+                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
 
         return apply_scale(outputs)
 
@@ -216,16 +203,14 @@ class GradScaler(object):
 
         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
         # Google says mypy struggles with defaultdicts type annotations.
-        per_device_and_dtype_grads = defaultdict(
-            lambda: defaultdict(list))  # type: ignore[var-annotated]
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
         with torch.no_grad():
             for group in optimizer.param_groups:
                 for param in group["params"]:
                     if param.grad is None:
                         continue
                     if (not allow_fp16) and param.grad.dtype == torch.float16:
-                        raise ValueError(
-                            "Attempting to unscale FP16 gradients.")
+                        raise ValueError("Attempting to unscale FP16 gradients.")
                     if param.grad.is_sparse:
                         # is_coalesced() == False means the sparse grad has values with duplicate indices.
                         # coalesce() deduplicates indices and adds all values that have the same index.
@@ -238,22 +223,17 @@ class GradScaler(object):
                         to_unscale = param.grad
 
                     # TODO: is there a way to split by device and dtype without appending in the inner loop?
-                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
-                        to_unscale)
+                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
 
             for device, per_dtype_grads in per_device_and_dtype_grads.items():
                 for grads in per_dtype_grads.values():
-                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
-                                                                     per_device_found_inf.get(
-                                                                         device),
-                                                                     per_device_inv_scale.get(device))
+                    torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
+                                                                     per_device_inv_scale.get(device))
         # For tensor parallel paramters it should be all-reduced over tensor parallel process group
         if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
             vals = [val for val in per_device_found_inf._per_device_tensors.values()]
             coalesced = _flatten_dense_tensors(vals)
-            dist.all_reduce(coalesced,
-                            op=dist.ReduceOp.MAX,
-                            group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
             for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
                 buf.copy_(synced)
         return per_device_found_inf._per_device_tensors
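The ColossalAI-specific tail of this hunk keeps the overflow flags consistent under tensor parallelism: each rank's found_inf tensors are flattened, max-reduced over the ParallelMode.MODEL group, and copied back, so all ranks agree on whether the step must be skipped. A reduced sketch of that synchronization, assuming torch.distributed is already initialized and group is the model-parallel process group:

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_found_inf(found_inf_tensors, group):
    # Max-reduce the per-device inf/nan flags so every rank sees the same verdict.
    coalesced = _flatten_dense_tensors(found_inf_tensors)
    dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=group)
    for buf, synced in zip(found_inf_tensors, _unflatten_dense_tensors(coalesced, found_inf_tensors)):
        buf.copy_(synced)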
@@ -298,19 +278,16 @@ class GradScaler(object):
         optimizer_state = self._per_optimizer_states[id(optimizer)]
 
         if optimizer_state["stage"] is OptState.UNSCALED:
-            raise RuntimeError(
-                "unscale_() has already been called on this optimizer since the last update().")
+            raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
         elif optimizer_state["stage"] is OptState.STEPPED:
             raise RuntimeError("unscale_() is being called after step().")
 
         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
         assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full(
-            (1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
 
-        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
-            optimizer, inv_scale, found_inf, False)
+        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
         optimizer_state["stage"] = OptState.UNSCALED
 
     def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
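unscale_() is typically called explicitly only when gradients need to be inspected or clipped at their true, unscaled values before the step; step() then detects the UNSCALED stage and does not unscale again. A common usage pattern (scaler, model, optimizer and loss assumed to be defined elsewhere):

import torch

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                    # gradients are now unscaled
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                                        # will not unscale a second time
scaler.update()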
@@ -344,16 +321,14 @@ class GradScaler(object):
             return optimizer.step(*args, **kwargs)
 
         if "closure" in kwargs:
-            raise RuntimeError(
-                "Closure use is not currently supported if GradScaler is enabled.")
+            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
 
         self._check_scale_growth_tracker("step")
 
         optimizer_state = self._per_optimizer_states[id(optimizer)]
 
         if optimizer_state["stage"] is OptState.STEPPED:
-            raise RuntimeError(
-                "step() has already been called since the last update().")
+            raise RuntimeError("step() has already been called since the last update().")
 
         retval = None
 
@@ -369,11 +344,9 @@ class GradScaler(object):
         if optimizer_state["stage"] is OptState.READY:
             self.unscale_(optimizer)
 
-        assert len(optimizer_state["found_inf_per_device"]
-                   ) > 0, "No inf checks were recorded for this optimizer."
+        assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
 
-        retval = self._maybe_opt_step(
-            optimizer, optimizer_state, *args, **kwargs)
+        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
 
         optimizer_state["stage"] = OptState.STEPPED
 
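The flow reformatted above (unscale if the optimizer is still READY, assert that inf checks were recorded, then _maybe_opt_step) reduces to: sum the per-device found_inf flags and run the real optimizer step only when that sum is zero. A plain-Python sketch of that decision, mirroring the stock PyTorch helper:

def _maybe_opt_step(optimizer, optimizer_state, *args, **kwargs):
    # Step only if no gradient on any device overflowed during unscaling.
    retval = None
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
        retval = optimizer.step(*args, **kwargs)
    return retval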
@@ -407,35 +380,32 @@ class GradScaler(object):
         if new_scale is not None:
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
                 self._scale.fill_(new_scale)  # type: ignore[union-attr]
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                 # type: ignore[attr-defined]
                 assert isinstance(new_scale, torch.cuda.FloatTensor), reason
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
                 self._scale.copy_(new_scale)  # type: ignore[union-attr]
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
-            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
-                          for state in self._per_optimizer_states.values()
-                          for found_inf in state["found_inf_per_device"].values()]
+            found_infs = [
+                found_inf.to(device=_scale.device, non_blocking=True)
+                for state in self._per_optimizer_states.values()
+                for found_inf in state["found_inf_per_device"].values()
+            ]
 
-            assert len(
-                found_infs) > 0, "No inf checks were recorded prior to update."
+            assert len(found_infs) > 0, "No inf checks were recorded prior to update."
 
             found_inf_combined = found_infs[0]
             if len(found_infs) > 1:
                 for i in range(1, len(found_infs)):
                     found_inf_combined += found_infs[i]
 
-            torch._amp_update_scale_(_scale,
-                                     _growth_tracker,
-                                     found_inf_combined,
-                                     self._growth_factor,
-                                     self._backoff_factor,
-                                     self._growth_interval)
+            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
+                                     self._backoff_factor, self._growth_interval)
 
         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
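torch._amp_update_scale_ implements the usual dynamic loss-scaling rule: on overflow the scale is multiplied by backoff_factor and the growth counter resets; otherwise the counter increments, and after growth_interval consecutive clean steps the scale is multiplied by growth_factor. A plain-Python sketch of that rule (the fused kernel operates on tensors in place; this shows only the logic):

def update_scale(scale, growth_tracker, found_inf, growth_factor, backoff_factor, growth_interval):
    if found_inf:
        # Overflow: back off the scale and restart the clean-step counter.
        return scale * backoff_factor, 0
    growth_tracker += 1
    if growth_tracker == growth_interval:
        # Enough consecutive clean steps: grow the scale.
        return scale * growth_factor, 0
    return scale, growth_tracker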
@@ -522,11 +492,13 @@ class GradScaler(object):
         If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
         should be called after :meth:`update`.
         """
-        return {"scale": self.get_scale(),
-                "growth_factor": self._growth_factor,
-                "backoff_factor": self._backoff_factor,
-                "growth_interval": self._growth_interval,
-                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
+        return {
+            "scale": self.get_scale(),
+            "growth_factor": self._growth_factor,
+            "backoff_factor": self._backoff_factor,
+            "growth_interval": self._growth_interval,
+            "_growth_tracker": self._get_growth_tracker()
+        } if self._enabled else {}
 
     def load_state_dict(self, state_dict):
         r"""
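With the state dict now laid out one key per line, checkpointing the scaler together with the model stays the usual pattern; as the docstring notes, state_dict() should be called after update(). A sketch (model, optimizer and scaler assumed; the path is illustrative):

import torch

# Saving, after scaler.update() for the current iteration
torch.save({"model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict()}, "checkpoint.pt")

# Resuming
ckpt = torch.load("checkpoint.pt")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scaler.load_state_dict(ckpt["scaler"])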
@@ -572,10 +544,8 @@ class GradScaler(object):
     def _check_inf_per_device(self, optimizer):
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
 
-        dummy_inv_scale = torch.full(
-            (1,), 1.0, dtype=torch.float32, device=_scale.device)
-        found_inf = torch.full(
-            (1,), 0.0, dtype=torch.float32, device=_scale.device)
+        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
 
         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)