ColossalAI/colossalai/amp/naive_amp/mixed_precision_optimizer.py

170 lines
6.2 KiB
Python

from typing import Dict, List
import torch
from torch import Tensor
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(
self,
working_params: List[Parameter],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.params = working_params
def check_local_overflow(self) -> bool:
for p in self.params:
if p.grad is not None and not torch.isfinite(p.grad).all():
return True
return False
class MixedPrecisionOptimizer(OptimizerWrapper):
def __init__(
self,
optim: Optimizer,
precision: str = "fp16",
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
):
super().__init__(optim)
if precision == "fp16":
working_params = []
for group in self.optim.param_groups:
for p in group["params"]:
working_params.append(p)
self.mixed_precision = NaiveFP16MixedPrecisionMixin(
working_params,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
elif precision == "bf16":
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f"Unsupported precision: {precision}")
if max_norm > 0.0:
raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
# create master weights
for group in self.optim.param_groups:
master_params = []
for p in group["params"]:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
master_p = p.detach().float()
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group["params"] = master_params
def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
p.grad = None
self.mixed_precision.pre_zero_grad()
return super().zero_grad(*args, **kwargs)
def _unscale_and_clip_grads(self, total_norm: float) -> None:
div_scale = 1.0
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()
if self.max_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.grad.data.mul_(1.0 / div_scale)
def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.0:
return 0.0
grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0:
return 0.0
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
return total_norm.item()
def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
# prepare grads
for group in self.optim.param_groups:
for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
if working_param.grad is not None:
p.grad = working_param.grad.data.float()
working_param.grad = None
total_norm = self._compute_grad_norm()
self._unscale_and_clip_grads(total_norm)
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
working_param.data.copy_(p.data)
def update_master_params(self, model: Module):
# Update master params from working params
with torch.no_grad():
for p in model.parameters():
if (p is None) or (p not in self.working_to_master_map):
continue
master_param = self.working_to_master_map[p]
master_param.data.copy_(p.data)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}