# Disclaimer: Modified from https://github.com/NUS-HPC-AI-Lab/pytorch-lamb/blob/master/optim/lamb.py
from typing import Dict, Optional

import torch
import torch.distributed as dist

from colossalai.interface.optimizer import DistributedOptim
from colossalai.tensor.d_tensor import is_distributed_tensor

__all__ = ["DistributedLamb"]


class DistributedLamb(DistributedOptim):
    r"""Implements the Lamb algorithm, with extra support for ZeRO 2 and Tensor Parallel.
    Proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    It is recommended to use this optimizer with the HybridParallelPlugin/ZeRO plugin and a booster,
    which will take care of setup_distributed for you.

    Example with 4 devices:
        >>> optim = DistributedLamb(model.parameters(), lr=1e-3)
        >>> proc_mesh = ProcessGroupMesh(tp_size, zero_size)
        >>> tp_group = proc_mesh.get_group_along_axis(0)
        >>> dp_group = proc_mesh.get_group_along_axis(1)
        >>> optim.setup_distributed(tp_group, dp_group)

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        bias_correction (bool, optional): whether to fold Adam-style bias correction
            into the learning rate (default: True)
    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0,
        bias_correction=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        # Distributed attributes are assigned later via setup_distributed(tp_group, dp_group).
        self.shard_to_working_param = {}
        self.tp_size = self.dp_size = 1
        self.is_zero = False
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super().__init__(params, defaults)

    def setup_distributed(
        self,
        tp_group: Optional[dist.ProcessGroup] = None,
        dp_group: Optional[dist.ProcessGroup] = None,
        shard_to_working_param: Optional[Dict] = None,
        padding_map=None,
        is_zero: Optional[bool] = False,
    ):
        """Assign process groups for TP and ZeRO 2.

        Arguments:
            tp_group (dist.ProcessGroup): Tensor Parallel process group
            dp_group (dist.ProcessGroup): ZeRO 2 process group
            shard_to_working_param (Dict): ZeRO 2 feeds the optimizer sharded param views because grads are sharded.
                This maps id(view) to the working params used in forward & backward.
            padding_map: Unused; kept for interface compatibility.
            is_zero (bool): Whether ZeRO 2 is used.
        """
        self.tp_group = tp_group
        self.dp_group = dp_group
        if tp_group is not None:
            self.tp_size = dist.get_world_size(tp_group)
        if dp_group is not None:
            self.dp_size = dist.get_world_size(dp_group)

        self.shard_to_working_param = shard_to_working_param if shard_to_working_param is not None else {}
        self.is_zero = is_zero
        self.is_dist = {}
        # Cache parameter layout
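        # For each master param, record which working param it corresponds to and whether
        # that working param is a TP-sharded (distributed) tensor, so step() knows when the
        # trust-ratio norms must be reduced across ranks.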
        for group in self.param_groups:
            for p in group["params"]:
                # w/o ZeRO: master param = working param
                self.shard_to_working_param[id(p)] = self.shard_to_working_param.get(id(p), p)
                self.is_dist[p] = (
                    is_distributed_tensor(p)
                    if self.dp_size <= 1
                    else is_distributed_tensor(self.shard_to_working_param.get(id(p), None))
                )

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instead.")

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
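                # i.e. m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
                #      v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2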

                scaled_lr = group["lr"]
                if group["bias_correction"]:
                    bias_correction1 = 1 - beta1 ** state["step"]
                    bias_correction2 = 1 - beta2 ** state["step"]
                    # Apply debiasing to lr to avoid broadcast
                    scaled_lr *= (bias_correction2**0.5) / bias_correction1
                    # exp_avg.div_(bias_correction1)
                    # exp_avg_sq.div_(bias_correction2)
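                    # Scaling the lr by sqrt(bias_correction2) / bias_correction1 matches
                    # debiasing the moments in place (the commented-out lines above), up to
                    # how eps enters the denominator, while avoiding extra tensor-wide ops.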

                update = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
                if group["weight_decay"] != 0:
                    update.add_(p.data, alpha=group["weight_decay"])

                # Compute global layer-wise trust ratio
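                # LAMB scales each layer's step by ||w|| / ||update||. When the param or its
                # grads/optimizer states are sharded (TP and/or ZeRO), the two norms must be
                # accumulated across the corresponding process groups first.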
                if self.is_dist[p] or self.is_zero:
                    p_local = p
                    g_sum = (update**2).sum()
                    if self.dp_size > 1 and self.is_zero:
                        # ZeRO 2 shards grads and optimizer states but not params: all-reduce
                        # the update norm over the DP group, and take the param norm from the
                        # unsharded working copy, which needs no extra communication.
                        dist.all_reduce(g_sum, group=self.dp_group)
                        p_local = self.shard_to_working_param[id(p)]

                    w_sum = (p_local**2).sum()
                    sums = torch.stack([w_sum, g_sum])

                    # Get global l2 norms
                    if self.tp_size > 1:
                        dist.all_reduce(sums, group=self.tp_group)
                    w_norm, g_norm = sums.sqrt().chunk(2)
                else:
                    # Fall back to vanilla Lamb
                    w_norm = torch.norm(p)
                    g_norm = torch.norm(update)

                trust_ratio = torch.where(w_norm > 0 and g_norm > 0, (w_norm / g_norm), 1.0).item()
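                # trust_ratio falls back to 1 when either norm is zero. The applied step is
                # p <- p - (lr * debias * trust_ratio) * (m_t / (sqrt(v_t) + eps) + wd * p).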
                scaled_lr *= trust_ratio
                p.data.add_(update, alpha=-scaled_lr)

        return loss
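

# Minimal usage sketch (illustrative only; `model` and `data` are placeholders, not part of
# this module). Boosting with the HybridParallelPlugin/ZeRO plugin calls setup_distributed()
# for you; when used standalone, call it once with no groups to get vanilla Lamb behavior:
#
#     optim = DistributedLamb(model.parameters(), lr=1e-3, weight_decay=0.01)
#     optim.setup_distributed()  # no TP/ZeRO process groups -> single-device Lamb
#     loss = model(data).sum()
#     loss.backward()
#     optim.step()
#     optim.zero_grad()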