from typing import Dict, List, Tuple

import torch
from torch import Tensor, inf
from torch.nn import Module, Parameter
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper

from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
    """FP16 mixin that reports gradient overflow when any tracked working
    parameter holds a non-finite (inf/NaN) gradient."""

    def __init__(
        self,
        working_params: List[Parameter],
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
    ) -> None:
        super().__init__(
            initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
        )
        self.params = working_params

    def check_local_overflow(self) -> bool:
        for p in self.params:
            if p.grad is not None and not torch.isfinite(p.grad).all():
                return True
        return False


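# Illustrative note (not from the source file): check_local_overflow() flips to True as
# soon as any tracked gradient contains inf or NaN -- e.g. a parameter whose .grad holds
# torch.tensor([float("inf")]) would trip it -- which is what lets the dynamic FP16 loss
# scaler back off its scale instead of applying a corrupted update.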
class MixedPrecisionOptimizer(OptimizerWrapper):
    """Optimizer wrapper for naive AMP: keeps fp32 master copies of fp16/bf16
    working parameters, handles loss scaling (for fp16) and optional gradient
    clipping, and syncs updated master weights back to the working parameters."""

    def __init__(
        self,
        optim: Optimizer,
        precision: str = "fp16",
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
    ):
        super().__init__(optim)
        if precision == "fp16":
            working_params = []
            for group in self.optim.param_groups:
                for p in group["params"]:
                    working_params.append(p)
            self.mixed_precision = NaiveFP16MixedPrecisionMixin(
                working_params,
                initial_scale=initial_scale,
                min_scale=min_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval,
                hysteresis=hysteresis,
                max_scale=max_scale,
            )
        elif precision == "bf16":
            self.mixed_precision = BF16MixedPrecisionMixin()
        else:
            raise ValueError(f"Unsupported precision: {precision}")
        self.max_norm = max_norm
        self.working_to_master_map: Dict[Parameter, Tensor] = {}
        self.master_to_working_map: Dict[Tensor, Parameter] = {}

        # create master weights: non-fp32 trainable params get a detached fp32 copy,
        # and the inner optimizer's param groups are rewired to update those copies
        for group in self.optim.param_groups:
            master_params = []
            for p in group["params"]:
                if p.requires_grad:
                    master_p = p
                    if p.dtype != torch.float:
                        master_p = p.detach().float()
                    self.working_to_master_map[p] = master_p
                    self.master_to_working_map[master_p] = p
                    master_params.append(master_p)
            group["params"] = master_params

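    # Illustrative note (not from the source): for a layer held in fp16, the loop above
    # stores {fp16_param: fp32_clone} in working_to_master_map and the inner optimizer
    # is left updating the fp32 clone; fp32 params simply map to themselves.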
    def backward(self, loss: Tensor, *args, **kwargs):
        # Let the mixed-precision handler pre-process the loss (e.g. apply the fp16 loss scale)
        # before backpropagating.
        loss = self.mixed_precision.pre_backward(loss)
        loss.backward(*args, **kwargs)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
        # Same idea as backward(), but starting from an externally supplied gradient.
        grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
        tensor.backward(grad)

    def zero_grad(self, *args, **kwargs):
        # Clear working-param grads, then let the inner optimizer clear the master grads.
        for p in self.working_to_master_map.keys():
            p.grad = None
        self.mixed_precision.pre_zero_grad()
        return super().zero_grad(*args, **kwargs)

    def _unscale_and_clip_grads(self, total_norm: float) -> None:
        """
        Unscale and clip gradients before performing the optimization step.

        Args:
            total_norm (float): The computed total gradient norm.

        Returns:
            None
        """
        div_scale = 1.0

        # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
        if self.mixed_precision is not None:
            div_scale = self.mixed_precision.get_grad_div_scale()

        if self.max_norm > 0.0:
            # Calculate the scaling factor for gradient clipping
            # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm

            # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
            if clip > 1:
                div_scale = clip * div_scale

        # Apply the scaling factor to gradients
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.grad.data.mul_(1.0 / div_scale)
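    # Worked example (illustrative, not from the source): with max_norm=1.0, a grad
    # division scale div_scale=1024.0 and total_norm=4096.0 (computed on the still-scaled
    # grads), clip = (4096/1024 + 1e-6) / 1.0 > 1, so div_scale grows to roughly 4096 and
    # every gradient is multiplied by about 1/4096, leaving a clipped norm near max_norm.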

    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor, Tensor]], norm_type: int = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.

        Args:
            param_gradient_pairs (List[Tuple[Tensor, Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.

        Returns:
            float: The total norm of the given gradients.
        """

        if len(param_gradient_pairs) == 0:
            return 0.0

        # gradients used for norm calculation.
        gradients = [grad for param, grad in param_gradient_pairs]

        if norm_type == inf:
            total_norm = max(grad.data.abs().max() for grad in gradients)
        else:
            total_norm_exponentiated = 0.0
            for grad in gradients:
                total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
            total_norm = total_norm_exponentiated ** (1.0 / norm_type)

        return total_norm
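    # Illustrative check (not from the source): with two gradients of L2 norm 3.0 and 4.0
    # and norm_type=2, the loop above gives total_norm = (3.0**2 + 4.0**2) ** 0.5 = 5.0.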

    def step(self, *args, **kwargs):
        # Skip the step entirely (and drop the grads) if the loss scaler detected overflow.
        if self.mixed_precision.should_skip_step():
            self.zero_grad()
            return
        # prepare grads: move working (fp16/bf16) grads onto the fp32 master params
        for group in self.optim.param_groups:
            for p in group["params"]:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                if working_param.grad is not None:
                    p.grad = working_param.grad.data.float()
                    working_param.grad = None

        # gradient unscale and clip.
        if self.max_norm <= 0:
            # no need to compute gradient norm.
            total_norm = 0.0
        else:
            # compute the total norm.
            param_gradient_pairs = [
                (self.master_to_working_map[p], p.grad)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
            total_norm = self._compute_grad_norm(param_gradient_pairs)
        self._unscale_and_clip_grads(total_norm)

        self.optim.step(*args, **kwargs)
        # update working params from the stepped master params
        for group in self.optim.param_groups:
            for p in group["params"]:
                working_param = self.master_to_working_map[p]
                if p is working_param:
                    continue
                working_param.data.copy_(p.data)

    def update_master_params(self, model: Module):
        # Update master params from working params
        with torch.no_grad():
            for p in model.parameters():
                if (p is None) or (p not in self.working_to_master_map):
                    continue
                master_param = self.working_to_master_map[p]
                master_param.data.copy_(p.data)

    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
        return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}

    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
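

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module): it assumes
# a CUDA device for fp16 compute, uses torch.optim.SGD as a stand-in inner
# optimizer, and must be run as a module inside the package so the relative
# import above resolves; depending on the ColossalAI version, distributed
# initialization (e.g. colossalai.launch) may also be required first.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if torch.cuda.is_available():
        # A tiny fp16 model; the wrapper keeps fp32 master copies internally.
        model = torch.nn.Linear(8, 8).cuda().half()
        inner_optim = torch.optim.SGD(model.parameters(), lr=1e-2)
        optim = MixedPrecisionOptimizer(inner_optim, precision="fp16", max_norm=1.0)

        x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
        loss = model(x).float().sum()

        optim.zero_grad()
        optim.backward(loss)  # scales the loss before calling loss.backward()
        optim.step()          # copy grads to master, unscale/clip, step, sync back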