mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)
* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)
* Add layer norm gradients all-reduce for sequence parallel.
* skip pipeline inference test
* [hotfix] fixing policies of sequence parallel (#4922)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
---------
Co-authored-by: littsk <1214689160@qq.com>
* Hotfix/add grad all reduce for sequence parallel (#4927)
* Add layer norm gradients all-reduce for sequence parallel.
* fix parameter passing when calling get_autopolicy
* fix bug using wrong variables
---------
Co-authored-by: littsk <1214689160@qq.com>
* fix policy initialization
* fix bloom and chatglm policies
* polish code of handling layernorm
* fix moe module
* polish code of class initializing
---------
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
parent d99b2c961a
commit 1a3315e336
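The patch replaces the old `no_sync`/`sync_grads` pair on `HybridParallelModule` with an explicit manual-synchronization API (`no_sync`, `sync_dp_grads`, `sync_sp_grads`) and routes layer-norm gradients that are only partially derived under sequence parallelism through a new `SeqParallelUtils` helper. The sketch below shows how those pieces fit together, mirroring the `execute_pipeline` logic further down in the diff; it is illustrative only: `run_microbatches` is a hypothetical placeholder for the pipeline schedule's forward/backward work, and `model` is assumed to be an already-constructed `HybridParallelModule`.

def manual_grad_sync_sketch(model, optimizer, run_microbatches):
    # Disable automatic DP/SP gradient all-reduce while pipeline microbatches
    # accumulate gradients (no_sync is now a real context manager).
    with model.no_sync():
        run_microbatches()

    # Gradients of parameters shared across pipeline stages.
    model.sync_shared_params()

    # Layer-norm (and other partially derived) gradients across the TP group;
    # this is a no-op unless sequence parallelism is enabled.
    model.sync_sp_grads()

    # Plain data-parallel all-reduce, averaged over the DP group.
    model.sync_dp_grads()

    # Any wrapped optimizer with a step() method is assumed here.
    optimizer.step()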
@@ -1,6 +1,6 @@
 import ctypes
 import random
-from contextlib import nullcontext
+from contextlib import contextmanager
 from functools import partial
 from types import MethodType
 from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -25,6 +25,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
@@ -47,12 +48,17 @@ class HybridParallelModule(ModelWrapper):
         precision: str,
         shard_config: ShardConfig,
         dp_group: ProcessGroup,
+        tp_group: ProcessGroup,
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
+        self.shard_config = shard_config
         self.dp_group = dp_group
+        self.tp_group = tp_group
+        self.use_dpp = use_ddp
+        self.require_grad_sync = True

         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -98,19 +104,75 @@ class HybridParallelModule(ModelWrapper):
                 dist.all_reduce(param.grad, group=group)
         dist.barrier()

-    def no_sync(self) -> Iterator[None]:
-        # no sync grads across data parallel
-        return nullcontext()
-
-    def sync_grads(self):
-        # sync grad across data parallel
+    @contextmanager
+    def no_sync(self):
+        r"""
+        A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization
+        when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass
+        when exiting the context.
+        """
+        # Store the current value of 'require_grad_sync' to restore it later.
+        old_require_grad_sync = self.require_grad_sync
+        # Disable automatic gradient synchronization.
+        self.require_grad_sync = False
+        try:
+            if self.use_dpp:
+                # If using data parallel processing (use_dpp), disable synchronization too.
+                with self.module.no_sync():
+                    yield
+            else:
+                yield
+        finally:
+            # Restore the original value of 'require_grad_sync'.
+            self.require_grad_sync = old_require_grad_sync
+
+    def sync_dp_grads(self):
+        r"""
+        Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1.
+        This function performs an all-reduce operation to combine gradients from different devices in the DP group.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
+
+        # Check if the DP group size is 1, meaning no synchronization is needed.
         if self.dp_group.size() == 1:
             return

+        # Iterate through the model's parameters and perform gradient synchronization.
         for p in self.module.parameters():
             if p.grad is not None:
+                # Perform all-reduce to combine gradients from different devices.
                 dist.all_reduce(p.grad, group=self.dp_group)
+                # Normalize the gradient by dividing it by the DP group size.
                 p.grad.div_(self.dp_group.size())

+    def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
+        r"""
+        Synchronize gradients that are partially derived within sequence parallelism
+        if sequence parallelism is enabled. Gradients can be provided explicitly or extracted
+        from the module.
+
+        Args:
+            grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not
+                provided, gradients will be extracted from the model.
+
+        Returns:
+            None
+        """
+
+        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+            if grads is not None:
+                # Synchronize provided gradient tensors across the tensor parallelism group.
+                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+            else:
+                # Synchronize gradients from the model across the tensor parallelism group.
+                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
@@ -166,7 +228,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def __init__(
         self,
         optim: Optimizer,
-        model: Module,
+        model: HybridParallelModule,
         use_pipeline: bool,
         param_info: OrderedDict,
         max_norm: float = 0,
@@ -176,13 +238,69 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.param_info = param_info
         if use_pipeline:
             init_pipeline_optimizer(optim, model)
+        self.model = model
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
         self.max_norm = max_norm
         self.tp_pg = tp_process_group
         self.pp_pg = pp_process_group
+        self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
         super().__init__(optim)

+    def backward(self, loss: Tensor, *args, **kwargs):
+        r"""
+        Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+        This method performs the backward pass for gradient computation. If sequence parallelism is enabled
+        and gradient synchronization is required, it will synchronize gradients that are partially derived
+        within sequence parallelism across tp parallelism groups.
+
+        Args:
+            loss (Tensor): The loss tensor to compute gradients with respect to.
+            *args: Additional positional arguments to be passed to the superclass backward method.
+            **kwargs: Additional keyword arguments to be passed to the superclass backward method.
+
+        Returns:
+            None
+        """
+
+        # Call the superclass backward method to compute gradients.
+        super().backward(loss, *args, **kwargs)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+        """
+        Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+        This method performs a backward pass for gradient computation using a precomputed gradient tensor.
+        If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+        gradients that are partially derived within sequence parallelism across tp parallelism groups.
+
+        Args:
+            tensor (Tensor): The input tensor for which gradients are computed.
+            grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+        Returns:
+            None
+        """
+
+        # Call the superclass backward method to compute gradients.
+        super().backward_by_grad(tensor, grad)
+
+        if self.model.require_grad_sync:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self.model.sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
     def step(self, *args, **kwargs):
         r"""
         Perform an optimization step.
@ -220,8 +338,6 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||||
if len(param_gradient_pairs) == 0:
|
if len(param_gradient_pairs) == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
|
||||||
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
|
||||||
norm_type = float(norm_type)
|
norm_type = float(norm_type)
|
||||||
|
|
||||||
# gradients used for norm calculation.
|
# gradients used for norm calculation.
|
||||||
|
@ -230,9 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||||
if norm_type == inf:
|
if norm_type == inf:
|
||||||
total_norm = max(grad.data.abs().max() for grad in gradients)
|
total_norm = max(grad.data.abs().max() for grad in gradients)
|
||||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||||
if tp_size > 1:
|
if self.tp_size > 1:
|
||||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||||
if pp_size > 1:
|
if self.pp_size > 1:
|
||||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
|
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
|
||||||
total_norm = total_norm_cuda.item()
|
total_norm = total_norm_cuda.item()
|
||||||
else:
|
else:
|
||||||
|
@ -250,16 +366,16 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||||
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
|
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
|
||||||
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
|
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
|
||||||
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
|
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
|
||||||
if tp_size > 1:
|
if self.tp_size > 1:
|
||||||
param_for_grad = grad_to_param_mapping[id(grad)]
|
param_for_grad = grad_to_param_mapping[id(grad)]
|
||||||
if not is_distributed_tensor(param_for_grad):
|
if not is_distributed_tensor(param_for_grad):
|
||||||
grad_norm_exponentiated /= tp_size
|
grad_norm_exponentiated /= self.tp_size
|
||||||
|
|
||||||
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
|
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
|
||||||
# it means that this parameter is used in two different pipeline stages.
|
# it means that this parameter is used in two different pipeline stages.
|
||||||
# To avoid redundant norm calculations, we divide the exponent of this norm by
|
# To avoid redundant norm calculations, we divide the exponent of this norm by
|
||||||
# the number of shared stages.
|
# the number of shared stages.
|
||||||
if pp_size > 1:
|
if self.pp_size > 1:
|
||||||
for shared_param in self.shared_params:
|
for shared_param in self.shared_params:
|
||||||
if self.stage_manager.stage in shared_param:
|
if self.stage_manager.stage in shared_param:
|
||||||
stage_shared_param = shared_param[self.stage_manager.stage]
|
stage_shared_param = shared_param[self.stage_manager.stage]
|
||||||
|
@ -269,10 +385,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
|
||||||
total_norm_exponentiated += grad_norm_exponentiated
|
total_norm_exponentiated += grad_norm_exponentiated
|
||||||
|
|
||||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||||
if tp_size > 1:
|
if self.tp_size > 1:
|
||||||
# compute norm in tp process group
|
# compute norm in tp process group
|
||||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
||||||
if pp_size > 1:
|
if self.pp_size > 1:
|
||||||
# compute norm in pp process group
|
# compute norm in pp process group
|
||||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
|
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
|
||||||
|
|
||||||
|
@ -314,7 +430,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
optim: Optimizer,
|
optim: Optimizer,
|
||||||
model: Module,
|
model: HybridParallelModule,
|
||||||
use_pipeline: bool,
|
use_pipeline: bool,
|
||||||
param_info: OrderedDict,
|
param_info: OrderedDict,
|
||||||
precision: str = "fp16",
|
precision: str = "fp16",
|
||||||
|
@ -329,11 +445,14 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
tp_process_group: Optional[ProcessGroup] = None, # if using tp
|
||||||
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
||||||
):
|
):
|
||||||
|
self.model = model
|
||||||
self.param_info = param_info
|
self.param_info = param_info
|
||||||
self.stage_manager = model.stage_manager
|
self.stage_manager = model.stage_manager
|
||||||
self.shared_params = model.shared_params
|
self.shared_params = model.shared_params
|
||||||
self.tp_pg = tp_process_group
|
self.tp_pg = tp_process_group
|
||||||
self.pp_pg = pp_process_group
|
self.pp_pg = pp_process_group
|
||||||
|
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
||||||
|
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
||||||
if use_pipeline:
|
if use_pipeline:
|
||||||
init_pipeline_optimizer(optim, model)
|
init_pipeline_optimizer(optim, model)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -349,6 +468,59 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
max_norm=max_norm,
|
max_norm=max_norm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def backward(self, loss: Tensor, *args, **kwargs):
|
||||||
|
r"""
|
||||||
|
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
|
This method performs backward pass for gradient computation. If sequence parallelism is enabled
|
||||||
|
and gradient synchronization is required, it will synchronize gradients that are partially derived
|
||||||
|
within sequence parallelism across tp parallelism groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (Tensor): The loss tensor to compute gradients with respect to.
|
||||||
|
*args: Additional positional arguments to be passed to the superclass backward method.
|
||||||
|
**kwargs: Additional keyword arguments to be passed to the superclass backward method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Call the superclass backward method to compute gradients.
|
||||||
|
super().backward(loss, *args, **kwargs)
|
||||||
|
|
||||||
|
if self.model.require_grad_sync:
|
||||||
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
|
self.model.sync_sp_grads()
|
||||||
|
else:
|
||||||
|
# If gradient synchronization is not required, return.
|
||||||
|
return
|
||||||
|
|
||||||
|
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
|
||||||
|
"""
|
||||||
|
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
|
This method performs a backward pass for gradient computation using a precomputed gradient tensor.
|
||||||
|
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
|
||||||
|
gradients that are partially derived within sequence parallelism across tp parallelism groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (Tensor): The input tensor for which gradients are computed.
|
||||||
|
grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Call the superclass backward method to compute gradients.
|
||||||
|
super().backward_by_grad(tensor, grad)
|
||||||
|
|
||||||
|
if self.model.require_grad_sync:
|
||||||
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
|
self.model.sync_sp_grads()
|
||||||
|
else:
|
||||||
|
# If gradient synchronization is not required, return.
|
||||||
|
return
|
||||||
|
|
||||||
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
|
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
|
||||||
r"""
|
r"""
|
||||||
Compute and return the gradient norm for gradient clipping.
|
Compute and return the gradient norm for gradient clipping.
|
||||||
|
@ -363,8 +535,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
if len(param_gradient_pairs) == 0:
|
if len(param_gradient_pairs) == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
|
|
||||||
pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
|
|
||||||
norm_type = float(norm_type)
|
norm_type = float(norm_type)
|
||||||
|
|
||||||
if norm_type == inf:
|
if norm_type == inf:
|
||||||
|
@ -374,9 +544,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
|
|
||||||
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
|
||||||
|
|
||||||
if tp_size > 1:
|
if self.tp_size > 1:
|
||||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
|
||||||
if pp_size > 1:
|
if self.pp_size > 1:
|
||||||
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
|
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
|
||||||
|
|
||||||
total_norm = total_norm_cuda.item()
|
total_norm = total_norm_cuda.item()
|
||||||
|
@ -396,16 +566,16 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
|
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
|
||||||
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
|
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
|
||||||
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
|
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
|
||||||
if tp_size > 1:
|
if self.tp_size > 1:
|
||||||
param_for_grad = grad_to_param_mapping[id(grad)]
|
param_for_grad = grad_to_param_mapping[id(grad)]
|
||||||
if not is_distributed_tensor(param_for_grad):
|
if not is_distributed_tensor(param_for_grad):
|
||||||
grad_norm_exponentiated /= tp_size
|
grad_norm_exponentiated /= self.tp_size
|
||||||
|
|
||||||
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
|
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
|
||||||
# it means that this parameter is used in two different pipeline stages.
|
# it means that this parameter is used in two different pipeline stages.
|
||||||
# To avoid redundant norm calculations, we divide the exponent of this norm by
|
# To avoid redundant norm calculations, we divide the exponent of this norm by
|
||||||
# the number of shared stages.
|
# the number of shared stages.
|
||||||
if pp_size > 1:
|
if self.pp_size > 1:
|
||||||
for shared_param in self.shared_params:
|
for shared_param in self.shared_params:
|
||||||
if self.stage_manager.stage in shared_param:
|
if self.stage_manager.stage in shared_param:
|
||||||
stage_working_shared_param = shared_param[self.stage_manager.stage]
|
stage_working_shared_param = shared_param[self.stage_manager.stage]
|
||||||
|
@ -416,10 +586,10 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
|
||||||
total_norm_exponentiated += grad_norm_exponentiated
|
total_norm_exponentiated += grad_norm_exponentiated
|
||||||
|
|
||||||
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
|
||||||
if tp_size > 1:
|
if self.tp_size > 1:
|
||||||
# compute norm in tp process group
|
# compute norm in tp process group
|
||||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
|
||||||
if pp_size > 1:
|
if self.pp_size > 1:
|
||||||
# compute norm in pp process group
|
# compute norm in pp process group
|
||||||
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
|
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
|
||||||
|
|
||||||
|
@ -433,7 +603,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
model: Module,
|
model: HybridParallelModule,
|
||||||
use_pipeline: bool,
|
use_pipeline: bool,
|
||||||
param_info: OrderedDict,
|
param_info: OrderedDict,
|
||||||
initial_scale: int = 2**16, # grad scaler config
|
initial_scale: int = 2**16, # grad scaler config
|
||||||
|
@ -455,6 +625,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||||
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
pp_process_group: Optional[ProcessGroup] = None, # if using pp
|
||||||
forced_dtype: Optional[torch.dtype] = None,
|
forced_dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
|
self.model = model
|
||||||
self.param_info = param_info
|
self.param_info = param_info
|
||||||
self.stage_manager = model.stage_manager
|
self.stage_manager = model.stage_manager
|
||||||
self.shared_params = model.shared_params
|
self.shared_params = model.shared_params
|
||||||
|
@ -483,6 +654,123 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
|
||||||
forced_dtype=forced_dtype,
|
forced_dtype=forced_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sync_dp_grads(self):
|
||||||
|
r"""
|
||||||
|
Synchronize gradients in the data parallelism dimension.
|
||||||
|
|
||||||
|
This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients
|
||||||
|
in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions,
|
||||||
|
namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization
|
||||||
|
and readability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Call the superclass `_sync_grad` method to synchronize gradients.
|
||||||
|
super()._sync_grad()
|
||||||
|
|
||||||
|
def _sync_sp_grads(self):
|
||||||
|
r"""
|
||||||
|
Synchronize gradients that are partially derived within sequence parallelism.
|
||||||
|
|
||||||
|
This method is responsible for synchronizing partially derived gradients across tp parallelism groups.
|
||||||
|
It identifies gradients that are partially derived or not and synchronizes them.
|
||||||
|
If synchronization is required and gradients are found to be synchronized,
|
||||||
|
it performs the synchronization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_all_working_grads() -> List[Tensor]:
|
||||||
|
"""Retrieve all working gradients from different parameter groups."""
|
||||||
|
all_working_grads = []
|
||||||
|
for group_id in range(self.num_param_groups):
|
||||||
|
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
|
||||||
|
all_working_grads.extend(working_grads)
|
||||||
|
return all_working_grads
|
||||||
|
|
||||||
|
def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
|
||||||
|
"""Identify gradients to be synchronized in the sequence parallelism."""
|
||||||
|
grads_to_sync = []
|
||||||
|
for grad in all_working_grads:
|
||||||
|
param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
|
||||||
|
param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
|
||||||
|
if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
|
||||||
|
grads_to_sync.append(grad)
|
||||||
|
|
||||||
|
if len(grads_to_sync) > 0:
|
||||||
|
return grads_to_sync
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get all working gradients and gradients to be synchronized.
|
||||||
|
all_working_grads = _get_all_working_grads()
|
||||||
|
grads_to_sync = _get_grads_to_sync(all_working_grads)
|
||||||
|
|
||||||
|
if self.require_grad_sync and grads_to_sync is not None:
|
||||||
|
# Synchronize sequence parallelism gradients if required.
|
||||||
|
SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
def backward(self, loss, retain_graph=False):
|
||||||
|
"""
|
||||||
|
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
|
This method performs the backward pass for gradient computation based on a given loss tensor.
|
||||||
|
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
|
||||||
|
gradients that are partially derived within sequence parallelism across TP parallelism groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss: The loss tensor to compute gradients with respect to.
|
||||||
|
retain_graph (bool): Whether to retain the computation graph.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call the superclass backward method to compute gradients.
|
||||||
|
super().backward(loss, retain_graph)
|
||||||
|
|
||||||
|
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||||
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
|
self._sync_sp_grads()
|
||||||
|
else:
|
||||||
|
# If gradient synchronization is not required, return.
|
||||||
|
return
|
||||||
|
|
||||||
|
def backward_by_grad(self, tensor, grad):
|
||||||
|
"""
|
||||||
|
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
|
||||||
|
|
||||||
|
This method performs a backward pass for gradient computation based on a precomputed gradient tensor.
|
||||||
|
If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
|
||||||
|
gradients that are partially derived within sequence parallelism across TP parallelism groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: The input tensor for which gradients are computed.
|
||||||
|
grad: The precomputed gradient tensor to compute gradients with respect to the input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
# Call the superclass backward_by_grad method to compute gradients.
|
||||||
|
super().backward_by_grad(tensor, grad)
|
||||||
|
|
||||||
|
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
|
||||||
|
# If gradient synchronization is required, sync sequence parallelism gradients.
|
||||||
|
self._sync_sp_grads()
|
||||||
|
else:
|
||||||
|
# If gradient synchronization is not required, return.
|
||||||
|
return
|
||||||
|
|
||||||
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
|
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
|
||||||
r"""
|
r"""
|
||||||
Compute and return the gradient norm for gradient clipping.
|
Compute and return the gradient norm for gradient clipping.
|
||||||
|
@@ -768,7 +1056,14 @@ class HybridParallelPlugin(PipelinePluginBase):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(
-                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+                model,
+                precision=self.precision,
+                shard_config=self.shard_config,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_ddp=use_ddp,
+                ddp_config=self.ddp_config,
+                custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
@@ -826,17 +1121,32 @@ class HybridParallelPlugin(PipelinePluginBase):
         return_outputs: bool = False,
     ) -> dict:
         assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
-        # return loss or outputs if needed
+
+        # Create a context for gradient synchronization based on the optimizer type.
+        # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
+        # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once),
+        # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
         with ctx:
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
+
+        # Synchronize the grads of shared parameters of the model.
         model.sync_shared_params()
+
+        # Synchronize sequence parallelism gradients of the model.
+        model.sync_sp_grads()
+
+        # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
+        # Otherwise, synchronize data parallelism gradients of the model.
+        # This is because these are two different forms of data parallelism.
         if isinstance(optimizer, HybridParallelZeroOptimizer):
-            optimizer.sync_grad()
+            optimizer.sync_dp_grads()
         else:
-            model.sync_grads()
+            model.sync_dp_grads()

         return outputs

     def prepare_dataloader(
@@ -338,7 +338,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(
-                model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+                module=model,
+                precision=self.precision,
+                shard_config=self.shard_config,
+                dp_group=self.dp_group,
+                tp_group=self.tp_group,
+                use_ddp=use_ddp,
+                ddp_config=self.ddp_config,
+                custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
@@ -218,10 +218,10 @@ class TPInferEngine:
         ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
         model_name = model.__class__.__name__
         assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

         model = model.model if self.shard_config.inference_gptq else model
+        policy = get_autopolicy(model, shard_config=self.shard_config)

-        policy = get_autopolicy(model, inference_only=True)
         self.model, _ = shardformer.optimize(model, policy)

         if self.shard_config.inference_gptq:
@ -235,6 +235,14 @@ class SubModuleReplacementDescription:
|
||||||
|
|
||||||
|
|
||||||
class Policy(ABC):
|
class Policy(ABC):
|
||||||
|
r"""
|
||||||
|
The base class for all the policies. For each different model, it should have a different policy class,
|
||||||
|
like BertPolicy for Bert Model or OPTPolicy for OPT model.
|
||||||
|
|
||||||
|
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
|
||||||
|
built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
|
||||||
|
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self)
|
def __init__(self)
|
||||||
self.model = None
|
self.model = None
|
||||||
|
@ -245,6 +253,16 @@ class Policy(ABC):
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
|
def set_shard_config(self, shard_config: ShardConfig) -> None:
|
||||||
|
r"""
|
||||||
|
Set shard config as an attribute of the Policy object.
|
||||||
|
Args:
|
||||||
|
shard_config (:class:`ShardConfig`): The shard config to be used
|
||||||
|
"""
|
||||||
|
self.shard_config = shard_config
|
||||||
|
|
||||||
|
self.config_sanity_check()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def preprocess(self) -> nn.Module:
|
def preprocess(self) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -2,7 +2,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||||
from .embedding import Embedding1D, VocabParallelEmbedding1D
|
from .embedding import Embedding1D, VocabParallelEmbedding1D
|
||||||
from .linear import Linear1D_Col, Linear1D_Row
|
from .linear import Linear1D_Col, Linear1D_Row
|
||||||
from .loss import cross_entropy_1d
|
from .loss import cross_entropy_1d
|
||||||
from .normalization import FusedLayerNorm, FusedRMSNorm
|
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||||
from .parallel_module import ParallelModule
|
from .parallel_module import ParallelModule
|
||||||
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ __all__ = [
|
||||||
"DropoutForParallelInput",
|
"DropoutForParallelInput",
|
||||||
"DropoutForReplicatedInput",
|
"DropoutForReplicatedInput",
|
||||||
"cross_entropy_1d",
|
"cross_entropy_1d",
|
||||||
|
"BaseLayerNorm",
|
||||||
|
"LayerNorm",
|
||||||
|
"RMSNorm",
|
||||||
"FusedLayerNorm",
|
"FusedLayerNorm",
|
||||||
"FusedRMSNorm",
|
"FusedRMSNorm",
|
||||||
"FusedLinear1D_Col",
|
"FusedLinear1D_Col",
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
|
|
||||||
__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
|
from .utils import SeqParallelUtils
|
||||||
|
|
||||||
|
__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
|
||||||
|
|
||||||
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||||
1024,
|
1024,
|
||||||
|
@ -35,7 +38,103 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FusedLayerNorm:
|
class BaseLayerNorm(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
|
||||||
|
"""
|
||||||
|
Convert a native PyTorch layer normalization module to a specific layer normalization module,
|
||||||
|
and optionally mark parameters for gradient aggregation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): The native PyTorch layer normalization module to be converted.
|
||||||
|
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The specific layer normalization module.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the provided module is not an instance of the supported layer normalization type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(BaseLayerNorm):
|
||||||
|
r"""
|
||||||
|
This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"FusedLayerNorm is not implemented as a physical class. "
|
||||||
|
"It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Convert a native RMSNorm module to colossalai layer norm module,
|
||||||
|
and optionally mark parameters for gradient aggregation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): The native RMSNorm module to be converted.
|
||||||
|
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The RMSNorm module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LazyInitContext.materialize(module)
|
||||||
|
|
||||||
|
if sp_partial_derived:
|
||||||
|
# Since gradients are computed using only a subset of the data,
|
||||||
|
# aggregation of these gradients is necessary during backpropagation.
|
||||||
|
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||||
|
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
|
||||||
|
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(BaseLayerNorm):
|
||||||
|
r"""
|
||||||
|
This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"LayerNorm is not implemented as a physical class. "
|
||||||
|
"It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||||
|
r"""
|
||||||
|
Convert a native pytorch layer norm module to colossalai layer norm module,
|
||||||
|
and optionally marking parameters for gradient aggregation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
|
||||||
|
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The LayerNorm module.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the provided module is not an instance of nn.LayerNorm.
|
||||||
|
"""
|
||||||
|
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
|
||||||
|
|
||||||
|
LazyInitContext.materialize(module)
|
||||||
|
|
||||||
|
if sp_partial_derived:
|
||||||
|
# Since gradients are computed using only a subset of the data,
|
||||||
|
# aggregation of these gradients is necessary during backpropagation.
|
||||||
|
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||||
|
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
|
||||||
|
SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
|
||||||
|
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class FusedLayerNorm(BaseLayerNorm):
|
||||||
r"""
|
r"""
|
||||||
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
|
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
|
||||||
"""
|
"""
|
||||||
|
@ -43,15 +142,29 @@ class FusedLayerNorm:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"FusedLayerNorm is not implemented as a physical class. "
|
"FusedLayerNorm is not implemented as a physical class. "
|
||||||
"It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
|
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
|
def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||||
r"""
|
r"""
|
||||||
Convert a native pytorch layer norm module to colossalai layer norm module
|
Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
|
||||||
|
and optionally marking parameters for gradient aggregation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
|
||||||
|
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the provided module is not an instance of nn.LayerNorm.
|
||||||
"""
|
"""
|
||||||
# check if apex is installed
|
# check if apex is installed
|
||||||
|
|
||||||
|
assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pass
|
pass
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -85,10 +198,18 @@ class FusedLayerNorm:
|
||||||
|
|
||||||
layernorm.weight = module.weight
|
layernorm.weight = module.weight
|
||||||
layernorm.bias = module.bias
|
layernorm.bias = module.bias
|
||||||
|
|
||||||
|
if sp_partial_derived:
|
||||||
|
# Since gradients are computed using only a subset of the data,
|
||||||
|
# aggregation of these gradients is necessary during backpropagation.
|
||||||
|
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||||
|
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
|
||||||
|
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
|
||||||
|
|
||||||
return layernorm
|
return layernorm
|
||||||
|
|
||||||
|
|
||||||
class FusedRMSNorm:
|
class FusedRMSNorm(BaseLayerNorm):
|
||||||
"""
|
"""
|
||||||
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
|
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
|
||||||
"""
|
"""
|
||||||
|
@ -96,11 +217,22 @@ class FusedRMSNorm:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"FusedRMSNorm is not implemented as a physical class. "
|
"FusedRMSNorm is not implemented as a physical class. "
|
||||||
"It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
|
"It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
|
def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
|
||||||
|
r"""
|
||||||
|
Convert a native RMSNorm module to FusedRMSNorm module provided by apex,
|
||||||
|
and optionally marking parameters for gradient aggregation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
|
||||||
|
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: FusedRMSNorm module.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -124,4 +256,10 @@ class FusedRMSNorm:
|
||||||
|
|
||||||
rmsnorm.weight = module.weight
|
rmsnorm.weight = module.weight
|
||||||
|
|
||||||
|
if sp_partial_derived:
|
||||||
|
# Since gradients are computed using only a subset of the data,
|
||||||
|
# aggregation of these gradients is necessary during backpropagation.
|
||||||
|
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
|
||||||
|
SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
|
||||||
|
|
||||||
return rmsnorm
|
return rmsnorm
|
||||||
|
|
|
@@ -1,8 +1,82 @@
 from contextlib import contextmanager
+from typing import List

 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup, get_world_size
+
+
+class SeqParallelUtils:
+    @staticmethod
+    def marked_as_sp_partial_derived_param(param):
+        """
+        Mark a parameter as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to mark as partially derived.
+        """
+        setattr(param, "partial_derived", True)
+
+    @staticmethod
+    def is_sp_partial_derived_param(param):
+        """
+        Check if a parameter is marked as partially derived in sequence parallelism.
+
+        Args:
+            param: The parameter to check.
+
+        Returns:
+            bool: True if the parameter is marked as partially derived, False otherwise.
+        """
+        return getattr(param, "partial_derived", False)
+
+    @staticmethod
+    def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+        """
+        Allreduce partial derived gradients across the specified process group.
+
+        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
+
+        Args:
+            tp_group (ProcessGroup): The process group for gradient synchronization.
+            model (nn.Module): The model from which gradients will be synchronized.
+            grads (List[torch.Tensor]): The list of gradients to be synchronized.
+
+        Raises:
+            AssertionError: If both `model` and `grads` are provided or neither is provided.
+        """
+        # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
+        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
+
+        # Get the size of the process group, which determines whether synchronization is needed.
+        tp_size = get_world_size(tp_group) if tp_group is not None else 1
+
+        if tp_size == 1:
+            # If the process group size is 1, no synchronization is required.
+            return
+
+        if model is not None:
+            # If `model` is provided, extract partial derived gradients from the model's parameters.
+            grads = []
+            for p in model.parameters():
+                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
+                    grads.append(p.grad.data)
+
+            # Flatten and reduce the gradients using the specified process group.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+
+            # Unflatten the synchronized gradients and update the model's gradients.
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+        else:
+            # If `grads` are provided explicitly, synchronize those gradients directly.
+            coalesced = _flatten_dense_tensors(grads)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+
+
 class Randomizer:
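To make the new utility concrete, here is a small illustrative sketch (not part of the patch) of how SeqParallelUtils is intended to be used: a layer norm's parameters are marked as partially derived, and their gradients are later all-reduced over the tensor-parallel group. The `tp_group` and the surrounding training loop are assumed to exist, and `torch.distributed` must already be initialized.

import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.shardformer.layer.utils import SeqParallelUtils


def sync_layernorm_grads(norm: nn.LayerNorm, tp_group: ProcessGroup) -> None:
    # Mark parameters whose gradients are only partially derived under sequence
    # parallelism (the LayerNorm/RMSNorm wrappers in this patch do exactly this).
    SeqParallelUtils.marked_as_sp_partial_derived_param(norm.weight)
    SeqParallelUtils.marked_as_sp_partial_derived_param(norm.bias)

    # ... forward and backward passes happen elsewhere ...

    # Collect the marked gradients and all-reduce them across the TP group.
    grads = [
        p.grad for p in norm.parameters()
        if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p)
    ]
    SeqParallelUtils.allreduce_partial_data_grad(tp_group=tp_group, grads=grads)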
@@ -4,6 +4,7 @@ from typing import Optional

import torch.nn as nn

+from ..shard.shard_config import ShardConfig
from .base_policy import Policy

__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -197,7 +198,7 @@ def _fullname(obj):
    return module + "." + klass.__qualname__


-def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
+def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
    r"""
    Return the auto policy for the model

@@ -208,7 +209,7 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
        :class:`Policy`: The auto policy for the model
    """
    full_name = _fullname(model)
-    if inference_only:
+    if shard_config.inference_only:
        policy_location = _INFER_POLICY_LIST.get(full_name, None)
    else:
        policy_location = _POLICY_LIST.get(full_name, None)
@@ -218,5 +219,5 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
        )
    else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location, shard_config.inference_only)
        return policy()
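A hedged usage sketch of the reworked signature: the caller now hands over the whole ShardConfig (which, per the hunk above, still exposes inference_only) instead of a bare boolean. The import path of get_autopolicy is inferred from the file shown here and may differ in other releases.

import torch.nn as nn

from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy


def resolve_policy(model: nn.Module, shard_config: ShardConfig):
    # The full config travels with the lookup, so the selected policy can later
    # read flags such as enable_fused_normalization or enable_sequence_parallelism.
    return get_autopolicy(model, shard_config)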
@@ -11,6 +11,7 @@ from torch.nn import Module

from colossalai.pipeline.stage_manager import PipelineStageManager

+from ..layer.normalization import BaseLayerNorm
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig

@@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
        ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
    """
    suffix: str
-    target_module: ParallelModule
+    target_module: Union[ParallelModule, BaseLayerNorm]
    kwargs: Dict[str, Any] = None
    ignore_if_not_exist: bool = False

@@ -77,7 +78,6 @@ class Policy(ABC):
    def set_model(self, model: nn.Module) -> None:
        r"""
        Set model as an attribute of the Policy object so that we can access the model's attributes.

        Args:
            model (:class:`nn.Module`): The model to be perform
        """
@@ -86,11 +86,11 @@ class Policy(ABC):
    def set_shard_config(self, shard_config: ShardConfig) -> None:
        r"""
        Set shard config as an attribute of the Policy object.

        Args:
            shard_config (:class:`ShardConfig`): The shard config to be perform
        """
        self.shard_config = shard_config

        self.config_sanity_check()

    @property
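A small sketch of what the widened target_module type enables: a replacement description may now point at a layer-norm class and carry extra kwargs. The imports follow paths used elsewhere in this diff; the sp_partial_derived value shown is an example, not a requirement.

from colossalai.shardformer.layer import FusedLayerNorm
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription

# Replace a submodule named "<prefix>.output.LayerNorm" with the fused kernel and
# mark its gradients as partially derived under sequence parallelism.
layernorm_replacement = SubModuleReplacementDescription(
    suffix="output.LayerNorm",
    target_module=FusedLayerNorm,
    kwargs={"sp_partial_derived": True},
)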
@@ -60,6 +60,12 @@ class BertPolicy(Policy):
        )

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
@@ -141,33 +147,34 @@ class BertPolicy(Policy):
            )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # Handle bert layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="attention.output.LayerNorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="output.LayerNorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=BertLayer,
-            )
-            # handle embedding layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="LayerNorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    )
-                ],
-                policy=policy,
-                target_key=BertEmbeddings,
-            )
+        # Handle bert layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="attention.output.LayerNorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="output.LayerNorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+            ],
+            policy=policy,
+            target_key=BertLayer,
+        )
+        # handle embedding layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="LayerNorm",
+                    target_module=norm_cls,
+                )
+            ],
+            policy=policy,
+            target_key=BertEmbeddings,
+        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
@@ -288,9 +295,6 @@ class BertPolicy(Policy):


# BertModel
class BertModelPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bert.modeling_bert import BertModel
@@ -313,9 +317,6 @@ class BertModelPolicy(BertPolicy):


# BertForPreTraining
class BertForPreTrainingPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        policy = self.add_lm_head_policy(policy)
@@ -355,9 +356,6 @@ class BertForPreTrainingPolicy(BertPolicy):


# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        policy = self.add_lm_head_policy(policy)
@@ -396,9 +394,6 @@ class BertLMHeadModelPolicy(BertPolicy):


# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        policy = self.add_lm_head_policy(policy)
@@ -437,9 +432,6 @@ class BertForMaskedLMPolicy(BertPolicy):


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForSequenceClassification

@@ -484,9 +476,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):


# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForTokenClassification

@@ -531,9 +520,6 @@ class BertForTokenClassificationPolicy(BertPolicy):


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
@@ -564,9 +550,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):


# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForMultipleChoice

@@ -610,9 +593,6 @@ class BertForMultipleChoicePolicy(BertPolicy):


class BertForQuestionAnsweringPolicy(BertPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForQuestionAnswering

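To connect the sp_partial_derived kwarg above with the gradient all-reduce added earlier, here is an illustrative (non-ColossalAI) version of the bookkeeping: tag the norm's parameters, then reduce exactly the tagged gradients after backward. The attribute name mirrors the kwarg, but the real implementation lives in SeqParallelUtils and may differ.

import torch.distributed as dist
import torch.nn as nn


def tag_partial_derived(norm: nn.Module) -> None:
    # Illustrative tag; ColossalAI keeps this state through SeqParallelUtils.
    for p in norm.parameters():
        p.sp_partial_derived = True


def sync_tagged_grads(model: nn.Module, tp_group: dist.ProcessGroup) -> None:
    # After backward, only the tagged gradients need the extra all-reduce
    # across the tensor-parallel group used for sequence parallelism.
    for p in model.parameters():
        if getattr(p, "sp_partial_derived", False) and p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=tp_group)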
@@ -43,6 +43,11 @@ class BlipPolicy(Policy):

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        if self.shard_config.enable_tensor_parallelism:
            policy[Blip2EncoderLayer] = ModulePolicyDescription(
                attribute_replacement={
@@ -214,94 +219,93 @@ class BlipPolicy(Policy):
        policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # Handle Blip2EncoderLayer layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="layer_norm1",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="layer_norm2",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=Blip2EncoderLayer,
-            )
+        # Handle Blip2EncoderLayer layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="layer_norm1",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="layer_norm2",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=Blip2EncoderLayer,
+        )

        # handle Blip2VisionModel layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="post_layernorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=Blip2VisionModel,
        )

        # handle Blip2VisionModel layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layernorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=Blip2QFormerModel,
        )

        # handle Blip2QFormerLayer layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="attention.output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="crossattention.output.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="output_query.LayerNorm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=Blip2QFormerLayer,
        )

        # handle OPTForCausalLM layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="model.decoder.final_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=OPTForCausalLM,
        )

        # handle OPTDecoderLayer layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="final_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=OPTDecoderLayer,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
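For reference, the dotted suffix strings used throughout these descriptions ("layer_norm1", "attention.output.LayerNorm", "model.decoder.final_layer_norm", ...) resolve to submodules by attribute walking. A minimal equivalent of that lookup is sketched below; ColossalAI's own helper may differ.

import torch.nn as nn


def resolve_suffix(root: nn.Module, suffix: str) -> nn.Module:
    # Walk "a.b.c" one attribute at a time, e.g. "attention.output.LayerNorm".
    module = root
    for name in suffix.split("."):
        module = getattr(module, name)
    return module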
@@ -42,6 +42,10 @@ class BloomPolicy(Policy):

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
@@ -97,38 +101,39 @@ class BloomPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # handle bloom model
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="ln_f",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="word_embeddings_layernorm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=BloomModel,
-            )
+        # handle bloom model
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="ln_f",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="word_embeddings_layernorm",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=BloomModel,
+        )

        # handle bloom block
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="input_layernorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="post_attention_layernorm",
-                    target_module=col_nn.FusedLayerNorm,
-                ),
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
            ],
            policy=policy,
            target_key=BloomBlock,
        )

        if use_sequence_parallel:
            self.append_or_create_method_replacement(
@@ -225,9 +230,6 @@ class BloomPolicy(Policy):


class BloomModelPolicy(BloomPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bloom.modeling_bloom import BloomModel
@@ -45,6 +45,16 @@ class ChatGLMPolicy(Policy):

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            if self.model.config.rmsnorm:
+                norm_cls = col_nn.FusedRMSNorm
+            else:
+                norm_cls = col_nn.FusedLayerNorm
+        else:
+            if self.model.config.rmsnorm:
+                norm_cls = col_nn.RMSNorm
+            else:
+                norm_cls = col_nn.LayerNorm
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
@@ -96,52 +106,34 @@ class ChatGLMPolicy(Policy):
            )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            if not self.model.config.rmsnorm:
-                self.append_or_create_submodule_replacement(
-                    description=[
-                        SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
-                        SubModuleReplacementDescription(
-                            suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm
-                        ),
-                    ],
-                    policy=policy,
-                    target_key=GLMBlock,
-                )
-                if self.model.config.post_layer_norm:
-                    self.append_or_create_submodule_replacement(
-                        description=[
-                            SubModuleReplacementDescription(
-                                suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm
-                            )
-                        ],
-                        policy=policy,
-                        target_key=ChatGLMModel,
-                    )
-            else:
-                self.append_or_create_submodule_replacement(
-                    description=[
-                        SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
-                        SubModuleReplacementDescription(
-                            suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
-                        ),
-                    ],
-                    policy=policy,
-                    target_key=GLMBlock,
-                )
-                if self.model.config.post_layer_norm:
-                    self.append_or_create_submodule_replacement(
-                        description=[
-                            SubModuleReplacementDescription(
-                                suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
-                            )
-                        ],
-                        policy=policy,
-                        target_key=ChatGLMModel,
-                    )
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+            ],
+            policy=policy,
+            target_key=GLMBlock,
+        )
+
+        if self.model.config.post_layer_norm:
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="encoder.final_layernorm",
+                        target_module=norm_cls,
+                    )
+                ],
+                policy=policy,
+                target_key=ChatGLMModel,
+            )

        # use flash attention
        if self.shard_config.enable_flash_attention:
@@ -224,9 +216,6 @@ class ChatGLMPolicy(Policy):


class ChatGLMModelPolicy(ChatGLMPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        pass

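The ChatGLM hunk above selects among four norm implementations. A compact sketch of the same decision, assuming the col_nn alias used throughout these policies points at colossalai.shardformer.layer:

import colossalai.shardformer.layer as col_nn


def pick_norm_cls(enable_fused_normalization: bool, rmsnorm: bool):
    # Fused vs. plain kernels on one axis, RMSNorm vs. LayerNorm on the other.
    if enable_fused_normalization:
        return col_nn.FusedRMSNorm if rmsnorm else col_nn.FusedLayerNorm
    return col_nn.RMSNorm if rmsnorm else col_nn.LayerNorm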
@@ -39,6 +39,11 @@ class GPT2Policy(Policy):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
@@ -102,33 +107,37 @@ class GPT2Policy(Policy):
            )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="ln_f",
-                    target_module=col_nn.FusedLayerNorm,
-                ),
-                policy=policy,
-                target_key=GPT2Model,
-            )
-
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="ln_1",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="ln_2",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
-                    ),
-                ],
-                policy=policy,
-                target_key=GPT2Block,
-            )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="ln_f",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=GPT2Model,
+        )
+
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="ln_1",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="ln_2",
+                    target_module=norm_cls,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+                SubModuleReplacementDescription(
+                    suffix="ln_cross_attn",
+                    target_module=norm_cls,
+                    ignore_if_not_exist=True,
+                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                ),
+            ],
+            policy=policy,
+            target_key=GPT2Block,
+        )

        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
@@ -192,9 +201,6 @@ class GPT2Policy(Policy):


# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model

@@ -216,9 +222,6 @@ class GPT2ModelPolicy(GPT2Policy):


# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

@@ -263,9 +266,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):


# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel

@@ -317,9 +317,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):


# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering

@@ -347,9 +344,6 @@ class GPT2ForQuestionAnsweringPolicy(GPT2Policy):


# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification

@@ -387,9 +381,6 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):


# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification

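The GPT-2 hunk keeps ignore_if_not_exist=True for ln_cross_attn, which only exists when a block is configured with cross-attention. A sketch of the intended semantics (not the sharder's actual code):

import torch.nn as nn


def lookup_optional(parent: nn.Module, name: str, ignore_if_not_exist: bool = False):
    # Return None instead of raising when an optional submodule is absent.
    if not hasattr(parent, name):
        if ignore_if_not_exist:
            return None
        raise AttributeError(f"{type(parent).__name__} has no submodule {name!r}")
    return getattr(parent, name)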
@@ -6,7 +6,7 @@ import torch.nn as nn
from torch import Tensor
from torch.nn import Module

-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D

from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -35,6 +35,11 @@ class LlamaPolicy(Policy):

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedRMSNorm
+        else:
+            norm_cls = RMSNorm
+
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@@ -93,31 +98,31 @@ class LlamaPolicy(Policy):
            )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="input_layernorm",
-                        target_module=FusedRMSNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="post_attention_layernorm",
-                        target_module=FusedRMSNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=LlamaDecoderLayer,
-            )
-
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=LlamaModel,
-            )
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="input_layernorm",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="post_attention_layernorm",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=LlamaDecoderLayer,
+        )
+
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=LlamaModel,
+        )

+        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(
                description={
@@ -174,9 +179,6 @@ class LlamaPolicy(Policy):


class LlamaModelPolicy(LlamaPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.llama.modeling_llama import LlamaModel
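Since the Llama policy now falls back from FusedRMSNorm to a plain RMSNorm, a reference (unfused) forward pass is sketched below for intuition; it is not ColossalAI's implementation.

import torch
import torch.nn as nn


class RefRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square of the last dimension;
        # unlike LayerNorm there is no mean subtraction and no bias.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight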
@@ -5,7 +5,7 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn

-from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
@@ -42,6 +42,12 @@ class OPTPolicy(Policy):
        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedLayerNorm
+        else:
+            norm_cls = LayerNorm
+
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@@ -94,26 +100,25 @@ class OPTPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
-                ),
-                policy=policy,
-                target_key=OPTDecoder,
-            )
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
-                    ),
-                ],
-                policy=policy,
-                target_key=OPTDecoderLayer,
-            )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+            ),
+            policy=policy,
+            target_key=OPTDecoder,
+        )
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                ),
+                SubModuleReplacementDescription(
+                    suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                ),
+            ],
+            policy=policy,
+            target_key=OPTDecoderLayer,
+        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
@@ -183,9 +188,6 @@ class OPTPolicy(Policy):


class OPTModelPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTModel

@@ -253,9 +255,6 @@ class OPTForCausalLMPolicy(OPTPolicy):


class OPTForSequenceClassificationPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTForSequenceClassification

@@ -281,9 +280,6 @@ class OPTForSequenceClassificationPolicy(OPTPolicy):


class OPTForQuestionAnsweringPolicy(OPTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.opt.modeling_opt import OPTForQuestionAnswering

@@ -24,6 +24,11 @@ class SamPolicy(Policy):

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        if self.shard_config.enable_tensor_parallelism:
            policy[SamVisionLayer] = ModulePolicyDescription(
                attribute_replacement={
@@ -151,58 +156,57 @@ class SamPolicy(Policy):
            )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # Handle SamVisionLayer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="layer_norm1",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="layer_norm2",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=SamVisionLayer,
-            )
+        # Handle SamVisionLayer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="layer_norm1",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="layer_norm2",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=SamVisionLayer,
+        )

        # Handle SamTwoWayAttentionBlock
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm1",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm2",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm3",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="layer_norm4",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=SamTwoWayAttentionBlock,
        )

        # Handle SamTwoWayTransformer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm_final_attn",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=SamTwoWayTransformer,
        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import (
    FusedRMSNorm,
    Linear1D_Col,
    Linear1D_Row,
+    RMSNorm,
    VocabParallelEmbedding1D,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
@@ -58,6 +59,11 @@ class T5BasePolicy(Policy):

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = FusedRMSNorm
+        else:
+            norm_cls = RMSNorm
+
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@@ -169,38 +175,37 @@ class T5BasePolicy(Policy):
            )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="layer_norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=T5LayerFF,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(
-                    suffix="layer_norm",
-                    target_module=FusedRMSNorm,
-                ),
-                policy=policy,
-                target_key=T5LayerFF,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5LayerSelfAttention,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5LayerCrossAttention,
-            )
-            self.append_or_create_submodule_replacement(
-                description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
-                policy=policy,
-                target_key=T5Stack,
-            )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=T5LayerFF,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(
+                suffix="layer_norm",
+                target_module=norm_cls,
+            ),
+            policy=policy,
+            target_key=T5LayerFF,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5LayerSelfAttention,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5LayerCrossAttention,
+        )
+        self.append_or_create_submodule_replacement(
+            description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
+            policy=policy,
+            target_key=T5Stack,
+        )

        # use flash attention
        if self.shard_config.enable_flash_attention:
@@ -363,9 +368,6 @@ class T5BasePolicy(Policy):


class T5ModelPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers import T5Model

@@ -402,9 +404,6 @@ class T5ModelPolicy(T5BasePolicy):


class T5ForConditionalGenerationPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers import T5ForConditionalGeneration

@@ -466,9 +465,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):


class T5EncoderPolicy(T5BasePolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers import T5EncoderModel

@@ -159,9 +159,6 @@ class ViTPolicy(Policy):

# ViTModel
class ViTModelPolicy(ViTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.vit.modeling_vit import ViTModel

@@ -227,9 +224,6 @@ class ViTForImageClassificationPolicy(ViTPolicy):

# ViTForMaskedImageModeling
class ViTForMaskedImageModelingPolicy(ViTPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel

@@ -52,6 +52,11 @@ class WhisperPolicy(Policy):

        policy = {}

+        if self.shard_config.enable_fused_normalization:
+            norm_cls = col_nn.FusedLayerNorm
+        else:
+            norm_cls = col_nn.LayerNorm
+
        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn(
@@ -161,62 +166,61 @@ class WhisperPolicy(Policy):
        )

        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            # Handle encoder layer
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="final_layer_norm",
-                        target_module=col_nn.FusedLayerNorm,
-                    ),
-                ],
-                policy=policy,
-                target_key=WhisperEncoderLayer,
-            )
+        # Handle encoder layer
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn_layer_norm",
+                    target_module=norm_cls,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="final_layer_norm",
+                    target_module=norm_cls,
+                ),
+            ],
+            policy=policy,
+            target_key=WhisperEncoderLayer,
+        )

        # Handle decoder layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="self_attn_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
                SubModuleReplacementDescription(
                    suffix="final_layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                ),
            ],
            policy=policy,
            target_key=WhisperDecoderLayer,
        )

        # handle encoder layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=WhisperEncoder,
        )

        # handle decoder layer
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(
                    suffix="layer_norm",
-                    target_module=col_nn.FusedLayerNorm,
+                    target_module=norm_cls,
                )
            ],
            policy=policy,
            target_key=WhisperDecoder,
        )

        # enable flash attention
        if self.shard_config.enable_flash_attention:
@@ -416,9 +420,6 @@ class WhisperPolicy(Policy):


# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers import WhisperModel

@@ -441,9 +442,6 @@ class WhisperModelPolicy(WhisperPolicy):


# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def module_policy(self):
        from transformers import WhisperForConditionalGeneration

@@ -502,9 +500,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):


# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
-    def __init__(self) -> None:
-        super().__init__()
-
    def preprocess(self):
        return self.model

@@ -27,8 +27,8 @@ class ModelSharder(object):

    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
        self.model = model
-        self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
        self.shard_config = shard_config
+        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy

    def shard(self) -> List[Dict[int, Tensor]]:
        r"""
@@ -196,7 +196,7 @@ class ModelSharder(object):

            try:
                replace_layer = target_module.from_native_module(
-                    native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs
+                    native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs
                )
            except Exception as e:
                raise RuntimeError(
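The sharder change above passes the process group as a keyword argument when building the replacement layer. A generic sketch of that replacement step follows; from_native_module is the classmethod name shown in the hunk, while the surrounding helper is illustrative.

import torch.nn as nn


def swap_submodule(parent: nn.Module, name: str, target_cls, process_group, **kwargs) -> nn.Module:
    native = getattr(parent, name)
    # Passing the group by keyword avoids relying on the positional order of
    # from_native_module's parameters, which can vary across target classes.
    replaced = target_cls.from_native_module(native, process_group=process_group, **kwargs)
    setattr(parent, name, replaced)
    return replaced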
@@ -700,7 +700,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
    ############################

    # this method is used to sync gradient manually
-    def sync_grad(self):
+    def _sync_grad(self):
        for group_id in range(self.num_param_groups):
            param_group = self._working_param_groups[group_id]
            for param in param_group:
@@ -713,7 +713,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            # if not overlapping communication (no reduction hook is attached) when zero1
            # we need to manually reduce these gradients
            if not partition_grad and not self._overlap_communication:
-                self.sync_grad()
+                self._sync_grad()
            else:
                self._run_reduction()
@@ -24,7 +24,6 @@ for k, v in inputs.items():
        new_shape[0] = 16
        inputs[k] = v.to("cuda").repeat(*new_shape)


def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    model = transformers.LlamaForCausalLM(
        transformers.LlamaConfig(
@@ -59,6 +58,7 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
@@ -76,6 +76,7 @@ def check_tp_pipeline_inference(rank, world_size, port):


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@@ -128,7 +128,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "max_norm": 5,
@@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 2,
             "pp_size": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -147,7 +147,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -157,7 +157,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "bf16",
             "max_norm": 5,
@@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 2,
             "pp_size": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -174,7 +174,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -199,7 +199,7 @@ def run_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -208,7 +208,7 @@ def run_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
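These test_config dictionaries are consumed elsewhere in the test harness, outside this diff. Roughly, and with the helper name build_booster and the use_lazy_init handling being a simplification of what the shared test utilities do:

# Rough sketch: turning a test_config dict like the ones above into a booster.
# HybridParallelPlugin and Booster are real ColossalAI APIs; the surrounding
# handling here is simplified and assumes distributed has already been launched.
import copy

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def build_booster(test_config: dict) -> Booster:
    cfg = copy.deepcopy(test_config)
    cfg.pop("use_lazy_init", None)  # used by the harness for lazy model init, not a plugin kwarg
    plugin = HybridParallelPlugin(**cfg)  # tp_size / pp_size / precision / enable_all_optimization / ...
    return Booster(plugin=plugin)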
@@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp32",
             "max_norm": 5,
@@ -114,7 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 2,
             "pp_size": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -123,7 +123,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -148,7 +148,7 @@ def run_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp32",
             "max_norm": 5,
@@ -106,7 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
             "max_norm": 5,
@@ -116,7 +116,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -126,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -137,7 +137,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "bf16",
             "max_norm": 5,
@@ -146,7 +146,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -155,7 +155,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 1,
             "zero_stage": 2,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -181,7 +181,7 @@ def run_test(test_config):
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "bf16",
             "max_norm": 5,
@@ -191,7 +191,7 @@ def run_test(test_config):
             "pp_size": 2,
             "num_microbatches": 4,
             "zero_stage": 1,
-            "enable_all_optimization": False,
+            "enable_all_optimization": True,
             "use_lazy_init": False,
             "precision": "fp16",
             "max_norm": 5,
@@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     bert = unwrap_model(org_model, "BertModel", "bert")
     sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

+    norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
     col_layer_for_check = ["encoder.layer[0].output.dense"]
     row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]

@@ -50,8 +51,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         row_layer_grads = get_grad_tensors_for_check(
             bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
         )
+
+        norm_layer_grads = get_grad_tensors_for_check(
+            bert,
+            sharded_bert,
+            norm_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+
         grads_to_check.update(col_layer_grads)
         grads_to_check.update(row_layer_grads)
+        grads_to_check.update(norm_layer_grads)

     # optimizer executes step
     org_optimizer.step()
@@ -85,6 +99,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": True,
+            "precision": "fp32",
+        },
         {
             "tp_size": 1,
             "pp_size": 2,
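The new norm_layer_for_check / norm_layer_grads entries above (and the matching ones for Bloom, ChatGLM and GPT-2 below) exercise the layer-norm gradient handling this commit adds: with sequence parallelism, each tensor-parallel rank computes LayerNorm gradients from only its slice of the sequence, so those gradients must be all-reduced over the tensor-parallel group before the optimizer step. A minimal sketch of the idea (the function name and the plain isinstance scan are illustrative, not the actual ColossalAI implementation):

# Illustrative sketch: all-reduce LayerNorm grads over the tensor-parallel group
# when sequence parallelism shards the sequence dimension across that group.
import torch.distributed as dist
import torch.nn as nn


def allreduce_layernorm_grads(model: nn.Module, tp_group=None) -> None:
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                if param.grad is not None:
                    # Each rank saw a different sequence shard, so summing over the
                    # TP group recovers the gradient for the full sequence.
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)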
@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     bloom = unwrap_model(org_model, "BloomModel", "transformer")
     sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer")

+    norm_layer_for_check = ["word_embeddings_layernorm", "h[0].input_layernorm"]
     row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"]
     col_layer_for_check = ["h[0].self_attention.dense"]

@@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         col_layer_grads = get_grad_tensors_for_check(
             bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
         )
+
+        norm_layer_grads = get_grad_tensors_for_check(
+            bloom,
+            sharded_bloom,
+            norm_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+
         grads_to_check.update(col_layer_grads)
         grads_to_check.update(row_layer_grads)
+        grads_to_check.update(norm_layer_grads)

     # optimizer executes step
     org_optimizer.step()
@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer")
     shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer")

+    norm_layer_for_check = ["encoder.layers[0].input_layernorm"]
     row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"]
     col_layer_for_check = ["encoder.layers[0].self_attention.dense"]

@@ -66,8 +67,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             dim=1,
             verbose=False,
         )
+
+        norm_layer_grads = get_grad_tensors_for_check(
+            chatglm_model,
+            shard_chatglm_model,
+            norm_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+
         grads_to_check.update(col_layer_grads)
         grads_to_check.update(row_layer_grads)
+        grads_to_check.update(norm_layer_grads)

     # optimizer executes step
     org_optimizer.step()
@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     gpt2 = unwrap_model(org_model, "GPT2Model", "transformer")
     sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer")

+    norm_layer_for_check = ["h[0].ln_1", "h[0].ln_2"]
     col_layer_for_check = ["h[0].mlp.c_fc"]
     row_layer_for_check = ["wte", "h[0].mlp.c_proj"]

@@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         row_layer_grads = get_grad_tensors_for_check(
             gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
         )
+
+        norm_layer_grads = get_grad_tensors_for_check(
+            gpt2,
+            sharded_gpt2,
+            norm_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
+
         grads_to_check.update(col_layer_grads)
         grads_to_check.update(row_layer_grads)
+        grads_to_check.update(norm_layer_grads)

     # optimizer executes step
     org_optimizer.step()