[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish code of handling layernorm

* fix moe module

* polish code of class initialization

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
littsk 2023-11-03 13:32:43 +08:00 committed by GitHub
parent d99b2c961a
commit 1a3315e336
30 changed files with 1120 additions and 552 deletions
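
For orientation, the core idea behind this hotfix: under sequence parallelism each tensor-parallel rank processes only a shard of the sequence, so the gradients of replicated parameters such as LayerNorm weight and bias are partial sums and must be all-reduced (SUM) across the tensor-parallel group. A minimal single-process sketch of that fact, with the two sequence-parallel ranks simulated by slicing one sequence (sizes and the sum-style loss are illustrative assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Full sequence on one logical device: (seq_len, hidden_size).
x = torch.randn(8, 16)
ln_full = nn.LayerNorm(16)

# Use a sum-style loss so per-token gradients add up exactly.
ln_full(x).sum().backward()
full_grad = ln_full.weight.grad.clone()

# Sequence parallelism: each rank sees only a sequence shard but holds the
# same replicated LayerNorm parameters.
shard_grads = []
for shard in x.chunk(2, dim=0):  # simulate 2 sequence-parallel ranks
    ln_local = nn.LayerNorm(16)
    ln_local.load_state_dict(ln_full.state_dict())
    ln_local(shard).sum().backward()
    shard_grads.append(ln_local.weight.grad)

# Summing the partial gradients (exactly what the all-reduce does) recovers
# the gradient computed on the full sequence.
assert torch.allclose(full_grad, sum(shard_grads), atol=1e-5)
```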

View File

@ -1,6 +1,6 @@
import ctypes
import random
from contextlib import nullcontext
from contextlib import contextmanager
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
@ -25,6 +25,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
@ -47,12 +48,17 @@ class HybridParallelModule(ModelWrapper):
precision: str,
shard_config: ShardConfig,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
self.dp_group = dp_group
self.tp_group = tp_group
self.use_ddp = use_ddp
self.require_grad_sync = True
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@ -98,19 +104,75 @@ class HybridParallelModule(ModelWrapper):
dist.all_reduce(param.grad, group=group)
dist.barrier()
def no_sync(self) -> Iterator[None]:
# no sync grads across data parallel
return nullcontext()
@contextmanager
def no_sync(self):
r"""
A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization
while 'no_sync' is active. After exiting the context, automatic synchronization resumes in the first
forward-backward pass.
"""
def sync_grads(self):
# sync grad across data parallel
# Store the current value of 'require_grad_sync' to restore it later.
old_require_grad_sync = self.require_grad_sync
# Disable automatic gradient synchronization.
self.require_grad_sync = False
try:
if self.use_ddp:
# If the module is wrapped with DDP (use_ddp), disable DDP's own gradient synchronization as well.
with self.module.no_sync():
yield
else:
yield
finally:
# Restore the original value of 'require_grad_sync'.
self.require_grad_sync = old_require_grad_sync
def sync_dp_grads(self):
r"""
Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1.
This function performs an all-reduce operation to combine gradients from different devices in the DP group.
Args:
None
Returns:
None
"""
# Check if the DP group size is 1, meaning no synchronization is needed.
if self.dp_group.size() == 1:
return
# Iterate through the model's parameters and perform gradient synchronization.
for p in self.module.parameters():
if p.grad is not None:
# Perform all-reduce to combine gradients from different devices.
dist.all_reduce(p.grad, group=self.dp_group)
# Normalize the gradient by dividing it by the DP group size.
p.grad.div_(self.dp_group.size())
def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
r"""
Synchronize gradients that are partially derived within sequence parallelism
if sequence parallelism is enabled. Gradients can be provided explicitly or extracted
from the module.
Args:
grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not
provided, gradients will be extracted from the model.
Returns:
None
"""
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
else:
# Synchronize gradients from the model across the tensor parallelism group.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
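
A rough usage sketch of the synchronization hooks introduced above, for a gradient-accumulation style loop. This is illustrative rather than the plugin's prescribed recipe: `model` is assumed to be a `HybridParallelModule`, `optimizer` one of the non-ZeRO optimizer wrappers in this file (which check `model.require_grad_sync`), and `criterion` / `micro_batches` are placeholder names.

```python
# Skip per-micro-batch synchronization, then synchronize once before stepping.
with model.no_sync():  # sets require_grad_sync = False (and enters DDP no_sync if use_ddp)
    for micro_batch in micro_batches[:-1]:
        loss = criterion(model(**micro_batch))
        optimizer.backward(loss)  # sync_sp_grads is skipped while no_sync is active

# Final micro-batch outside the context: require_grad_sync is True again,
# so this backward also all-reduces the sequence-parallel (layer norm) grads.
loss = criterion(model(**micro_batches[-1]))
optimizer.backward(loss)

# Data-parallel gradients are averaged manually when neither DDP nor ZeRO handles them.
model.sync_dp_grads()
optimizer.step()
optimizer.zero_grad()
```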
@ -166,7 +228,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(
self,
optim: Optimizer,
model: Module,
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
max_norm: float = 0,
@ -176,13 +238,69 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
self.model = model
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.max_norm = max_norm
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
super().__init__(optim)
def backward(self, loss: Tensor, *args, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
This method performs the backward pass for gradient computation. If sequence parallelism is enabled
and gradient synchronization is required, it will synchronize gradients that are partially derived
within sequence parallelism across tp parallelism groups.
Args:
loss (Tensor): The loss tensor to compute gradients with respect to.
*args: Additional positional arguments to be passed to the superclass backward method.
**kwargs: Additional keyword arguments to be passed to the superclass backward method.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
This method performs a backward pass for gradient computation using a precomputed gradient tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across tp parallelism groups.
Args:
tensor (Tensor): The input tensor for which gradients are computed.
grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def step(self, *args, **kwargs):
r"""
Perform an optimization step.
@ -220,8 +338,6 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
# gradients used for norm calculation.
@ -230,9 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
@ -250,16 +366,16 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
if self.tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
grad_norm_exponentiated /= self.tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
if pp_size > 1:
if self.pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
@ -269,10 +385,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
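
The clipping code above leans on the fact that a global p-norm decomposes into per-rank partial sums: each rank computes sum(|g|^p) over its local gradients, the partial sums are combined with an all-reduce(SUM) over the tp and pp groups, and the p-th root is applied once at the end. A single-process sketch of that identity, with the "ranks" simulated as slices of one gradient:

```python
import torch

torch.manual_seed(0)
p = 2.0
g = torch.randn(1024)  # the full (logical) gradient

# Reference: the norm computed in one piece.
full_norm = g.abs().pow(p).sum().pow(1.0 / p)

# Each simulated rank contributes an exponentiated partial sum; all_reduce(SUM)
# over the process group is just the sum of these partials.
partial_sums = torch.stack([shard.abs().pow(p).sum() for shard in g.chunk(4)])
total_norm = partial_sums.sum().pow(1.0 / p)

assert torch.allclose(full_norm, total_norm, atol=1e-5)
```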
@ -314,7 +430,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
self,
optim: Optimizer,
model: Module,
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = "fp16",
@ -329,11 +445,14 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
):
self.model = model
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(
@ -349,6 +468,59 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
max_norm=max_norm,
)
def backward(self, loss: Tensor, *args, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
This method performs the backward pass for gradient computation. If sequence parallelism is enabled
and gradient synchronization is required, it will synchronize gradients that are partially derived
within sequence parallelism across tp parallelism groups.
Args:
loss (Tensor): The loss tensor to compute gradients with respect to.
*args: Additional positional arguments to be passed to the superclass backward method.
**kwargs: Additional keyword arguments to be passed to the superclass backward method.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
This method performs a backward pass for gradient computation using a precomputed gradient tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across tp parallelism groups.
Args:
tensor (Tensor): The input tensor for which gradients are computed.
grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
self.model.sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
r"""
Compute and return the gradient norm for gradient clipping.
@ -363,8 +535,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
if len(param_gradient_pairs) == 0:
return 0.0
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
if norm_type == inf:
@ -374,9 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if tp_size > 1:
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
@ -396,16 +566,16 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
if tp_size > 1:
if self.tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
grad_norm_exponentiated /= tp_size
grad_norm_exponentiated /= self.tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
if pp_size > 1:
if self.pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_working_shared_param = shared_param[self.stage_manager.stage]
@ -416,10 +586,10 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
total_norm_exponentiated += grad_norm_exponentiated
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
if tp_size > 1:
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
if pp_size > 1:
if self.pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
@ -433,7 +603,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
@ -455,6 +625,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
):
self.model = model
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@ -483,6 +654,123 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
forced_dtype=forced_dtype,
)
def sync_dp_grads(self):
r"""
Synchronize gradients in the data parallelism dimension.
This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients
in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions,
namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization
and readability.
Args:
None
Returns:
None
"""
# Call the superclass `_sync_grad` method to synchronize gradients.
super()._sync_grad()
def _sync_sp_grads(self):
r"""
Synchronize gradients that are partially derived within sequence parallelism.
This method is responsible for synchronizing partially derived gradients across tp parallelism groups.
It identifies which working gradients are partially derived and, if synchronization is required and any such
gradients are found, performs the synchronization.
Args:
None
Returns:
None
"""
def _get_all_working_grads() -> List[Tensor]:
"""Retrieve all working gradients from different parameter groups."""
all_working_grads = []
for group_id in range(self.num_param_groups):
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
all_working_grads.extend(working_grads)
return all_working_grads
def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
"""Identify gradients to be synchronized in the sequence parallelism."""
grads_to_sync = []
for grad in all_working_grads:
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
grads_to_sync.append(grad)
if len(grads_to_sync) > 0:
return grads_to_sync
else:
return None
# Get all working gradients and gradients to be synchronized.
all_working_grads = _get_all_working_grads()
grads_to_sync = _get_grads_to_sync(all_working_grads)
if self.require_grad_sync and grads_to_sync is not None:
# Synchronize sequence parallelism gradients if required.
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
else:
return
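
`_get_grads_to_sync` maps a stored parameter id back to the live parameter object with `ctypes.cast(param_id, ctypes.py_object).value`. This relies on a CPython detail: `id()` is the object's memory address, and the cast is only safe while the object is still referenced elsewhere. A tiny sketch of the trick:

```python
import ctypes

import torch

param = torch.nn.Parameter(torch.ones(3))
param_id = id(param)  # what the gradient store records for each working grad

# CPython-specific: reinterpret the address as a Python object reference.
# Only valid while `param` is still alive.
recovered = ctypes.cast(param_id, ctypes.py_object).value
assert recovered is param
```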
def backward(self, loss, retain_graph=False):
"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
This method performs the backward pass for gradient computation based on a given loss tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across TP parallelism groups.
Args:
loss: The loss tensor to compute gradients with respect to.
retain_graph (bool): Whether to retain the computation graph.
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def backward_by_grad(self, tensor, grad):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
This method performs a backward pass for gradient computation based on a precomputed gradient tensor.
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
gradients that are partially derived within sequence parallelism across TP parallelism groups.
Args:
tensor: The input tensor for which gradients are computed.
grad: The precomputed gradient tensor to compute gradients with respect to the input tensor.
Returns:
None
"""
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
self._sync_sp_grads()
else:
# If gradient synchronization is not required, return.
return
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@ -768,7 +1056,14 @@ class HybridParallelPlugin(PipelinePluginBase):
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
tp_group=self.tp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
@ -826,17 +1121,32 @@ class HybridParallelPlugin(PipelinePluginBase):
return_outputs: bool = False,
) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
# return loss or outputs if needed
# Create a context for gradient synchronization based on the optimizer type.
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
# This avoids redundant gradient reduction in pipeline parallelism (gradients from multiple micro-batches
# should be reduced only once), so automatic reduction is disabled and manual reduction is performed instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx:
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()
# Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
# Otherwise, synchronize data parallelism gradients of the model.
# This is because these are two different forms of data parallelism.
if isinstance(optimizer, HybridParallelZeroOptimizer):
optimizer.sync_grad()
optimizer.sync_dp_grads()
else:
model.sync_grads()
model.sync_dp_grads()
return outputs
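
The once-per-step manual reduction above is justified by linearity of the all-reduce: accumulating the gradients of all micro-batches locally and reducing once yields the same result as reducing after every micro-batch, with a single collective instead of one per micro-batch. A single-process illustration with two simulated data-parallel ranks:

```python
import torch

torch.manual_seed(0)
num_microbatches = 4

# Per-micro-batch gradients on two simulated data-parallel ranks.
rank0 = [torch.randn(8) for _ in range(num_microbatches)]
rank1 = [torch.randn(8) for _ in range(num_microbatches)]

# Option A: reduce (here: add the two ranks) after every micro-batch.
reduce_each_step = sum(g0 + g1 for g0, g1 in zip(rank0, rank1))

# Option B: accumulate locally, reduce once at the end of the step.
reduce_once = sum(rank0) + sum(rank1)

assert torch.allclose(reduce_each_step, reduce_once, atol=1e-6)
```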
def prepare_dataloader(

View File

@ -338,7 +338,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_group,
tp_group=self.tp_group,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:

View File

@ -220,8 +220,8 @@ class TPInferEngine:
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
model = model.model if self.shard_config.inference_gptq else model
policy = get_autopolicy(model, shard_config=self.shard_config)
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)
if self.shard_config.inference_gptq:

View File

@ -235,6 +235,14 @@ class SubModuleReplacementDescription:
class Policy(ABC):
r"""
The base class for all the policies. For each different model, it should have a different policy class,
like BertPolicy for Bert Model or OPTPolicy for OPT model.
Shardformer provides many built-in sharding policies for mainstream models. You can use a
built-in policy by setting `policy = None`, which is the default argument for `ShardFormer.optimize`.
If you want to define your own policy, inherit from this class and override the methods you want to modify.
"""
def __init__(self):
self.model = None
@ -245,6 +253,16 @@ class Policy(ABC):
"""
self.model = model
def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.
Args:
shard_config (:class:`ShardConfig`): The shard config to be set
"""
self.shard_config = shard_config
self.config_sanity_check()
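
A short sketch of the built-in policy path described in the class docstring above: the plugin builds a `ShardFormer` from the shard config and calls `optimize` with `policy=None`, which resolves the policy automatically via `get_autopolicy` and hands the shard config to it through `set_shard_config`. The flag values and the `model` variable below are illustrative assumptions, and a real sequence-parallel setup also supplies the tensor-parallel process group:

```python
from colossalai.shardformer import ShardConfig, ShardFormer

# Illustrative flags only; ShardConfig has more fields (process groups, flash attention, ...).
shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_fused_normalization=False,   # fall back to the plain LayerNorm/RMSNorm wrappers
    enable_sequence_parallelism=True,
)

shardformer = ShardFormer(shard_config)

# policy=None (the default) selects the built-in policy for the model's class.
sharded_model, shared_params = shardformer.optimize(model, policy=None)
```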
@abstractmethod
def preprocess(self) -> nn.Module:
"""

View File

@ -2,7 +2,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@ -16,6 +16,9 @@ __all__ = [
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
"FusedLayerNorm",
"FusedRMSNorm",
"FusedLinear1D_Col",

View File

@ -1,11 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.lazy import LazyInitContext
__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
from .utils import SeqParallelUtils
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
@ -35,7 +38,103 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
]
class FusedLayerNorm:
class BaseLayerNorm(ABC):
@abstractmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
"""
Convert a native PyTorch layer normalization module to a specific layer normalization module,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.Module): The native PyTorch layer normalization module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The specific layer normalization module.
Raises:
AssertionError: If the provided module is not an instance of the supported layer normalization type.
"""
class RMSNorm(BaseLayerNorm):
r"""
This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
)
@staticmethod
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
"""
Convert a native RMSNorm module to a colossalai RMSNorm module,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.Module): The native RMSNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The RMSNorm module.
"""
LazyInitContext.materialize(module)
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
return module
class LayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"LayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
)
@staticmethod
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native pytorch layer norm module to colossalai layer norm module,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: The LayerNorm module.
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
LazyInitContext.materialize(module)
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
return module
class FusedLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
"""
@ -43,15 +142,29 @@ class FusedLayerNorm:
def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
)
@staticmethod
def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native pytorch layer norm module to colossalai layer norm module
Convert a native pytorch layer norm module to the FusedLayerNorm module provided by apex,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
# check if apex is installed
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
try:
pass
except ImportError:
@ -85,10 +198,18 @@ class FusedLayerNorm:
layernorm.weight = module.weight
layernorm.bias = module.bias
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
return layernorm
class FusedRMSNorm:
class FusedRMSNorm(BaseLayerNorm):
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""
@ -96,11 +217,22 @@ class FusedRMSNorm:
def __init__(self) -> None:
raise NotImplementedError(
"FusedRMSNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
"It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
)
@staticmethod
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
and optionally mark parameters for gradient aggregation.
Args:
module (nn.Module): The native RMSNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: FusedRMSNorm module.
"""
try:
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
except ImportError:
@ -124,4 +256,10 @@ class FusedRMSNorm:
rmsnorm.weight = module.weight
if sp_partial_derived:
# Since gradients are computed using only a subset of the data,
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
return rmsnorm
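
A brief sketch of how the wrappers above are meant to be used: `from_native_module` returns the (possibly converted) module and, when `sp_partial_derived=True`, tags its parameters so that the sequence-parallel all-reduce later collects their gradients. The plain `LayerNorm` wrapper is used here so the sketch does not require apex:

```python
import torch.nn as nn

from colossalai.shardformer.layer import LayerNorm
from colossalai.shardformer.layer.utils import SeqParallelUtils

native = nn.LayerNorm(64)

# Pass-through conversion that marks weight and bias as partially derived
# under sequence parallelism.
ln = LayerNorm.from_native_module(native, sp_partial_derived=True)

assert SeqParallelUtils.is_sp_partial_derived_param(ln.weight)
assert SeqParallelUtils.is_sp_partial_derived_param(ln.bias)
```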

View File

@ -1,8 +1,82 @@
from contextlib import contextmanager
from typing import List
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch import nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup, get_world_size
class SeqParallelUtils:
@staticmethod
def marked_as_sp_partial_derived_param(param):
"""
Mark a parameter as partially derived in sequence parallelism.
Args:
param: The parameter to mark as partially derived.
"""
setattr(param, "partial_derived", True)
@staticmethod
def is_sp_partial_derived_param(param):
"""
Check if a parameter is marked as partially derived in sequence parallelism.
Args:
param: The parameter to check.
Returns:
bool: True if the parameter is marked as partially derived, False otherwise.
"""
return getattr(param, "partial_derived", False)
@staticmethod
def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
"""
Allreduce partial derived gradients across the specified process group.
This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
Args:
tp_group (ProcessGroup): The process group for gradient synchronization.
model (nn.Module): The model from which gradients will be synchronized.
grads (List[torch.Tensor]): The list of gradients to be synchronized.
Raises:
AssertionError: If both `model` and `grads` are provided or neither is provided.
"""
# Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
# Get the size of the process group, which determines whether synchronization is needed.
tp_size = get_world_size(tp_group) if tp_group is not None else 1
if tp_size == 1:
# If the process group size is 1, no synchronization is required.
return
if model is not None:
# If `model` is provided, extract partial derived gradients from the model's parameters.
grads = []
for p in model.parameters():
if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
grads.append(p.grad.data)
# Flatten and reduce the gradients using the specified process group.
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
# Unflatten the synchronized gradients and update the model's gradients.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
else:
# If `grads` are provided explicitly, synchronize those gradients directly.
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
class Randomizer:
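
The coalescing pattern in `allreduce_partial_data_grad` above (flatten all gradients into one buffer, issue a single collective, then scatter the result back) is a standard way to reduce the number of communication calls. A single-process sketch of the flatten/unflatten round trip, with the all-reduce replaced by an in-place op for illustration:

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 2.0)]

# One contiguous buffer so a single collective can cover all tensors.
coalesced = _flatten_dense_tensors(grads)
coalesced.mul_(2)  # stand-in for dist.all_reduce(coalesced, group=tp_group)

# Copy the reduced values back into the original gradient tensors.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

assert torch.equal(grads[0], torch.full((3,), 2.0))
assert torch.equal(grads[1], torch.full((2, 2), 4.0))
```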

View File

@ -4,6 +4,7 @@ from typing import Optional
import torch.nn as nn
from ..shard.shard_config import ShardConfig
from .base_policy import Policy
__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@ -197,7 +198,7 @@ def _fullname(obj):
return module + "." + klass.__qualname__
def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
r"""
Return the auto policy for the model
@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
if inference_only:
if shard_config.inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)
@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location, inference_only)
policy = import_policy(policy_location, shard_config.inference_only)
return policy()

View File

@ -11,6 +11,7 @@ from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager
from ..layer.normalization import BaseLayerNorm
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
target_module: Union[ParallelModule, BaseLayerNorm]
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False
@ -77,7 +78,6 @@ class Policy(ABC):
def set_model(self, model: nn.Module) -> None:
r"""
Set model as an attribute of the Policy object so that we can access the model's attributes.
Args:
model (:class:`nn.Module`): The model to be set
"""
@ -86,11 +86,11 @@ class Policy(ABC):
def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.
Args:
shard_config (:class:`ShardConfig`): The shard config to be set
"""
self.shard_config = shard_config
self.config_sanity_check()
@property

View File

@ -60,6 +60,12 @@ class BertPolicy(Policy):
)
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@ -141,17 +147,18 @@ class BertPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
@ -162,7 +169,7 @@ class BertPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="LayerNorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
)
],
policy=policy,
@ -288,9 +295,6 @@ class BertPolicy(Policy):
# BertModel
class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertModel
@ -313,9 +317,6 @@ class BertModelPolicy(BertPolicy):
# BertForPreTraining
class BertForPreTrainingPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
@ -355,9 +356,6 @@ class BertForPreTrainingPolicy(BertPolicy):
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
@ -396,9 +394,6 @@ class BertLMHeadModelPolicy(BertPolicy):
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
@ -437,9 +432,6 @@ class BertForMaskedLMPolicy(BertPolicy):
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
@ -484,9 +476,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
@ -531,9 +520,6 @@ class BertForTokenClassificationPolicy(BertPolicy):
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
@ -564,9 +550,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
@ -610,9 +593,6 @@ class BertForMultipleChoicePolicy(BertPolicy):
class BertForQuestionAnsweringPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
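
The pattern repeated in this and the following policy files: choose the norm wrapper once from the shard config, then route `sp_partial_derived` through `SubModuleReplacementDescription.kwargs` so it reaches `from_native_module` during sharding. A condensed sketch of that selection (`shard_config` is assumed to be a populated `ShardConfig`):

```python
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

# Use apex-fused kernels only when enabled; otherwise fall back to the plain wrappers.
norm_cls = col_nn.FusedLayerNorm if shard_config.enable_fused_normalization else col_nn.LayerNorm
use_sequence_parallel = shard_config.enable_sequence_parallelism

replacement = SubModuleReplacementDescription(
    suffix="output.LayerNorm",
    target_module=norm_cls,
    # Forwarded to norm_cls.from_native_module(..., sp_partial_derived=...) when sharding.
    kwargs={"sp_partial_derived": use_sequence_parallel},
)
```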

View File

@ -43,6 +43,11 @@ class BlipPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_tensor_parallelism:
policy[Blip2EncoderLayer] = ModulePolicyDescription(
attribute_replacement={
@ -214,17 +219,16 @@ class BlipPolicy(Policy):
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle Blip2EncoderLayer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -236,7 +240,7 @@ class BlipPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="post_layernorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
)
],
policy=policy,
@ -248,7 +252,7 @@ class BlipPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="layernorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
)
],
policy=policy,
@ -260,15 +264,15 @@ class BlipPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="attention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="crossattention.output.LayerNorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="output_query.LayerNorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -280,7 +284,7 @@ class BlipPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="model.decoder.final_layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
)
],
policy=policy,
@ -292,11 +296,11 @@ class BlipPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,

View File

@ -42,6 +42,10 @@ class BloomPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@ -97,17 +101,16 @@ class BloomPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# handle bloom model
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -119,11 +122,13 @@ class BloomPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
@ -225,9 +230,6 @@ class BloomPolicy(Policy):
class BloomModelPolicy(BloomPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.bloom.modeling_bloom import BloomModel

View File

@ -45,6 +45,16 @@ class ChatGLMPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
if self.model.config.rmsnorm:
norm_cls = col_nn.FusedRMSNorm
else:
norm_cls = col_nn.FusedLayerNorm
else:
if self.model.config.rmsnorm:
norm_cls = col_nn.RMSNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@ -96,13 +106,17 @@ class ChatGLMPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
if not self.model.config.rmsnorm:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm
suffix="input_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
@ -113,30 +127,8 @@ class ChatGLMPolicy(Policy):
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm
)
],
policy=policy,
target_key=ChatGLMModel,
)
else:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
SubModuleReplacementDescription(
suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
),
],
policy=policy,
target_key=GLMBlock,
)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
suffix="encoder.final_layernorm",
target_module=norm_cls,
)
],
policy=policy,
@ -224,9 +216,6 @@ class ChatGLMPolicy(Policy):
class ChatGLMModelPolicy(ChatGLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
pass

View File

@ -39,6 +39,11 @@ class GPT2Policy(Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@ -102,11 +107,10 @@ class GPT2Policy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
policy=policy,
target_key=GPT2Model,
@ -116,14 +120,19 @@ class GPT2Policy(Policy):
description=[
SubModuleReplacementDescription(
suffix="ln_1",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="ln_2",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
suffix="ln_cross_attn",
target_module=norm_cls,
ignore_if_not_exist=True,
kwargs={"sp_partial_derived": use_sequence_parallel},
),
],
policy=policy,
@ -192,9 +201,6 @@ class GPT2Policy(Policy):
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
@ -216,9 +222,6 @@ class GPT2ModelPolicy(GPT2Policy):
# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
@ -263,9 +266,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
@ -317,9 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
@ -347,9 +344,6 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
@ -387,9 +381,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification

View File

@ -6,7 +6,7 @@ import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -35,6 +35,11 @@ class LlamaPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
norm_cls = RMSNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@ -93,16 +98,15 @@ class LlamaPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -112,12 +116,13 @@ class LlamaPolicy(Policy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
target_module=norm_cls,
),
policy=policy,
target_key=LlamaModel,
)
# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
@ -174,9 +179,6 @@ class LlamaPolicy(Policy):
class LlamaModelPolicy(LlamaPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
from transformers.models.llama.modeling_llama import LlamaModel

View File

@ -5,7 +5,7 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
@ -42,6 +42,12 @@ class OPTPolicy(Policy):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
norm_cls = LayerNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@ -94,10 +100,9 @@ class OPTPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
policy=policy,
target_key=OPTDecoder,
@ -105,10 +110,10 @@ class OPTPolicy(Policy):
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
SubModuleReplacementDescription(
suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
],
policy=policy,
@ -183,9 +188,6 @@ class OPTPolicy(Policy):
class OPTModelPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTModel
@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
class OPTForSequenceClassificationPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForSequenceClassification
@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):
class OPTForQuestionAnsweringPolicy(OPTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForQuestionAnswering

View File

@ -24,6 +24,11 @@ class SamPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_tensor_parallelism:
policy[SamVisionLayer] = ModulePolicyDescription(
attribute_replacement={
@ -151,17 +156,16 @@ class SamPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle SamVisionLayer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -173,19 +177,19 @@ class SamPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="layer_norm1",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm2",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm3",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="layer_norm4",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -197,7 +201,7 @@ class SamPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="layer_norm_final_attn",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
)
],
policy=policy,

View File

@ -11,6 +11,7 @@ from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
RMSNorm,
VocabParallelEmbedding1D,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
@ -58,6 +59,11 @@ class T5BasePolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
norm_cls = RMSNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@ -169,11 +175,10 @@ class T5BasePolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
target_module=norm_cls,
),
policy=policy,
target_key=T5LayerFF,
@ -181,23 +186,23 @@ class T5BasePolicy(Policy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="layer_norm",
target_module=FusedRMSNorm,
target_module=norm_cls,
),
policy=policy,
target_key=T5LayerFF,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
policy=policy,
target_key=T5LayerSelfAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
policy=policy,
target_key=T5LayerCrossAttention,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
policy=policy,
target_key=T5Stack,
)
@ -363,9 +368,6 @@ class T5BasePolicy(Policy):
class T5ModelPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5Model
@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy):
class T5ForConditionalGenerationPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5ForConditionalGeneration
@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
class T5EncoderPolicy(T5BasePolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import T5EncoderModel

View File

@ -159,9 +159,6 @@ class ViTPolicy(Policy):
# ViTModel
class ViTModelPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTModel
@ -227,9 +224,6 @@ class ViTForImageClassificationPolicy(ViTPolicy):
# ViTForMaskedImageModeling
class ViTForMaskedImageModelingPolicy(ViTPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel

View File

@ -52,6 +52,11 @@ class WhisperPolicy(Policy):
policy = {}
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
norm_cls = col_nn.LayerNorm
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
@ -161,17 +166,16 @@ class WhisperPolicy(Policy):
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle encoder layer
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -183,11 +187,11 @@ class WhisperPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="self_attn_layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
SubModuleReplacementDescription(
suffix="final_layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
),
],
policy=policy,
@ -199,7 +203,7 @@ class WhisperPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
)
],
policy=policy,
@ -211,7 +215,7 @@ class WhisperPolicy(Policy):
description=[
SubModuleReplacementDescription(
suffix="layer_norm",
target_module=col_nn.FusedLayerNorm,
target_module=norm_cls,
)
],
policy=policy,
@ -416,9 +420,6 @@ class WhisperPolicy(Policy):
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperModel
@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy):
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperForConditionalGeneration
@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def preprocess(self):
return self.model

View File

@ -27,8 +27,8 @@ class ModelSharder(object):
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
self.shard_config = shard_config
self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
def shard(self) -> List[Dict[int, Tensor]]:
r"""
@ -196,7 +196,7 @@ class ModelSharder(object):
try:
replace_layer = target_module.from_native_module(
native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs
native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs
)
except Exception as e:
raise RuntimeError(
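Passing the process group by keyword lets the sharder invoke every replacement the same way, including norm wrappers that are replicated rather than sharded over the tensor-parallel group. A hedged sketch of such a hook under that assumption; the class below is illustrative and is not Colossal-AI's LayerNorm implementation:

from torch import nn

class ReplicatedLayerNorm(nn.LayerNorm):
    @classmethod
    def from_native_module(cls, module: nn.LayerNorm, process_group=None, **kwargs) -> "ReplicatedLayerNorm":
        # A replicated norm accepts process_group only for a uniform call
        # signature; it does not shard anything with it.
        new_module = cls(
            module.normalized_shape,
            eps=module.eps,
            elementwise_affine=module.elementwise_affine,
        )
        if module.elementwise_affine:
            new_module.weight = module.weight
            new_module.bias = module.bias
        return new_module

With this shape, the sharder's target_module.from_native_module(native_sub_module, process_group=..., **kwargs) call works whether or not the replacement actually uses the group.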

View File

@ -700,7 +700,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
############################
# this method is used to sync gradient manually
def sync_grad(self):
def _sync_grad(self):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
@ -713,7 +713,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication:
self.sync_grad()
self._sync_grad()
else:
self._run_reduction()

View File

@ -24,7 +24,6 @@ for k, v in inputs.items():
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
@ -59,6 +58,7 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
@ -76,6 +76,7 @@ def check_tp_pipeline_inference(rank, world_size, port):
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

View File

@ -128,7 +128,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
@ -147,7 +147,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
@ -157,7 +157,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@ -174,7 +174,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@ -199,7 +199,7 @@ def run_test(test_config):
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@ -208,7 +208,7 @@ def run_test(test_config):
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,

View File

@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp32",
"max_norm": 5,
@ -114,7 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
@ -123,7 +123,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,
@ -148,7 +148,7 @@ def run_test(test_config):
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
"max_norm": 5,

View File

@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"max_norm": 5,
@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
@ -126,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "bf16",
"max_norm": 5,
@ -146,7 +146,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 1,
"zero_stage": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@ -155,7 +155,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 1,
"zero_stage": 2,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@ -181,7 +181,7 @@ def run_test(test_config):
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@ -191,7 +191,7 @@ def run_test(test_config):
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,

View File

@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
col_layer_for_check = ["encoder.layer[0].output.dense"]
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
@ -50,8 +51,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
norm_layer_grads = get_grad_tensors_for_check(
bert,
sharded_bert,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()
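The new norm-gradient checks here (and in the bloom, chatglm, and gpt2 tests below) exist because, with sequence parallelism enabled, each tensor-parallel rank only back-propagates through its slice of the sequence, leaving partial LayerNorm gradients. A hedged sketch of the synchronization the checks are meant to verify; this is illustrative, not the helper added by the patch:

import torch.distributed as dist
from torch import nn

def allreduce_norm_grads(norm: nn.LayerNorm, tp_group: dist.ProcessGroup) -> None:
    # Sum the partial LayerNorm weight/bias gradients over the tensor-parallel
    # group so they match the gradients of the unsharded reference model.
    for p in norm.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, group=tp_group)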
@ -85,6 +99,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp32",
},
{
"tp_size": 1,
"pp_size": 2,

View File

@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
bloom = unwrap_model(org_model, "BloomModel", "transformer")
sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer")
norm_layer_for_check = ["word_embeddings_layernorm", "h[0].input_layernorm"]
row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"]
col_layer_for_check = ["h[0].self_attention.dense"]
@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_grads = get_grad_tensors_for_check(
bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
norm_layer_grads = get_grad_tensors_for_check(
bloom,
sharded_bloom,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()

View File

@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer")
shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer")
norm_layer_for_check = ["encoder.layers[0].input_layernorm"]
row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"]
col_layer_for_check = ["encoder.layers[0].self_attention.dense"]
@ -66,8 +67,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False,
)
norm_layer_grads = get_grad_tensors_for_check(
chatglm_model,
shard_chatglm_model,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()

View File

@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
gpt2 = unwrap_model(org_model, "GPT2Model", "transformer")
sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer")
norm_layer_for_check = ["h[0].ln_1", "h[0].ln_2"]
col_layer_for_check = ["h[0].mlp.c_fc"]
row_layer_for_check = ["wte", "h[0].mlp.c_proj"]
@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_grads = get_grad_tensors_for_check(
gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
norm_layer_grads = get_grad_tensors_for_check(
gpt2,
sharded_gpt2,
norm_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()