fix format (#570)

pull/673/head
Kai Wang (Victor Kai) 2022-03-31 15:26:41 +08:00 committed by binmakeswell
parent 2a915a8b62
commit b0f708dfc1
1 changed files with 40 additions and 70 deletions

View File

@ -27,8 +27,7 @@ class _MultiDeviceReplicator(object):
def get(self, device) -> torch.Tensor: def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None) retval = self._per_device_tensors.get(device, None)
if retval is None: if retval is None:
retval = self.master.to( retval = self.master.to(device=device, non_blocking=True, copy=True)
device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval self._per_device_tensors[device] = retval
return retval return retval
@ -116,15 +115,9 @@ class GradScaler(object):
invokes the underlying ``optimizer.step()``, and other methods become no-ops. invokes the underlying ``optimizer.step()``, and other methods become no-ops.
""" """
def __init__(self, def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
init_scale=2.**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True):
if enabled and not torch.cuda.is_available(): if enabled and not torch.cuda.is_available():
warnings.warn( warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
self._enabled = False self._enabled = False
else: else:
self._enabled = enabled self._enabled = enabled
@ -142,23 +135,18 @@ class GradScaler(object):
self._init_growth_tracker = 0 self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale() # self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None self._growth_tracker = None
self._per_optimizer_states = defaultdict( self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
_refresh_per_optimizer_state)
def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]: def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, "Attempted {} but _scale is None. ".format( assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
funcname) + fix assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
funcname) + fix
return (self._scale, self._growth_tracker) return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self, dev): def _lazy_init_scale_growth_tracker(self, dev):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale" assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = torch.full( self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
(1,), self._init_scale, dtype=torch.float32, device=dev) self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
self._growth_tracker = torch.full(
(1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
def scale(self, outputs): def scale(self, outputs):
""" """
@ -201,8 +189,7 @@ class GradScaler(object):
else: else:
return iterable return iterable
else: else:
raise ValueError( raise ValueError("outputs must be a Tensor or an iterable of Tensors")
"outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs) return apply_scale(outputs)
@ -216,16 +203,14 @@ class GradScaler(object):
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations. # Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict( per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad(): with torch.no_grad():
for group in optimizer.param_groups: for group in optimizer.param_groups:
for param in group["params"]: for param in group["params"]:
if param.grad is None: if param.grad is None:
continue continue
if (not allow_fp16) and param.grad.dtype == torch.float16: if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError( raise ValueError("Attempting to unscale FP16 gradients.")
"Attempting to unscale FP16 gradients.")
if param.grad.is_sparse: if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices. # is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index. # coalesce() deduplicates indices and adds all values that have the same index.
@ -238,22 +223,17 @@ class GradScaler(object):
to_unscale = param.grad to_unscale = param.grad
# TODO: is there a way to split by device and dtype without appending in the inner loop? # TODO: is there a way to split by device and dtype without appending in the inner loop?
per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append( per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)
to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items(): for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values(): for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads, torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
per_device_found_inf.get(
device),
per_device_inv_scale.get(device)) per_device_inv_scale.get(device))
# For tensor parallel paramters it should be all-reduced over tensor parallel process group # For tensor parallel paramters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
vals = [val for val in per_device_found_inf._per_device_tensors.values()] vals = [val for val in per_device_found_inf._per_device_tensors.values()]
coalesced = _flatten_dense_tensors(vals) coalesced = _flatten_dense_tensors(vals)
dist.all_reduce(coalesced, dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL))
for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)): for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
buf.copy_(synced) buf.copy_(synced)
return per_device_found_inf._per_device_tensors return per_device_found_inf._per_device_tensors
@ -298,19 +278,16 @@ class GradScaler(object):
optimizer_state = self._per_optimizer_states[id(optimizer)] optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError( raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
"unscale_() has already been called on this optimizer since the last update().")
elif optimizer_state["stage"] is OptState.STEPPED: elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().") raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float() inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full( found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
(1,), 0.0, dtype=torch.float32, device=self._scale.device)
optimizer_state["found_inf_per_device"] = self._unscale_grads_( optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
optimizer, inv_scale, found_inf, False)
optimizer_state["stage"] = OptState.UNSCALED optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
@ -344,16 +321,14 @@ class GradScaler(object):
return optimizer.step(*args, **kwargs) return optimizer.step(*args, **kwargs)
if "closure" in kwargs: if "closure" in kwargs:
raise RuntimeError( raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
"Closure use is not currently supported if GradScaler is enabled.")
self._check_scale_growth_tracker("step") self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)] optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED: if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError( raise RuntimeError("step() has already been called since the last update().")
"step() has already been called since the last update().")
retval = None retval = None
@ -369,11 +344,9 @@ class GradScaler(object):
if optimizer_state["stage"] is OptState.READY: if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer) self.unscale_(optimizer)
assert len(optimizer_state["found_inf_per_device"] assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
) > 0, "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step( retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED optimizer_state["stage"] = OptState.STEPPED
@ -407,35 +380,32 @@ class GradScaler(object):
if new_scale is not None: if new_scale is not None:
# Accept a new user-defined scale. # Accept a new user-defined scale.
if isinstance(new_scale, float): if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr] self._scale.fill_(new_scale) # type: ignore[union-attr]
else: else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
# type: ignore[attr-defined] # type: ignore[attr-defined]
assert isinstance(new_scale, torch.cuda.FloatTensor), reason assert isinstance(new_scale, torch.cuda.FloatTensor), reason
assert new_scale.numel() == 1, reason assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr] self._scale.copy_(new_scale) # type: ignore[union-attr]
else: else:
# Consume shared inf/nan data collected from optimizers to update the scale. # Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [found_inf.to(device=_scale.device, non_blocking=True) found_infs = [
for state in self._per_optimizer_states.values() found_inf.to(device=_scale.device, non_blocking=True)
for found_inf in state["found_inf_per_device"].values()] for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len( assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0] found_inf_combined = found_infs[0]
if len(found_infs) > 1: if len(found_infs) > 1:
for i in range(1, len(found_infs)): for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i] found_inf_combined += found_infs[i]
torch._amp_update_scale_(_scale, torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
_growth_tracker, self._backoff_factor, self._growth_interval)
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval)
# To prepare for next iteration, clear the data collected from optimizers this iteration. # To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@ -522,11 +492,13 @@ class GradScaler(object):
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`. should be called after :meth:`update`.
""" """
return {"scale": self.get_scale(), return {
"growth_factor": self._growth_factor, "scale": self.get_scale(),
"backoff_factor": self._backoff_factor, "growth_factor": self._growth_factor,
"growth_interval": self._growth_interval, "backoff_factor": self._backoff_factor,
"_growth_tracker": self._get_growth_tracker()} if self._enabled else {} "growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()
} if self._enabled else {}
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
r""" r"""
@ -572,10 +544,8 @@ class GradScaler(object):
def _check_inf_per_device(self, optimizer): def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full( dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
(1,), 1.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)