# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from torch.optim import Optimizer

__all__ = ["Adafactor"]


# Adafactor
class Adafactor(Optimizer):
    """Adafactor optimizer with optionally factored second-moment estimates.

    Arguments:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        lr: manual learning rate. Only used when ``relative_step=False``; when
            ``relative_step=True`` the value is ignored and the step size is
            derived from the step counter instead.
        eps: pair ``(eps1, eps2)``. ``eps1`` is added to the squared gradient,
            ``eps2`` is the lower bound of the parameter scale used when
            ``scale_parameter=True``.
        clip_threshold: the update is divided by
            ``max(1, RMS(update) / clip_threshold)``.
        decay_rate: exponent of the second-moment decay schedule,
            ``beta2_t = 1 - step ** decay_rate``.
        beta1: coefficient of the running average of the update; ``None``
            disables the first moment entirely.
        weight_decay: decay coefficient; it is multiplied by the current
            learning rate before being applied to the parameter.
        scale_parameter: if True, the learning rate is multiplied by
            ``max(eps2, RMS(parameter))``.
        relative_step: if True, the step size is computed from the step
            counter (see ``_get_lr``) instead of being read from ``lr``.
        warmup_init: if True, the relative step size is capped by
            ``1e-6 * step`` instead of ``1e-2``. Requires ``relative_step=True``.

    Raises:
        ValueError: if ``warmup_init=True`` is combined with
            ``relative_step=False``.
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        if relative_step:
            # A manual `lr` has no meaning when the step size is computed from
            # the step counter. It is dropped (rather than rejected) so that
            # callers which always forward an `lr` keep working.
            lr = None
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        """Return the effective learning rate for one parameter.

        With ``relative_step`` the base step is
        ``min(min_step, 1 / sqrt(step))`` where ``min_step`` is
        ``1e-6 * step`` under ``warmup_init`` and ``1e-2`` otherwise; without
        it the group's ``lr`` is used. With ``scale_parameter`` the result is
        multiplied by ``max(eps[1], state["RMS"])``.
        """
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        """Return ``(factored, use_first_moment)`` for a parameter.

        The second moment is factored for tensors with at least two
        dimensions; the first moment is used when ``beta1`` is not None.
        """
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        """Return the root mean square of ``tensor`` as a 0-dim tensor."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """Rebuild the inverse square root of the second moment from its factors.

        The row statistics are normalised by their mean over the last
        dimension, then the reciprocal square roots of the row and column
        statistics are broadcast-multiplied into the full gradient shape.
        """
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, but the closure normally calls
            # backward(), so autograd has to be re-enabled around it.
            with torch.enable_grad():
                loss = closure()

        # Each param group is a dict holding "params" plus the hyper-parameters
        # "lr", "eps", "clip_threshold", "decay_rate", "beta1", "weight_decay",
        # "scale_parameter", "relative_step" and "warmup_init".
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # The gradient has the same shape as the parameter.
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                # Per-parameter state:
                #   factored (>= 2 dims): step, exp_avg_sq_row, exp_avg_sq_col, RMS
                #   otherwise:            step, exp_avg_sq, RMS
                #   plus exp_avg when beta1 is set.
                state = self.state[p]
                grad_shape = grad.shape
                factored, use_first_moment = self._get_options(group, grad_shape)

                # State Initialization
                if len(state) == 0:
                    state["step"] = 0
                    if use_first_moment:
                        # Exponential moving average of the update
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1], device=grad.device)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:], device=grad.device
                        )
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)
                    state["RMS"] = 0

                state["step"] += 1
                # Refresh the parameter RMS before deriving the learning rate;
                # `scale_parameter` relies on it.
                state["RMS"] = self._rms(p)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    # Running average of the squared gradient, reduced over the last dim
                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    # Running average of the squared gradient, reduced over the second-to-last dim
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximate 1 / sqrt(second moment) from the two factors
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip the update so that its RMS does not exceed clip_threshold
                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                if group["weight_decay"] != 0:
                    p.add_(p, alpha=(-group["weight_decay"] * lr))

                p.add_(-update)

        return loss