Making large AI models cheaper, faster and more accessible
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

181 lines
7.5 KiB

# Disclaimer: Modified from https://github.com/NUS-HPC-AI-Lab/pytorch-lamb/blob/master/optim/lamb.py
from typing import Dict, Optional
import torch
import torch.distributed as dist
from colossalai.interface.optimizer import DistributedOptim
from colossalai.tensor.d_tensor import is_distributed_tensor
__all__ = ["DistributedLamb"]
class DistributedLamb(DistributedOptim):
r"""Implements the Lamb algorithm, with extra support for ZeRO 2 and Tensor Parallel.
Proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
It's recommended to use this with HybridParallelPlugin/ZeRO plugin and booster,
which will take care of setup_distributed.
Example with 4 devices:
>>> optim = DistributedLamb(model.parameters(), lr=1e-3)
>>> proc_mesh = ProcessGroupMesh(tp_size, zero_size)
>>> tp_group = proc_mesh.get_group_along_axis(0)
>>> dp_group = proc_mesh.get_group_along_axis(1)
>>> optim.setup_distributed(tp_group, dp_group)
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0,
bias_correction=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
# self.setup_distributed(tp_group, dp_group)
self.shard_to_working_param = {}
self.tp_size = self.dp_size = 1
self.is_zero = False
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super().__init__(params, defaults)
def setup_distributed(
self,
tp_group: Optional[dist.ProcessGroup] = None,
dp_group: Optional[dist.ProcessGroup] = None,
shard_to_working_param: Optional[Dict] = {},
padding_map=None,
is_zero: Optional[bool] = False,
):
"""Assign process groups for TP and ZeRO 2.
Arguments:
tp_group (dist.ProcessGroup): Tensor Parallel process group
dp_group (dist.ProcessGroup): ZeRO 2 process group
shard_to_working_param (Dict): ZeRO 2 feeds the optimizer a sharded param view as grads are sharded.
This maps from id(view) to working params used in forward & backward.
padding_map: An empty interface placeholder
is_zero (bool): Whether to use ZeRO 2.
"""
self.tp_group = tp_group
self.dp_group = dp_group
if tp_group is not None:
self.tp_size = dist.get_world_size(tp_group)
if dp_group is not None:
self.dp_size = dist.get_world_size(dp_group)
self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
self.is_zero = is_zero
self.is_dist = {}
# Cache parameter layout
for group in self.param_groups:
for p in group["params"]:
# w/o ZeRO: master param = working param
self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)
self.is_dist[p] = (
is_distributed_tensor(p)
if self.dp_size <= 1
else is_distributed_tensor(self.shard_to_working_param.get(id(p), None))
)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instad.")
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# m_t
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# v_t
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
scaled_lr = group["lr"]
if group["bias_correction"]:
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
# Apply debiasing to lr to avoid broadcast
scaled_lr *= (bias_correction2**0.5) / bias_correction1
# exp_avg.div_(bias_correction1)
# exp_avg_sq.div_(bias_correction2)
update = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
if group["weight_decay"] != 0:
update.add_(p.data, alpha=group["weight_decay"])
# Compute global layer-wise trust ratio
if self.is_dist[p] or self.is_zero:
p_local = p
g_sum = (update**2).sum()
if self.dp_size > 1 and self.is_zero:
# ZeRO 2 doesn't shard param. Compute full param norm w/o communication.
dist.all_reduce(g_sum, group=self.dp_group)
p_local = self.shard_to_working_param[id(p)]
w_sum = (p_local**2).sum()
sums = torch.stack([w_sum, g_sum])
# Get global l2 norms
if self.tp_size > 1:
dist.all_reduce(sums, group=self.tp_group)
w_norm, g_norm = sums.sqrt().chunk(2)
else:
# Fall back to vanilla Lamb
w_norm = torch.norm(p)
g_norm = torch.norm(update)
trust_ratio = torch.where(w_norm > 0 and g_norm > 0, (w_norm / g_norm), 1.0).item()
scaled_lr *= trust_ratio
p.data.add_(update, alpha=-scaled_lr)
return loss