fix format (#570)

pull/673/head
Kai Wang (Victor Kai) 2022-03-31 15:26:41 +08:00 committed by binmakeswell
parent 2a915a8b62
commit b0f708dfc1
1 changed files with 40 additions and 70 deletions

View File

@ -27,8 +27,7 @@ class _MultiDeviceReplicator(object):
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(
device=device, non_blocking=True, copy=True)
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
@ -116,15 +115,9 @@ class GradScaler(object):
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
"""
def __init__(self,
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
if enabled and not torch.cuda.is_available():
warnings.warn(
"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
self._enabled = False
else:
self._enabled = enabled
@ -142,23 +135,18 @@ class GradScaler(object):
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(
_refresh_per_optimizer_state)
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format(
funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
funcname) + fix
assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self, dev):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = torch.full(
(1,), self._init_scale, dtype=torch.float32, device=dev)
self._growth_tracker = torch.full(
(1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
def scale(self, outputs):
"""
@ -201,8 +189,7 @@ class GradScaler(object):
else:
return iterable
else:
raise ValueError(
"outputs must be a Tensor or an iterable of Tensors")
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
@ -216,16 +203,14 @@ class GradScaler(object):
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(
lambda: defaultdict(list)) # type: ignore[var-annotated]
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError(
"Attempting to unscale FP16 gradients.")
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
@ -238,22 +223,17 @@ class GradScaler(object):
to_unscale = param.grad
# TODO: is there a way to split by device and dtype without appending in the inner loop?
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
to_unscale)
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads,
per_device_found_inf.get(
device),
torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
per_device_inv_scale.get(device))
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
vals = [val for val in per_device_found_inf._per_device_tensors.values()]
coalesced = _flatten_dense_tensors(vals)
dist.all_reduce(coalesced,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL))
dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
buf.copy_(synced)
return per_device_found_inf._per_device_tensors
@ -298,19 +278,16 @@ class GradScaler(object):
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update().")
raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device)
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
@ -344,16 +321,14 @@ class GradScaler(object):
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError(
"Closure use is not currently supported if GradScaler is enabled.")
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError(
"step() has already been called since the last update().")
raise RuntimeError("step() has already been called since the last update().")
retval = None
@ -369,11 +344,9 @@ class GradScaler(object):
if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer)
assert len(optimizer_state["found_inf_per_device"]
) > 0, "No inf checks were recorded for this optimizer."
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step(
optimizer, optimizer_state, *args, **kwargs)
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED
@ -407,35 +380,32 @@ class GradScaler(object):
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
# type: ignore[attr-defined]
assert isinstance(new_scale, torch.cuda.FloatTensor), reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()]
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(
found_infs) > 0, "No inf checks were recorded prior to update."
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
torch._amp_update_scale_(_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval)
torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
self._backoff_factor, self._growth_interval)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@ -522,11 +492,13 @@ class GradScaler(object):
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return {"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
return {
"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()
} if self._enabled else {}
def load_state_dict(self, state_dict):
r"""
@ -572,10 +544,8 @@ class GradScaler(object):
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full(
(1,), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=_scale.device)
dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)