from typing import Dict, List

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper

from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
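    """FP16 mixin that detects gradient overflow by scanning the gradients of the
    working params it was constructed with.
    """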

    def __init__(self,
                 working_params: List[Parameter],
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
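        """Forward all loss-scale hyperparameters to ``FP16MixedPrecisionMixin`` and
        keep a reference to the working (low-precision) params so that overflow can
        be checked locally on this rank.
        """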
        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
                         max_scale)
        self.params = working_params

    def check_local_overflow(self) -> bool:
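        # any non-finite (inf/NaN) value in a local gradient means this step overflowed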
        for p in self.params:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return True
        return False


class MixedPrecisionOptimizer(OptimizerWrapper):
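    """Optimizer wrapper that keeps fp32 master weights for a low-precision
    (fp16/bf16) working model: it scales the loss, skips steps on overflow,
    unscales gradients before the real optimizer step, and copies the updated
    master weights back to the working params.

    A minimal usage sketch (``MyModel`` and the hyperparameters are illustrative
    assumptions, not part of this module)::

        model = MyModel().half().cuda()
        optim = MixedPrecisionOptimizer(torch.optim.Adam(model.parameters(), lr=1e-3))
        optim.zero_grad()
        optim.backward(loss)
        optim.step()
    """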

    def __init__(self,
                 optim: Optimizer,
                 precision: str = 'fp16',
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0.0):
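        """
        Args:
            optim (Optimizer): the fp32 optimizer to wrap.
            precision (str): ``'fp16'`` (dynamic loss scaling) or ``'bf16'``.
            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
                hysteresis, max_scale: dynamic loss-scale schedule, forwarded to
                ``NaiveFP16MixedPrecisionMixin``; only used when ``precision='fp16'``.
            max_norm (float): gradient clipping threshold; values > 0 are not supported yet.
        """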
        super().__init__(optim)
        if precision == 'fp16':
            working_params = []
            for group in self.optim.param_groups:
                for p in group['params']:
                    working_params.append(p)
            self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
                                                                initial_scale=initial_scale,
                                                                min_scale=min_scale,
                                                                growth_factor=growth_factor,
                                                                backoff_factor=backoff_factor,
                                                                growth_interval=growth_interval,
                                                                hysteresis=hysteresis,
                                                                max_scale=max_scale)
        elif precision == 'bf16':
            self.mixed_precision = BF16MixedPrecisionMixin()
        else:
            raise ValueError(f'Unsupported precision: {precision}')
        if max_norm > 0.0:
            raise NotImplementedError('max_norm is not supported yet.')
        self.max_norm = max_norm
        self.working_to_master_map: Dict[Parameter, Tensor] = {}
        self.master_to_working_map: Dict[Tensor, Parameter] = {}

        # create master weights
        for group in self.optim.param_groups:
            master_params = []
            for p in group['params']:
                if p.requires_grad:
                    master_p = p
                    if p.dtype != torch.float:
                        master_p = p.detach().float()
                    self.working_to_master_map[p] = master_p
                    self.master_to_working_map[master_p] = p
                    master_params.append(master_p)
            group['params'] = master_params

    def backward(self, loss: Tensor, *args, **kwargs):
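        # let the mixin pre-process the loss (loss scaling for fp16) before backward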
        loss = self.mixed_precision.pre_backward(loss)
        loss.backward(*args, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
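        # variant of backward() that starts from an externally provided gradient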
        grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
        tensor.backward(grad)

    def zero_grad(self, *args, **kwargs):
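        # clear grads on the working params as well; super().zero_grad() only touches
        # the master params registered in the wrapped optimizer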
        for p in self.working_to_master_map.keys():
            p.grad = None
        self.mixed_precision.pre_zero_grad()
        return super().zero_grad(*args, **kwargs)

    def _unscale_and_clip_grads(self, total_norm: float) -> None:
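        # fold loss-scale removal and gradient clipping into a single division;
        # the clipping branch is currently unreachable because max_norm > 0
        # raises NotImplementedError in __init__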
        div_scale = 1.0
        if self.mixed_precision is not None:
            div_scale = self.mixed_precision.get_grad_div_scale()

        if self.max_norm > 0.:
            # total_norm was computed on still-scaled grads, so it is really norm * scale
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
            if clip > 1:
                div_scale = clip * div_scale

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data.mul_(1. / div_scale)

    def _compute_grad_norm(self) -> float:
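        # L2 norm over all master grads; returns 0 when clipping is disabled or
        # there are no grads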
        if self.max_norm <= 0.:
            return 0.
        grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
        if len(grads) == 0:
            return 0.
        device = grads[0].device
        # TODO(ver217): support tp
        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
        return total_norm.item()

    def step(self, *args, **kwargs):
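        # on overflow, drop this step's grads and bail out; otherwise move working
        # grads onto the fp32 master params, unscale/clip, step, and copy the
        # updated master values back into the working params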
        if self.mixed_precision.should_skip_step():
            self.zero_grad()
            return
        # prepare grads
        for group in self.optim.param_groups:
            for p in group['params']:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                if working_param.grad is not None:
                    p.grad = working_param.grad.data.float()
                    working_param.grad = None
        total_norm = self._compute_grad_norm()
        self._unscale_and_clip_grads(total_norm)
        self.optim.step(*args, **kwargs)
        # update working params
        for group in self.optim.param_groups:
            for p in group['params']:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                working_param.data.copy_(p.data)