mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
151 lines
5.8 KiB
151 lines
5.8 KiB
7 months ago
|
# Copied from https://github.com/yangluo7/CAME/blob/master/came_pytorch/CAME.py
|
||
|
import torch
|
||
|
import torch.optim
|
||
|
|
||
|
|
||
|
class CAME(torch.optim.Optimizer):
    """Implements CAME algorithm.

    This implementation is based on:
    `CAME: Confidence-guided Adaptive Memory Efficient Optimization`

    For parameters with two or more dimensions the second-moment and
    instability statistics are stored in factored form (one vector of means
    over the last dimension and one over the second-to-last dimension)
    instead of a full tensor. Parameters with fewer dimensions keep a full
    second-moment tensor and skip the confidence rescaling.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): external learning rate. The signature default is ``None``
            for backward compatibility, but a positive value is required.
        eps (tuple[float, float]): regularization constants for square gradient
            and instability respectively (default: (1e-30, 1e-16))
        clip_threshold (float): threshold of root-mean-square of
            final gradient update (default: 1.0)
        betas (tuple[float, float, float]): coefficient used for computing running averages of
            update, square gradient and instability (default: (0.9, 0.999, 0.9999)))
        weight_decay (float, optional): weight decay coefficient; applied as
            ``p -= weight_decay * lr * p`` before the update (default: 0)

    Raises:
        ValueError: if ``lr`` is ``None`` or not strictly positive, or if any
            entry of ``betas`` lies outside ``[0, 1]``.
    """

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-16),
        clip_threshold=1.0,
        betas=(0.9, 0.999, 0.9999),
        weight_decay=0.0,
    ):
        # Explicit checks instead of `assert`: asserts vanish under `python -O`,
        # and `None > 0.0` would otherwise surface as an opaque TypeError.
        # `not lr > 0.0` (rather than `lr <= 0.0`) also rejects NaN.
        if lr is None or not lr > 0.0:
            raise ValueError(f"CAME requires a positive learning rate, got lr={lr!r}")
        if not all(0.0 <= beta <= 1.0 for beta in betas):
            raise ValueError(f"CAME betas must all lie in [0, 1], got betas={betas!r}")

        defaults = dict(
            lr=lr,
            eps=eps,
            clip_threshold=clip_threshold,
            betas=betas,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        """Capability flag read by external training code; always ``True``."""
        return True

    @property
    def supports_flat_params(self):
        """Capability flag read by external training code; always ``False``."""
        return False

    def _get_options(self, param_shape):
        """Return ``True`` when the shape has at least two dimensions (factored statistics)."""
        factored = len(param_shape) >= 2
        return factored

    def _rms(self, tensor):
        """Return the root-mean-square of ``tensor`` (L2 norm / sqrt(numel))."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Rebuild an inverse-square-root scaling tensor from factored statistics.

        Args:
            exp_avg_sq_row: running means taken over the last dimension.
            exp_avg_sq_col: running means taken over the second-to-last dimension.

        Returns:
            The broadcast product of ``rsqrt(row / mean(row))`` (as a column)
            and ``rsqrt(col)`` (as a row).
        """
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The step itself runs without autograd, but the closure has to
            # be able to build a graph and call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Per-group hyperparameters are loop-invariant; read them once.
            lr = group["lr"]
            eps = group["eps"]
            betas = group["betas"]
            beta1, beta2, beta3 = betas[0], betas[1], betas[2]
            clip_threshold = group["clip_threshold"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("CAME does not support sparse gradients.")

                state = self.state[p]
                grad_shape = grad.shape

                factored = self._get_options(grad_shape)
                # State Initialization
                if len(state) == 0:
                    state["step"] = 0

                    state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # Row statistics drop the last dim, column statistics
                        # drop the second-to-last dim.
                        row_shape = grad_shape[:-1]
                        col_shape = grad_shape[:-2] + grad_shape[-1:]
                        state["exp_avg_sq_row"] = torch.zeros(row_shape, dtype=p.dtype, device=p.device)
                        state["exp_avg_sq_col"] = torch.zeros(col_shape, dtype=p.dtype, device=p.device)

                        state["exp_avg_res_row"] = torch.zeros(row_shape, dtype=p.dtype, device=p.device)
                        state["exp_avg_res_col"] = torch.zeros(col_shape, dtype=p.dtype, device=p.device)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(p)

                state["step"] += 1

                update = (grad**2) + eps[0]

                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2).add_(update.mean(dim=-1), alpha=1.0 - beta2)
                    exp_avg_sq_col.mul_(beta2).add_(update.mean(dim=-2), alpha=1.0 - beta2)

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2).add_(update, alpha=1.0 - beta2)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Scale the update down only when its RMS exceeds clip_threshold.
                update.div_((self._rms(update) / clip_threshold).clamp_(min=1.0))

                exp_avg = state["exp_avg"]
                exp_avg.mul_(beta1).add_(update, alpha=1 - beta1)

                # Confidence-guided strategy
                # Calculation of instability
                res = (update - exp_avg) ** 2 + eps[1]

                if factored:
                    exp_avg_res_row = state["exp_avg_res_row"]
                    exp_avg_res_col = state["exp_avg_res_col"]
                    exp_avg_res_row.mul_(beta3).add_(res.mean(dim=-1), alpha=1.0 - beta3)
                    exp_avg_res_col.mul_(beta3).add_(res.mean(dim=-2), alpha=1.0 - beta3)

                    # Approximation of exponential moving average of instability
                    res_approx = self._approx_sq_grad(exp_avg_res_row, exp_avg_res_col)
                    update = res_approx.mul_(exp_avg)
                else:
                    # Clone so the in-place lr scaling below leaves the
                    # stored momentum untouched.
                    update = exp_avg.clone()

                if weight_decay != 0:
                    p.data.add_(p.data, alpha=-weight_decay * lr)
                update.mul_(lr)
                p.data.sub_(update)

        return loss
|