mirror of https://github.com/hpcaitech/ColossalAI
[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)
* Add clip_grad_norm for hybrid_parallel_plugin * polish code * add unittests * Move tp to a higher-level optimizer interface. * bug fix * polish code
parent df63564184
commit 83b52c56cd
@@ -1,7 +1,7 @@
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import torch
-from torch import Tensor
+from torch import Tensor, inf
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer
 
@@ -68,8 +68,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
             self.mixed_precision = BF16MixedPrecisionMixin()
         else:
             raise ValueError(f"Unsupported precision: {precision}")
-        if max_norm > 0.0:
-            raise NotImplementedError("max_norm is not supported yet.")
         self.max_norm = max_norm
         self.working_to_master_map: Dict[Parameter, Tensor] = {}
         self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -102,32 +100,65 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
         return super().zero_grad(*args, **kwargs)
 
     def _unscale_and_clip_grads(self, total_norm: float) -> None:
+        """
+        Unscale and clip gradients before performing the optimization step.
+
+        Args:
+            total_norm (float): The computed total gradient norm.
+
+        Returns:
+            None
+        """
         div_scale = 1.0
+
+        # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
         if self.mixed_precision is not None:
             div_scale = self.mixed_precision.get_grad_div_scale()
+
         if self.max_norm > 0.0:
-            # norm is in fact norm*scale
+            # Calculate the scaling factor for gradient clipping
+            # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
+
+            # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
            if clip > 1:
                div_scale = clip * div_scale
+
+        # Apply the scaling factor to gradients
         for group in self.param_groups:
             for p in group["params"]:
                 if p.grad is None:
                     continue
                 p.grad.data.mul_(1.0 / div_scale)
 
-    def _compute_grad_norm(self) -> float:
-        if self.max_norm <= 0.0:
-            return 0.0
-        grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
-        if len(grads) == 0:
-            return 0.0
-        device = grads[0].device
-        # TODO(ver217): support tp
-        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
-        return total_norm.item()
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        # gradients used for norm calculation.
+        gradients = [grad for param, grad in param_gradient_pairs]
+
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
+            total_norm = total_norm_exponentiated ** (1.0 / norm_type)
+
+        return total_norm
 
     def step(self, *args, **kwargs):
         if self.mixed_precision.should_skip_step():
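The clipping arithmetic in _unscale_and_clip_grads above folds loss-scale removal and norm clipping into one multiply per gradient. A quick standalone sketch with made-up numbers (illustrative only, not part of the patch):

# Illustrative values: an fp16 loss scale of 1024 and a max_norm of 1.0.
max_norm = 1.0
div_scale = 1024.0       # gradient division scale from the mixed-precision handler
total_norm = 2048.0      # norm of the still-scaled gradients, i.e. the true norm is 2.0

clip = ((total_norm / div_scale) + 1e-6) / max_norm  # ~2.0, so clipping is needed
if clip > 1:
    div_scale = clip * div_scale                     # ~2048.0

# Each gradient is multiplied by 1.0 / div_scale, which unscales it and
# simultaneously brings the true norm from ~2.0 down to ~max_norm.
print(1.0 / div_scale)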
@@ -142,8 +173,22 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if working_param.grad is not None:
                     p.grad = working_param.grad.data.float()
                     working_param.grad = None
-        total_norm = self._compute_grad_norm()
+
+        # gradient unscale and clip.
+        if self.max_norm <= 0:
+            # no need to compute gradient norm.
+            total_norm = 0.0
+        else:
+            # compute the total norm.
+            param_gradient_pairs = [
+                (self.master_to_working_map[p], p.grad)
+                for group in self.param_groups
+                for p in group["params"]
+                if p.grad is not None
+            ]
+            total_norm = self._compute_grad_norm(param_gradient_pairs)
         self._unscale_and_clip_grads(total_norm)
 
         self.optim.step(*args, **kwargs)
         # update working params
         for group in self.optim.param_groups:
@@ -1,3 +1,4 @@
+import ctypes
 import random
 from contextlib import nullcontext
 from functools import partial
@@ -7,7 +8,8 @@ from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple,
 import numpy as np
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import Tensor, inf
+from torch.distributed import ProcessGroup, get_world_size
 from torch.nn import Module, SyncBatchNorm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -24,6 +26,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 
 from .pp_plugin_base import PipelinePluginBase
@@ -160,12 +163,143 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module):
 
 
 class HybridParallelNaiveOptimizer(OptimizerWrapper):
-    def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
+    def __init__(
+        self,
+        optim: Optimizer,
+        model: Module,
+        use_pipeline: bool,
+        param_info: OrderedDict,
+        max_norm: float = 0,
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
+    ):
         self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.max_norm = max_norm
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         super().__init__(optim)
 
+    def step(self, *args, **kwargs):
+        r"""
+        Perform an optimization step.
+
+        Args:
+            *args: Variable-length positional arguments to be passed to the optimizer's step function.
+            **kwargs: Keyword arguments to be passed to the optimizer's step function.
+        """
+        if self.max_norm > 0:
+            # Compute the total gradient norm.
+            param_gradient_pairs = [
+                (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
+            ]
+            total_norm = self._compute_grad_norm(param_gradient_pairs)
+
+            # Clip the gradients to prevent exploding gradients.
+            self._clip_grad_norm(total_norm)
+
+        # Perform the optimization step using the underlying optimizer.
+        self.optim.step(*args, **kwargs)
+
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        # gradients used for norm calculation.
+        gradients = [grad for param, grad in param_gradient_pairs]
+
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+            total_norm = total_norm_cuda.item()
+        else:
+            # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'.
+            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
+
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
+                if tp_size > 1:
+                    param_for_grad = grad_to_param_mapping[id(grad)]
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_shared_param = shared_param[self.stage_manager.stage]
+                            if grad is stage_shared_param.grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # compute the total_norm
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
+    def _clip_grad_norm(self, total_norm: float) -> None:
+        r"""
+        Clips the gradients of the model's parameters to prevent exploding gradients.
+
+        Args:
+            total_norm (float): The computed total gradient norm.
+
+        Returns:
+            None
+        """
+        clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6))
+        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+
+        for group in self.optim.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                p.grad.data.mul_(clip_coef_clamped)
+
     def update_master_params(self, model: Module):
         pass
 
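The two divisions in _compute_grad_norm above (by tp_size for parameters that are not distributed tensors, and by the number of pipeline stages holding a shared parameter) keep the SUM all-reduce from counting the same gradient more than once. As a sketch of the intended invariant (notation is ours, not taken from the patch):

\lVert g \rVert_p = \Big( \operatorname{AllReduce}_{\mathrm{SUM}} \sum_{i \in \text{local grads}} \frac{\lVert g_i \rVert_p^{\,p}}{c_i} \Big)^{1/p},
\qquad c_i = \#\{\text{ranks in the reduction groups that hold an identical copy of } g_i\}.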
@@ -192,23 +326,108 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         hysteresis: int = 2,
         max_scale: float = 2**32,
         max_norm: float = 0,
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
     ):
         self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(
             optim,
-            precision,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            max_norm,
+            precision=precision,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            max_norm=max_norm,
         )
 
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        if norm_type == inf:
+            # The parent class calculates the norm of 'dp' gradients,
+            # so we need to calculate the norm of 'tp' and 'pp' gradients.
+            total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
+
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+
+            total_norm = total_norm_cuda.item()
+        else:
+            # gradients used for norm calculation.
+            gradients = [grad for param, grad in param_gradient_pairs]
+            # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism.
+            grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs}
+
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
+                if tp_size > 1:
+                    param_for_grad = grad_to_param_mapping[id(grad)]
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_working_shared_param = shared_param[self.stage_manager.stage]
+                            stage_master_shared_param = self.working_to_master_map[stage_working_shared_param]
+                            if grad is stage_master_shared_param.grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # compute the total_norm
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
 
 class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
@@ -233,9 +452,15 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
+        pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
     ):
         self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.dp_pg = dp_process_group
+        self.tp_pg = tp_process_group
+        self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
@@ -255,10 +480,90 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             partition_grad,
             cpu_offload,
             dp_process_group,
-            tp_process_group,
             forced_dtype,
         )
 
+    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            gradients (List[Tensor]): A list of tensors containing gradients.
+            norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2.
+
+        Returns:
+            float: The computed gradient norm.
+        """
+        # Check if the list of gradients is empty
+        if len(gradients) == 0:
+            return 0.0
+
+        dp_size = get_world_size(self.dp_pg) if self.dp_pg is not None else 1
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        if norm_type == inf:
+            # The parent class calculates the norm of 'dp' gradients,
+            # so we only need to calculate the norm of 'tp' and 'pp' gradients.
+            total_norm = super()._compute_grad_norm(gradients, norm_type)
+
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+
+            total_norm = total_norm_cuda.item()
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size'.
+                if tp_size > 1:
+                    param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                    param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
+
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_shared_param = shared_param[self.stage_manager.stage]
+                            working_grad = self._grad_store.get_working_grad_by_param_id(id(stage_shared_param))
+                            if grad is working_grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            if dp_size > 1:
+                # compute norm in dp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # Compute the 'total_norm' from 'total_norm_exponentiated'
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
 
 class HybridParallelPlugin(PipelinePluginBase):
     """
@@ -475,11 +780,19 @@ class HybridParallelPlugin(PipelinePluginBase):
                    param_info=param_info,
                    precision=self.precision,
                    max_norm=self.max_norm,
+                    pp_process_group=self.pp_group,
+                    tp_process_group=self.tp_group,
                    **self.amp_config,
                )
            else:
                optimizer = HybridParallelNaiveOptimizer(
-                    optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+                    optimizer,
+                    model,
+                    use_pipeline=self.enable_pipeline_parallelism,
+                    param_info=param_info,
+                    max_norm=self.max_norm,
+                    pp_process_group=self.pp_group,
+                    tp_process_group=self.tp_group,
                )
        else:
            assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
@@ -491,6 +804,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                param_info=param_info,
                dp_process_group=self.dp_group,
                tp_process_group=self.tp_group,
+                pp_process_group=self.pp_group,
                verbose=True,
                clip_grad_norm=self.max_norm,
                **self.zero_config,
@@ -3,9 +3,7 @@ from typing import Optional
 
 import torch
 import torch.distributed as dist
-from torch import Tensor, inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from torch.distributed import ProcessGroup
 
 
 def flatten(input_):
@@ -192,53 +190,6 @@ def calculate_global_norm_from_list(norm_list):
         total_norm += norm**2.0
     return math.sqrt(total_norm)
 
 
-def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
-    """Clips gradient norm of an iterable of parameters.
-
-    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-    added functionality to handle model parallel parameters.
-
-    Args:
-        gradients (Tensor): The gradients to compute norm
-        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
-        tp_group (ProcessGroup): The process group of Tensor Parallelism
-        norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
-
-    Returns:
-        int: The total norm of given gradients
-    """
-
-    norm_type = float(norm_type)
-    if norm_type == inf:
-        total_norm = max(g.data.abs().max() for g in gradients)
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
-
-        # Take max across all GPUs.
-        if tp_group is not None:
-            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
-        total_norm = total_norm_cuda[0].item()
-    else:
-        total_norm = 0.0
-        for g in gradients:
-            param_norm = g.data.double().norm(norm_type)
-            total_norm += param_norm.item() ** norm_type
-
-        # Sum across all model parallel GPUs.
-        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-        torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
-
-        if tp_group is not None:
-            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
-
-        total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type)
-
-    if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm:
-        total_norm = -1
-
-    return total_norm
-
-
 def sync_tensor(flat_tensor, tensor_list):
     """
     Synchronize the flattened tensor and unflattened tensor list. When
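For reference, the root-sum-square combination used by calculate_global_norm_from_list above recovers the same value as computing a single L2 norm over all gradients at once. A small self-contained check (illustrative only, not part of the patch):

import math

import torch

def l2_norm(grads):
    # L2 norm over a list of gradient tensors.
    return float(sum(g.double().norm(2) ** 2 for g in grads) ** 0.5)

grads_group_a = [torch.randn(3), torch.randn(5)]
grads_group_b = [torch.randn(7)]

per_group = [l2_norm(grads_group_a), l2_norm(grads_group_b)]
combined = math.sqrt(sum(n**2 for n in per_group))   # calculate_global_norm_from_list-style
direct = l2_norm(grads_group_a + grads_group_b)      # one norm over everything
assert abs(combined - direct) < 1e-6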
@@ -21,6 +21,8 @@ class GradientStore(BaseStore):
         # for zero2, it's `param_id: [grad_local_rank]`
         self._working_index = 0 if partition_grad else self._local_rank
 
+        self.grad_to_param_mapping = dict()
+
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
         """Return list of gradient slices of a specific parameter
 
@@ -54,6 +56,8 @@ class GradientStore(BaseStore):
         else:
             self._grads_of_params[group_id][param_id].append(grad)
 
+        self.grad_to_param_mapping[id(grad)] = param_id
+
     def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
         """Add a gradient slice on an existing slice of the parameter's gradient
         Used when no_sync is not activated.
@@ -83,8 +87,37 @@ class GradientStore(BaseStore):
 
         return grad_list
 
+    def get_working_grad_by_param_id(self, param_id) -> Tensor:
+        """
+        Return the working gradient for the specified parameter.
+
+        Args:
+            param_id (int): The index of the parameter.
+
+        Returns:
+            Tensor: The working gradient slices for the specified param_id.
+        """
+        for group in self._grads_of_params.values():
+            if param_id in group.keys():
+                return group[param_id][self._working_index]
+
+        raise KeyError(f"Working gradient for param_id {param_id} not found.")
+
     def reset_grads_by_group_id(self, group_id: int):
         self._grads_of_params[group_id] = dict()
 
     def reset_all_gradients(self):
         self._grads_of_params = dict()
+
+    def get_param_id_for_grad(self, grad: Tensor) -> int:
+        """Return the id of a parameter which the gradient slice belongs to
+
+        Args:
+            grad (Tensor): the gradient slice
+
+        Returns:
+            int: the id of a parameter which the gradient slice belongs to
+        """
+
+        return self.grad_to_param_mapping[id(grad)]
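A minimal toy sketch of the two lookups the GradientStore additions above provide, using plain dicts in place of the real bookkeeping classes (illustrative only):

import torch

grads_of_params = {0: {}}      # group_id -> {param_id: [gradient slices]}
grad_to_param_mapping = {}     # id(grad) -> param_id
working_index = 0

param_id = 42
grad = torch.randn(4)
grads_of_params[0].setdefault(param_id, []).append(grad)
grad_to_param_mapping[id(grad)] = param_id   # registered when the slice is appended

# get_param_id_for_grad-style lookup
assert grad_to_param_mapping[id(grad)] == param_id
# get_working_grad_by_param_id-style lookup
assert grads_of_params[0][param_id][working_index] is grad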
@@ -2,11 +2,12 @@
 import copy
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, Iterator, Optional, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch import Tensor, inf
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
@@ -21,14 +22,7 @@ from colossalai.logging import get_dist_logger
 # from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device
 
-from ._utils import (
-    calculate_global_norm_from_list,
-    compute_norm,
-    flatten,
-    has_inf_or_nan,
-    release_param_grad,
-    sync_tensor,
-)
+from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
 
 
@@ -80,7 +74,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
-        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         forced_dtype: Optional[torch.dtype] = None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -101,8 +94,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)
 
-        self.tp_pg = tp_process_group
-
         # working and master params for mixed precision training
         self._working_param_groups = dict()
         self._master_param_groups_of_current_rank = dict()
@@ -433,7 +424,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             # compute norm
             working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
-            norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
+            norm_group = self._compute_grad_norm(gradients=working_grads)
             norm_groups.append(norm_group)
 
             self._grad_store.reset_grads_by_group_id(group_id)
@@ -467,6 +458,44 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
+    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            gradients (List[Tensor]): The gradients to compute norm
+            norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
+
+        Returns:
+            float: The total norm of given gradients
+        """
+        if len(gradients) == 0:
+            return 0.0
+
+        norm_type = float(norm_type)
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
+            total_norm = total_norm_cuda.item()
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            # Sum across all model parallel GPUs.
+            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            torch.distributed.all_reduce(
+                total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
+            )
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
     #############################
     # Mixed Precision Utilities #
     #############################
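The _compute_grad_norm added above reduces with MAX for the infinity norm and with a SUM of p-th powers (followed by a root) otherwise. A single-process sketch of why those reductions are the right ones, with plain lists standing in for data-parallel ranks (illustrative only):

import torch

ranks = [[torch.tensor([3.0, -1.0])], [torch.tensor([0.5, 2.0])]]  # one shard of grads per rank
all_grads = [g for shard in ranks for g in shard]

# inf-norm: the MAX of per-rank maxima equals the global maximum.
per_rank_inf = [max(g.abs().max().item() for g in shard) for shard in ranks]
assert max(per_rank_inf) == max(g.abs().max().item() for g in all_grads)

# p-norm (p=2): SUM the per-rank sums of |g|^p, then take the p-th root.
p = 2.0
per_rank_pow = [sum(g.double().norm(p).item() ** p for g in shard) for shard in ranks]
global_norm = sum(per_rank_pow) ** (1.0 / p)
direct = sum(g.double().norm(p).item() ** p for g in all_grads) ** (1.0 / p)
assert abs(global_norm - direct) < 1e-6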
@@ -0,0 +1,258 @@
+import pytest
+import torch
+from torch.nn.utils.clip_grad import clip_grad_norm_
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_all_grad_tensors,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    get_grad_tensors_for_check,
+    run_forward_backward_with_hybrid_plugin,
+    unwrap_model,
+)
+
+
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
+        model_fn, loss_fn, test_config
+    )
+
+    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+    )
+
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
+
+    bert = unwrap_model(org_model, "BertModel", "bert")
+    sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
+
+    col_layer_for_check = ["encoder.layer[0].output.dense"]
+    row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
+
+    if test_config["precision"] == "fp32":
+        atol, rtol = 1e-4, 1e-3
+    elif test_config["precision"] == "fp16":
+        atol, rtol = 5e-3, 5e-3
+    else:
+        atol, rtol = 2e-2, 2e-2
+
+    # Check grads
+    # Save gradient tensors for comparison between the original model and the sharded model.
+    grads_to_check = {}
+    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+        col_layer_grads = get_grad_tensors_for_check(
+            bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+        )
+        row_layer_grads = get_grad_tensors_for_check(
+            bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+        )
+        grads_to_check.update(col_layer_grads)
+        grads_to_check.update(row_layer_grads)
+    check_all_grad_tensors(grads_to_check)
+
+    # Check gradient norm
+    # Convert the gradient data of the working parameter to float and assign it to the master parameter's gradient
+    # Note that this operation should have been done in the 'step' function, but it is performed here in advance for gradient norm calculation purposes.
+    # Although it will be done again in the 'step' function, it does not affect correctness.
+    for group in sharded_optimizer.optim.param_groups:
+        for p in group["params"]:
+            working_param = sharded_optimizer.master_to_working_map[p]
+            if p is working_param:
+                continue
+            if working_param.grad is not None:
+                p.grad = working_param.grad.data.float()
+                working_param.grad = None
+    # Create a list of parameter-gradient pairs containing working parameters and their gradients
+    param_gradient_pairs = [
+        (sharded_optimizer.master_to_working_map[p], p.grad)
+        for group in sharded_optimizer.param_groups
+        for p in group["params"]
+        if p.grad is not None
+    ]
+
+    origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"])
+    # Calculate the gradient norm of the sharded optimizer
+    device = origin_norm.device
+    hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device)
+
+    # If using fp16 precision, divide by the initial scale
+    if test_config["precision"] == "fp16":
+        hybrid_norm /= test_config["initial_scale"]
+
+    # Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model
+    assert torch.allclose(
+        origin_norm, hybrid_norm, atol=atol, rtol=rtol
+    ), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}"
+
+    # Optimizer executes step
+    org_optimizer.step()
+    sharded_optimizer.step()
+
+    # Check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+        if test_config["precision"] == "fp32":
+            atol, rtol = 1e-5, 1e-3
+        elif test_config["precision"] == "fp16":
+            atol, rtol = 5e-3, 5e-3
+        else:
+            atol, rtol = 2e-2, 2e-2
+        if org_model.__class__.__name__ == "BertModel":
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+    # Check weights
+    if test_config["precision"] == "fp32":
+        atol, rtol = 5e-3, 1e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "max_norm": 5,
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "max_norm": 5,
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "max_norm": 5,
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": True,
+            "precision": "bf16",
+            "max_norm": 5,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "bf16",
+            "max_norm": 5,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "bf16",
+            "max_norm": 5,
+        },
+    ],
+)
+def run_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+@parameterize(
+    "test_config",
+    [
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "bf16",
+            "max_norm": 5,
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp16",
+            "max_norm": 5,
+            "initial_scale": 1,
+        },
+    ],
+)
+def run_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
+def check_grad_clip_norm(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_test()
+
+
+def check_grad_clip_norm_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_3d_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_grad_clip_norm():
+    spawn(check_grad_clip_norm, 4)
+
+
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_grad_clip_norm_3d():
+    spawn(check_grad_clip_norm_3d, 8)
+
+
+if __name__ == "__main__":
+    test_grad_clip_norm()
+    test_grad_clip_norm_3d()
@ -0,0 +1,197 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.nn.utils.clip_grad import clip_grad_norm_
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
from tests.test_shardformer.test_model._utils import (
|
||||||
|
build_model_from_hybrid_plugin,
|
||||||
|
check_all_grad_tensors,
|
||||||
|
check_loss,
|
||||||
|
check_output_hidden_state,
|
||||||
|
check_weight,
|
||||||
|
get_grad_tensors_for_check,
|
||||||
|
run_forward_backward_with_hybrid_plugin,
|
||||||
|
unwrap_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||||
|
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||||
|
model_fn, loss_fn, test_config
|
||||||
|
)
|
||||||
|
|
||||||
|
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||||
|
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||||
|
)
|
||||||
|
|
||||||
|
stage_manager = booster.plugin.stage_manager
|
||||||
|
tp_group = booster.plugin.tp_group
|
||||||
|
|
||||||
|
bert = unwrap_model(org_model, "BertModel", "bert")
|
||||||
|
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
|
||||||
|
|
||||||
|
col_layer_for_check = ["encoder.layer[0].output.dense"]
|
||||||
|
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
|
||||||
|
|
||||||
|
if test_config["precision"] == "fp32":
|
||||||
|
atol, rtol = 1e-4, 1e-3
|
||||||
|
elif test_config["precision"] == "fp16":
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 2e-2, 2e-2
|
||||||
|
|
||||||
|
# Check grads
|
||||||
|
# Save gradient tensors for comparison between the original model and the sharded model.
|
||||||
|
grads_to_check = {}
|
||||||
|
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
|
||||||
|
col_layer_grads = get_grad_tensors_for_check(
|
||||||
|
bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
|
||||||
|
)
|
||||||
|
row_layer_grads = get_grad_tensors_for_check(
|
||||||
|
bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
|
||||||
|
)
|
||||||
|
grads_to_check.update(col_layer_grads)
|
||||||
|
grads_to_check.update(row_layer_grads)
|
||||||
|
check_all_grad_tensors(grads_to_check)
|
||||||
|
|
||||||
|
# Check grad norm
|
||||||
|
param_gradient_pairs = [
|
||||||
|
(p, p.grad) for group in sharded_optimizer.param_groups for p in group["params"] if p.grad is not None
|
||||||
|
]
|
||||||
|
origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"])
|
||||||
|
device = origin_norm.device
|
||||||
|
hybrid_norm = torch.tensor(sharded_optimizer._compute_grad_norm(param_gradient_pairs)).to(device)
|
||||||
|
assert torch.allclose(
|
||||||
|
origin_norm, hybrid_norm, atol=atol, rtol=rtol
|
||||||
|
), f"orgin origin model grad norm is not equal to shard model grad norm\n{origin_norm}\n{hybrid_norm}"
|
||||||
|
|
||||||
|
# optimizer executes step
|
||||||
|
org_optimizer.step()
|
||||||
|
sharded_optimizer.step()
|
||||||
|
|
||||||
|
# check last hidden state & loss
|
||||||
|
if stage_manager is None or stage_manager.is_last_stage():
|
||||||
|
if test_config["precision"] == "fp32":
|
||||||
|
atol, rtol = 1e-5, 1e-3
|
||||||
|
elif test_config["precision"] == "fp16":
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 2e-2, 2e-2
|
||||||
|
|
||||||
|
if org_model.__class__.__name__ == "BertModel":
|
||||||
|
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
# check weights
|
||||||
|
if test_config["precision"] == "fp32":
|
||||||
|
atol, rtol = 5e-3, 1e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
|
||||||
|
"test_config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"tp_size": 1,
|
||||||
|
"pp_size": 2,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"enable_all_optimization": False,
|
||||||
|
"use_lazy_init": True,
|
||||||
|
"precision": "fp32",
|
||||||
|
"max_norm": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 2,
|
||||||
|
"pp_size": 1,
|
||||||
|
"enable_all_optimization": False,
|
||||||
|
"use_lazy_init": False,
|
||||||
|
"precision": "fp32",
|
||||||
|
"max_norm": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"tp_size": 2,
|
||||||
|
"pp_size": 2,
|
||||||
|
"num_microbatches": 4,
|
||||||
|
"enable_all_optimization": False,
|
||||||
|
"use_lazy_init": False,
|
||||||
|
"precision": "fp32",
|
||||||
|
"max_norm": 5,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def run_test(test_config):
|
||||||
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
|
||||||
|
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
|
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||||
|
|
||||||
|
clear_layout_converter()
|
||||||
|
Randomizer.reset_index()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp32",
            "max_norm": 5,
        },
    ],
)
def run_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()

def check_grad_clip_norm(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_test()


def check_grad_clip_norm_3d(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_3d_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm():
    spawn(check_grad_clip_norm, 4)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm_3d():
    spawn(check_grad_clip_norm_3d, 8)


if __name__ == "__main__":
    test_grad_clip_norm()
    test_grad_clip_norm_3d()
@@ -0,0 +1,241 @@
import math

import pytest
import torch
import torch.distributed as dist
from torch.nn.utils.clip_grad import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_loss,
    check_output_hidden_state,
    check_weight,
    run_forward_backward_with_hybrid_plugin,
    unwrap_model,
)


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    bert = unwrap_model(org_model, "BertModel", "bert")
    sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

    col_layer_for_check = ["encoder.layer[0].output.dense"]

    if test_config["precision"] == "fp32":
        atol, rtol = 1e-4, 1e-3
    elif test_config["precision"] == "fp16":
        atol, rtol = 5e-3, 5e-3
    else:
        atol, rtol = 2e-2, 2e-2

    dist.barrier()
    # Check gradient norm
    origin_norm = clip_grad_norm_(org_model.parameters(), test_config["max_norm"])

    # Calculate the gradient norm of the sharded optimizer
    device = origin_norm.device
    norm_groups = []
    for group_id in range(sharded_optimizer.num_param_groups):
        working_grads = sharded_optimizer._grad_store.get_working_grads_by_group_id(group_id)
        norm_group = sharded_optimizer._compute_grad_norm(gradients=working_grads)
        norm_groups.append(norm_group)
    total_norm = 0.0
    for norm in norm_groups:
        total_norm += norm**2.0
    hybrid_norm = torch.tensor(math.sqrt(total_norm)).to(device)

    # If using fp16 precision, divide by the initial scale
    if test_config["precision"] == "fp16":
        hybrid_norm /= test_config["initial_scale"]

    # Assert that the gradient norm of the original model is close to the gradient norm of the hybrid model
    assert torch.allclose(
        origin_norm, hybrid_norm, atol=atol, rtol=rtol
    ), f"Original model grad norm is not equal to sharded model grad norm\n{origin_norm}\n{hybrid_norm}"

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage():
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-5, 1e-3
        elif test_config["precision"] == "fp16":
            atol, rtol = 5e-3, 5e-3
        else:
            atol, rtol = 2e-2, 2e-2
        if org_model.__class__.__name__ == "BertModel":
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    if test_config["precision"] == "fp32":
        atol, rtol = 5e-3, 1e-3
    else:
        atol, rtol = 5e-3, 5e-3
    if stage_manager is None or stage_manager.is_first_stage():
        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)

    torch.cuda.empty_cache()

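# Reference-only helper mirroring the aggregation above: per-parameter-group 2-norms are combined
# as the square root of the sum of their squares, and for fp16 the result is divided by the
# initial loss scale. Illustrative sketch; the test computes this inline and does not call it.
def _aggregate_group_norms(norm_groups, initial_scale: float = 1.0) -> float:
    total = sum(norm**2.0 for norm in norm_groups)
    return math.sqrt(total) / initial_scale
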
@parameterize(
    "test_config",
    [
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "max_norm": 5,
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 2,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "use_lazy_init": True,
            "precision": "bf16",
            "max_norm": 5,
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 2,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,
        },
    ],
)
def run_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()

@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,
            "initial_scale": 1,
        },
    ],
)
def run_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()

def check_grad_clip_norm(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_test()


def check_grad_clip_norm_3d(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_3d_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm():
    spawn(check_grad_clip_norm, 4)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_grad_clip_norm_3d():
    spawn(check_grad_clip_norm_3d, 8)


if __name__ == "__main__":
    test_grad_clip_norm()
    test_grad_clip_norm_3d()