mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
  * Add layer norm gradients all-reduce for sequence parallel.
  * skip pipeline inference test
* [hotfix] fixing policies of sequence parallel (#4922)
  * Add layer norm gradients all-reduce for sequence parallel.
  * fix parameter passing when calling get_autopolicy
  ---------
  Co-authored-by: littsk <1214689160@qq.com>
* Hotfix/add grad all reduce for sequence parallel (#4927)
  * Add layer norm gradients all-reduce for sequence parallel.
  * fix parameter passing when calling get_autopolicy
  * fix bug using wrong variables
  ---------
  Co-authored-by: littsk <1214689160@qq.com>
* fix policy initialization
* fix bloom and chatglm policies
* polish code of handling layernorm
* fix moe module
* polish code of class initializing
---------
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>

pull/5007/head
parent
d99b2c961a
commit
1a3315e336
|
@ -1,6 +1,6 @@
|
|||
import ctypes
|
||||
import random
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
|
||||
|
@ -25,6 +25,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
|
|||
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
from colossalai.tensor.d_tensor.api import is_distributed_tensor
|
||||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
|
@ -47,12 +48,17 @@ class HybridParallelModule(ModelWrapper):
|
|||
precision: str,
|
||||
shard_config: ShardConfig,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
use_ddp: bool,
|
||||
ddp_config: dict,
|
||||
custom_policy: Policy,
|
||||
) -> None:
|
||||
self.stage_manager = shard_config.pipeline_stage_manager
|
||||
self.shard_config = shard_config
|
||||
self.dp_group = dp_group
|
||||
self.tp_group = tp_group
|
||||
self.use_dpp = use_ddp
|
||||
self.require_grad_sync = True
|
||||
|
||||
shardformer = ShardFormer(shard_config)
|
||||
if custom_policy is not None:
|
||||
|
@ -98,19 +104,75 @@ class HybridParallelModule(ModelWrapper):
|
|||
dist.all_reduce(param.grad, group=group)
|
||||
dist.barrier()
|
||||
|
||||
def no_sync(self) -> Iterator[None]:
|
||||
# no sync grads across data parallel
|
||||
return nullcontext()
|
||||
@contextmanager
|
||||
def no_sync(self):
|
||||
r"""
|
||||
A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization
|
||||
when 'no_sync' is active. Otherwise, synchronization will occur in the first forward-backward pass
|
||||
when exiting the context.
|
||||
"""
|
||||
|
||||
def sync_grads(self):
|
||||
# sync grad across data parallel
|
||||
# Store the current value of 'require_grad_sync' to restore it later.
|
||||
old_require_grad_sync = self.require_grad_sync
|
||||
# Disable automatic gradient synchronization.
|
||||
self.require_grad_sync = False
|
||||
try:
|
||||
if self.use_dpp:
|
||||
# If using data parallel processing (use_dpp), disable synchronization too.
|
||||
with self.module.no_sync():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
finally:
|
||||
# Restore the original value of 'require_grad_sync'.
|
||||
self.require_grad_sync = old_require_grad_sync
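As a usage sketch (illustrative, not part of the commit), the context manager above supports gradient accumulation by deferring synchronization to the final micro-batch; names such as micro_batches and criterion are assumptions, and the exact training-loop wiring depends on how the plugin is configured.

# Accumulate gradients locally for all but the last micro-batch, then let the
# final backward pass outside no_sync() trigger the usual synchronization.
with model.no_sync():
    for batch in micro_batches[:-1]:
        loss = criterion(model(**batch))
        optimizer.backward(loss)  # no all-reduce while require_grad_sync is False
loss = criterion(model(**micro_batches[-1]))
optimizer.backward(loss)          # gradients are synchronized on this pass
optimizer.step()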
|
||||
|
||||
def sync_dp_grads(self):
|
||||
r"""
|
||||
Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1.
|
||||
This function performs an all-reduce operation to combine gradients from different devices in the DP group.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Check if the DP group size is 1, meaning no synchronization is needed.
|
||||
if self.dp_group.size() == 1:
|
||||
return
|
||||
|
||||
# Iterate through the model's parameters and perform gradient synchronization.
|
||||
for p in self.module.parameters():
|
||||
if p.grad is not None:
|
||||
# Perform all-reduce to combine gradients from different devices.
|
||||
dist.all_reduce(p.grad, group=self.dp_group)
|
||||
# Normalize the gradient by dividing it by the DP group size.
|
||||
p.grad.div_(self.dp_group.size())
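For reference, the data-parallel synchronization above amounts to a summed all-reduce followed by division by the group size. A minimal standalone sketch of the same pattern (illustrative; assumes torch.distributed is already initialized and the helper name is an assumption):

import torch
import torch.distributed as dist

def average_dp_grads(model: torch.nn.Module, dp_group: dist.ProcessGroup) -> None:
    # Average gradients across a data-parallel group with one all-reduce per parameter.
    world_size = dist.get_world_size(dp_group)
    if world_size == 1:
        return  # nothing to synchronize on a single rank
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=dp_group)
            p.grad.div_(world_size)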
|
||||
|
||||
def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
|
||||
r"""
|
||||
Synchronize gradients that are partially derived within sequence parallelism
|
||||
if sequence parallelism is enabled. Gradients can be provided explicitly or extracted
|
||||
from the module.
|
||||
|
||||
Args:
|
||||
grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not
|
||||
provided, gradients will be extracted from the model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
|
||||
if grads is not None:
|
||||
# Synchronize provided gradient tensors across the tensor parallelism group.
|
||||
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
|
||||
else:
|
||||
# Synchronize gradients from the model across the tensor parallelism group.
|
||||
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
|
@ -166,7 +228,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
def __init__(
|
||||
self,
|
||||
optim: Optimizer,
|
||||
model: Module,
|
||||
model: HybridParallelModule,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
max_norm: float = 0,
|
||||
|
@ -176,13 +238,69 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
self.param_info = param_info
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
self.model = model
|
||||
self.stage_manager = model.stage_manager
|
||||
self.shared_params = model.shared_params
|
||||
self.max_norm = max_norm
|
||||
self.tp_pg = tp_process_group
|
||||
self.pp_pg = pp_process_group
|
||||
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
||||
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||
super().__init__(optim)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
r"""
|
||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
This method performs backward pass for gradient computation. If sequence parallelism is enabled
|
||||
and gradient synchronization is required, it will synchronize gradients that are partially derived
|
||||
within sequence parallelism across tp parallelism groups.
|
||||
|
||||
Args:
|
||||
loss (Tensor): The loss tensor to compute gradients with respect to.
|
||||
*args: Additional positional arguments to be passed to the superclass backward method.
|
||||
**kwargs: Additional keyword arguments to be passed to the superclass backward method.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward(loss, *args, **kwargs)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
self.model.sync_sp_grads()
|
||||
else:
|
||||
# If gradient synchronization is not required, return.
|
||||
return
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
"""
|
||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
This method performs a backward pass for gradient computation using a precomputed gradient tensor.
|
||||
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
|
||||
gradients that are partially derived within sequence parallelism across tp parallelism groups.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): The input tensor for which gradients are computed.
|
||||
grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
self.model.sync_sp_grads()
|
||||
else:
|
||||
# If gradient synchronization is not required, return.
|
||||
return
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
r"""
|
||||
Perform an optimization step.
|
||||
|
@ -220,8 +338,6 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
if len(param_gradient_pairs) == 0:
|
||||
return 0.0
|
||||
|
||||
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
||||
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||
norm_type = float(norm_type)
|
||||
|
||||
# gradients used for norm calculation.
|
||||
|
@ -230,9 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
if norm_type == inf:
|
||||
total_norm = max(grad.data.abs().max() for grad in gradients)
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
if tp_size > 1:
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
if pp_size > 1:
|
||||
if self.pp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
|
||||
total_norm = total_norm_cuda.item()
|
||||
else:
|
||||
|
@ -250,16 +366,16 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
|
||||
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
|
||||
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
|
||||
if tp_size > 1:
|
||||
if self.tp_size > 1:
|
||||
param_for_grad = grad_to_param_mapping[id(grad)]
|
||||
if not is_distributed_tensor(param_for_grad):
|
||||
grad_norm_exponentiated /= tp_size
|
||||
grad_norm_exponentiated /= self.tp_size
|
||||
|
||||
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
|
||||
# it means that this parameter is used in two different pipeline stages.
|
||||
# To avoid redundant norm calculations, we divide the exponent of this norm by
|
||||
# the number of shared stages.
|
||||
if pp_size > 1:
|
||||
if self.pp_size > 1:
|
||||
for shared_param in self.shared_params:
|
||||
if self.stage_manager.stage in shared_param:
|
||||
stage_shared_param = shared_param[self.stage_manager.stage]
|
||||
|
@ -269,10 +385,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
|||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||
if tp_size > 1:
|
||||
if self.tp_size > 1:
|
||||
# compute norm in tp process group
|
||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
||||
if pp_size > 1:
|
||||
if self.pp_size > 1:
|
||||
# compute norm in pp process group
|
||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
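The clipping code above sums the exponentiated per-gradient norms locally and then all-reduces that sum across the TP and PP groups before the final root is taken. A condensed sketch of that reduction (illustrative only; it omits the distributed-tensor and shared-parameter corrections discussed in the comments above):

import torch
import torch.distributed as dist

def distributed_grad_norm(grads, norm_type: float, process_groups) -> float:
    # Sum |g|^p locally, reduce the sum over each process group, then take the p-th root.
    local_sum = sum(float(g.data.norm(norm_type)) ** norm_type for g in grads)
    total = torch.tensor([local_sum], device=grads[0].device if grads else "cpu")
    for pg in process_groups:
        if pg is not None and dist.get_world_size(pg) > 1:
            dist.all_reduce(total, op=dist.ReduceOp.SUM, group=pg)
    return total.item() ** (1.0 / norm_type)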
|
||||
|
||||
|
@ -314,7 +430,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
def __init__(
|
||||
self,
|
||||
optim: Optimizer,
|
||||
model: Module,
|
||||
model: HybridParallelModule,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
precision: str = "fp16",
|
||||
|
@ -329,11 +445,14 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
||||
):
|
||||
self.model = model
|
||||
self.param_info = param_info
|
||||
self.stage_manager = model.stage_manager
|
||||
self.shared_params = model.shared_params
|
||||
self.tp_pg = tp_process_group
|
||||
self.pp_pg = pp_process_group
|
||||
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
||||
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||
if use_pipeline:
|
||||
init_pipeline_optimizer(optim, model)
|
||||
super().__init__(
|
||||
|
@ -349,6 +468,59 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
max_norm=max_norm,
|
||||
)
|
||||
|
||||
def backward(self, loss: Tensor, *args, **kwargs):
|
||||
r"""
|
||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
This method performs backward pass for gradient computation. If sequence parallelism is enabled
|
||||
and gradient synchronization is required, it will synchronize gradients that are partially derived
|
||||
within sequence parallelism across tp parallelism groups.
|
||||
|
||||
Args:
|
||||
loss (Tensor): The loss tensor to compute gradients with respect to.
|
||||
*args: Additional positional arguments to be passed to the superclass backward method.
|
||||
**kwargs: Additional keyword arguments to be passed to the superclass backward method.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward(loss, *args, **kwargs)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
self.model.sync_sp_grads()
|
||||
else:
|
||||
# If gradient synchronization is not required, return.
|
||||
return
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||
"""
|
||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
This method performs a backward pass for gradient computation using a precomputed gradient tensor.
|
||||
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
|
||||
gradients that are partially derived within sequence parallelism across tp parallelism groups.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): The input tensor for which gradients are computed.
|
||||
grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
|
||||
if self.model.require_grad_sync:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
self.model.sync_sp_grads()
|
||||
else:
|
||||
# If gradient synchronization is not required, return.
|
||||
return
|
||||
|
||||
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
|
||||
r"""
|
||||
Compute and return the gradient norm for gradient clipping.
|
||||
|
@ -363,8 +535,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
if len(param_gradient_pairs) == 0:
|
||||
return 0.0
|
||||
|
||||
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
||||
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||
norm_type = float(norm_type)
|
||||
|
||||
if norm_type == inf:
|
||||
|
@ -374,9 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
|
||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||
|
||||
if tp_size > 1:
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||
if pp_size > 1:
|
||||
if self.pp_size > 1:
|
||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
|
||||
|
||||
total_norm = total_norm_cuda.item()
|
||||
|
@ -396,16 +566,16 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
|
||||
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
|
||||
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
|
||||
if tp_size > 1:
|
||||
if self.tp_size > 1:
|
||||
param_for_grad = grad_to_param_mapping[id(grad)]
|
||||
if not is_distributed_tensor(param_for_grad):
|
||||
grad_norm_exponentiated /= tp_size
|
||||
grad_norm_exponentiated /= self.tp_size
|
||||
|
||||
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
|
||||
# it means that this parameter is used in two different pipeline stages.
|
||||
# To avoid redundant norm calculations, we divide the exponent of this norm by
|
||||
# the number of shared stages.
|
||||
if pp_size > 1:
|
||||
if self.pp_size > 1:
|
||||
for shared_param in self.shared_params:
|
||||
if self.stage_manager.stage in shared_param:
|
||||
stage_working_shared_param = shared_param[self.stage_manager.stage]
|
||||
|
@ -416,10 +586,10 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
|||
total_norm_exponentiated += grad_norm_exponentiated
|
||||
|
||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||
if tp_size > 1:
|
||||
if self.tp_size > 1:
|
||||
# compute norm in tp process group
|
||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
||||
if pp_size > 1:
|
||||
if self.pp_size > 1:
|
||||
# compute norm in pp process group
|
||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
|
||||
|
||||
|
@ -433,7 +603,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
model: Module,
|
||||
model: HybridParallelModule,
|
||||
use_pipeline: bool,
|
||||
param_info: OrderedDict,
|
||||
initial_scale: int = 2**16, # grad scaler config
|
||||
|
@ -455,6 +625,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
||||
forced_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.param_info = param_info
|
||||
self.stage_manager = model.stage_manager
|
||||
self.shared_params = model.shared_params
|
||||
|
@ -483,6 +654,123 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
|||
forced_dtype=forced_dtype,
|
||||
)
|
||||
|
||||
def sync_dp_grads(self):
|
||||
r"""
|
||||
Synchronize gradients in the data parallelism dimension.
|
||||
|
||||
This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients
|
||||
in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions,
|
||||
namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization
|
||||
and readability.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# Call the superclass `_sync_grad` method to synchronize gradients.
|
||||
super()._sync_grad()
|
||||
|
||||
def _sync_sp_grads(self):
|
||||
r"""
|
||||
Synchronize gradients that are partially derived within sequence parallelism.
|
||||
|
||||
This method is responsible for synchronizing partially derived gradients across tp parallelism groups.
|
||||
It identifies gradients that are partially derived and synchronizes them.
|
||||
If synchronization is required and such gradients are found,
|
||||
it performs the synchronization.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
def _get_all_working_grads() -> List[Tensor]:
|
||||
"""Retrieve all working gradients from different parameter groups."""
|
||||
all_working_grads = []
|
||||
for group_id in range(self.num_param_groups):
|
||||
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
|
||||
all_working_grads.extend(working_grads)
|
||||
return all_working_grads
|
||||
|
||||
def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
|
||||
"""Identify gradients to be synchronized in the sequence parallelism."""
|
||||
grads_to_sync = []
|
||||
for grad in all_working_grads:
|
||||
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
|
||||
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
|
||||
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
|
||||
grads_to_sync.append(grad)
|
||||
|
||||
if len(grads_to_sync) > 0:
|
||||
return grads_to_sync
|
||||
else:
|
||||
return None
|
||||
|
||||
# Get all working gradients and gradients to be synchronized.
|
||||
all_working_grads = _get_all_working_grads()
|
||||
grads_to_sync = _get_grads_to_sync(all_working_grads)
|
||||
|
||||
if self.require_grad_sync and grads_to_sync is not None:
|
||||
# Synchronize sequence parallelism gradients if required.
|
||||
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
|
||||
else:
|
||||
return
|
||||
|
||||
def backward(self, loss, retain_graph=False):
|
||||
"""
|
||||
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
This method performs the backward pass for gradient computation based on a given loss tensor.
|
||||
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
|
||||
gradients that are partially derived within sequence parallelism across TP parallelism groups.
|
||||
|
||||
Args:
|
||||
loss: The loss tensor to compute gradients with respect to.
|
||||
retain_graph (bool): Whether to retain the computation graph.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call the superclass backward method to compute gradients.
|
||||
super().backward(loss, retain_graph)
|
||||
|
||||
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
self._sync_sp_grads()
|
||||
else:
|
||||
# If gradient synchronization is not required, return.
|
||||
return
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
"""
|
||||
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||
|
||||
This method performs a backward pass for gradient computation based on a precomputed gradient tensor.
|
||||
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
|
||||
gradients that are partially derived within sequence parallelism across TP parallelism groups.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor for which gradients are computed.
|
||||
grad: The precomputed gradient tensor to compute gradients with respect to the input tensor.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Call the superclass backward_by_grad method to compute gradients.
|
||||
super().backward_by_grad(tensor, grad)
|
||||
|
||||
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||
self._sync_sp_grads()
|
||||
else:
|
||||
# If gradient synchronization is not required, return.
|
||||
return
|
||||
|
||||
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
|
||||
r"""
|
||||
Compute and return the gradient norm for gradient clipping.
|
||||
|
@ -768,7 +1056,14 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
if not isinstance(model, ModelWrapper):
|
||||
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
||||
model = HybridParallelModule(
|
||||
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
|
||||
model,
|
||||
precision=self.precision,
|
||||
shard_config=self.shard_config,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
use_ddp=use_ddp,
|
||||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.zero_stage == 0:
|
||||
|
@ -826,17 +1121,32 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return_outputs: bool = False,
|
||||
) -> dict:
|
||||
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
|
||||
# return loss or outputs if needed
|
||||
|
||||
# Create a context for gradient synchronization based on the optimizer type.
|
||||
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
|
||||
# This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once),
|
||||
# so we disable it, performing manual reduction instead.
|
||||
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
with ctx:
|
||||
outputs = self.schedule.forward_backward_step(
|
||||
model, data_iter, criterion, optimizer, return_loss, return_outputs
|
||||
)
|
||||
|
||||
# Synchronize the grads of shared parameters of the model.
|
||||
model.sync_shared_params()
|
||||
|
||||
# Synchronize sequence parallelism gradients of the model.
|
||||
model.sync_sp_grads()
|
||||
|
||||
# Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
|
||||
# Otherwise, synchronize data parallelism gradients of the model.
|
||||
# This is because these are two different forms of data parallelism.
|
||||
if isinstance(optimizer, HybridParallelZeroOptimizer):
|
||||
optimizer.sync_grad()
|
||||
optimizer.sync_dp_grads()
|
||||
else:
|
||||
model.sync_grads()
|
||||
model.sync_dp_grads()
|
||||
|
||||
return outputs
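For context, a call site for this method looks roughly like the sketch below. The variable names (plugin, train_iter, _criterion) and the keyword arguments are assumptions based on the signature shown in this hunk, not a verbatim ColossalAI example.

# Illustrative pipeline step: forward/backward with manual gradient synchronization
# handled inside execute_pipeline(), followed by an optimizer update.
outputs = plugin.execute_pipeline(
    train_iter,     # iterator yielding micro-batched inputs
    model,          # HybridParallelModule produced by the plugin
    _criterion,     # callable mapping (model outputs, inputs) -> loss
    optimizer,      # wrapped optimizer, e.g. HybridParallelZeroOptimizer
    return_loss=True,
)
optimizer.step()
optimizer.zero_grad()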
|
||||
|
||||
def prepare_dataloader(
|
||||
|
|
|
@ -338,7 +338,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
if not isinstance(model, ModelWrapper):
|
||||
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
|
||||
model = HybridParallelModule(
|
||||
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
|
||||
module=model,
|
||||
precision=self.precision,
|
||||
shard_config=self.shard_config,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
use_ddp=use_ddp,
|
||||
ddp_config=self.ddp_config,
|
||||
custom_policy=self.custom_policy,
|
||||
)
|
||||
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
|
||||
if self.zero_stage == 0:
|
||||
|
|
|
@ -218,10 +218,10 @@ class TPInferEngine:
|
|||
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
|
||||
model_name = model.__class__.__name__
|
||||
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
|
||||
|
||||
|
||||
model = model.model if self.shard_config.inference_gptq else model
|
||||
policy = get_autopolicy(model, shard_config=self.shard_config)
|
||||
|
||||
policy = get_autopolicy(model, inference_only=True)
|
||||
self.model, _ = shardformer.optimize(model, policy)
|
||||
|
||||
if self.shard_config.inference_gptq:
|
||||
|
|
|
@ -235,6 +235,14 @@ class SubModuleReplacementDescription:
|
|||
|
||||
|
||||
class Policy(ABC):
|
||||
r"""
|
||||
The base class for all the policies. For each different model, it should have a different policy class,
|
||||
like BertPolicy for Bert Model or OPTPolicy for OPT model.
|
||||
|
||||
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
|
||||
built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
|
||||
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
|
@ -245,6 +253,16 @@ class Policy(ABC):
|
|||
"""
|
||||
self.model = model
|
||||
|
||||
def set_shard_config(self, shard_config: ShardConfig) -> None:
|
||||
r"""
|
||||
Set shard config as an attribute of the Policy object.
|
||||
Args:
|
||||
shard_config (:class:`ShardConfig`): The shard config to be set
|
||||
"""
|
||||
self.shard_config = shard_config
|
||||
|
||||
self.config_sanity_check()
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self) -> nn.Module:
|
||||
"""
|
||||
|
|
|
@ -2,7 +2,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
|||
from .embedding import Embedding1D, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row
|
||||
from .loss import cross_entropy_1d
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
|
||||
|
@ -16,6 +16,9 @@ __all__ = [
|
|||
"DropoutForParallelInput",
|
||||
"DropoutForReplicatedInput",
|
||||
"cross_entropy_1d",
|
||||
"BaseLayerNorm",
|
||||
"LayerNorm",
|
||||
"RMSNorm",
|
||||
"FusedLayerNorm",
|
||||
"FusedRMSNorm",
|
||||
"FusedLinear1D_Col",
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
|
||||
__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
|
||||
from .utils import SeqParallelUtils
|
||||
|
||||
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
|
||||
|
||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||
1024,
|
||||
|
@ -35,7 +38,103 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
|
|||
]
|
||||
|
||||
|
||||
class FusedLayerNorm:
|
||||
class BaseLayerNorm(ABC):
|
||||
@abstractmethod
|
||||
def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
|
||||
"""
|
||||
Convert a native PyTorch layer normalization module to a specific layer normalization module,
|
||||
and optionally mark parameters for gradient aggregation.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The native PyTorch layer normalization module to be converted.
|
||||
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||
|
||||
Returns:
|
||||
nn.Module: The specific layer normalization module.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the provided module is not an instance of the supported layer normalization type.
|
||||
"""
|
||||
|
||||
|
||||
class RMSNorm(BaseLayerNorm):
|
||||
r"""
|
||||
This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"FusedLayerNorm is not implemented as a physical class. "
|
||||
"It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||
"""
|
||||
Convert a native RMSNorm module to colossalai layer norm module,
|
||||
and optionally mark parameters for gradient aggregation.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The native RMSNorm module to be converted.
|
||||
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||
|
||||
Returns:
|
||||
nn.Module: The RMSNorm module.
|
||||
"""
|
||||
|
||||
LazyInitContext.materialize(module)
|
||||
|
||||
if sp_partial_derived:
|
||||
# Since gradients are computed using only a subset of the data,
|
||||
# aggregation of these gradients is necessary during backpropagation.
|
||||
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
class LayerNorm(BaseLayerNorm):
|
||||
r"""
|
||||
This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"LayerNorm is not implemented as a physical class. "
|
||||
"It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||
r"""
|
||||
Convert a native pytorch layer norm module to colossalai layer norm module,
|
||||
and optionally mark parameters for gradient aggregation.
|
||||
|
||||
Args:
|
||||
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
|
||||
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||
|
||||
Returns:
|
||||
nn.Module: The LayerNorm module.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the provided module is not an instance of nn.LayerNorm.
|
||||
"""
|
||||
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
|
||||
|
||||
LazyInitContext.materialize(module)
|
||||
|
||||
if sp_partial_derived:
|
||||
# Since gradients are computed using only a subset of the data,
|
||||
# aggregation of these gradients is necessary during backpropagation.
|
||||
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
|
||||
SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
|
||||
|
||||
return module
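A brief usage sketch of the wrapper above (illustrative; the hidden size and the assertions are only there to show the marking behaviour):

import torch.nn as nn
from colossalai.shardformer.layer import LayerNorm
from colossalai.shardformer.layer.utils import SeqParallelUtils

norm = nn.LayerNorm(1024)
# Returns the same module, with weight and bias marked for sequence-parallel grad aggregation.
norm = LayerNorm.from_native_module(norm, sp_partial_derived=True)
assert SeqParallelUtils.is_sp_partial_derived_param(norm.weight)
assert SeqParallelUtils.is_sp_partial_derived_param(norm.bias)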
|
||||
|
||||
|
||||
class FusedLayerNorm(BaseLayerNorm):
|
||||
r"""
|
||||
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
|
@ -43,15 +142,29 @@ class FusedLayerNorm:
|
|||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"FusedLayerNorm is not implemented as a physical class. "
|
||||
"It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
|
||||
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
|
||||
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||
r"""
|
||||
Convert a native pytorch layer norm module to colossalai layer norm module
|
||||
Convert a native pytorch layer norm module to the FusedLayerNorm module provided by apex,
|
||||
and optionally mark parameters for gradient aggregation.
|
||||
|
||||
Args:
|
||||
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
|
||||
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||
|
||||
Returns:
|
||||
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
|
||||
|
||||
Raises:
|
||||
AssertionError: If the provided module is not an instance of nn.LayerNorm.
|
||||
"""
|
||||
# check if apex is installed
|
||||
|
||||
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
|
||||
|
||||
try:
|
||||
pass
|
||||
except ImportError:
|
||||
|
@ -85,10 +198,18 @@ class FusedLayerNorm:
|
|||
|
||||
layernorm.weight = module.weight
|
||||
layernorm.bias = module.bias
|
||||
|
||||
if sp_partial_derived:
|
||||
# Since gradients are computed using only a subset of the data,
|
||||
# aggregation of these gradients is necessary during backpropagation.
|
||||
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
|
||||
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
|
||||
|
||||
return layernorm
|
||||
|
||||
|
||||
class FusedRMSNorm:
|
||||
class FusedRMSNorm(BaseLayerNorm):
|
||||
"""
|
||||
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
|
@ -96,11 +217,22 @@ class FusedRMSNorm:
|
|||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"FusedRMSNorm is not implemented as a physical class. "
|
||||
"It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
|
||||
"It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
|
||||
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||
r"""
|
||||
Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
|
||||
and optionally mark parameters for gradient aggregation.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The native RMSNorm module to be converted.
|
||||
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||
|
||||
Returns:
|
||||
nn.Module: FusedRMSNorm module.
|
||||
"""
|
||||
try:
|
||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||
except ImportError:
|
||||
|
@ -124,4 +256,10 @@ class FusedRMSNorm:
|
|||
|
||||
rmsnorm.weight = module.weight
|
||||
|
||||
if sp_partial_derived:
|
||||
# Since gradients are computed using only a subset of the data,
|
||||
# aggregation of these gradients is necessary during backpropagation.
|
||||
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||
SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
|
||||
|
||||
return rmsnorm
|
||||
|
|
|
@ -1,8 +1,82 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch import nn
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup, get_world_size
|
||||
|
||||
|
||||
class SeqParallelUtils:
|
||||
@staticmethod
|
||||
def marked_as_sp_partial_derived_param(param):
|
||||
"""
|
||||
Mark a parameter as partially derived in sequence parallelism.
|
||||
|
||||
Args:
|
||||
param: The parameter to mark as partially derived.
|
||||
"""
|
||||
setattr(param, "partial_derived", True)
|
||||
|
||||
@staticmethod
|
||||
def is_sp_partial_derived_param(param):
|
||||
"""
|
||||
Check if a parameter is marked as partially derived in sequence parallelism.
|
||||
|
||||
Args:
|
||||
param: The parameter to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the parameter is marked as partially derived, False otherwise.
|
||||
"""
|
||||
return getattr(param, "partial_derived", False)
|
||||
|
||||
@staticmethod
|
||||
def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
|
||||
"""
|
||||
Allreduce partial derived gradients across the specified process group.
|
||||
|
||||
This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
|
||||
|
||||
Args:
|
||||
tp_group (ProcessGroup): The process group for gradient synchronization.
|
||||
model (nn.Module): The model from which gradients will be synchronized.
|
||||
grads (List[torch.Tensor]): The list of gradients to be synchronized.
|
||||
|
||||
Raises:
|
||||
AssertionError: If both `model` and `grads` are provided or neither is provided.
|
||||
"""
|
||||
# Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
|
||||
assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
|
||||
|
||||
# Get the size of the process group, which determines whether synchronization is needed.
|
||||
tp_size = get_world_size(tp_group) if tp_group is not None else 1
|
||||
|
||||
if tp_size == 1:
|
||||
# If the process group size is 1, no synchronization is required.
|
||||
return
|
||||
|
||||
if model is not None:
|
||||
# If `model` is provided, extract partial derived gradients from the model's parameters.
|
||||
grads = []
|
||||
for p in model.parameters():
|
||||
if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
|
||||
grads.append(p.grad.data)
|
||||
|
||||
# Flatten and reduce the gradients using the specified process group.
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
|
||||
|
||||
# Unflatten the synchronized gradients and update the model's gradients.
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
|
||||
else:
|
||||
# If `grads` are provided explicitly, synchronize those gradients directly.
|
||||
coalesced = _flatten_dense_tensors(grads)
|
||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
|
||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||
buf.copy_(synced)
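The flatten/all-reduce/unflatten sequence above coalesces many small gradient tensors into a single collective call. The same pattern in isolation (illustrative; assumes an initialized process group):

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def coalesced_all_reduce(tensors, group=None):
    # One collective for a whole list of tensors: flatten, reduce, then copy back in place.
    flat = _flatten_dense_tensors(tensors)
    dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=group)
    for buf, synced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
        buf.copy_(synced)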
|
||||
|
||||
|
||||
class Randomizer:
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Optional
|
|||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..shard.shard_config import ShardConfig
|
||||
from .base_policy import Policy
|
||||
|
||||
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
|
||||
|
@ -197,7 +198,7 @@ def _fullname(obj):
|
|||
return module + "." + klass.__qualname__
|
||||
|
||||
|
||||
def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
|
||||
def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
|
||||
r"""
|
||||
Return the auto policy for the model
|
||||
|
||||
|
@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
|
|||
:class:`Policy`: The auto policy for the model
|
||||
"""
|
||||
full_name = _fullname(model)
|
||||
if inference_only:
|
||||
if shard_config.inference_only:
|
||||
policy_location = _INFER_POLICY_LIST.get(full_name, None)
|
||||
else:
|
||||
policy_location = _POLICY_LIST.get(full_name, None)
|
||||
|
@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
|
|||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
|
||||
)
|
||||
else:
|
||||
policy = import_policy(policy_location, inference_only)
|
||||
policy = import_policy(policy_location, shard_config.inference_only)
|
||||
return policy()
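With this change, callers pass the shard config instead of a bare inference_only flag. A hedged usage sketch (the import path and the way inference_only is set on ShardConfig are assumptions; `model` stands for any supported HuggingFace module):

from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy

shard_config = ShardConfig()        # assumed to expose an `inference_only` attribute, as read above
shard_config.inference_only = True  # select the inference policy table
policy = get_autopolicy(model, shard_config=shard_config)
policy.set_shard_config(shard_config)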
|
||||
|
|
|
@ -11,6 +11,7 @@ from torch.nn import Module
|
|||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ..layer.normalization import BaseLayerNorm
|
||||
from ..layer.parallel_module import ParallelModule
|
||||
from ..shard.shard_config import ShardConfig
|
||||
|
||||
|
@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
|
|||
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
|
||||
"""
|
||||
suffix: str
|
||||
target_module: ParallelModule
|
||||
target_module: Union[ParallelModule, BaseLayerNorm]
|
||||
kwargs: Dict[str, Any] = None
|
||||
ignore_if_not_exist: bool = False
|
||||
|
||||
|
@ -77,7 +78,6 @@ class Policy(ABC):
|
|||
def set_model(self, model: nn.Module) -> None:
|
||||
r"""
|
||||
Set model as an attribute of the Policy object so that we can access the model's attributes.
|
||||
|
||||
Args:
|
||||
model (:class:`nn.Module`): The model to be set
|
||||
"""
|
||||
|
@ -86,11 +86,11 @@ class Policy(ABC):
|
|||
def set_shard_config(self, shard_config: ShardConfig) -> None:
|
||||
r"""
|
||||
Set shard config as an attribute of the Policy object.
|
||||
|
||||
Args:
|
||||
shard_config (:class:`ShardConfig`): The shard config to be set
|
||||
"""
|
||||
self.shard_config = shard_config
|
||||
|
||||
self.config_sanity_check()
|
||||
|
||||
@property
|
||||
|
|
|
@ -60,6 +60,12 @@ class BertPolicy(Policy):
|
|||
)
|
||||
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = col_nn.FusedLayerNorm
|
||||
else:
|
||||
norm_cls = col_nn.LayerNorm
|
||||
|
||||
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
|
||||
overlap = self.shard_config.enable_sequence_overlap
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
|
@ -141,33 +147,34 @@ class BertPolicy(Policy):
|
|||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
# Handle bert layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BertLayer,
|
||||
)
|
||||
# handle embedding layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BertEmbeddings,
|
||||
)
|
||||
# Handle bert layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.LayerNorm",
|
||||
target_module=norm_cls,
|
||||
kwargs={"sp_partial_derived": use_sequence_parallel},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.LayerNorm",
|
||||
target_module=norm_cls,
|
||||
kwargs={"sp_partial_derived": use_sequence_parallel},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BertLayer,
|
||||
)
|
||||
# handle embedding layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="LayerNorm",
|
||||
target_module=norm_cls,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BertEmbeddings,
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
|
@ -288,9 +295,6 @@ class BertPolicy(Policy):
|
|||
|
||||
# BertModel
|
||||
class BertModelPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
from transformers.models.bert.modeling_bert import BertModel
|
||||
|
@ -313,9 +317,6 @@ class BertModelPolicy(BertPolicy):
|
|||
|
||||
# BertForPreTraining
|
||||
class BertForPreTrainingPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
|
@ -355,9 +356,6 @@ class BertForPreTrainingPolicy(BertPolicy):
|
|||
|
||||
# BertLMHeadModel
|
||||
class BertLMHeadModelPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
|
@ -396,9 +394,6 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
|
||||
# BertForMaskedLM
|
||||
class BertForMaskedLMPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
|
@ -437,9 +432,6 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
|
||||
# BertForSequenceClassification
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForSequenceClassification
|
||||
|
||||
|
@ -484,9 +476,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
|||
|
||||
# BertForTokenClassification
|
||||
class BertForTokenClassificationPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForTokenClassification
|
||||
|
||||
|
@ -531,9 +520,6 @@ class BertForTokenClassificationPolicy(BertPolicy):
|
|||
|
||||
# BertForNextSentencePrediction
|
||||
class BertForNextSentencePredictionPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
|
||||
|
@ -564,9 +550,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
|
|||
|
||||
# BertForMultipleChoice
|
||||
class BertForMultipleChoicePolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForMultipleChoice
|
||||
|
||||
|
@ -610,9 +593,6 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
|||
|
||||
|
||||
class BertForQuestionAnsweringPolicy(BertPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
|
||||
|
||||
|
|
|
@ -43,6 +43,11 @@ class BlipPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = col_nn.FusedLayerNorm
|
||||
else:
|
||||
norm_cls = col_nn.LayerNorm
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[Blip2EncoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
|
@ -214,94 +219,93 @@ class BlipPolicy(Policy):
|
|||
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
# Handle Blip2EncoderLayer layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm1",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm2",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2EncoderLayer,
|
||||
)
|
||||
# Handle Blip2EncoderLayer layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm1",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm2",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2EncoderLayer,
|
||||
)
|
||||
|
||||
# handle Blip2VisionModel layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_layernorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2VisionModel,
|
||||
)
|
||||
# handle Blip2VisionModel layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_layernorm",
|
||||
target_module=norm_cls,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2VisionModel,
|
||||
)
|
||||
|
||||
# handle Blip2VisionModel layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layernorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2QFormerModel,
|
||||
)
|
||||
# handle Blip2VisionModel layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layernorm",
|
||||
target_module=norm_cls,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2QFormerModel,
|
||||
)
|
||||
|
||||
# handle Blip2QFormerLayer layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.LayerNorm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2QFormerLayer,
|
||||
)
|
||||
# handle Blip2QFormerLayer layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.LayerNorm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="crossattention.output.LayerNorm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output_query.LayerNorm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=Blip2QFormerLayer,
|
||||
)
|
||||
|
||||
# handle OPTForCausalLM layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="model.decoder.final_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
)
|
||||
# handle OPTForCausalLM layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="model.decoder.final_layer_norm",
|
||||
target_module=norm_cls,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
)
|
||||
|
||||
# handle OPTDecoderLayer layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OPTDecoderLayer,
|
||||
)
|
||||
# handle OPTDecoderLayer layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn_layer_norm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_layer_norm",
|
||||
target_module=norm_cls,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OPTDecoderLayer,
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
|
|
|
@@ -42,6 +42,10 @@ class BloomPolicy(Policy):
        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = col_nn.FusedLayerNorm
        else:
            norm_cls = col_nn.LayerNorm
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:

@@ -97,38 +101,39 @@ class BloomPolicy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # handle bloom model
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="ln_f",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="word_embeddings_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=BloomModel,
            )
        # handle bloom model
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="ln_f",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="word_embeddings_layernorm",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=BloomModel,
        )

            # handle bloom block
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=BloomBlock,
            )
        # handle bloom block
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
            ],
            policy=policy,
            target_key=BloomBlock,
        )

        if use_sequence_parallel:
            self.append_or_create_method_replacement(

@@ -225,9 +230,6 @@ class BloomPolicy(Policy):


class BloomModelPolicy(BloomPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bloom.modeling_bloom import BloomModel

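Every policy touched in this diff follows the same pattern: pick a norm wrapper class once, up front, instead of gating each replacement on enable_fused_normalization. A minimal sketch of that dispatch, assuming only that colossalai is installed; it reuses the FusedLayerNorm/LayerNorm names imported from colossalai.shardformer.layer elsewhere in this diff, and the helper name pick_norm_cls is illustrative rather than part of the library:

from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm


def pick_norm_cls(shard_config):
    # Mirror of the dispatch added in the policies above: use the fused kernel
    # when enable_fused_normalization is set, otherwise the plain wrapper, so a
    # single norm_cls can be passed as target_module in every replacement.
    return FusedLayerNorm if shard_config.enable_fused_normalization else LayerNorm
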
@@ -45,6 +45,16 @@ class ChatGLMPolicy(Policy):
        policy = {}

        if self.shard_config.enable_fused_normalization:
            if self.model.config.rmsnorm:
                norm_cls = col_nn.FusedRMSNorm
            else:
                norm_cls = col_nn.FusedLayerNorm
        else:
            if self.model.config.rmsnorm:
                norm_cls = col_nn.RMSNorm
            else:
                norm_cls = col_nn.LayerNorm
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:

@@ -96,52 +106,34 @@ class ChatGLMPolicy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            if not self.model.config.rmsnorm:
                self.append_or_create_submodule_replacement(
                    description=[
                        SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
                        SubModuleReplacementDescription(
                            suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm
                        ),
                    ],
                    policy=policy,
                    target_key=GLMBlock,
                )
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
            ],
            policy=policy,
            target_key=GLMBlock,
        )

                if self.model.config.post_layer_norm:
                    self.append_or_create_submodule_replacement(
                        description=[
                            SubModuleReplacementDescription(
                                suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm
                            )
                        ],
                        policy=policy,
                        target_key=ChatGLMModel,
                    )

            else:
                self.append_or_create_submodule_replacement(
                    description=[
                        SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
                        SubModuleReplacementDescription(
                            suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
                        ),
                    ],
                    policy=policy,
                    target_key=GLMBlock,
                )

                if self.model.config.post_layer_norm:
                    self.append_or_create_submodule_replacement(
                        description=[
                            SubModuleReplacementDescription(
                                suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
                            )
                        ],
                        policy=policy,
                        target_key=ChatGLMModel,
        if self.model.config.post_layer_norm:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="encoder.final_layernorm",
                        target_module=norm_cls,
                    )
                ],
                policy=policy,
                target_key=ChatGLMModel,
            )

        # use flash attention
        if self.shard_config.enable_flash_attention:

@@ -224,9 +216,6 @@ class ChatGLMPolicy(Policy):


class ChatGLMModelPolicy(ChatGLMPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        pass

@@ -39,6 +39,11 @@ class GPT2Policy(Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = col_nn.FusedLayerNorm
        else:
            norm_cls = col_nn.LayerNorm
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:

@@ -102,33 +107,37 @@ class GPT2Policy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="ln_f",
                    target_module=col_nn.FusedLayerNorm,
                ),
                policy=policy,
                target_key=GPT2Model,
            )
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="ln_f",
                target_module=norm_cls,
            ),
            policy=policy,
            target_key=GPT2Model,
        )

            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="ln_1",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="ln_2",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
                    ),
                ],
                policy=policy,
                target_key=GPT2Block,
            )
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="ln_1",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
                SubModuleReplacementDescription(
                    suffix="ln_2",
                    target_module=norm_cls,
                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
                SubModuleReplacementDescription(
                    suffix="ln_cross_attn",
                    target_module=norm_cls,
                    ignore_if_not_exist=True,
                    kwargs={"sp_partial_derived": use_sequence_parallel},
                ),
            ],
            policy=policy,
            target_key=GPT2Block,
        )

        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(

@@ -192,9 +201,6 @@ class GPT2Policy(Policy):

# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model


@@ -216,9 +222,6 @@ class GPT2ModelPolicy(GPT2Policy):

# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel


@@ -263,9 +266,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):

# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel


@@ -317,9 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):

# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering


@@ -347,9 +344,6 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):

# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification


@@ -387,9 +381,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):

# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification

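The kwargs={"sp_partial_derived": use_sequence_parallel} entries added above tag layer norms whose parameter gradients are only partial sums when the sequence dimension is split across tensor-parallel ranks. Conceptually, those gradients then have to be all-reduced over the tensor-parallel group before the optimizer step; the sketch below is illustrative only (the helper name and its placement are assumptions, not the library's implementation), using only the standard torch.distributed API:

import torch.distributed as dist


def allreduce_partial_layernorm_grads(norm_module, tp_group):
    # With sequence parallelism, each rank only sees its slice of the tokens, so
    # LayerNorm weight/bias grads are partial sums; summing them across the
    # tensor-parallel group restores the full gradient before the optimizer step.
    for param in norm_module.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, group=tp_group)
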
@@ -6,7 +6,7 @@ import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D

from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -35,6 +35,11 @@ class LlamaPolicy(Policy):
        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = FusedRMSNorm
        else:
            norm_cls = RMSNorm

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

@@ -93,31 +98,31 @@ class LlamaPolicy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key=LlamaDecoderLayer,
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
                    target_module=norm_cls,
                ),
                policy=policy,
                target_key=LlamaModel,
            )
                SubModuleReplacementDescription(
                    suffix="post_attention_layernorm",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=LlamaDecoderLayer,
        )

        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="norm",
                target_module=norm_cls,
            ),
            policy=policy,
            target_key=LlamaModel,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
                description={

@@ -174,9 +179,6 @@ class LlamaPolicy(Policy):


class LlamaModelPolicy(LlamaPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.llama.modeling_llama import LlamaModel

@@ -5,7 +5,7 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn

from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func

@@ -42,6 +42,12 @@ class OPTPolicy(Policy):
        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = FusedLayerNorm
        else:
            norm_cls = LayerNorm

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

@@ -94,26 +100,25 @@ class OPTPolicy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
            ),
            policy=policy,
            target_key=OPTDecoder,
        )
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
                ),
                policy=policy,
                target_key=OPTDecoder,
            )
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
                    ),
                    SubModuleReplacementDescription(
                        suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
                    ),
                ],
                policy=policy,
                target_key=OPTDecoderLayer,
            )
                SubModuleReplacementDescription(
                    suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
                ),
            ],
            policy=policy,
            target_key=OPTDecoderLayer,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:

@@ -183,9 +188,6 @@ class OPTPolicy(Policy):


class OPTModelPolicy(OPTPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTModel


@@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):


class OPTForSequenceClassificationPolicy(OPTPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTForSequenceClassification


@@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):


class OPTForQuestionAnsweringPolicy(OPTPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTForQuestionAnswering

@@ -24,6 +24,11 @@ class SamPolicy(Policy):
        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = col_nn.FusedLayerNorm
        else:
            norm_cls = col_nn.LayerNorm

        if self.shard_config.enable_tensor_parallelism:
            policy[SamVisionLayer] = ModulePolicyDescription(
                attribute_replacement={

@@ -151,58 +156,57 @@ class SamPolicy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # Handle SamVisionLayer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layer_norm1",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="layer_norm2",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=SamVisionLayer,
            )
        # Handle SamVisionLayer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm1",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm2",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=SamVisionLayer,
        )

            # Handle SamTwoWayAttentionBlock
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layer_norm1",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="layer_norm2",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="layer_norm3",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="layer_norm4",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=SamTwoWayAttentionBlock,
            )
        # Handle SamTwoWayAttentionBlock
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm1",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm2",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm3",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm4",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=SamTwoWayAttentionBlock,
        )

            # Handle SamTwoWayTransformer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layer_norm_final_attn",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ],
                policy=policy,
                target_key=SamTwoWayTransformer,
            )
        # Handle SamTwoWayTransformer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm_final_attn",
                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=SamTwoWayTransformer,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:

@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import (
    FusedRMSNorm,
    Linear1D_Col,
    Linear1D_Row,
    RMSNorm,
    VocabParallelEmbedding1D,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription

@@ -58,6 +59,11 @@ class T5BasePolicy(Policy):
        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = FusedRMSNorm
        else:
            norm_cls = RMSNorm

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

@@ -169,38 +175,37 @@ class T5BasePolicy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="layer_norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=T5LayerFF,
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="layer_norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=T5LayerFF,
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5LayerSelfAttention,
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5LayerCrossAttention,
            )
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
                policy=policy,
                target_key=T5Stack,
            )
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="layer_norm",
                target_module=norm_cls,
            ),
            policy=policy,
            target_key=T5LayerFF,
        )
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(
                suffix="layer_norm",
                target_module=norm_cls,
            ),
            policy=policy,
            target_key=T5LayerFF,
        )
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
            policy=policy,
            target_key=T5LayerSelfAttention,
        )
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
            policy=policy,
            target_key=T5LayerCrossAttention,
        )
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
            policy=policy,
            target_key=T5Stack,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:

@@ -363,9 +368,6 @@ class T5BasePolicy(Policy):


class T5ModelPolicy(T5BasePolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import T5Model


@@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy):


class T5ForConditionalGenerationPolicy(T5BasePolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import T5ForConditionalGeneration


@@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):


class T5EncoderPolicy(T5BasePolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import T5EncoderModel

@@ -159,9 +159,6 @@ class ViTPolicy(Policy):

# ViTModel
class ViTModelPolicy(ViTPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.vit.modeling_vit import ViTModel


@@ -227,9 +224,6 @@ class ViTForImageClassificationPolicy(ViTPolicy):

# ViTForMaskedImageModeling
class ViTForMaskedImageModelingPolicy(ViTPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel

@@ -52,6 +52,11 @@ class WhisperPolicy(Policy):
        policy = {}

        if self.shard_config.enable_fused_normalization:
            norm_cls = col_nn.FusedLayerNorm
        else:
            norm_cls = col_nn.LayerNorm

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn(

@@ -161,62 +166,61 @@ class WhisperPolicy(Policy):
        )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # Handle encoder layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="self_attn_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="final_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=WhisperEncoderLayer,
            )
        # Handle encoder layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="final_layer_norm",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=WhisperEncoderLayer,
        )

            # Handle decoder layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="self_attn_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="final_layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=WhisperDecoderLayer,
            )
        # Handle decoder layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm",
                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="final_layer_norm",
                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=WhisperDecoderLayer,
        )

            # handle encoder layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ],
                policy=policy,
                target_key=WhisperEncoder,
            )
        # handle encoder layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm",
                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=WhisperEncoder,
        )

            # handle decoder layer
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="layer_norm",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ],
                policy=policy,
                target_key=WhisperDecoder,
            )
        # handle decoder layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm",
                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=WhisperDecoder,
        )

        # enable flash attention
        if self.shard_config.enable_flash_attention:

@@ -416,9 +420,6 @@ class WhisperPolicy(Policy):

# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import WhisperModel


@@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy):

# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import WhisperForConditionalGeneration


@@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):

# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
    def __init__(self) -> None:
        super().__init__()

    def preprocess(self):
        return self.model

@@ -27,8 +27,8 @@ class ModelSharder(object):

    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
        self.model = model
        self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
        self.shard_config = shard_config
        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy

    def shard(self) -> List[Dict[int, Tensor]]:
        r"""

@@ -196,7 +196,7 @@ class ModelSharder(object):

            try:
                replace_layer = target_module.from_native_module(
                    native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs
                    native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs
                )
            except Exception as e:
                raise RuntimeError(

@@ -700,7 +700,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
    ############################

    # this method is used to sync gradient manually
    def sync_grad(self):
    def _sync_grad(self):
        for group_id in range(self.num_param_groups):
            param_group = self._working_param_groups[group_id]
            for param in param_group:

@@ -713,7 +713,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        # if not overlapping communication (no reduction hook is attached) when zero1
        # we need to manually reduce these gradients
        if not partition_grad and not self._overlap_communication:
            self.sync_grad()
            self._sync_grad()
        else:
            self._run_reduction()

@@ -24,7 +24,6 @@ for k, v in inputs.items():
    new_shape[0] = 16
    inputs[k] = v.to("cuda").repeat(*new_shape)


def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(

@@ -59,6 +58,7 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])

@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)

@@ -76,6 +76,7 @@ def check_tp_pipeline_inference(rank, world_size, port):


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")

@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

@@ -128,7 +128,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "max_norm": 5,

@@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,

@@ -147,7 +147,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,

@@ -157,7 +157,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "bf16",
            "max_norm": 5,

@@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,

@@ -174,7 +174,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,

@@ -199,7 +199,7 @@ def run_test(test_config):
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,

@@ -208,7 +208,7 @@ def run_test(test_config):
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,

@@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp32",
            "max_norm": 5,

@@ -114,7 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp32",
            "max_norm": 5,

@@ -123,7 +123,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp32",
            "max_norm": 5,

@@ -148,7 +148,7 @@ def run_test(test_config):
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp32",
            "max_norm": 5,

@@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "max_norm": 5,

@@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,

@@ -126,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 2,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,

@@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "bf16",
            "max_norm": 5,

@@ -146,7 +146,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,

@@ -155,7 +155,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "tp_size": 2,
            "pp_size": 1,
            "zero_stage": 2,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,

@@ -181,7 +181,7 @@ def run_test(test_config):
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "bf16",
            "max_norm": 5,

@@ -191,7 +191,7 @@ def run_test(test_config):
            "pp_size": 2,
            "num_microbatches": 4,
            "zero_stage": 1,
            "enable_all_optimization": False,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp16",
            "max_norm": 5,

@@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    bert = unwrap_model(org_model, "BertModel", "bert")
    sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

    norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
    col_layer_for_check = ["encoder.layer[0].output.dense"]
    row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]


@@ -50,8 +51,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        row_layer_grads = get_grad_tensors_for_check(
            bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
        )

        norm_layer_grads = get_grad_tensors_for_check(
            bert,
            sharded_bert,
            norm_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )

        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()

@@ -85,6 +99,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp32",
        },
        {
            "tp_size": 1,
            "pp_size": 2,

@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    bloom = unwrap_model(org_model, "BloomModel", "transformer")
    sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer")

    norm_layer_for_check = ["word_embeddings_layernorm", "h[0].input_layernorm"]
    row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"]
    col_layer_for_check = ["h[0].self_attention.dense"]

@@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        col_layer_grads = get_grad_tensors_for_check(
            bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )

        norm_layer_grads = get_grad_tensors_for_check(
            bloom,
            sharded_bloom,
            norm_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )

        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()

@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer")
    shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer")

    norm_layer_for_check = ["encoder.layers[0].input_layernorm"]
    row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"]
    col_layer_for_check = ["encoder.layers[0].self_attention.dense"]

@@ -66,8 +67,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            dim=1,
            verbose=False,
        )

        norm_layer_grads = get_grad_tensors_for_check(
            chatglm_model,
            shard_chatglm_model,
            norm_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )

        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()

@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    gpt2 = unwrap_model(org_model, "GPT2Model", "transformer")
    sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer")

    norm_layer_for_check = ["h[0].ln_1", "h[0].ln_2"]
    col_layer_for_check = ["h[0].mlp.c_fc"]
    row_layer_for_check = ["wte", "h[0].mlp.c_proj"]

@@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        row_layer_grads = get_grad_tensors_for_check(
            gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
        )

        norm_layer_grads = get_grad_tensors_for_check(
            gpt2,
            sharded_gpt2,
            norm_layer_for_check,
            tp_group,
            atol=atol,
            rtol=rtol,
            dim=1,
            verbose=False,
        )

        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()

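The added norm_layer_grads entries above extend the existing row/column checks so the layer-norm gradients are verified as well. Roughly, such a comparison boils down to slicing the unsharded gradient the same way the parameter is sharded and comparing it with the shard on each rank; the helper below is a generic, hedged illustration of that idea and is not the repo's get_grad_tensors_for_check:

import torch


def grad_matches(full_grad: torch.Tensor, shard_grad: torch.Tensor, dim: int, rank: int, world_size: int,
                 atol: float = 1e-5, rtol: float = 1e-3) -> bool:
    # If shapes already match (e.g. a replicated norm parameter), compare directly;
    # otherwise compare against this rank's chunk along the sharding dimension.
    if full_grad.shape == shard_grad.shape:
        return torch.allclose(full_grad, shard_grad, atol=atol, rtol=rtol)
    chunk = torch.chunk(full_grad, world_size, dim=dim)[rank]
    return torch.allclose(chunk, shard_grad, atol=atol, rtol=rtol)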