[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish code of handling layernorm

* fix moe module

* polish code of class initialization

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
littsk 2023-11-03 13:32:43 +08:00 committed by GitHub
parent d99b2c961a
commit 1a3315e336
30 changed files with 1120 additions and 552 deletions

View File

@@ -1,6 +1,6 @@
 import ctypes
 import random
-from contextlib import nullcontext
+from contextlib import contextmanager
 from functools import partial
 from types import MethodType
 from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -25,6 +25,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -47,12 +48,17 @@ class HybridParallelModule(ModelWrapper):
         precision: str,
         shard_config: ShardConfig,
         dp_group: ProcessGroup,
+        tp_group: ProcessGroup,
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
+        self.shard_config = shard_config
         self.dp_group = dp_group
+        self.tp_group = tp_group
+        self.use_dpp = use_ddp
+        self.require_grad_sync = True
         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -98,19 +104,75 @@ class HybridParallelModule(ModelWrapper):
                 dist.all_reduce(param.grad, group=group)
             dist.barrier()
 
-    def no_sync(self) -> Iterator[None]:
-        # no sync grads across data parallel
-        return nullcontext()
-
-    def sync_grads(self):
-        # sync grad across data parallel
+    @contextmanager
+    def no_sync(self):
+        r"""
+        A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization
+        when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass
+        after exiting the context.
+        """
+        # Store the current value of 'require_grad_sync' to restore it later.
+        old_require_grad_sync = self.require_grad_sync
+        # Disable automatic gradient synchronization.
+        self.require_grad_sync = False
+        try:
+            if self.use_dpp:
+                # If the module is wrapped with DDP, disable DDP's own gradient synchronization as well.
+                with self.module.no_sync():
+                    yield
+            else:
+                yield
+        finally:
+            # Restore the original value of 'require_grad_sync'.
+            self.require_grad_sync = old_require_grad_sync
+
+    def sync_dp_grads(self):
+        r"""
+        Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1.
+        This function performs an all-reduce operation to combine gradients from different devices in the DP group.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
+        # Check if the DP group size is 1, meaning no synchronization is needed.
         if self.dp_group.size() == 1:
             return
+
+        # Iterate through the model's parameters and perform gradient synchronization.
         for p in self.module.parameters():
             if p.grad is not None:
+                # Perform all-reduce to combine gradients from different devices.
                 dist.all_reduce(p.grad, group=self.dp_group)
+                # Normalize the gradient by dividing it by the DP group size.
                 p.grad.div_(self.dp_group.size())
 
+    def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
+        r"""
+        Synchronize gradients that are partially derived within sequence parallelism
+        if sequence parallelism is enabled. Gradients can be provided explicitly or extracted
+        from the module.
+
+        Args:
+            grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not
+                provided, gradients will be extracted from the model.
+
+        Returns:
+            None
+        """
+        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+            if grads is not None:
+                # Synchronize provided gradient tensors across the tensor parallelism group.
+                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+            else:
+                # Synchronize gradients from the model across the tensor parallelism group.
+                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
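
To make the intent of the new synchronization hooks concrete, here is a minimal usage sketch (not part of the diff): `dataloader`, `criterion`, and `accumulation_steps` are assumed to exist, and `model` / `optimizer` stand for the HybridParallelModule and the optimizer wrapper produced by the plugin.

# Gradient accumulation with the new no_sync() context (illustrative sketch only).
for step, batch in enumerate(dataloader):
    if (step + 1) % accumulation_steps != 0:
        # Intermediate micro-batches: suppress automatic DP/SP gradient all-reduce.
        with model.no_sync():
            outputs = model(**batch)
            loss = criterion(outputs)
            optimizer.backward(loss)
    else:
        # Final micro-batch: require_grad_sync is True again, so optimizer.backward()
        # all-reduces the partially derived (layer norm) grads across the TP group;
        # DP grads are synchronized explicitly before stepping.
        outputs = model(**batch)
        loss = criterion(outputs)
        optimizer.backward(loss)
        model.sync_dp_grads()
        optimizer.step()
        optimizer.zero_grad()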
@@ -166,7 +228,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def __init__(
         self,
         optim: Optimizer,
-        model: Module,
+        model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
         max_norm: float = 0,
@@ -176,13 +238,69 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
+        self.model = model
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
         self.max_norm = max_norm
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
+        self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         super().__init__(optim)
 
+    def backward(self, loss: Tensor, *args, **kwargs):
+        r"""
+        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation. If sequence parallelism is enabled
+        and gradient synchronization is required, it will synchronize gradients that are partially derived
+        within sequence parallelism across tp parallelism groups.
+
+        Args:
+            loss (Tensor): The loss tensor to compute gradients with respect to.
+            *args: Additional positional arguments to be passed to the superclass backward method.
+            **kwargs: Additional keyword arguments to be passed to the superclass backward method.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward method to compute gradients.
+        super().backward(loss, *args, **kwargs)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        """
+        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation using a precomputed gradient tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across tp parallelism groups.
+
+        Args:
+            tensor (Tensor): The input tensor for which gradients are computed.
+            grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward method to compute gradients.
+        super().backward_by_grad(tensor, grad)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
     def step(self, *args, **kwargs):
         r"""
         Perform an optimization step.
@@ -220,8 +338,6 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if len(param_gradient_pairs) == 0:
             return 0.0
 
-        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
-        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         norm_type = float(norm_type)
 
         # gradients used for norm calculation.
@@ -230,9 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
             total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-            if tp_size > 1:
+            if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
-            if pp_size > 1:
+            if self.pp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
             total_norm = total_norm_cuda.item()
         else:
@@ -250,16 +366,16 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
                 # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
                 # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
-                if tp_size > 1:
+                if self.tp_size > 1:
                     param_for_grad = grad_to_param_mapping[id(grad)]
                     if not is_distributed_tensor(param_for_grad):
-                        grad_norm_exponentiated /= tp_size
+                        grad_norm_exponentiated /= self.tp_size
                 # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
                 # it means that this parameter is used in two different pipeline stages.
                 # To avoid redundant norm calculations, we divide the exponent of this norm by
                 # the number of shared stages.
-                if pp_size > 1:
+                if self.pp_size > 1:
                     for shared_param in self.shared_params:
                         if self.stage_manager.stage in shared_param:
                             stage_shared_param = shared_param[self.stage_manager.stage]
@@ -269,10 +385,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 total_norm_exponentiated += grad_norm_exponentiated
 
             total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
-            if tp_size > 1:
+            if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
-            if pp_size > 1:
+            if self.pp_size > 1:
                 # compute norm in pp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
@@ -314,7 +430,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
         self,
         optim: Optimizer,
-        model: Module,
+        model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
         precision: str = "fp16",
@@ -329,11 +445,14 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
     ):
+        self.model = model
         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
+        self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
         super().__init__(
@@ -349,6 +468,59 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             max_norm=max_norm,
         )
 
+    def backward(self, loss: Tensor, *args, **kwargs):
+        r"""
+        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation. If sequence parallelism is enabled
+        and gradient synchronization is required, it will synchronize gradients that are partially derived
+        within sequence parallelism across tp parallelism groups.
+
+        Args:
+            loss (Tensor): The loss tensor to compute gradients with respect to.
+            *args: Additional positional arguments to be passed to the superclass backward method.
+            **kwargs: Additional keyword arguments to be passed to the superclass backward method.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward method to compute gradients.
+        super().backward(loss, *args, **kwargs)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        """
+        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation using a precomputed gradient tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across tp parallelism groups.
+
+        Args:
+            tensor (Tensor): The input tensor for which gradients are computed.
+            grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward method to compute gradients.
+        super().backward_by_grad(tensor, grad)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
     def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
         r"""
         Compute and return the gradient norm for gradient clipping.
@@ -363,8 +535,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
         if len(param_gradient_pairs) == 0:
             return 0.0
 
-        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
-        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         norm_type = float(norm_type)
 
         if norm_type == inf:
@@ -374,9 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-            if tp_size > 1:
+            if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
-            if pp_size > 1:
+            if self.pp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
             total_norm = total_norm_cuda.item()
@@ -396,16 +566,16 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
                 # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
                 # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
-                if tp_size > 1:
+                if self.tp_size > 1:
                     param_for_grad = grad_to_param_mapping[id(grad)]
                     if not is_distributed_tensor(param_for_grad):
-                        grad_norm_exponentiated /= tp_size
+                        grad_norm_exponentiated /= self.tp_size
                 # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
                 # it means that this parameter is used in two different pipeline stages.
                 # To avoid redundant norm calculations, we divide the exponent of this norm by
                 # the number of shared stages.
-                if pp_size > 1:
+                if self.pp_size > 1:
                     for shared_param in self.shared_params:
                         if self.stage_manager.stage in shared_param:
                             stage_working_shared_param = shared_param[self.stage_manager.stage]
@@ -416,10 +586,10 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated
 
             total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
-            if tp_size > 1:
+            if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
-            if pp_size > 1:
+            if self.pp_size > 1:
                 # compute norm in pp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
@@ -433,7 +603,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
-        model: Module,
+        model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
@@ -455,6 +625,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
     ):
+        self.model = model
         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
@@ -483,6 +654,123 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             forced_dtype=forced_dtype,
         )
 
+    def sync_dp_grads(self):
+        r"""
+        Synchronize gradients in the data parallelism dimension.
+
+        This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients
+        in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions,
+        namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization
+        and readability.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
+        # Call the superclass `_sync_grad` method to synchronize gradients.
+        super()._sync_grad()
+
+    def _sync_sp_grads(self):
+        r"""
+        Synchronize gradients that are partially derived within sequence parallelism.
+
+        This method is responsible for synchronizing partially derived gradients across tp parallelism groups.
+        It identifies which working gradients are partially derived and, if synchronization is required and
+        such gradients exist, performs the all-reduce.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
+
+        def _get_all_working_grads() -> List[Tensor]:
+            """Retrieve all working gradients from different parameter groups."""
+            all_working_grads = []
+            for group_id in range(self.num_param_groups):
+                working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+                all_working_grads.extend(working_grads)
+            return all_working_grads
+
+        def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
+            """Identify gradients to be synchronized in the sequence parallelism."""
+            grads_to_sync = []
+            for grad in all_working_grads:
+                param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+                param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
+                if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
+                    grads_to_sync.append(grad)
+
+            if len(grads_to_sync) > 0:
+                return grads_to_sync
+            else:
+                return None
+
+        # Get all working gradients and gradients to be synchronized.
+        all_working_grads = _get_all_working_grads()
+        grads_to_sync = _get_grads_to_sync(all_working_grads)
+        if self.require_grad_sync and grads_to_sync is not None:
+            # Synchronize sequence parallelism gradients if required.
+            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+        else:
+            return
+
+    def backward(self, loss, retain_graph=False):
+        """
+        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+        This method performs the backward pass for gradient computation based on a given loss tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across TP parallelism groups.
+
+        Args:
+            loss: The loss tensor to compute gradients with respect to.
+            retain_graph (bool): Whether to retain the computation graph.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward method to compute gradients.
+        super().backward(loss, retain_graph)
+
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self._sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def backward_by_grad(self, tensor, grad):
+        """
+        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation based on a precomputed gradient tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across TP parallelism groups.
+
+        Args:
+            tensor: The input tensor for which gradients are computed.
+            grad: The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+        Returns:
+            None
+        """
+        # Call the superclass backward_by_grad method to compute gradients.
+        super().backward_by_grad(tensor, grad)
+
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self._sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
@@ -768,7 +1056,14 @@ class HybridParallelPlugin(PipelinePluginBase):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(
-                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+                model,
+                precision=self.precision,
+                shard_config=self.shard_config,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_ddp=use_ddp,
+                ddp_config=self.ddp_config,
+                custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
@@ -826,17 +1121,32 @@ class HybridParallelPlugin(PipelinePluginBase):
         return_outputs: bool = False,
     ) -> dict:
         assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
-        # return loss or outputs if needed
+
+        # Create a context for gradient synchronization based on the optimizer type.
+        # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
+        # This avoids redundant gradient reduction in pipeline parallelism (each microbatch's gradients
+        # should be reduced only once), so automatic reduction is disabled here and performed manually below.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
         with ctx:
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
+
+        # Synchronize the grads of shared parameters of the model.
         model.sync_shared_params()
+
+        # Synchronize sequence parallelism gradients of the model.
+        model.sync_sp_grads()
+
+        # If the optimizer is a HybridParallelZeroOptimizer, let it synchronize data parallelism gradients;
+        # otherwise synchronize the data parallelism gradients of the model.
+        # This distinction exists because ZeRO and plain DP are two different forms of data parallelism.
         if isinstance(optimizer, HybridParallelZeroOptimizer):
-            optimizer.sync_grad()
+            optimizer.sync_dp_grads()
         else:
-            model.sync_grads()
+            model.sync_dp_grads()
+
         return outputs
 
     def prepare_dataloader(
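
Schematically, the manual reduction order used in the pipeline path above is as follows (an annotated sketch of the same calls, not additional code in the commit):

with ctx:                            # no_sync(): no per-microbatch all-reduce
    outputs = schedule.forward_backward_step(...)
model.sync_shared_params()           # grads of parameters shared across pipeline stages
model.sync_sp_grads()                # partially derived layer norm grads, all-reduced over the TP group
if isinstance(optimizer, HybridParallelZeroOptimizer):
    optimizer.sync_dp_grads()        # ZeRO reduces DP grads through its own bucket logic
else:
    model.sync_dp_grads()            # plain DP all-reduce followed by division by dp_group.size()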

View File

@@ -338,7 +338,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(
-                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+                module=model,
+                precision=self.precision,
+                shard_config=self.shard_config,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_ddp=use_ddp,
+                ddp_config=self.ddp_config,
+                custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:

View File

@@ -218,10 +218,10 @@ class TPInferEngine:
         ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
         model_name = model.__class__.__name__
         assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
         model = model.model if self.shard_config.inference_gptq else model
-        policy = get_autopolicy(model, inference_only=True)
+        policy = get_autopolicy(model, shard_config=self.shard_config)
         self.model, _ = shardformer.optimize(model, policy)
         if self.shard_config.inference_gptq:

View File

@@ -235,6 +235,14 @@ class SubModuleReplacementDescription:
 class Policy(ABC):
+    r"""
+    The base class for all the policies. For each different model, it should have a different policy class,
+    like BertPolicy for Bert Model or OPTPolicy for OPT model.
+
+    Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
+    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
+    If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
+    """
+
     def __init__(self)
         self.model = None
@@ -245,6 +253,16 @@ class Policy(ABC):
         """
         self.model = model
 
+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be perform
+        """
+        self.shard_config = shard_config
+        self.config_sanity_check()
+
     @abstractmethod
     def preprocess(self) -> nn.Module:
         """

View File

@@ -2,7 +2,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
 from .loss import cross_entropy_1d
-from .normalization import FusedLayerNorm, FusedRMSNorm
+from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@@ -16,6 +16,9 @@ __all__ = [
     "DropoutForParallelInput",
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
+    "BaseLayerNorm",
+    "LayerNorm",
+    "RMSNorm",
     "FusedLayerNorm",
     "FusedRMSNorm",
     "FusedLinear1D_Col",

View File

@@ -1,11 +1,14 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+from abc import ABC, abstractmethod
+
 import torch.nn as nn
 
 from colossalai.lazy import LazyInitContext
 
-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
@@ -35,7 +38,103 @@
 ]
 
-class FusedLayerNorm:
+class BaseLayerNorm(ABC):
+    @abstractmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+        """
+        Convert a native PyTorch layer normalization module to a specific layer normalization module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native PyTorch layer normalization module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The specific layer normalization module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of the supported layer normalization type.
+        """
+
+
+class RMSNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "RMSNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to a colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        """
+        Convert a native RMSNorm module to a colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The RMSNorm module.
+        """
+        LazyInitContext.materialize(module)
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+        return module
+
+
+class LayerNorm(BaseLayerNorm):
+    r"""
+    This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+    """
+
+    def __init__(self) -> None:
+        raise NotImplementedError(
+            "LayerNorm is not implemented as a physical class. "
+            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to a colossalai layer norm module."
+        )
+
+    @staticmethod
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native pytorch layer norm module to a colossalai layer norm module,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: The LayerNorm module.
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
+        """
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+        LazyInitContext.materialize(module)
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+        return module
+
+
+class FusedLayerNorm(BaseLayerNorm):
     r"""
     This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
     """
@@ -43,15 +142,29 @@ class FusedLayerNorm:
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedLayerNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to a FusedLayerNorm module provided by apex."
         )
 
     @staticmethod
-    def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
-        Convert a native pytorch layer norm module to colossalai layer norm module
+        Convert a native pytorch layer norm module to a FusedLayerNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+
+        Raises:
+            AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
         # check if apex is installed
+        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
         try:
             pass
         except ImportError:
@@ -85,10 +198,18 @@ class FusedLayerNorm:
         layernorm.weight = module.weight
         layernorm.bias = module.bias
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+            SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
         return layernorm
 
 
-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
@@ -96,11 +217,22 @@ class FusedRMSNorm:
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
-            "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+            "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to a FusedRMSNorm module provided by apex."
        )
 
     @staticmethod
-    def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+    def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+        r"""
+        Convert a native RMSNorm module to a FusedRMSNorm module provided by apex,
+        and optionally mark parameters for gradient aggregation.
+
+        Args:
+            module (nn.Module): The native RMSNorm module to be converted.
+            sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+        Returns:
+            nn.Module: FusedRMSNorm module.
+        """
         try:
             from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
         except ImportError:
@@ -124,4 +256,10 @@ class FusedRMSNorm:
         rmsnorm.weight = module.weight
 
+        if sp_partial_derived:
+            # Since gradients are computed using only a subset of the data,
+            # aggregation of these gradients is necessary during backpropagation.
+            # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+            SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
         return rmsnorm
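
A short sketch of how these wrappers are expected to be used (illustrative, not part of the diff; the hidden size is arbitrary):

import torch.nn as nn
from colossalai.shardformer.layer import LayerNorm
from colossalai.shardformer.layer.utils import SeqParallelUtils

native_ln = nn.LayerNorm(1024)
# With sp_partial_derived=True, the weight and bias are tagged so that their
# gradients are all-reduced across the tensor parallel group after backward.
ln = LayerNorm.from_native_module(native_ln, sp_partial_derived=True)
assert SeqParallelUtils.is_sp_partial_derived_param(ln.weight)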

View File

@@ -1,8 +1,82 @@
 from contextlib import contextmanager
+from typing import List
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup, get_world_size
+
+
+class SeqParallelUtils:
+    @staticmethod
+    def marked_as_sp_partial_derived_param(param):
+        """
+        Mark a parameter as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to mark as partially derived.
+        """
+        setattr(param, "partial_derived", True)
+
+    @staticmethod
+    def is_sp_partial_derived_param(param):
+        """
+        Check if a parameter is marked as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to check.
+
+        Returns:
+            bool: True if the parameter is marked as partially derived, False otherwise.
+        """
+        return getattr(param, "partial_derived", False)
+
+    @staticmethod
+    def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+        """
+        Allreduce partial derived gradients across the specified process group.
+
+        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
+
+        Args:
+            tp_group (ProcessGroup): The process group for gradient synchronization.
+            model (nn.Module): The model from which gradients will be synchronized.
+            grads (List[torch.Tensor]): The list of gradients to be synchronized.
+
+        Raises:
+            AssertionError: If both `model` and `grads` are provided or neither is provided.
+        """
+        # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
+        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
+
+        # Get the size of the process group, which determines whether synchronization is needed.
+        tp_size = get_world_size(tp_group) if tp_group is not None else 1
+
+        if tp_size == 1:
+            # If the process group size is 1, no synchronization is required.
+            return
+
+        if model is not None:
+            # If `model` is provided, extract partial derived gradients from the model's parameters.
+            grads = []
+            for p in model.parameters():
+                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
+                    grads.append(p.grad.data)
+
+            # Flatten and reduce the gradients using the specified process group.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+
+            # Unflatten the synchronized gradients and update the model's gradients.
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+        else:
+            # If `grads` are provided explicitly, synchronize those gradients directly.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+
+
 class Randomizer:
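
For illustration, the marking/all-reduce pair can be exercised directly (a sketch, not part of the diff; it assumes torch.distributed is already initialized and `tp_group` is the tensor-parallel process group):

import torch
from colossalai.shardformer.layer.utils import SeqParallelUtils

param = torch.nn.Parameter(torch.ones(8))
SeqParallelUtils.marked_as_sp_partial_derived_param(param)

# ... after a backward pass, param.grad holds only a partial gradient ...
param.grad = torch.ones_like(param)

# Either pass the gradients explicitly, or pass model=<module> to let the helper
# collect every grad whose parameter carries the partial_derived mark.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=tp_group, grads=[param.grad])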

View File

@@ -4,6 +4,7 @@ from typing import Optional
 import torch.nn as nn
 
+from ..shard.shard_config import ShardConfig
 from .base_policy import Policy
 
 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -197,7 +198,7 @@ def _fullname(obj):
     return module + "." + klass.__qualname__
 
-def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
+def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
     r"""
     Return the auto policy for the model
 
@@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if inference_only:
+    if shard_config.inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)
@@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location, shard_config.inference_only)
     return policy()
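
Call sites now hand the whole shard config to get_autopolicy instead of a bare boolean. A sketch of the updated usage (not part of the diff; `model` is assumed to be a supported transformers model, and the config values shown are hypothetical):

from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy

shard_config = ShardConfig(enable_tensor_parallelism=False)  # hypothetical config values
policy = get_autopolicy(model, shard_config=shard_config)
policy.set_shard_config(shard_config)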

View File

@@ -11,6 +11,7 @@ from torch.nn import Module
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
+from ..layer.normalization import BaseLayerNorm
 from ..layer.parallel_module import ParallelModule
 from ..shard.shard_config import ShardConfig
 
@@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
         ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
     """
     suffix: str
-    target_module: ParallelModule
+    target_module: Union[ParallelModule, BaseLayerNorm]
     kwargs: Dict[str, Any] = None
     ignore_if_not_exist: bool = False
 
@@ -77,7 +78,6 @@ class Policy(ABC):
     def set_model(self, model: nn.Module) -> None:
         r"""
         Set model as an attribute of the Policy object so that we can access the model's attributes.
-
         Args:
             model (:class:`nn.Module`): The model to be perform
         """
@@ -86,11 +86,11 @@ class Policy(ABC):
     def set_shard_config(self, shard_config: ShardConfig) -> None:
         r"""
         Set shard config as an attribute of the Policy object.
         Args:
             shard_config (:class:`ShardConfig`): The shard config to be perform
         """
         self.shard_config = shard_config
         self.config_sanity_check()
 
     @property

View File

@@ -60,6 +60,12 @@ class BertPolicy(Policy):
         )
 
         policy = {}
+
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
         overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
@@ -141,33 +147,34 @@ class BertPolicy(Policy):
         )
 
         # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # Handle bert layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="attention.output.LayerNorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="output.LayerNorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=BertLayer,
-            )
-            # handle embedding layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="LayerNorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    )
-                ],
-                policy=policy,
-                target_key=BertEmbeddings,
-            )
+        # Handle bert layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="attention.output.LayerNorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="output.LayerNorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+            ],
+            policy=policy,
+            target_key=BertLayer,
+        )
+        # handle embedding layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="LayerNorm",
+                    target_module=norm_cls,
+                )
+            ],
+            policy=policy,
+            target_key=BertEmbeddings,
+        )
 
         # use flash attention
         if self.shard_config.enable_flash_attention:
@@ -288,9 +295,6 @@ class BertPolicy(Policy):
 
 # BertModel
 class BertModelPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         policy = super().module_policy()
         from transformers.models.bert.modeling_bert import BertModel
@@ -313,9 +317,6 @@ class BertModelPolicy(BertPolicy):
 
 # BertForPreTraining
 class BertForPreTrainingPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         policy = super().module_policy()
         policy = self.add_lm_head_policy(policy)
@@ -355,9 +356,6 @@ class BertForPreTrainingPolicy(BertPolicy):
 
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         policy = super().module_policy()
         policy = self.add_lm_head_policy(policy)
@@ -396,9 +394,6 @@ class BertLMHeadModelPolicy(BertPolicy):
 
 # BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         policy = super().module_policy()
         policy = self.add_lm_head_policy(policy)
@@ -437,9 +432,6 @@ class BertForMaskedLMPolicy(BertPolicy):
 
 # BertForSequenceClassification
 class BertForSequenceClassificationPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers.models.bert.modeling_bert import BertForSequenceClassification
@@ -484,9 +476,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):
 
 # BertForTokenClassification
 class BertForTokenClassificationPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers.models.bert.modeling_bert import BertForTokenClassification
@@ -531,9 +520,6 @@ class BertForTokenClassificationPolicy(BertPolicy):
 
 # BertForNextSentencePrediction
 class BertForNextSentencePredictionPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         policy = super().module_policy()
         from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
@@ -564,9 +550,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
 
 # BertForMultipleChoice
 class BertForMultipleChoicePolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers.models.bert.modeling_bert import BertForMultipleChoice
@@ -610,9 +593,6 @@ class BertForMultipleChoicePolicy(BertPolicy):
 
 class BertForQuestionAnsweringPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
     def module_policy(self):
         from transformers.models.bert.modeling_bert import BertForQuestionAnswering

View File

@@ -43,6 +43,11 @@ class BlipPolicy(Policy):
         policy = {}
 
+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
         if self.shard_config.enable_tensor_parallelism:
             policy[Blip2EncoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
@ -214,94 +219,93 @@ class BlipPolicy(Policy):
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: # Handle Blip2EncoderLayer layer
# Handle Blip2EncoderLayer layer self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=[
description=[ SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="layer_norm1",
suffix="layer_norm1", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="layer_norm2",
suffix="layer_norm2", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), ],
], policy=policy,
policy=policy, target_key=Blip2EncoderLayer,
target_key=Blip2EncoderLayer, )
)
# handle Blip2VisionModel layer # handle Blip2VisionModel layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="post_layernorm", suffix="post_layernorm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
) )
], ],
policy=policy, policy=policy,
target_key=Blip2VisionModel, target_key=Blip2VisionModel,
) )
# handle Blip2VisionModel layer # handle Blip2VisionModel layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layernorm", suffix="layernorm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
) )
], ],
policy=policy, policy=policy,
target_key=Blip2QFormerModel, target_key=Blip2QFormerModel,
) )
# handle Blip2QFormerLayer layer # handle Blip2QFormerLayer layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.LayerNorm", suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="crossattention.output.LayerNorm", suffix="crossattention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output_query.LayerNorm", suffix="output_query.LayerNorm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
], ],
policy=policy, policy=policy,
target_key=Blip2QFormerLayer, target_key=Blip2QFormerLayer,
) )
# handle OPTForCausalLM layer # handle OPTForCausalLM layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="model.decoder.final_layer_norm", suffix="model.decoder.final_layer_norm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
) )
], ],
policy=policy, policy=policy,
target_key=OPTForCausalLM, target_key=OPTForCausalLM,
) )
# handle OPTDecoderLayer layer # handle OPTDecoderLayer layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn_layer_norm", suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="final_layer_norm", suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
], ],
policy=policy, policy=policy,
target_key=OPTDecoderLayer, target_key=OPTDecoderLayer,
) )
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
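Across this commit the policies stop gating layer-norm replacement on enable_fused_normalization; each module_policy() now resolves a norm_cls up front and always registers the submodule replacement, so the replacement also happens when fused kernels are disabled. A rough standalone sketch of that dispatch, with stub classes standing in for col_nn.FusedLayerNorm and col_nn.LayerNorm (an illustration only, not the colossalai code):

from dataclasses import dataclass

class FusedLayerNorm: ...  # stand-in for col_nn.FusedLayerNorm
class LayerNorm: ...       # stand-in for col_nn.LayerNorm

@dataclass
class ShardConfigStub:
    enable_fused_normalization: bool = False

def pick_norm_cls(cfg: ShardConfigStub):
    # Mirrors the branch added at the top of each module_policy() above.
    return FusedLayerNorm if cfg.enable_fused_normalization else LayerNorm

assert pick_norm_cls(ShardConfigStub(True)) is FusedLayerNorm
assert pick_norm_cls(ShardConfigStub()) is LayerNorm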

View File


@ -42,6 +42,10 @@ class BloomPolicy(Policy):
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
@ -97,38 +101,39 @@ class BloomPolicy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: # handle bloom model
# handle bloom model self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=[
description=[ SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="ln_f",
suffix="ln_f", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="word_embeddings_layernorm",
suffix="word_embeddings_layernorm", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), ],
], policy=policy,
policy=policy, target_key=BloomModel,
target_key=BloomModel, )
)
# handle bloom block # handle bloom block
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="input_layernorm", suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), kwargs={"sp_partial_derived": use_sequence_parallel},
SubModuleReplacementDescription( ),
suffix="post_attention_layernorm", SubModuleReplacementDescription(
target_module=col_nn.FusedLayerNorm, suffix="post_attention_layernorm",
), target_module=norm_cls,
], kwargs={"sp_partial_derived": use_sequence_parallel},
policy=policy, ),
target_key=BloomBlock, ],
) policy=policy,
target_key=BloomBlock,
)
if use_sequence_parallel: if use_sequence_parallel:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
@ -225,9 +230,6 @@ class BloomPolicy(Policy):
class BloomModelPolicy(BloomPolicy): class BloomModelPolicy(BloomPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
policy = super().module_policy() policy = super().module_policy()
from transformers.models.bloom.modeling_bloom import BloomModel from transformers.models.bloom.modeling_bloom import BloomModel
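The new kwargs entry sp_partial_derived marks norms whose gradients are accumulated from only a slice of the sequence when sequence parallelism is enabled; those weight and bias gradients then need an all-reduce over the tensor-parallel group before the optimizer step. A hedged sketch of that reduction in plain PyTorch (the function name and group handling are illustrative, not the actual colossalai helper):

import torch
import torch.distributed as dist

def allreduce_partial_derived_norm_grads(norm: torch.nn.LayerNorm, tp_group=None) -> None:
    # With sequence parallelism each rank only saw 1/tp_size of the tokens, so its
    # locally accumulated LayerNorm grads are partial sums; summing across the
    # group recovers the full-sequence gradient.
    for p in norm.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=tp_group)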

View File


@ -45,6 +45,16 @@ class ChatGLMPolicy(Policy):
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
if self.model.config.rmsnorm:
norm_cls = col_nn.FusedRMSNorm
else:
norm_cls = col_nn.FusedLayerNorm
else:
if self.model.config.rmsnorm:
norm_cls = col_nn.RMSNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
@ -96,52 +106,34 @@ class ChatGLMPolicy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(
if not self.model.config.rmsnorm: description=[
self.append_or_create_submodule_replacement( SubModuleReplacementDescription(
description=[ suffix="input_layernorm",
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), target_module=norm_cls,
SubModuleReplacementDescription( kwargs={"sp_partial_derived": use_sequence_parallel},
suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm ),
), SubModuleReplacementDescription(
], suffix="post_attention_layernorm",
policy=policy, target_module=norm_cls,
target_key=GLMBlock, kwargs={"sp_partial_derived": use_sequence_parallel},
) ),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm: if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm suffix="encoder.final_layernorm",
) target_module=norm_cls,
],
policy=policy,
target_key=ChatGLMModel,
)
else:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
)
],
policy=policy,
target_key=ChatGLMModel,
) )
],
policy=policy,
target_key=ChatGLMModel,
)
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
@ -224,9 +216,6 @@ class ChatGLMPolicy(Policy):
class ChatGLMModelPolicy(ChatGLMPolicy): class ChatGLMModelPolicy(ChatGLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
pass pass

View File


@ -39,6 +39,11 @@ class GPT2Policy(Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
@ -102,33 +107,37 @@ class GPT2Policy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(
description=SubModuleReplacementDescription( suffix="ln_f",
suffix="ln_f", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), policy=policy,
policy=policy, target_key=GPT2Model,
target_key=GPT2Model, )
)
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="ln_1", suffix="ln_1",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), kwargs={"sp_partial_derived": use_sequence_parallel},
SubModuleReplacementDescription( ),
suffix="ln_2", SubModuleReplacementDescription(
target_module=col_nn.FusedLayerNorm, suffix="ln_2",
), target_module=norm_cls,
SubModuleReplacementDescription( kwargs={"sp_partial_derived": use_sequence_parallel},
suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True ),
), SubModuleReplacementDescription(
], suffix="ln_cross_attn",
policy=policy, target_module=norm_cls,
target_key=GPT2Block, ignore_if_not_exist=True,
) kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
target_key=GPT2Block,
)
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
@ -192,9 +201,6 @@ class GPT2Policy(Policy):
# GPT2Model # GPT2Model
class GPT2ModelPolicy(GPT2Policy): class GPT2ModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2Model
@ -216,9 +222,6 @@ class GPT2ModelPolicy(GPT2Policy):
# GPT2LMHeadModel # GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy): class GPT2LMHeadModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
@ -263,9 +266,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
# GPT2DoubleHeadsModel # GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy): class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
@ -317,9 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
# GPT2ForQuestionAnswering # GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy): class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
@ -347,9 +344,6 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
# GPT2ForTokenClassification # GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy): class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
@ -387,9 +381,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
# GPT2ForSequenceClassification # GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy): class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification

View File


@ -6,7 +6,7 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.nn import Module from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -35,6 +35,11 @@ class LlamaPolicy(Policy):
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
norm_cls = RMSNorm
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@ -93,31 +98,31 @@ class LlamaPolicy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=[
description=[ SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="input_layernorm",
suffix="input_layernorm", target_module=norm_cls,
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=LlamaDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
), ),
policy=policy, SubModuleReplacementDescription(
target_key=LlamaModel, suffix="post_attention_layernorm",
) target_module=norm_cls,
),
],
policy=policy,
target_key=LlamaDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=norm_cls,
),
policy=policy,
target_key=LlamaModel,
)
# use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={ description={
@ -174,9 +179,6 @@ class LlamaPolicy(Policy):
class LlamaModelPolicy(LlamaPolicy): class LlamaModelPolicy(LlamaPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
policy = super().module_policy() policy = super().module_policy()
from transformers.models.llama.modeling_llama import LlamaModel from transformers.models.llama.modeling_llama import LlamaModel

View File


@ -5,7 +5,7 @@ from typing import Callable, Dict, List
import torch.nn as nn import torch.nn as nn
from torch import Tensor, nn from torch import Tensor, nn
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_ from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.jit import get_jit_fused_dropout_add_func
@ -42,6 +42,12 @@ class OPTPolicy(Policy):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
norm_cls = LayerNorm
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@ -94,26 +100,25 @@ class OPTPolicy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(
description=SubModuleReplacementDescription( suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True ),
policy=policy,
target_key=OPTDecoder,
)
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
), ),
policy=policy, SubModuleReplacementDescription(
target_key=OPTDecoder, suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
) ),
self.append_or_create_submodule_replacement( ],
description=[ policy=policy,
SubModuleReplacementDescription( target_key=OPTDecoderLayer,
suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True )
),
SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
),
],
policy=policy,
target_key=OPTDecoderLayer,
)
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
@ -183,9 +188,6 @@ class OPTPolicy(Policy):
class OPTModelPolicy(OPTPolicy): class OPTModelPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.opt.modeling_opt import OPTModel from transformers.models.opt.modeling_opt import OPTModel
@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
class OPTForSequenceClassificationPolicy(OPTPolicy): class OPTForSequenceClassificationPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForSequenceClassification from transformers.models.opt.modeling_opt import OPTForSequenceClassification
@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
class OPTForQuestionAnsweringPolicy(OPTPolicy): class OPTForQuestionAnsweringPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForQuestionAnswering from transformers.models.opt.modeling_opt import OPTForQuestionAnswering

View File


@ -24,6 +24,11 @@ class SamPolicy(Policy):
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[SamVisionLayer] = ModulePolicyDescription( policy[SamVisionLayer] = ModulePolicyDescription(
attribute_replacement={ attribute_replacement={
@ -151,58 +156,57 @@ class SamPolicy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: # Handle SamVisionLayer
# Handle SamVisionLayer self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=[
description=[ SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="layer_norm1",
suffix="layer_norm1", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="layer_norm2",
suffix="layer_norm2", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), ],
], policy=policy,
policy=policy, target_key=SamVisionLayer,
target_key=SamVisionLayer, )
)
# Handle SamTwoWayAttentionBlock # Handle SamTwoWayAttentionBlock
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layer_norm1", suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layer_norm2", suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layer_norm3", suffix="layer_norm3",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layer_norm4", suffix="layer_norm4",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
], ],
policy=policy, policy=policy,
target_key=SamTwoWayAttentionBlock, target_key=SamTwoWayAttentionBlock,
) )
# Handle SamTwoWayTransformer # Handle SamTwoWayTransformer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layer_norm_final_attn", suffix="layer_norm_final_attn",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
) )
], ],
policy=policy, policy=policy,
target_key=SamTwoWayTransformer, target_key=SamTwoWayTransformer,
) )
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:

View File


@ -11,6 +11,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm, FusedRMSNorm,
Linear1D_Col, Linear1D_Col,
Linear1D_Row, Linear1D_Row,
RMSNorm,
VocabParallelEmbedding1D, VocabParallelEmbedding1D,
) )
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
@ -58,6 +59,11 @@ class T5BasePolicy(Policy):
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
norm_cls = RMSNorm
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@ -169,38 +175,37 @@ class T5BasePolicy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(
description=SubModuleReplacementDescription( suffix="layer_norm",
suffix="layer_norm", target_module=norm_cls,
target_module=FusedRMSNorm, ),
), policy=policy,
policy=policy, target_key=T5LayerFF,
target_key=T5LayerFF, )
) self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(
description=SubModuleReplacementDescription( suffix="layer_norm",
suffix="layer_norm", target_module=norm_cls,
target_module=FusedRMSNorm, ),
), policy=policy,
policy=policy, target_key=T5LayerFF,
target_key=T5LayerFF, )
) self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), policy=policy,
policy=policy, target_key=T5LayerSelfAttention,
target_key=T5LayerSelfAttention, )
) self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), policy=policy,
policy=policy, target_key=T5LayerCrossAttention,
target_key=T5LayerCrossAttention, )
) self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm), policy=policy,
policy=policy, target_key=T5Stack,
target_key=T5Stack, )
)
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
@ -363,9 +368,6 @@ class T5BasePolicy(Policy):
class T5ModelPolicy(T5BasePolicy): class T5ModelPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers import T5Model from transformers import T5Model
@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy):
class T5ForConditionalGenerationPolicy(T5BasePolicy): class T5ForConditionalGenerationPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers import T5ForConditionalGeneration from transformers import T5ForConditionalGeneration
@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
class T5EncoderPolicy(T5BasePolicy): class T5EncoderPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers import T5EncoderModel from transformers import T5EncoderModel

View File


@ -159,9 +159,6 @@ class ViTPolicy(Policy):
# ViTModel # ViTModel
class ViTModelPolicy(ViTPolicy): class ViTModelPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.vit.modeling_vit import ViTModel from transformers.models.vit.modeling_vit import ViTModel
@ -227,9 +224,6 @@ class ViTForImageClassificationPolicy(ViTPolicy):
# ViTForMaskedImageModeling # ViTForMaskedImageModeling
class ViTForMaskedImageModelingPolicy(ViTPolicy): class ViTForMaskedImageModelingPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel

View File


@ -52,6 +52,11 @@ class WhisperPolicy(Policy):
policy = {} policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn( warnings.warn(
@ -161,62 +166,61 @@ class WhisperPolicy(Policy):
) )
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: # Handle encoder layer
# Handle encoder layer self.append_or_create_submodule_replacement(
self.append_or_create_submodule_replacement( description=[
description=[ SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="self_attn_layer_norm",
suffix="self_attn_layer_norm", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), SubModuleReplacementDescription(
SubModuleReplacementDescription( suffix="final_layer_norm",
suffix="final_layer_norm", target_module=norm_cls,
target_module=col_nn.FusedLayerNorm, ),
), ],
], policy=policy,
policy=policy, target_key=WhisperEncoderLayer,
target_key=WhisperEncoderLayer, )
)
# Handle decoder layer # Handle decoder layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn_layer_norm", suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="final_layer_norm", suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
), ),
], ],
policy=policy, policy=policy,
target_key=WhisperDecoderLayer, target_key=WhisperDecoderLayer,
) )
# handle encoder layer # handle encoder layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layer_norm", suffix="layer_norm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
) )
], ],
policy=policy, policy=policy,
target_key=WhisperEncoder, target_key=WhisperEncoder,
) )
# handle decoder layer # handle decoder layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="layer_norm", suffix="layer_norm",
target_module=col_nn.FusedLayerNorm, target_module=norm_cls,
) )
], ],
policy=policy, policy=policy,
target_key=WhisperDecoder, target_key=WhisperDecoder,
) )
# enable flash attention # enable flash attention
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
@ -416,9 +420,6 @@ class WhisperPolicy(Policy):
# WhisperModel # WhisperModel
class WhisperModelPolicy(WhisperPolicy): class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers import WhisperModel from transformers import WhisperModel
@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy):
# WhisperForConditionalGeneration # WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy): class WhisperForConditionalGenerationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self): def module_policy(self):
from transformers import WhisperForConditionalGeneration from transformers import WhisperForConditionalGeneration
@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
# WhisperForAudioClassification # WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy): class WhisperForAudioClassificationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def preprocess(self): def preprocess(self):
return self.model return self.model

View File


@ -27,8 +27,8 @@ class ModelSharder(object):
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None: def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model self.model = model
self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
self.shard_config = shard_config self.shard_config = shard_config
self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
def shard(self) -> List[Dict[int, Tensor]]: def shard(self) -> List[Dict[int, Tensor]]:
r""" r"""
@ -196,7 +196,7 @@ class ModelSharder(object):
try: try:
replace_layer = target_module.from_native_module( replace_layer = target_module.from_native_module(
native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs
) )
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
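get_autopolicy is now handed the whole shard_config rather than just shard_config.inference_only, so an auto-selected policy can read whichever sharding flags it needs. A trimmed sketch of the resulting constructor shape (stubbed and simplified; not the full ModelSharder):

def get_autopolicy(model, shard_config):
    # Stub standing in for the real auto-policy lookup.
    ...

class ModelSharderSketch:
    def __init__(self, model, policy=None, shard_config=None):
        self.model = model
        self.shard_config = shard_config
        # Policies now receive the full ShardConfig, e.g. to check
        # enable_sequence_parallelism, instead of a single boolean.
        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy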

View File


@ -700,7 +700,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
############################ ############################
# this method is used to sync gradient manually # this method is used to sync gradient manually
def sync_grad(self): def _sync_grad(self):
for group_id in range(self.num_param_groups): for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id] param_group = self._working_param_groups[group_id]
for param in param_group: for param in param_group:
@ -713,7 +713,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# if not overlapping communication (no reduction hook is attached) when zero1 # if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients # we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication: if not partition_grad and not self._overlap_communication:
self.sync_grad() self._sync_grad()
else: else:
self._run_reduction() self._run_reduction()
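sync_grad is renamed to the private _sync_grad, and the post-backward path keeps the same branch: zero1 without overlapped communication has no reduction hooks attached, so gradients are synchronized manually, otherwise the bucketed reduction is flushed. A condensed restatement of that control flow (method bodies stubbed; not the full LowLevelZeroOptimizer):

class ZeroGradSyncSketch:
    def __init__(self, partition_grad: bool, overlap_communication: bool) -> None:
        self._partition_grad = partition_grad
        self._overlap_communication = overlap_communication

    def _sync_grad(self) -> None:
        # would hand each parameter's gradient to the reduction logic here
        pass

    def _run_reduction(self) -> None:
        # would flush the bucketed reductions registered by backward hooks
        pass

    def after_backward(self) -> None:
        if not self._partition_grad and not self._overlap_communication:
            self._sync_grad()  # manual sync, matching the diff above
        else:
            self._run_reduction()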

View File


@ -24,7 +24,6 @@ for k, v in inputs.items():
new_shape[0] = 16 new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape) inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.LlamaForCausalLM( model = transformers.LlamaForCausalLM(
transformers.LlamaConfig( transformers.LlamaConfig(
@ -59,6 +58,7 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_si
@parameterize("pp_size", [2]) @parameterize("pp_size", [2])
@parameterize("max_output_len", [4]) @parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1]) @parameterize("micro_batch_size", [1])
@clear_cache_before_run() @clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
@ -76,6 +76,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
@clear_cache_before_run() @clear_cache_before_run()

View File


@ -128,7 +128,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1, "tp_size": 1,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,
@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,
@ -147,7 +147,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,
@ -157,7 +157,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1, "tp_size": 1,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -174,7 +174,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -199,7 +199,7 @@ def run_test(test_config):
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -208,7 +208,7 @@ def run_test(test_config):
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,

View File


@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1, "tp_size": 1,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp32", "precision": "fp32",
"max_norm": 5, "max_norm": 5,
@ -114,7 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp32", "precision": "fp32",
"max_norm": 5, "max_norm": 5,
@ -123,7 +123,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp32", "precision": "fp32",
"max_norm": 5, "max_norm": 5,
@ -148,7 +148,7 @@ def run_test(test_config):
"tp_size": 2, "tp_size": 2,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp32", "precision": "fp32",
"max_norm": 5, "max_norm": 5,

View File


@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"zero_stage": 1, "zero_stage": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,
@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
"zero_stage": 1, "zero_stage": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,
@ -126,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
"zero_stage": 2, "zero_stage": 2,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,
@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"zero_stage": 1, "zero_stage": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": True, "use_lazy_init": True,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -146,7 +146,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
"zero_stage": 1, "zero_stage": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -155,7 +155,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
"zero_stage": 2, "zero_stage": 2,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -181,7 +181,7 @@ def run_test(test_config):
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"zero_stage": 1, "zero_stage": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "bf16", "precision": "bf16",
"max_norm": 5, "max_norm": 5,
@ -191,7 +191,7 @@ def run_test(test_config):
"pp_size": 2, "pp_size": 2,
"num_microbatches": 4, "num_microbatches": 4,
"zero_stage": 1, "zero_stage": 1,
"enable_all_optimization": False, "enable_all_optimization": True,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp16", "precision": "fp16",
"max_norm": 5, "max_norm": 5,

View File


@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
bert = unwrap_model(org_model, "BertModel", "bert") bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
col_layer_for_check = ["encoder.layer[0].output.dense"] col_layer_for_check = ["encoder.layer[0].output.dense"]
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
@ -50,8 +51,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_grads = get_grad_tensors_for_check( row_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
) )
norm_layer_grads = get_grad_tensors_for_check(
bert,
sharded_bert,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads) grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads) grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step # optimizer executes step
org_optimizer.step() org_optimizer.step()
@ -85,6 +99,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize( @parameterize(
"test_config", "test_config",
[ [
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp32",
},
{ {
"tp_size": 1, "tp_size": 1,
"pp_size": 2, "pp_size": 2,

View File


@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
bloom = unwrap_model(org_model, "BloomModel", "transformer") bloom = unwrap_model(org_model, "BloomModel", "transformer")
sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer") sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer")
norm_layer_for_check = ["word_embeddings_layernorm", "h[0].input_layernorm"]
row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"] row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"]
col_layer_for_check = ["h[0].self_attention.dense"] col_layer_for_check = ["h[0].self_attention.dense"]
@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_grads = get_grad_tensors_for_check( col_layer_grads = get_grad_tensors_for_check(
bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
) )
norm_layer_grads = get_grad_tensors_for_check(
bloom,
sharded_bloom,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads) grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads) grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step # optimizer executes step
org_optimizer.step() org_optimizer.step()

View File


@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer") chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer")
shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer")
norm_layer_for_check = ["encoder.layers[0].input_layernorm"]
row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"]
col_layer_for_check = ["encoder.layers[0].self_attention.dense"] col_layer_for_check = ["encoder.layers[0].self_attention.dense"]
@ -66,8 +67,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1, dim=1,
verbose=False, verbose=False,
) )
norm_layer_grads = get_grad_tensors_for_check(
chatglm_model,
shard_chatglm_model,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads) grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads) grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step # optimizer executes step
org_optimizer.step() org_optimizer.step()

View File


@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
gpt2 = unwrap_model(org_model, "GPT2Model", "transformer") gpt2 = unwrap_model(org_model, "GPT2Model", "transformer")
sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer") sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer")
norm_layer_for_check = ["h[0].ln_1", "h[0].ln_2"]
col_layer_for_check = ["h[0].mlp.c_fc"] col_layer_for_check = ["h[0].mlp.c_fc"]
row_layer_for_check = ["wte", "h[0].mlp.c_proj"] row_layer_for_check = ["wte", "h[0].mlp.c_proj"]
@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_grads = get_grad_tensors_for_check( row_layer_grads = get_grad_tensors_for_check(
gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
) )
norm_layer_grads = get_grad_tensors_for_check(
gpt2,
sharded_gpt2,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads) grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads) grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step # optimizer executes step
org_optimizer.step() org_optimizer.step()
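The test changes above add norm_layer_for_check entries so the layer-norm gradients themselves are compared between the original and the sharded model, which is exactly the check that fails if the sequence-parallel all-reduce is missing. A minimal self-contained demonstration of the property being relied on (plain PyTorch, no distributed setup; get_grad_tensors_for_check itself is not reproduced here):

import torch

torch.manual_seed(0)
ln = torch.nn.LayerNorm(8)
x = torch.randn(4, 8)

# Full-sequence reference gradient.
ln.weight.grad = None
ln(x).sum().backward()
full_grad = ln.weight.grad.clone()

# Two "sequence-parallel ranks", each seeing half of the tokens.
partial = []
for chunk in x.chunk(2, dim=0):
    ln.weight.grad = None
    ln(chunk).sum().backward()
    partial.append(ln.weight.grad.clone())

# Summing (i.e. all-reducing) the partial grads recovers the full gradient,
# which is why a missing all-reduce shows up as a gradient mismatch in these tests.
assert torch.allclose(partial[0] + partial[1], full_grad, atol=1e-6)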