From b0f708dfc172b60ca959f245e10145ef2b045f48 Mon Sep 17 00:00:00 2001
From: "Kai Wang (Victor Kai)" <37533040+kaiwang960112@users.noreply.github.com>
Date: Thu, 31 Mar 2022 15:26:41 +0800
Subject: [PATCH] fix format (#570)

---
 colossalai/amp/torch_amp/_grad_scaler.py | 110 +++++++++--------------
 1 file changed, 40 insertions(+), 70 deletions(-)

diff --git a/colossalai/amp/torch_amp/_grad_scaler.py b/colossalai/amp/torch_amp/_grad_scaler.py
index b3ad5c084..48c7eb949 100644
--- a/colossalai/amp/torch_amp/_grad_scaler.py
+++ b/colossalai/amp/torch_amp/_grad_scaler.py
@@ -27,8 +27,7 @@ class _MultiDeviceReplicator(object):
     def get(self, device) -> torch.Tensor:
         retval = self._per_device_tensors.get(device, None)
         if retval is None:
-            retval = self.master.to(
-                device=device, non_blocking=True, copy=True)
+            retval = self.master.to(device=device, non_blocking=True, copy=True)
             self._per_device_tensors[device] = retval
         return retval

@@ -116,15 +115,9 @@ class GradScaler(object):
     invokes the underlying ``optimizer.step()``, and other methods become no-ops.
     """

-    def __init__(self,
-                 init_scale=2.**16,
-                 growth_factor=2.0,
-                 backoff_factor=0.5,
-                 growth_interval=2000,
-                 enabled=True):
+    def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
         if enabled and not torch.cuda.is_available():
-            warnings.warn(
-                "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
+            warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
             self._enabled = False
         else:
             self._enabled = enabled
@@ -142,23 +135,18 @@ class GradScaler(object):
             self._init_growth_tracker = 0
             # self._growth_tracker will be lazily initialized during the first call to scale()
             self._growth_tracker = None
-            self._per_optimizer_states = defaultdict(
-                _refresh_per_optimizer_state)
+            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

     def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
         fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
-        assert self._scale is not None, "Attempted {} but _scale is None. ".format(
-            funcname) + fix
-        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(
-            funcname) + fix
+        assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix
+        assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix
         return (self._scale, self._growth_tracker)

     def _lazy_init_scale_growth_tracker(self, dev):
         assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
-        self._scale = torch.full(
-            (1,), self._init_scale, dtype=torch.float32, device=dev)
-        self._growth_tracker = torch.full(
-            (1,), self._init_growth_tracker, dtype=torch.int32, device=dev)
+        self._scale = torch.full((1,), self._init_scale, dtype=torch.float32, device=dev)
+        self._growth_tracker = torch.full((1,), self._init_growth_tracker, dtype=torch.int32, device=dev)

     def scale(self, outputs):
         """
@@ -201,8 +189,7 @@ class GradScaler(object):
                 else:
                     return iterable
             else:
-                raise ValueError(
-                    "outputs must be a Tensor or an iterable of Tensors")
+                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

         return apply_scale(outputs)

@@ -216,16 +203,14 @@ class GradScaler(object):

         # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
         # Google says mypy struggles with defaultdicts type annotations.
-        per_device_and_dtype_grads = defaultdict(
-            lambda: defaultdict(list))  # type: ignore[var-annotated]
+        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))    # type: ignore[var-annotated]
         with torch.no_grad():
             for group in optimizer.param_groups:
                 for param in group["params"]:
                     if param.grad is None:
                         continue
                     if (not allow_fp16) and param.grad.dtype == torch.float16:
-                        raise ValueError(
-                            "Attempting to unscale FP16 gradients.")
+                        raise ValueError("Attempting to unscale FP16 gradients.")
                     if param.grad.is_sparse:
                         # is_coalesced() == False means the sparse grad has values with duplicate indices.
                         # coalesce() deduplicates indices and adds all values that have the same index.
@@ -238,22 +223,17 @@ class GradScaler(object):
                         to_unscale = param.grad

                     # TODO: is there a way to split by device and dtype without appending in the inner loop?
-                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
-                        to_unscale)
+                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(to_unscale)

             for device, per_dtype_grads in per_device_and_dtype_grads.items():
                 for grads in per_dtype_grads.values():
-                    torch._amp_foreach_non_finite_check_and_unscale_(grads,
-                                                                     per_device_found_inf.get(
-                                                                         device),
+                    torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
                                                                      per_device_inv_scale.get(device))
         # For tensor parallel paramters it should be all-reduced over tensor parallel process group
         if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
             vals = [val for val in per_device_found_inf._per_device_tensors.values()]
             coalesced = _flatten_dense_tensors(vals)
-            dist.all_reduce(coalesced,
-                            op=dist.ReduceOp.MAX,
-                            group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
             for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
                 buf.copy_(synced)
         return per_device_found_inf._per_device_tensors
@@ -298,19 +278,16 @@ class GradScaler(object):
         optimizer_state = self._per_optimizer_states[id(optimizer)]

         if optimizer_state["stage"] is OptState.UNSCALED:
-            raise RuntimeError(
-                "unscale_() has already been called on this optimizer since the last update().")
+            raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
         elif optimizer_state["stage"] is OptState.STEPPED:
             raise RuntimeError("unscale_() is being called after step().")

         # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
         assert self._scale is not None
         inv_scale = self._scale.double().reciprocal().float()
-        found_inf = torch.full(
-            (1,), 0.0, dtype=torch.float32, device=self._scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)

-        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
-            optimizer, inv_scale, found_inf, False)
+        optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
         optimizer_state["stage"] = OptState.UNSCALED

     def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
@@ -344,16 +321,14 @@ class GradScaler(object):
             return optimizer.step(*args, **kwargs)

         if "closure" in kwargs:
-            raise RuntimeError(
-                "Closure use is not currently supported if GradScaler is enabled.")
+            raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

         self._check_scale_growth_tracker("step")

         optimizer_state = self._per_optimizer_states[id(optimizer)]

         if optimizer_state["stage"] is OptState.STEPPED:
-            raise RuntimeError(
-                "step() has already been called since the last update().")
+            raise RuntimeError("step() has already been called since the last update().")

         retval = None

@@ -369,11 +344,9 @@ class GradScaler(object):
         if optimizer_state["stage"] is OptState.READY:
             self.unscale_(optimizer)

-        assert len(optimizer_state["found_inf_per_device"]
-                   ) > 0, "No inf checks were recorded for this optimizer."
+        assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."

-        retval = self._maybe_opt_step(
-            optimizer, optimizer_state, *args, **kwargs)
+        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

         optimizer_state["stage"] = OptState.STEPPED

@@ -407,35 +380,32 @@ class GradScaler(object):
         if new_scale is not None:
             # Accept a new user-defined scale.
             if isinstance(new_scale, float):
-                self._scale.fill_(new_scale)  # type: ignore[union-attr]
+                self._scale.fill_(new_scale)    # type: ignore[union-attr]
             else:
                 reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                 # type: ignore[attr-defined]
                 assert isinstance(new_scale, torch.cuda.FloatTensor), reason
                 assert new_scale.numel() == 1, reason
                 assert new_scale.requires_grad is False, reason
-                self._scale.copy_(new_scale)  # type: ignore[union-attr]
+                self._scale.copy_(new_scale)    # type: ignore[union-attr]
         else:
             # Consume shared inf/nan data collected from optimizers to update the scale.
             # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
-            found_infs = [found_inf.to(device=_scale.device, non_blocking=True)
-                          for state in self._per_optimizer_states.values()
-                          for found_inf in state["found_inf_per_device"].values()]
+            found_infs = [
+                found_inf.to(device=_scale.device, non_blocking=True)
+                for state in self._per_optimizer_states.values()
+                for found_inf in state["found_inf_per_device"].values()
+            ]

-            assert len(
-                found_infs) > 0, "No inf checks were recorded prior to update."
+            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

             found_inf_combined = found_infs[0]
             if len(found_infs) > 1:
                 for i in range(1, len(found_infs)):
                     found_inf_combined += found_infs[i]

-            torch._amp_update_scale_(_scale,
-                                     _growth_tracker,
-                                     found_inf_combined,
-                                     self._growth_factor,
-                                     self._backoff_factor,
-                                     self._growth_interval)
+            torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
+                                     self._backoff_factor, self._growth_interval)

         # To prepare for next iteration, clear the data collected from optimizers this iteration.
         self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
@@ -522,11 +492,13 @@ class GradScaler(object):
         If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
         should be called after :meth:`update`.
         """
-        return {"scale": self.get_scale(),
-                "growth_factor": self._growth_factor,
-                "backoff_factor": self._backoff_factor,
-                "growth_interval": self._growth_interval,
-                "_growth_tracker": self._get_growth_tracker()} if self._enabled else {}
+        return {
+            "scale": self.get_scale(),
+            "growth_factor": self._growth_factor,
+            "backoff_factor": self._backoff_factor,
+            "growth_interval": self._growth_interval,
+            "_growth_tracker": self._get_growth_tracker()
+        } if self._enabled else {}

     def load_state_dict(self, state_dict):
         r"""
@@ -572,10 +544,8 @@ class GradScaler(object):

     def _check_inf_per_device(self, optimizer):
         _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
-        dummy_inv_scale = torch.full(
-            (1,), 1.0, dtype=torch.float32, device=_scale.device)
-        found_inf = torch.full(
-            (1,), 0.0, dtype=torch.float32, device=_scale.device)
+        dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
+        found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)

         self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
             self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
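Note (not part of the patch): the commit only reflows the file, so the public GradScaler API is unchanged and the usual AMP training loop from the class docstring still applies. Below is a minimal sketch of that loop under the assumption of a CUDA build; `model`, `optimizer`, and `loss_fn` are placeholder names, and the import path simply mirrors the file touched above.

    import torch
    from colossalai.amp.torch_amp._grad_scaler import GradScaler

    model = torch.nn.Linear(16, 4).cuda()    # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()
    scaler = GradScaler(init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000)

    for _ in range(10):
        inputs = torch.randn(8, 16, device="cuda")
        targets = torch.randn(8, 4, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(inputs), targets)
        scaler.scale(loss).backward()    # backward on the scaled loss
        scaler.step(optimizer)           # unscales grads; skips the step if infs/NaNs were found
        scaler.update()                  # grows or backs off the scale for the next iteration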
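Note (not part of the patch): the ColossalAI-specific logic preserved by the `_unscale_grads_` hunk is the found-inf synchronization: the per-device found_inf flags are flattened into one buffer, all-reduced with MAX over the MODEL process group (so any rank that saw an inf/NaN poisons the step on every rank), and copied back in place. A standalone sketch of that pattern, assuming torch.distributed is already initialized; `sync_found_inf` is a hypothetical helper name, not part of the patched file.

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def sync_found_inf(per_device_tensors, group=None):
        # per_device_tensors: dict mapping device -> 1-element float32 found_inf tensor
        vals = list(per_device_tensors.values())
        coalesced = _flatten_dense_tensors(vals)    # concatenate into one contiguous buffer
        dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=group)
        # scatter the reduced values back into the original tensors in place
        for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
            buf.copy_(synced)
        return per_device_tensors

MAX is the natural reduction here because each found_inf flag is 0.0 or 1.0: the result is 1.0 exactly when any participating rank recorded a non-finite gradient.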