[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)

* Add clip_grad_norm for hybrid_parallel_plugin

* polish code

* add unittests

* Move tp to a higher-level optimizer interface.

* bug fix

* polish code
littsk 2023-10-12 11:32:37 +08:00 committed by GitHub
parent df63564184
commit 83b52c56cd
8 changed files with 1158 additions and 90 deletions
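For context, a minimal sketch of how the new clipping is exercised from user code, assuming the usual Booster workflow, a 2-GPU torchrun-style launch, and a toy model (the model, sizes, and launch command are illustrative, not part of this patch):

    import colossalai
    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # assumes a distributed launch, e.g. `colossalai run --nproc_per_node 2 demo.py`
    colossalai.launch_from_torch(config={})

    # max_norm > 0 turns on the gradient clipping added by this commit
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp32", max_norm=1.0)
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(32, 32).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer, *_ = booster.boost(model, optimizer)

    loss = model(torch.randn(8, 32).cuda()).sum()
    booster.backward(loss, optimizer)
    optimizer.step()       # the wrapped optimizer clips gradients to max_norm before stepping
    optimizer.zero_grad()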


@@ -1,7 +1,7 @@
from typing import Dict, List, Tuple
import torch
from torch import Tensor, inf
from torch.nn import Module, Parameter
from torch.optim import Optimizer
@@ -68,8 +68,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f"Unsupported precision: {precision}")
-if max_norm > 0.0:
-raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -102,32 +100,65 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
return super().zero_grad(*args, **kwargs)
def _unscale_and_clip_grads(self, total_norm: float) -> None:
"""
Unscale and clip gradients before performing the optimization step.
Args:
total_norm (float): The computed total gradient norm.
Returns:
None
"""
div_scale = 1.0
# If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()
if self.max_norm > 0.0:
# Calculate the scaling factor for gradient clipping
# The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
# If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
if clip > 1:
div_scale = clip * div_scale
# Apply the scaling factor to gradients
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.grad.data.mul_(1.0 / div_scale)
-def _compute_grad_norm(self) -> float:
-if self.max_norm <= 0.0:
-return 0.0
-grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
-if len(grads) == 0:
-return 0.0
-device = grads[0].device
-# TODO(ver217): support tp
-total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
-return total_norm.item()
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
Returns:
float: The total norm of the given gradients.
"""
if len(param_gradient_pairs) == 0:
return 0.0
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
else:
total_norm_exponentiated = 0.0
for grad in gradients:
total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
total_norm = total_norm_exponentiated ** (1.0 / norm_type)
return total_norm
def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
@@ -142,8 +173,22 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
if working_param.grad is not None:
p.grad = working_param.grad.data.float()
working_param.grad = None
-total_norm = self._compute_grad_norm()
# gradient unscale and clip.
if self.max_norm <= 0:
# no need to compute gradient norm.
total_norm = 0.0
else:
# compute the total norm.
param_gradient_pairs = [
(self.master_to_working_map[p], p.grad)
for group in self.param_groups
for p in group["params"]
if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._unscale_and_clip_grads(total_norm)
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
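The unscale-and-clip arithmetic above can be checked in isolation. A minimal sketch on plain tensors (the helper below is illustrative, not part of the patch); it assumes the gradients hold loss_scale * true_grad, so the norm measured on them is the scaled norm:

    import torch

    def unscale_and_clip_(grads, total_norm, loss_scale=1.0, max_norm=1.0, eps=1e-6):
        # mirrors the patch: dividing by div_scale both unscales and, if needed, clips
        div_scale = loss_scale
        if max_norm > 0.0:
            clip = (total_norm / div_scale + eps) / max_norm
            if clip > 1:
                div_scale = clip * div_scale
        for g in grads:
            g.mul_(1.0 / div_scale)

    grads = [torch.full((4,), 8.0)]                  # scaled gradients (loss_scale = 2)
    total_norm = grads[0].norm(2).item()             # norm of the *scaled* gradients: 16.0
    unscale_and_clip_(grads, total_norm, loss_scale=2.0, max_norm=1.0)
    print(grads[0].norm(2))                          # ~1.0 after unscaling and clipping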


@@ -1,3 +1,4 @@
import ctypes
import random
from contextlib import nullcontext
from functools import partial
@@ -7,7 +8,8 @@ from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple,
import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor, inf
from torch.distributed import ProcessGroup, get_world_size
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -24,6 +26,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
from .pp_plugin_base import PipelinePluginBase
@@ -160,12 +163,143 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
class HybridParallelNaiveOptimizer(OptimizerWrapper):
-def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
def __init__(
self,
optim: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
max_norm: float = 0,
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.max_norm = max_norm
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
super().__init__(optim)
def step(self, *args, **kwargs):
r"""
Perform an optimization step.
Args:
*args: Variable-length positional arguments to be passed to the optimizer's step function.
**kwargs: Keyword arguments to be passed to the optimizer's step function.
"""
if self.max_norm > 0:
# Compute the total gradient norm.
param_gradient_pairs = [
(p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
# Clip the gradients to prevent exploding gradients.
self._clip_grad_norm(total_norm)
# Perform the optimization step using the underlying optimizer.
self.optim.step(*args, **kwargs)
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
Returns:
float: The total norm of the given gradients.
"""
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
# grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'.
grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
# If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
# it indicates that the parameter is not distributed across devices of the 'tp_group'.
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
if pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
if grad is stage_shared_param.grad:
grad_norm_exponentiated /= len(shared_param)
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
# compute the total_norm
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
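The division by tp_size in the branch above keeps the subsequent SUM all-reduce from over-counting gradients whose parameters are replicated, rather than sharded, across the tensor-parallel group. The identity behind it can be checked on a single process (a sketch; the all-reduce over identical ranks is emulated by a multiplication):

    import torch

    tp_size = 4
    p = 2.0
    g = torch.tensor([3.0, 4.0])                # a gradient replicated on every tp rank

    # each rank's contribution before the SUM all-reduce, pre-divided because g is replicated
    per_rank_contribution = (g.norm(p) ** p) / tp_size

    # dist.all_reduce(SUM) over tp_size identical contributions is just a multiplication here
    reduced = per_rank_contribution * tp_size

    print(reduced ** (1.0 / p))                 # tensor(5.) == ||g||_2, counted exactly once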
def _clip_grad_norm(self, total_norm: float) -> None:
r"""
Clips the gradients of the model's parameters to prevent exploding gradients.
Args:
total_norm (float): The computed total gradient norm.
Returns:
None
"""
clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6))
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for group in self.optim.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.grad.data.mul_(clip_coef_clamped)
def update_master_params(self, model: Module):
pass
@@ -192,23 +326,108 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(
optim,
precision=precision,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
max_norm=max_norm,
)
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
Returns:
float: The total norm of the given gradients.
"""
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
if norm_type == inf:
# The parent class calculates the norm of 'dp' gradients,
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
# grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism.
grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
# If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
# it indicates that the parameter is not distributed across devices of the 'tp_group'.
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
if pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_working_shared_param = shared_param[self.stage_manager.stage]
stage_master_shared_param = self.working_to_master_map[stage_working_shared_param]
if grad is stage_master_shared_param.grad:
grad_norm_exponentiated /= len(shared_param)
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
# compute the total_norm
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
@@ -233,9 +452,15 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
):
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
@@ -255,10 +480,90 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
partition_grad,
cpu_offload,
dp_process_group,
-tp_process_group,
forced_dtype,
)
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
gradients (List[Tensor]): A list of tensors containing gradients.
norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2.
Returns:
float: The computed gradient norm.
"""
# Check if the list of gradients is empty
if len(gradients) == 0:
return 0.0
dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
if norm_type == inf:
# The parent class calculates the norm of 'dp' gradients,
# so we only need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
# If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
# it indicates that the parameter is not distributed across devices of the 'tp_group'.
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
if pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
if grad is working_grad:
grad_norm_exponentiated /= len(shared_param)
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
if tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
# Compute the 'total_norm' from 'total_norm_exponentiated'
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
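The ZeRO variant above maps a working gradient back to its parameter through the stored id() and ctypes.cast(param_id, ctypes.py_object).value. That idiom relies on CPython using the object's address as its id() and on the parameter still being alive; a standalone illustration (names are illustrative):

    import ctypes

    import torch

    param = torch.nn.Parameter(torch.zeros(2))
    param_id = id(param)                 # what the gradient store records for each gradient

    # valid only in CPython and only while `param` is still referenced somewhere
    recovered = ctypes.cast(param_id, ctypes.py_object).value
    assert recovered is param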
class HybridParallelPlugin(PipelinePluginBase):
"""
@@ -475,11 +780,19 @@ class HybridParallelPlugin(PipelinePluginBase):
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
tp_process_group=self.tp_group,
**self.amp_config,
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
tp_process_group=self.tp_group,
)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
@@ -491,6 +804,7 @@ class HybridParallelPlugin(PipelinePluginBase):
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,


@@ -3,9 +3,7 @@ from typing import Optional
import torch
import torch.distributed as dist
-from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.distributed import ProcessGroup
def flatten(input_):
@@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list):
total_norm += norm**2.0
return math.sqrt(total_norm)
-def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
-"""Clips gradient norm of an iterable of parameters.
-This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-added functionality to handle model parallel parameters.
-Args:
-gradients (Tensor): The gradients to compute norm
-dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
-tp_group (ProcessGroup): The process group of Tensor Parallelism
-norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
-Returns:
-int: The total norm of given gradients
-"""
-norm_type = float(norm_type)
-if norm_type == inf:
-total_norm = max(g.data.abs().max() for g in gradients)
-total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
-# Take max across all GPUs.
-if tp_group is not None:
-dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
-total_norm = total_norm_cuda[0].item()
-else:
-total_norm = 0.0
-for g in gradients:
-param_norm = g.data.double().norm(norm_type)
-total_norm += param_norm.item() ** norm_type
-# Sum across all model parallel GPUs.
-total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
-if tp_group is not None:
-dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
-total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)
-if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
-total_norm = -1
-return total_norm
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When


@@ -21,6 +21,8 @@ class GradientStore(BaseStore):
# for zero2, it's `param_id: [grad_local_rank]`
self._working_index = 0 if partition_grad else self._local_rank
self.grad_to_param_mapping = dict()
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
"""Return list of gradient slices of a specific parameter
@@ -54,6 +56,8 @@ class GradientStore(BaseStore):
else:
self._grads_of_params[group_id][param_id].append(grad)
self.grad_to_param_mapping[id(grad)] = param_id
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
"""Add a gradient slice on an existing slice of the parameter's gradient
Used when no_sync is not activated.
@@ -83,8 +87,37 @@ class GradientStore(BaseStore):
return grad_list
def get_working_grad_by_param_id(self, param_id) -> Tensor:
"""
Return the working gradient for the specified parameter.
Args:
param_id (int): The id of the parameter.
Returns:
Tensor: The working gradient slice for the specified param_id.
"""
for group in self._grads_of_params.values():
if param_id in group.keys():
return group[param_id][self._working_index]
raise KeyError(f"Working gradient for param_id {param_id} not found.")
def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()
def reset_all_gradients(self):
self._grads_of_params = dict()
def get_param_id_for_grad(self, grad: Tensor) -> int:
"""Return the id of a parameter which the gradient slice belongs to
Args:
grad (Tensor): the gradient slice
Returns:
int: the id of a parameter which the gradient slice belongs to
"""
return self.grad_to_param_mapping[id(grad)]
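Both new accessors are backed by the plain id-keyed grad_to_param_mapping dict that is filled when a gradient is appended; stripped of the surrounding ZeRO bookkeeping, the lookup pattern is just (a mock sketch, not the real GradientStore API):

    import torch

    grad_to_param_mapping = {}                      # id(grad) -> param_id, as in the patch

    param = torch.nn.Parameter(torch.zeros(3))
    grad = torch.ones(3)
    grad_to_param_mapping[id(grad)] = id(param)     # recorded when the gradient is stored

    # later, given only the gradient tensor, recover which parameter it belongs to
    assert grad_to_param_mapping[id(grad)] == id(param)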


@@ -2,11 +2,12 @@
import copy
from contextlib import contextmanager
from functools import partial
from typing import Dict, Iterator, List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
@@ -21,14 +22,7 @@ from colossalai.logging import get_dist_logger
# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
-from ._utils import (
-calculate_global_norm_from_list,
-compute_norm,
-flatten,
-has_inf_or_nan,
-release_param_grad,
-sync_tensor,
-)
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -80,7 +74,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
-tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -101,8 +94,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
-self.tp_pg = tp_process_group
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
@@ -433,7 +424,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# compute norm
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
-norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
norm_group = self._compute_grad_norm(gradients=working_grads)
norm_groups.append(norm_group)
self._grad_store.reset_grads_by_group_id(group_id)
@@ -467,6 +458,44 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
gradients (List[Tensor]): The gradients to compute norm
norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
Returns:
float: The total norm of given gradients
"""
if len(gradients) == 0:
return 0.0
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
else:
total_norm_exponentiated = 0.0
for grad in gradients:
grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
total_norm_exponentiated += grad_norm_exponentiated
# Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
return total_norm
#############################
# Mixed Precision Utilities #
#############################
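The per-group norms returned by _compute_grad_norm are combined by calculate_global_norm_from_list as the square root of the sum of squares, which for the 2-norm is the same as taking one norm over all gradients at once. A quick numeric check (illustrative values):

    import math

    import torch

    g1 = torch.tensor([3.0, 4.0])        # gradients of parameter group 0
    g2 = torch.tensor([5.0, 12.0])       # gradients of parameter group 1

    norm_groups = [g1.norm(2).item(), g2.norm(2).item()]        # 5.0 and 13.0
    combined = math.sqrt(sum(n ** 2 for n in norm_groups))

    assert abs(combined - torch.cat([g1, g2]).norm(2).item()) < 1e-6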


@@ -0,0 +1,258 @@
import pytest
import torch
from torch.nn.utils.clip_grad import clip_grad_norm_
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
col_layer_for_check = ["encoder.layer[0].output.dense"]
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
elif test_config["precision"] == "fp16":
atol, rtol = 5e-3, 5e-3
else:
atol, rtol = 2e-2, 2e-2
# Check grads
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
col_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
row_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
check_all_grad_tensors(grads_to_check)
# Check gradient norm
# Convert the gradient data of the working parameter to float and assign it to the master parameter's gradient
# Note that this operation should have been done in the 'step' function, but it is performed here in advance for gradient norm calculation purposes.
# Although it will be done again in the 'step' function, it does not affect correctness.
for group in sharded_optimizer.optim.param_groups:
for p in group["params"]:
working_param = sharded_optimizer.master_to_working_map[p]
if p is working_param:
continue
if working_param.grad is not None:
p.grad = working_param.grad.data.float()
working_param.grad = None
# Create a list of parameter-gradient pairs containing working parameters and their gradients
param_gradient_pairs = [
(sharded_optimizer.master_to_working_map[p], p.grad)
for group in sharded_optimizer.param_groups
for p in group["params"]
if p.grad is not None
]
origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"])
# Calculate the gradient norm of the sharded optimizer
device = origin_norm.device
hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device)
# If using fp16 precision, divide by the initial scale
if test_config["precision"] == "fp16":
hybrid_norm /= test_config["initial_scale"]
# Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model
assert torch.allclose(
origin_norm, hybrid_norm, atol=atol, rtol=rtol
), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}"
# Optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# Check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
elif test_config["precision"] == "fp16":
atol, rtol = 5e-3, 5e-3
else:
atol, rtol = 2e-2, 2e-2
if org_model.__class__.__name__ == "BertModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# Check weights
if test_config["precision"] == "fp32":
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
],
)
def run_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
],
)
def run_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_grad_clip_norm(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_test()
def check_grad_clip_norm_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm():
spawn(check_grad_clip_norm, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm_3d():
spawn(check_grad_clip_norm_3d, 8)
if __name__ == "__main__":
test_grad_clip_norm()
test_grad_clip_norm_3d()


@@ -0,0 +1,197 @@
import pytest
import torch
from torch.nn.utils.clip_grad import clip_grad_norm_
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
col_layer_for_check = ["encoder.layer[0].output.dense"]
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
elif test_config["precision"] == "fp16":
atol, rtol = 5e-3, 5e-3
else:
atol, rtol = 2e-2, 2e-2
# Check grads
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
col_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
row_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
check_all_grad_tensors(grads_to_check)
# Check grad norm
param_gradient_pairs = [
(p, p.grad) for group in sharded_optimizer.param_groups for p in group["params"] if p.grad is not None
]
origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"])
device = origin_norm.device
hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device)
assert torch.allclose(
origin_norm, hybrid_norm, atol=atol, rtol=rtol
), f"orgin origin model grad norm is not equal to shard model grad norm\n{origin_norm}\n{hybrid_norm}"
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
elif test_config["precision"] == "fp16":
atol, rtol = 5e-3, 5e-3
else:
atol, rtol = 2e-2, 2e-2
if org_model.__class__.__name__ == "BertModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config["precision"] == "fp32":
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": True,
"precision": "fp32",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
},
],
)
def run_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
},
],
)
def run_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_grad_clip_norm(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_test()
def check_grad_clip_norm_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm():
spawn(check_grad_clip_norm, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm_3d():
spawn(check_grad_clip_norm_3d, 8)
if __name__ == "__main__":
test_grad_clip_norm()
test_grad_clip_norm_3d()


@@ -0,0 +1,241 @@
import math
import pytest
import torch
import torch.distributed as dist
from torch.nn.utils.clip_grad import clip_grad_norm_
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
col_layer_for_check = ["encoder.layer[0].output.dense"]
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
elif test_config["precision"] == "fp16":
atol, rtol = 5e-3, 5e-3
else:
atol, rtol = 2e-2, 2e-2
dist.barrier()
# Check gradient norm
origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"])
# Calculate the gradient norm of the sharded optimizer
device = origin_norm.device
norm_groups = []
for group_id in range(sharded_optimizer.num_param_groups):
working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id)
norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads)
norm_groups.append(norm_group)
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
hybrid_norm = torch.tensor(math.sqrt(total_norm)).to(device)
# If using fp16 precision, divide by the initial scale
if test_config["precision"] == "fp16":
hybrid_norm /= test_config["initial_scale"]
# Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model
assert torch.allclose(
origin_norm, hybrid_norm, atol=atol, rtol=rtol
), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}"
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
elif test_config["precision"] == "fp16":
atol, rtol = 5e-3, 5e-3
else:
atol, rtol = 2e-2, 2e-2
if org_model.__class__.__name__ == "BertModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config["precision"] == "fp32":
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
],
)
def run_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
],
)
def run_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_grad_clip_norm(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_test()
def check_grad_clip_norm_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm():
spawn(check_grad_clip_norm, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm_3d():
spawn(check_grad_clip_norm_3d, 8)
if __name__ == "__main__":
test_grad_clip_norm()
test_grad_clip_norm_3d()