[shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LLaMA model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flash attention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loosen tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config so that enable_all_optimization does not turn on sp

* add sp warnings for several models

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Zhongkai Zhao 2024-04-03 17:15:47 +08:00 committed by GitHub
parent 7e0ec5a85c
commit 8e412a548e
33 changed files with 1630 additions and 256 deletions
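For orientation before the diffs, here is a minimal usage sketch of the plugin API this commit introduces. The argument names and the mode constraints (split_gather/ring require tp_size > 1; all_to_all requires tp_size == 1 and pp_size == 1) are taken from the diff below; the Booster setup and the sizes are illustrative assumptions only, not code from this PR.

# Sketch only: selecting the two families of SP modes added here.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Megatron-style SP ("split_gather" or "ring"): the sp group is the tp group, so tp_size must be > 1.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",
)

# Ulysses-style SP ("all_to_all"): tp_size == 1 and pp_size == 1 are required; sp_size picks the sp group
# size, and the remaining ranks form the data-parallel dimension.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=4,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
)

booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)  # usual booster flow, omitted here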


@@ -34,7 +34,8 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer
 from .pp_plugin_base import PipelinePluginBase
-DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
 PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@@ -53,6 +54,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         shard_config: ShardConfig,
         dp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
@@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.shard_config = shard_config
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.sp_group = sp_group
         self.use_dpp = use_ddp
         self.require_grad_sync = True
@@ -168,13 +171,24 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         Returns:
             None
         """
-        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+        if self.shard_config.enable_sequence_parallelism:
+            if self.shard_config.sequence_parallelism_mode == "all_to_all":
+                return
+            if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized
+                # across the tensor parallelism group.
+                group = self.tp_group
+            else:
+                raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")
             if grads is not None:
                 # Synchronize provided gradient tensors across the tensor parallelism group.
-                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+                SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads)
             else:
                 # Synchronize gradients from the model across the tensor parallelism group.
-                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+                SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module)
     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
@@ -727,10 +741,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
         if self.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
-            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+            SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
             return
@@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
    Args:
        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+       sp_size (int): The size of sequence parallelism.
        precision (str, optional): Specifies the precision of parameters during training.
            Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
            Defaults to 'fp16'.
@@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+       sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
@@ -938,6 +953,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self,
         tp_size: int,
         pp_size: int,
+        sp_size: int = None,
         precision: str = "fp16",
         zero_stage: int = 0,
         enable_all_optimization: bool = False,
@@ -945,6 +961,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        sequence_parallelism_mode: str = None,
         enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
@@ -976,14 +993,41 @@ class HybridParallelPlugin(PipelinePluginBase):
         super().__init__()
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
         if enable_sequence_parallelism:
-            assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
+            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
+            assert (
+                self.sequence_parallelism_mode in SUPPORT_SP_MODE
+            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+            if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+                assert (
+                    tp_size > 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
+                if sp_size != 1:
+                    warnings.warn(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    )
+                self.sp_size = 1
+                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            elif self.sequence_parallelism_mode in ["all_to_all"]:
+                assert (
+                    tp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
+                assert (
+                    pp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
+                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
+                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+        else:
+            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            assert (
+                sp_size == 1 or sp_size is None
+            ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
+            self.sp_size = 1
         self.tp_size = tp_size
         self.pp_size = pp_size
-        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
@@ -992,7 +1036,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
@@ -1033,9 +1077,14 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
+            self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        else:
+            self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
+            sequence_parallel_process_group=self.sp_group,
             pipeline_stage_manager=self.stage_manager,
             enable_tensor_parallelism=self.tp_size > 1,
             enable_all_optimization=self.enable_all_optimization,
@@ -1043,6 +1092,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
+            sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             gradient_checkpoint_config=gradient_checkpoint_config,
@@ -1113,13 +1163,23 @@ class HybridParallelPlugin(PipelinePluginBase):
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
-            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+                self.dp_size == 1
+                and self.pp_size == 1
+                and self.enable_sequence_parallelism
+                and self.sequence_parallelism_mode == "all_to_all"
+            )
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
+            else:
+                dp_group = self.dp_group
             model = HybridParallelModule(
                 model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=self.dp_group,
+                dp_group=dp_group,
                 tp_group=self.tp_group,
+                sp_group=self.sp_group,
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
@@ -1149,7 +1209,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                         tp_process_group=self.tp_group,
                     )
             else:
-                if self.dp_size == 1:
+                zero_dp_size = dist.get_world_size(dp_group)
+                if zero_dp_size == 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you are not intended to use cpu_offload, please consider set zero_stage=0."
@@ -1161,7 +1222,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                    model,
                    use_pipeline=self.enable_pipeline_parallelism,
                    param_info=param_info,
-                   dp_process_group=self.dp_group,
+                   dp_process_group=dp_group,
                    tp_process_group=self.tp_group,
                    pp_process_group=self.pp_group,
                    verbose=True,


@@ -254,6 +254,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        # TODO: Currently moe only support partially sequence parallel
+        self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             pipeline_stage_manager=self.stage_manager,
@@ -365,6 +368,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 shard_config=self.shard_config,
                 dp_group=self.dp_group,
                 tp_group=self.tp_group,
+                sp_group=self.sp_group,
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,


@@ -161,7 +161,7 @@ class ProcessGroupMesh:
     @staticmethod
     def get_coords_along_axis(
-        base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
+        base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
     ) -> List[Tuple[int, ...]]:
         """Get coordinates along the given axis.
@@ -173,13 +173,28 @@ class ProcessGroupMesh:
         Returns:
             List[Tuple[int, ...]]: Coordinates along the axis.
         """
+        if isinstance(axis, int):
+            axis = [axis,]
+            assert isinstance(indices_at_axis[0], int)
+            indices_at_axis = [indices_at_axis,]
+        def add_index(base_coord, axis, indices_at_axis):
            coords_in_group = []
            for idx in indices_at_axis:
                coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
            return coords_in_group
+        coords_in_group = [base_coord]
+        for ax, indices_at_ax in zip(axis, indices_at_axis):
+            new_coords_in_group = []
+            for coords in coords_in_group:
+                new_coords_in_group += add_index(coords, ax, indices_at_ax)
+            coords_in_group = new_coords_in_group
+        return coords_in_group
     def create_group_along_axis(
-        self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+        self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None
     ) -> ProcessGroup:
         """Create all process groups along the given axis, and return the one which the current process belongs to.
@@ -191,10 +206,17 @@ class ProcessGroupMesh:
         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
         """
-        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+        if isinstance(axis, int):
+            axis = [axis,]
+            if indices_at_axis is not None:
+                assert isinstance(indices_at_axis[0], int)
+                indices_at_axis = [indices_at_axis,]
+        indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
         reduced_shape = list(self._shape)
         # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
-        reduced_shape[axis] = 1
+        for ax in axis:
+            reduced_shape[ax] = 1
         target_group = None
         # use Cartesian product to generate all combinations of coordinates
         for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
@@ -225,4 +247,3 @@ class ProcessGroupMesh:
            # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
            return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
        return self._ranks_to_group[ranks_in_group]
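The ProcessGroupMesh changes above generalize group creation from a single axis to a list of axes, which is what lets the plugin build the flattened dp x sp group it hands to DDP/ZeRO in all_to_all mode. A small sketch under assumed sizes (8 ranks arranged as dp=2, pp=1, tp=1, sp=4; the layout is illustrative and requires an initialized distributed environment):

# Sketch only: single-axis and multi-axis group creation on the 4D mesh.
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
mesh = ProcessGroupMesh(2, 1, 1, 4)

sp_group = mesh.get_group_along_axis(SP_AXIS)                    # single axis, as before
dp_sp_group = mesh.create_group_along_axis([DP_AXIS, SP_AXIS])   # new: a group spanning two axes

# In all_to_all mode the plugin uses this flattened dp x sp group for gradient synchronization,
# so every rank that holds a replica of the (sequence-sharded) parameters is covered.
assert dist.get_world_size(dp_sp_group) == 2 * 4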


@@ -1,4 +1,5 @@
 from .attn import AttnMaskType, ColoAttention
+from ._operation import all_to_all_comm
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
@@ -26,4 +27,5 @@ __all__ = [
     "ParallelModule",
     "AttnMaskType",
     "ColoAttention",
+    "all_to_all_comm",
 ]


@@ -167,6 +167,97 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None
+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+    # currently only support one single tensor as output
+    group_size = dist.get_world_size(process_group)
+    cur_rank = dist.get_rank(process_group)
+    # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+    # initialization of ring communication
+    recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+    rank_map = list(dist.get_process_group_ranks(process_group))
+    recv_rank = rank_map[recv_rank]
+    send_rank = rank_map[send_rank]
+    recv_tensors = {}
+    send_tensors = {}
+    for k, v in input_to_gather.items():
+        recv_tensors[k] = torch.empty_like(v)
+        send_tensors[k] = v.clone()
+    def communicate_step():
+        comm_ops = []
+        for k in recv_tensors:
+            comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+            comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+        return dist.batch_isend_irecv(comm_ops)
+    def switch_step():
+        for k in recv_tensors:
+            send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
+    output_tensors = []
+    handles = communicate_step()
+    # first round: special case, retrive from local tensor
+    output_tensors.append(func(**input_to_gather, **input_local))
+    for i in range(group_size - 2):
+        for handle in handles:
+            handle.wait()
+        switch_step()
+        handles = communicate_step()
+        # actual computation
+        output_tensors.append(func(**send_tensors, **input_local))
+    # final round: special case, no need to send/recv again
+    for handle in handles:
+        handle.wait()
+    output_tensors.append(func(**recv_tensors, **input_local))
+    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
+    """
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.process_group = process_group
+        ctx.dim = dim
+        return _gather(input_, dim, process_group)
+    @staticmethod
+    def backward(ctx, grad_output):
+        dim = ctx.dim
+        process_group = ctx.process_group
+        # do reduce-scatter
+        new_shape = list(grad_output.shape)
+        assert (
+            new_shape[dim] % dist.get_world_size(process_group) == 0
+        ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+        grad_list = [
+            item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
+        ]
+        output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+        dist.reduce_scatter(output, grad_list, group=process_group)
+        return output, None, None
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
@@ -178,7 +269,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -186,8 +277,21 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.overlap = overlap
-        input_parallel = _gather(input_, dim, process_group)
+        if ring is True:
+            input_to_gather = {"input": input_}
+            input_local = {"weight": weight}
+            output = _ring_as_gather(
+                F.linear,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+            )
+            if bias is not None:
+                output += bias
+        else:
+            input_parallel = _gather(input_, dim, process_group)
            if bias is not None:
                output = F.linear(input_parallel, weight, bias)
            else:
@@ -294,11 +398,146 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
            # wait until reduce-scatter finished
            reducescatter_handle.wait()
-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None
+def _ring_as_reducescatter(
+    func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1
+):
+    # currently only support one single tensor as output
+    group_size = dist.get_world_size(process_group)
+    cur_rank = dist.get_rank(process_group)
+    # initialization of ring communication
+    recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+    send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    rank_map = list(dist.get_process_group_ranks(process_group))
+    recv_rank = rank_map[recv_rank]
+    send_rank = rank_map[send_rank]
+    input_tensors = []
+    for _ in range(group_size):
+        input_tensors.append({})
+    for k, v in input_to_reducescatter.items():
+        input_shape = v.shape
+        assert input_shape[reducescatter_dim] % group_size == 0
+        _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+        for i in range(group_size):
+            input_tensors[i][k] = _input_tensors[i]
+    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+    input_tensors.reverse()
+    output_tensor = func(**input_tensors[0], **input_local)
+    recv_tensor = torch.empty_like(output_tensor)
+    send_tensor = output_tensor.clone()
+    def communicate_step():
+        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+        return dist.batch_isend_irecv([recv_op, send_op])
+    handles = communicate_step()
+    # first round: special case, retrive from local tensor
+    for i in range(group_size - 2):
+        # actual computation
+        output_tensor = func(**input_tensors[i + 1], **input_local)
+        for handle in handles:
+            handle.wait()
+        output_tensor += recv_tensor
+        tmp_tensor = send_tensor
+        send_tensor = output_tensor
+        output_tensor = tmp_tensor
+        handles = communicate_step()
+    # final round: special case, no need to send/recv again
+    output_tensor = func(**input_tensors[-1], **input_local)
+    for handle in handles:
+        handle.wait()
+    output_tensor += recv_tensor
+    return output_tensor
 class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
-    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+    """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
+    """
+    @staticmethod
+    def forward(ctx, input_, weight, bias, process_group, dim, ring):
+        ctx.save_for_backward(input_, weight, bias)
+        ctx.use_bias = bias is not None
+        ctx.process_group = process_group
+        ctx.dim = dim
+        if ring is True:
+            input_to_reducescatter = {"input": input_}
+            input_local = {"weight": weight}
+            if bias is not None:
+                input_to_reducescatter["bias"] = bias
+            output = _ring_as_reducescatter(
+                F.linear,
+                input_to_reducescatter=input_to_reducescatter,
+                input_local=input_local,
+                process_group=process_group,
+            )
+        else:
+            if bias is not None:
+                partial_output = F.linear(input_, weight, bias)
+            else:
+                partial_output = F.linear(input_, weight)
+            output_shape = list(partial_output.shape)
+            assert (
+                output_shape[dim] % dist.get_world_size(process_group) == 0
+            ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+            output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+            output_list = [
+                item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+            dist.reduce_scatter(output, output_list, group=process_group)
+        return output
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight, bias = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        dim = ctx.dim
+        process_group = ctx.process_group
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        if use_bias:
+            bias = bias.view(bias.shape)
+        grad_output = _gather(grad_output, dim, process_group)
+        # TODO Need to fully optimize
+        total_input = input_
+        grad_input = grad_output.matmul(weight)
+        grad_output = grad_output.contiguous()
+        # Convert the tensor shapes to 2D for execution compatibility
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(-1, grad_output.shape[-1])
+            total_input = total_input.view(-1, total_input.shape[-1])
+        grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+        return grad_input, grad_weight, grad_bias, None, None, None
+class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
+    """Reduce-scatter input from sequence parallel in forward and gather gradient in backward
     Args:
         input_ (`torch.Tensor`): The input tensor from sequence parallel region.
@@ -343,7 +582,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """
     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -351,6 +590,21 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.overlap = overlap
+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather["input"] = input_
+            input_local["other"] = weight
+            output = _ring_as_gather(
+                torch.matmul,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+                gather_dim=dim,
+            )
+        else:
            input_parallel = _gather(input_, dim, process_group)
            output = torch.matmul(input_parallel, weight)
@@ -433,7 +687,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
            # wait until reduce-scatter finished
            reducescatter_handle.wait()
-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None
 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -448,14 +702,17 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     """
     @staticmethod
-    def forward(ctx, input_, dim, process_group):
+    def forward(ctx, input_, dim, process_group, grad_scale=None):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.grad_scale = grad_scale
         return _split(input_, dim, process_group)
     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output, ctx.dim, ctx.process_group), None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None
 class _ReduceForward(torch.autograd.Function):
@@ -505,14 +762,50 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     """
     @staticmethod
-    def forward(ctx, input_, dim, process_group):
+    def forward(ctx, input_, dim, process_group, grad_scale=None):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.grad_scale = grad_scale
         return _gather(input_, dim, process_group)
     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output, ctx.dim, ctx.process_group), None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
+class _AllToAll(torch.autograd.Function):
+    """All-to-all communication.
+    Args:
+        input_: input matrix
+        process_group: communication group
+        scatter_dim: scatter dimension
+        gather_dim: gather dimension
+    """
+    @staticmethod
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+        ctx.process_group = process_group
+        ctx.scatter_dim = scatter_dim
+        ctx.gather_dim = gather_dim
+        world_size = dist.get_world_size(process_group)
+        bsz, _, _ = input_.shape
+        # using all_to_all_single when batch size is 1
+        if bsz == 1:
+            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
+        else:
+            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
+    @staticmethod
+    def backward(ctx, *grad_output):
+        process_group = ctx.process_group
+        scatter_dim = ctx.gather_dim
+        gather_dim = ctx.scatter_dim
+        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+        return (return_grad, None, None, None)
 class HookParameter(torch.autograd.Function):
@@ -608,6 +901,40 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output
+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
+    return torch.cat(output_list, dim=gather_dim).contiguous()
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
+    inp_shape = list(input_.shape)
+    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
+    if scatter_dim < 2:
+        input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
+    else:
+        input_t = (
+            input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
+            .transpose(0, 1)
+            .contiguous()
+        )
+    output = torch.empty_like(input_t)
+    dist.all_to_all_single(output, input_t, group=group)
+    if scatter_dim < 2:
+        output = output.transpose(0, 1).contiguous()
+    return output.reshape(
+        inp_shape[:gather_dim]
+        + [
+            inp_shape[gather_dim] * seq_world_size,
+        ]
+        + inp_shape[gather_dim + 1 :]
+    ).contiguous()
 def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
     return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
@@ -617,31 +944,39 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )
-def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
-    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+def gather_forward_reducescatter_backward(input_, process_group, dim):
+    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
+def reducescatter_forward_gather_backward(input_, process_group, dim):
+    return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)
 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )
-def gather_forward_split_backward(input_, dim, process_group):
-    return _GatherForwardSplitBackward.apply(input_, dim, process_group)
+def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
+    return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)
-def split_forward_gather_backward(input_, dim, process_group):
-    return _SplitForwardGatherBackward.apply(input_, dim, process_group)
+def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
+    return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)
 def reduce_forward(input_, process_group):
@@ -650,3 +985,7 @@ def reduce_forward(input_, process_group):
 def reduce_backward(input_, process_group):
     return _ReduceBackward.apply(input_, process_group)
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
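For context, this is the typical Ulysses-style use of the new all_to_all_comm primitive: scatter the hidden/head dimension and gather the sequence dimension before attention, then invert the mapping afterwards. The layout [batch, seq_len / sp_size, hidden_size] and the attn_fn callable are illustrative assumptions, not code taken from this PR.

# Sketch of Ulysses-style all-to-all around an attention kernel.
from colossalai.shardformer.layer import all_to_all_comm

def ulysses_attention(q, k, v, sp_group, attn_fn):
    # q, k, v: [batch, seq_len / sp_size, hidden_size] on each rank of sp_group.
    # Scatter dim 2 (hidden, i.e. groups of heads) and gather dim 1 (sequence):
    # afterwards each rank sees the full sequence but only hidden_size / sp_size channels.
    q = all_to_all_comm(q, sp_group, scatter_dim=2, gather_dim=1)
    k = all_to_all_comm(k, sp_group, scatter_dim=2, gather_dim=1)
    v = all_to_all_comm(v, sp_group, scatter_dim=2, gather_dim=1)
    out = attn_fn(q, k, v)  # any attention kernel over the full sequence, e.g. flash attention
    # Invert the mapping: back to the local sequence shard with all channels.
    return all_to_all_comm(out, sp_group, scatter_dim=1, gather_dim=2)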


@@ -23,11 +23,13 @@ from colossalai.tensor.d_tensor.api import (
 )
 from ._operation import (
+    gather_forward_reducescatter_backward,
     gather_forward_split_backward,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     reduce_forward,
+    reducescatter_forward_gather_backward,
     split_forward_gather_backward,
 )
 from .parallel_module import ParallelModule
@@ -74,7 +76,7 @@ class Linear1D_Col(ParallelModule):
         device: torch.device = None,
         process_group: ProcessGroup = None,
         gather_output: bool = False,
-        seq_parallel: bool = False,
+        seq_parallel_mode: str = None,
         seq_parallel_dim: int = 1,
         overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
@@ -89,7 +91,7 @@ class Linear1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
-        self.seq_parallel = seq_parallel
+        self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
@@ -196,12 +198,18 @@ class Linear1D_Col(ParallelModule):
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel:
-            output_parallel = linear_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
-            )
-        else:
+        if self.seq_parallel_mode is None:
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        elif self.seq_parallel_mode == "split_gather":
+            input_parallel = gather_forward_reducescatter_backward(
+                input_parallel, self.process_group, self.seq_parallel_dim
+            )
+            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
+        elif self.seq_parallel_mode == "ring":
+            output_parallel = linear_gather_forward_reducescatter_backward(
+                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+            )
         if self.gather_output:
             # All-gather across the partitions.
@@ -225,7 +233,8 @@ class Linear1D_Row(ParallelModule):
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
-       seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+       seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
+       seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
@@ -245,7 +254,7 @@ class Linear1D_Row(ParallelModule):
         dtype: torch.dtype = None,
         device: torch.device = None,
         process_group: ProcessGroup = None,
-        seq_parallel: bool = False,
+        seq_parallel_mode: str = None,
         seq_parallel_dim: int = 1,
         parallel_input: bool = True,
         skip_bias_add: bool = False,
@@ -265,7 +274,7 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
-        self.seq_parallel = seq_parallel
+        self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)
@@ -403,18 +412,26 @@ class Linear1D_Row(ParallelModule):
                    output_parallel_list[i], group=self.process_group, async_op=True
                )
                handle_list.append(handle)
-                # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
            for handle in handle_list:
                handle.wait()
            output = torch.cat(output_parallel_list, dim=-1)
         else:
-            output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
-            if self.seq_parallel:
-                output = linear_reducescatter_forward_gather_backward(
+            if self.seq_parallel_mode is None:
+                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+                output = reduce_forward(output_parallel, self.process_group)
+            elif self.seq_parallel_mode == "split_gather":
+                output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
+                output = reducescatter_forward_gather_backward(
                     output_parallel, self.process_group, self.seq_parallel_dim
                 )
-            else:
-                output = reduce_forward(output_parallel, self.process_group)
+            elif self.seq_parallel_mode == "ring":
+                output = linear_reducescatter_forward_gather_backward(
+                    input_,
+                    self.weight,
+                    process_group=self.process_group,
+                    dim=self.seq_parallel_dim,
+                    ring=True,
+                )
         if not self.skip_bias_add:
             if self.bias is not None:
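To make the split_gather path above concrete, here is a sketch of the expected activation flow through a Linear1D_Col / Linear1D_Row pair when seq_parallel_mode="split_gather". The group, sizes and shapes are assumptions for illustration; the input is already split along the sequence dimension (seq_parallel_dim=1), as the shardformer policies arrange.

# Sketch, assuming torch.distributed is initialized and tp_group is the tensor-parallel group.
import torch
import torch.distributed as dist
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

tp_group = dist.group.WORLD                      # placeholder for the real tp group
tp_size = dist.get_world_size(tp_group)
batch, seq_len, hidden = 2, 4096, 1024

col = Linear1D_Col(hidden, 4 * hidden, process_group=tp_group, seq_parallel_mode="split_gather")
row = Linear1D_Row(4 * hidden, hidden, process_group=tp_group, seq_parallel_mode="split_gather")

x = torch.randn(batch, seq_len // tp_size, hidden).cuda()  # sequence already split: [b, s / tp, h]
h = col(x)   # all-gather along the sequence dim before the matmul -> [b, s, 4h / tp]
y = row(h)   # partial matmul, then reduce-scatter along the sequence dim -> [b, s / tp, h]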


@@ -25,12 +25,12 @@ from colossalai.tensor.d_tensor.api import (
 from ._operation import (
     gather_forward_split_backward,
-    linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     matmul_gather_forward_reducescatter_backward,
     matmul_with_async_comm,
     reduce_backward,
     reduce_forward,
+    reducescatter_forward_gather_backward,
     split_forward_gather_backward,
 )
 from .parallel_module import ParallelModule
@@ -150,7 +150,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        device (`torch.device`): The device of parameters, defaults to None.
        n_fused (int): The number items fused, defaults to 3 (QKV).
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
-       seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+       seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
            to all GPUs, otherwise, every GPU will have its output
            which is :math:`Y_i = XA_i`, defaults to False
@@ -175,7 +175,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         process_group: ProcessGroup = None,
         async_communication: bool = False,
         gather_output: bool = False,
-        seq_parallel: bool = False,
+        seq_parallel_mode: str = None,
         overlap: bool = False,
         skip_bias_add: bool = False,
         n_fused: int = 3,
@@ -190,7 +190,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
-        self.seq_parallel = seq_parallel
+        self.seq_parallel_mode = seq_parallel_mode
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
         self.device = device
@@ -312,17 +312,22 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel:
-            input_parallel = input_
-            output_parallel = matmul_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
-            )
-        else:
+        if self.seq_parallel_mode is None:
             # Set up backprop all-reduce.
             input_parallel = reduce_backward(input_, self.process_group)
             output_parallel = matmul_with_async_comm(
                 input_parallel, self.weight, bias, self.process_group, self.async_communication
             )
+        elif self.seq_parallel_mode == "split_gather":
+            input_parallel = input_
+            output_parallel = matmul_gather_forward_reducescatter_backward(
+                input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
+            )
+        elif self.seq_parallel_mode == "ring":
+            input_parallel = input_
+            output_parallel = matmul_gather_forward_reducescatter_backward(
+                input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True
+            )
         if self.gather_output:
             # All-gather across the partitions.
@@ -347,7 +352,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
-       seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+       seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
@@ -366,7 +371,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         dtype: torch.dtype = None,
         device: torch.device = None,
         process_group: ProcessGroup = None,
-        seq_parallel: bool = False,
+        seq_parallel_mode: str = None,
         parallel_input: bool = True,
         skip_bias_add: bool = False,
         weight: Optional[Parameter] = None,
@@ -385,7 +390,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
-        self.seq_parallel = seq_parallel
+        self.seq_parallel_mode = seq_parallel_mode
         self.num_partitions = dist.get_world_size(self.process_group)
         if skip_bias_add and not bias:
@@ -528,11 +533,15 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
                handle.wait()
            output = torch.cat(output_parallel_list, dim=-1)
         else:
+            if self.seq_parallel_mode is None:
                output_parallel = torch.matmul(input_, self.weight)
-            if self.seq_parallel:
-                output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
-            else:
                output = reduce_forward(output_parallel, self.process_group)
+            elif self.seq_parallel_mode == "split_gather":
+                output_parallel = torch.matmul(input_, self.weight)
+                output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+            elif self.seq_parallel_mode == "ring":
+                output_parallel = torch.matmul(input_, self.weight)
+                output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
         if not self.skip_bias_add:
             if self.bias is not None:
@@ -702,7 +711,6 @@ class FusedLinear1D_Col(ParallelModule):
        # process_group=process_group,
        # is_transposed=False)
        # linear_1d.bias.data.copy_(sharded_bias.data)
-        print(linear_1d.weight.shape)
         return linear_1d
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:


@@ -35,17 +35,21 @@ class SeqParallelUtils:
         return getattr(param, "partial_derived", False)
     @staticmethod
-    def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+    def allreduce_partial_data_grad(
+        process_group: ProcessGroup,
+        model: nn.Module = None,
+        grads: List[torch.Tensor] = None,
+    ):
         """
         Allreduce partial derived gradients across the specified process group.
         This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
         Args:
-            tp_group (ProcessGroup): The process group for gradient synchronization.
+            process_group (ProcessGroup): The process group for gradient synchronization.
             model (nn.Module): The model from which gradients will be synchronized.
             grads (List[torch.Tensor]): The list of gradients to be synchronized.
+            only_sp_partial (bool): Whether handle all the parameters or only parameters marked as partial derived.
         Raises:
             AssertionError: If both `model` and `grads` are provided or neither is provided.
         """
@@ -53,22 +57,26 @@ class SeqParallelUtils:
         assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."
         # Get the size of the process group, which determines whether synchronization is needed.
-        tp_size = get_world_size(tp_group) if tp_group is not None else 1
-        if tp_size == 1:
+        group_size = get_world_size(process_group) if process_group is not None else 1
+        if group_size == 1:
             # If the process group size is 1, no synchronization is required.
             return
         if model is not None:
             # If `model` is provided, extract partial derived gradients from the model's parameters.
             grads = []
             for p in model.parameters():
-                if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
+                if p.grad is not None:
+                    if SeqParallelUtils.is_sp_partial_derived_param(p):
                         grads.append(p.grad.data)
             # Flatten and reduce the gradients using the specified process group.
+            if len(grads) == 0:
+                return
             coalesced = _flatten_dense_tensors(grads)
-            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
             # Unflatten the synchronized gradients and update the model's gradients.
             for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
@@ -76,7 +84,7 @@ class SeqParallelUtils:
         else:
             # If `grads` are provided explicitly, synchronize those gradients directly.
             coalesced = _flatten_dense_tensors(grads)
-            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+            dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced) buf.copy_(synced)
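A hypothetical usage sketch of the renamed helper follows; the import path and the call site (right after the backward pass) are assumptions, not code from this patch.

    # Assumed import path; adjust to wherever SeqParallelUtils lives in your tree.
    from colossalai.shardformer.layer.utils import SeqParallelUtils

    def sync_partial_grads(model, process_group):
        # All-reduce only the gradients of parameters flagged as partially derived
        # (e.g. LayerNorm weights whose grads were computed on a local sequence
        # shard). The helper returns early if the group size is 1 or if no such
        # gradient exists.
        SeqParallelUtils.allreduce_partial_data_grad(process_group=process_group, model=model)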

View File

@ -186,6 +186,7 @@ class BertPipelineForwards:
# split the input tensor along sequence dimension # split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config is not None and shard_config.enable_sequence_parallelism: if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
) )
@ -240,6 +241,7 @@ class BertPipelineForwards:
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config is not None and shard_config.enable_sequence_parallelism: if shard_config is not None and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
) )
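The BERT pipeline forward above (and the BLOOM, ChatGLM and GPT-2 ones that follow) only wraps the encoder when the mode is split_gather, using the split-in-forward / gather-in-backward pair. A rough sketch of that pair, illustrative rather than the library's exact _operation code, and assuming the split dimension divides evenly:

    import torch
    import torch.distributed as dist

    class _SplitForwardGatherBackward(torch.autograd.Function):
        """Keep this rank's sequence chunk in forward; all-gather gradients in backward."""

        @staticmethod
        def forward(ctx, input_, dim, process_group):
            ctx.dim, ctx.group = dim, process_group
            rank = dist.get_rank(process_group)
            world_size = dist.get_world_size(process_group)
            return input_.chunk(world_size, dim=dim)[rank].contiguous()

        @staticmethod
        def backward(ctx, grad_output):
            world_size = dist.get_world_size(ctx.group)
            parts = [torch.empty_like(grad_output) for _ in range(world_size)]
            dist.all_gather(parts, grad_output.contiguous(), group=ctx.group)
            return torch.cat(parts, dim=ctx.dim), None, None

The mirror-image gather_forward_split_backward is applied at the end of the encoder so downstream heads see the full sequence again.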

View File

@ -213,7 +213,8 @@ class BloomPipelineForwards:
# split the input tensor along sequence dimension # split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism: if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
) )
@ -261,7 +262,8 @@ class BloomPipelineForwards:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism: if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
) )

View File

@ -191,11 +191,10 @@ class ChatGLMPipelineForwards:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
if shard_config.enable_sequence_parallelism: if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
dim=0,
process_group=shard_config.tensor_parallel_process_group,
) )
for idx in range(start_idx, end_idx): for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx) layer = self.encoder._get_layer(idx)
@ -222,11 +221,10 @@ class ChatGLMPipelineForwards:
if use_cache: if use_cache:
presents = presents + (kv_cache,) presents = presents + (kv_cache,)
if shard_config.enable_sequence_parallelism: if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
dim=0,
process_group=shard_config.tensor_parallel_process_group,
) )
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)

View File

@ -218,7 +218,8 @@ class GPT2PipelineForwards:
# split the input tensor along sequence dimension # split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism: if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
dim=1, dim=1,
@ -278,7 +279,8 @@ class GPT2PipelineForwards:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism: if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, hidden_states,
dim=1, dim=1,
@ -1141,7 +1143,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, hidden_states,
dim=1, dim=1,
process_group=shard_config.tensor_parallel_process_group, process_group=shard_config.sequence_parallel_process_group,
) )
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -1208,7 +1210,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, hidden_states,
dim=1, dim=1,
process_group=shard_config.tensor_parallel_process_group, process_group=shard_config.sequence_parallel_process_group,
) )
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
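The GPT-2 sequence-parallel forward now splits and gathers over shard_config.sequence_parallel_process_group rather than the tensor-parallel group. A toy sketch of carving out such a group with plain torch.distributed, assuming contiguous rank blocks (the real plugin derives it from its process-group mesh):

    import torch.distributed as dist

    def build_sequence_parallel_group(sp_size):
        # Every rank must call dist.new_group for every group, including groups
        # it does not belong to; keep only the group containing this rank.
        world_size, rank = dist.get_world_size(), dist.get_rank()
        sp_group = None
        for start in range(0, world_size, sp_size):
            ranks = list(range(start, start + sp_size))
            group = dist.new_group(ranks)
            if rank in ranks:
                sp_group = group
        return sp_group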

View File

@ -1,18 +1,32 @@
import math
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import ( from transformers.modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
SequenceClassifierOutputWithPast, SequenceClassifierOutputWithPast,
) )
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d from ..layer import ColoAttention, cross_entropy_1d
@ -438,7 +452,7 @@ class LlamaPipelineForwards:
return {"hidden_states": hidden_states} return {"hidden_states": hidden_states}
def get_llama_flash_attention_forward(shard_config: ShardConfig): def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
llama_version = 2 llama_version = 2
@ -459,18 +473,30 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: if past_key_value is not None:
@ -490,6 +516,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value return attn_output, None, past_key_value
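For the all_to_all mode, all_to_all_comm trades each rank's sequence shard of every head for the full sequence of a head shard before attention, and the output call with scatter_dim=1, gather_dim=2 reverses the exchange afterwards. A rough, autograd-free sketch of the exchange itself (argument names and defaults are assumptions; the real helper also defines a backward pass):

    import torch
    import torch.distributed as dist

    def all_to_all_sketch(x, process_group, scatter_dim, gather_dim):
        # Split `x` along scatter_dim, exchange chunks across the group, then
        # concatenate what was received along gather_dim. Assumes scatter_dim
        # is divisible by the group size.
        world_size = dist.get_world_size(process_group)
        inputs = [c.contiguous() for c in x.chunk(world_size, dim=scatter_dim)]
        outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
        dist.all_to_all(outputs, inputs, group=process_group)
        return torch.cat(outputs, dim=gather_dim)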
@ -726,3 +755,261 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
) )
return forward return forward
def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
# sp: modify q_len when the sequence parallel mode is split_gather or ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
# sp: all-to-all communication when introducing sequence parallelism
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
return forward
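A quick shape walkthrough of the all_to_all branch above, with made-up sizes, showing why bsz and q_len are re-read after the exchange:

    # Toy sizes; hidden = num_heads * head_dim, and heads divide evenly by sp_size.
    bsz, local_len, num_heads, head_dim, sp_size = 2, 1024, 32, 128, 4
    hidden = num_heads * head_dim            # 4096

    # before all_to_all: each rank holds (bsz, local_len, hidden)
    # after  all_to_all: the hidden/head dim is scattered, the sequence is gathered
    q_len = local_len * sp_size              # 4096 full-sequence tokens
    local_hidden = hidden // sp_size         # 1024 = (num_heads // sp_size) * head_dim
    # query_states is now (bsz, q_len, local_hidden), and the subsequent .view uses
    # self.num_heads already divided by sp_size via the policy's attribute_replacement.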
def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
logger = logging.get_logger(__name__)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
# modify past_key_values_length when using sequence parallel
past_key_values_length *= sp_size
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
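In all_to_all mode, the split at the embedding carries grad_scale 1/sp_size and the gather after the final norm carries grad_scale sp_size, presumably to compensate for gradient averaging performed over the same ranks elsewhere (the exact rationale is not spelled out in this diff). A sketch of the gather side, under the assumption that grad_scale simply multiplies the backward gradient:

    import torch
    import torch.distributed as dist

    class _GatherForwardSplitBackward(torch.autograd.Function):
        """All-gather the sequence in forward; keep the local shard (scaled) in backward."""

        @staticmethod
        def forward(ctx, input_, dim, process_group, grad_scale=None):
            ctx.dim, ctx.group, ctx.grad_scale = dim, process_group, grad_scale
            world_size = dist.get_world_size(process_group)
            parts = [torch.empty_like(input_) for _ in range(world_size)]
            dist.all_gather(parts, input_.contiguous(), group=process_group)
            return torch.cat(parts, dim=dim)

        @staticmethod
        def backward(ctx, grad_output):
            rank = dist.get_rank(ctx.group)
            world_size = dist.get_world_size(ctx.group)
            grad = grad_output.chunk(world_size, dim=ctx.dim)[rank].contiguous()
            if ctx.grad_scale is not None:
                grad = grad * ctx.grad_scale
            return grad, None, None, None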

View File

@ -1,3 +1,4 @@
import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List from typing import Callable, Dict, List
@ -66,8 +67,17 @@ class BertPolicy(Policy):
else: else:
norm_cls = col_nn.LayerNorm norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
if sp_mode == "ring":
warnings.warn(
f"For Bert, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription( policy[BertLayer] = ModulePolicyDescription(
attribute_replacement={ attribute_replacement={
@ -85,7 +95,7 @@ class BertPolicy(Policy):
suffix="attention.self.query", suffix="attention.self.query",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel": use_sequence_parallel, "seq_parallel_mode": sp_mode,
"overlap": overlap, "overlap": overlap,
}, },
), ),
@ -93,7 +103,7 @@ class BertPolicy(Policy):
suffix="attention.self.key", suffix="attention.self.key",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel": use_sequence_parallel, "seq_parallel_mode": sp_mode,
"overlap": overlap, "overlap": overlap,
}, },
), ),
@ -101,7 +111,7 @@ class BertPolicy(Policy):
suffix="attention.self.value", suffix="attention.self.value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel": use_sequence_parallel, "seq_parallel_mode": sp_mode,
"overlap": overlap, "overlap": overlap,
}, },
), ),
@ -112,7 +122,7 @@ class BertPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dense", suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel}, kwargs={"seq_parallel_mode": sp_mode},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.dropout", suffix="attention.output.dropout",
@ -122,14 +132,14 @@ class BertPolicy(Policy):
suffix="intermediate.dense", suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={ kwargs={
"seq_parallel": use_sequence_parallel, "seq_parallel_mode": sp_mode,
"overlap": overlap, "overlap": overlap,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dense", suffix="output.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel}, kwargs={"seq_parallel_mode": sp_mode},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.dropout", suffix="output.dropout",
@ -151,7 +161,7 @@ class BertPolicy(Policy):
] ]
) )
if use_sequence_parallel: if sp_mode == "split_gather":
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)},
policy=policy, policy=policy,
@ -165,12 +175,12 @@ class BertPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.output.LayerNorm", suffix="attention.output.LayerNorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="output.LayerNorm", suffix="output.LayerNorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
], ],
policy=policy, policy=policy,

View File

@ -1,3 +1,4 @@
import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List from typing import Callable, Dict, List
@ -55,8 +56,18 @@ class BloomPolicy(Policy):
norm_cls = col_nn.FusedLayerNorm norm_cls = col_nn.FusedLayerNorm
else: else:
norm_cls = col_nn.LayerNorm norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
if sp_mode == "ring":
warnings.warn(
f"For BLOOM, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription( policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={ attribute_replacement={
@ -70,12 +81,12 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.query_key_value", suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.dense", suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel}, kwargs={"seq_parallel_mode": sp_mode},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.attention_dropout", suffix="self_attention.attention_dropout",
@ -84,12 +95,12 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h", suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h", suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel}, kwargs={"seq_parallel_mode": sp_mode},
), ),
], ],
) )
@ -132,19 +143,19 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="input_layernorm", suffix="input_layernorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="post_attention_layernorm", suffix="post_attention_layernorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
], ],
policy=policy, policy=policy,
target_key=BloomBlock, target_key=BloomBlock,
) )
if use_sequence_parallel: if sp_mode == "split_gather":
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)},
policy=policy, policy=policy,

View File

@ -1,3 +1,4 @@
import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List, Union from typing import Callable, Dict, List, Union
@ -55,8 +56,17 @@ class ChatGLMPolicy(Policy):
norm_cls = col_nn.RMSNorm norm_cls = col_nn.RMSNorm
else: else:
norm_cls = col_nn.LayerNorm norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
if sp_mode == "ring":
warnings.warn(
f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[ChatGLMModel] = ModulePolicyDescription( policy[ChatGLMModel] = ModulePolicyDescription(
attribute_replacement={}, attribute_replacement={},
@ -91,12 +101,12 @@ class ChatGLMPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.query_key_value", suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap}, kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.dense", suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row, target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0}, kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout", suffix="self_attention.core_attention.attention_dropout",
@ -110,12 +120,12 @@ class ChatGLMPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="input_layernorm", suffix="input_layernorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="post_attention_layernorm", suffix="post_attention_layernorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
], ],
policy=policy, policy=policy,
@ -145,7 +155,7 @@ class ChatGLMPolicy(Policy):
) )
# use sequence parallel # use sequence parallel
if use_sequence_parallel: if sp_mode == "split_gather":
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
policy=policy, policy=policy,

View File

@ -1,3 +1,4 @@
import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List from typing import Callable, Dict, List
@ -50,8 +51,25 @@ class GPT2Policy(Policy):
norm_cls = col_nn.FusedLayerNorm norm_cls = col_nn.FusedLayerNorm
else: else:
norm_cls = col_nn.LayerNorm norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
if sp_mode == "ring":
warnings.warn(
f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap overlap = self.shard_config.enable_sequence_overlap
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
# todo: currently sp cannot be used with flashattention
if sp_mode in ["split_gather", "ring", "all_to_all"]:
if use_flash_attention:
warnings.warn(
f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
)
self.shard_config.enable_flash_attention = False
use_flash_attention = False
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription( policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
@ -78,7 +96,7 @@ class GPT2Policy(Policy):
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={ kwargs={
"n_fused": 3, "n_fused": 3,
"seq_parallel": use_sequence_parallel, "seq_parallel_mode": sp_mode,
"overlap": overlap, "overlap": overlap,
}, },
), ),
@ -86,7 +104,7 @@ class GPT2Policy(Policy):
suffix="attn.c_proj", suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row, target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={ kwargs={
"seq_parallel": use_sequence_parallel, "seq_parallel_mode": sp_mode,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
@ -94,14 +112,16 @@ class GPT2Policy(Policy):
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={ kwargs={
"n_fused": 1, "n_fused": 1,
"seq_parallel": use_sequence_parallel, "seq_parallel_mode": sp_mode,
"overlap": overlap, "overlap": overlap,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.c_proj", suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row, target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={"seq_parallel": use_sequence_parallel}, kwargs={
"seq_parallel_mode": sp_mode,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.attn_dropout", suffix="attn.attn_dropout",
@ -133,25 +153,25 @@ class GPT2Policy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="ln_1", suffix="ln_1",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="ln_2", suffix="ln_2",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="ln_cross_attn", suffix="ln_cross_attn",
target_module=norm_cls, target_module=norm_cls,
ignore_if_not_exist=True, ignore_if_not_exist=True,
kwargs={"sp_partial_derived": use_sequence_parallel}, kwargs={"sp_partial_derived": sp_partial_derived},
), ),
], ],
policy=policy, policy=policy,
target_key=GPT2Block, target_key=GPT2Block,
) )
if self.shard_config.enable_flash_attention: if use_flash_attention:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={ description={
"forward": get_gpt2_flash_attention_forward(), "forward": get_gpt2_flash_attention_forward(),
@ -164,7 +184,7 @@ class GPT2Policy(Policy):
"forward": get_gpt_model_forward_for_flash_attn(self.shard_config) "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
} }
if self.shard_config.enable_sequence_parallelism: if sp_mode is not None:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
return policy return policy

View File

@ -12,6 +12,8 @@ from ..modeling.llama import (
LlamaPipelineForwards, LlamaPipelineForwards,
get_llama_flash_attention_forward, get_llama_flash_attention_forward,
get_llama_model_forward_for_flash_attn, get_llama_model_forward_for_flash_attn,
get_llama_seq_parallel_attention_forward,
get_llama_seq_parallel_model_forward,
get_lm_forward_with_dist_cross_entropy, get_lm_forward_with_dist_cross_entropy,
) )
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -45,9 +47,74 @@ class LlamaPolicy(Policy):
else: else:
norm_cls = RMSNorm norm_cls = RMSNorm
if self.shard_config.enable_sequence_parallelism: if self.pipeline_stage_manager is not None:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") self.shard_config.enable_sequence_overlap = False
self.shard_config.sequence_parallelism_mode = None
warnings.warn(
f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
)
sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
sp_group = (
self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
)
sp_partial_derived = sp_mode in ["split_gather", "ring"]
use_flash_attention = self.shard_config.enable_flash_attention
# Currently sp cannot be used with FlashAttention
if sp_mode in ["split_gather", "ring", "all_to_all"]:
if use_flash_attention:
warnings.warn(
f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically."
)
use_flash_attention = False
if sp_mode in ["split_gather", "ring"]:
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(
sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
),
},
policy=policy,
target_key=LlamaModel,
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=LlamaAttention,
)
elif sp_mode == "all_to_all":
decoder_attribute_replacement = {
"num_heads": self.model.config.num_attention_heads // sp_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
policy[LlamaAttention] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=LlamaAttention,
)
self.append_or_create_method_replacement(
description={
"forward": get_llama_seq_parallel_model_forward(
sp_mode=sp_mode,
sp_size=sp_size,
sp_group=sp_group,
),
},
policy=policy,
target_key=LlamaModel,
)
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = { decoder_attribute_replacement = {
@ -65,30 +132,37 @@ class LlamaPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.q_proj", suffix="self_attn.q_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.k_proj", suffix="self_attn.k_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.v_proj", suffix="self_attn.v_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="self_attn.o_proj", suffix="self_attn.o_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.gate_proj", suffix="mlp.gate_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.up_proj", suffix="mlp.up_proj",
target_module=Linear1D_Col, target_module=Linear1D_Col,
kwargs=dict(seq_parallel_mode=sp_mode),
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.down_proj", suffix="mlp.down_proj",
target_module=Linear1D_Row, target_module=Linear1D_Row,
kwargs=dict(seq_parallel_mode=sp_mode),
), ),
], ],
) )
@ -108,10 +182,12 @@ class LlamaPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="input_layernorm", suffix="input_layernorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="post_attention_layernorm", suffix="post_attention_layernorm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
), ),
], ],
policy=policy, policy=policy,
@ -122,16 +198,17 @@ class LlamaPolicy(Policy):
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
suffix="norm", suffix="norm",
target_module=norm_cls, target_module=norm_cls,
kwargs={"sp_partial_derived": sp_partial_derived},
), ),
policy=policy, policy=policy,
target_key=LlamaModel, target_key=LlamaModel,
) )
# use flash attention # use flash attention
if self.shard_config.enable_flash_attention: if use_flash_attention:
self.append_or_create_method_replacement( self.append_or_create_method_replacement(
description={ description={
"forward": get_llama_flash_attention_forward(self.shard_config), "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
}, },
policy=policy, policy=policy,
target_key=LlamaAttention, target_key=LlamaAttention,
@ -243,7 +320,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
policy = super().module_policy() policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism:
# add a new item for causal lm # add a new item for causal lm
new_item = { new_item = {
LlamaForCausalLM: ModulePolicyDescription( LlamaForCausalLM: ModulePolicyDescription(
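For all_to_all, the attribute_replacement earlier in this file re-partitions attention heads across the sequence-parallel group, so each rank's attention module owns num_heads // sp_size heads. A toy sanity check with made-up sizes:

    # Illustrative numbers only; real values come from the model config.
    num_attention_heads, num_key_value_heads, sp_size = 32, 8, 4

    assert num_attention_heads % sp_size == 0
    assert num_key_value_heads % sp_size == 0

    local_heads = num_attention_heads // sp_size      # 8 query heads per rank
    local_kv_heads = num_key_value_heads // sp_size   # 2 KV heads per rank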

View File

@ -1,3 +1,4 @@
import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@ -9,6 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from .grad_ckpt_config import GradientCheckpointConfig from .grad_ckpt_config import GradientCheckpointConfig
__all__ = ["ShardConfig"] __all__ = ["ShardConfig"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@dataclass @dataclass
@ -29,13 +31,15 @@ class ShardConfig:
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
""" """
tensor_parallel_process_group: Optional[ProcessGroup] = None tensor_parallel_process_group: Optional[ProcessGroup] = None
sequence_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None pipeline_stage_manager: Optional[PipelineStageManager] = None
enable_tensor_parallelism: bool = True enable_tensor_parallelism: bool = True
enable_all_optimization: bool = False
enable_fused_normalization: bool = False enable_fused_normalization: bool = False
enable_flash_attention: bool = False enable_flash_attention: bool = False
enable_jit_fused: bool = False enable_jit_fused: bool = False
enable_all_optimization: bool = False
enable_sequence_parallelism: bool = False enable_sequence_parallelism: bool = False
sequence_parallelism_mode: str = None
enable_sequence_overlap: bool = False enable_sequence_overlap: bool = False
parallel_output: bool = True parallel_output: bool = True
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
@ -50,22 +54,57 @@ class ShardConfig:
def tensor_parallel_size(self): def tensor_parallel_size(self):
return self._tensor_parallel_size return self._tensor_parallel_size
@property
def sequence_parallel_size(self):
return self._sequence_parallel_size
def __post_init__(self): def __post_init__(self):
if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
raise ValueError(
"enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True"
)
if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
# get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
# turn on all optimization if all_optimization is set to True # turn on all optimization if all_optimization is set to True
if self.enable_all_optimization: if self.enable_all_optimization:
self._turn_on_all_optimization() self._turn_on_all_optimization()
if self.enable_sequence_parallelism:
self.sequence_parallelism_mode = (
"split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode
)
assert (
self.sequence_parallelism_mode in SUPPORT_SP_MODE
), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
if self.sequence_parallelism_mode in ["split_gather", "ring"]:
assert (
self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
elif self.sequence_parallelism_mode in ["all_to_all"]:
assert (
not self.enable_tensor_parallelism
), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
if self.enable_sequence_overlap:
self.enable_sequence_overlap = False
warnings.warn(
f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
)
else:
if self.sequence_parallelism_mode:
self.sequence_parallelism_mode = None
warnings.warn(
f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
)
assert (
not self.enable_sequence_overlap
), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
# get the tensor parallel size
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
# get the sequence parallel size
if not self.enable_sequence_parallelism:
self._sequence_parallel_size = 1
else:
self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
def _turn_on_all_optimization(self): def _turn_on_all_optimization(self):
""" """
Turn on all optimization. Turn on all optimization.
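A hypothetical construction of the extended config, illustrating the constraints enforced in __post_init__ (split_gather and ring require tensor parallelism; all_to_all requires it to be disabled). The group handles are placeholders and the import path is assumed:

    import torch.distributed as dist
    from colossalai.shardformer import ShardConfig   # assumed public import path

    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,       # placeholder group handle
        sequence_parallel_process_group=sp_group,     # placeholder group handle
        enable_tensor_parallelism=True,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="ring",             # must be in SUPPORT_SP_MODE
    )
    assert shard_config.sequence_parallel_size == dist.get_world_size(sp_group)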

View File

@ -79,6 +79,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
master_weights: bool = True, # master weights master_weights: bool = True, # master weights
): ):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]["params"][0].dtype self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._logger = get_dist_logger() self._logger = get_dist_logger()
self._verbose = verbose self._verbose = verbose
@ -494,7 +495,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads # clear reduced grads
if self._overlap_communication: if self._overlap_communication:
get_accelerator().synchronize() get_accelerator().synchronize()
self.zero_grad() self.zero_grad()
def backward_by_grad(self, tensor, grad): def backward_by_grad(self, tensor, grad):

View File

@ -18,8 +18,23 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt') # tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids'] # input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask'] # attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) # input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) # attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
input_ids = torch.tensor(
[
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
],
dtype=torch.int64,
)
attention_mask = torch.tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
dtype=torch.int64,
)
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
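The longer, batch-of-two inputs make these tests shardable under sequence parallelism. A quick check of the divisibility constraints they satisfy (the multiple-of-4 requirement comes from the FlashAttention assert earlier in this diff; sp_size here is illustrative):

    seq_len, sp_size = 16, 2                 # sizes used by the updated data_gen
    assert seq_len % sp_size == 0            # the sequence must shard evenly across SP ranks
    assert seq_len % 4 == 0                  # the FlashAttention path asserts a multiple of 4
    local_len = seq_len // sp_size           # 8 tokens per rank after the split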
@ -35,9 +50,9 @@ def data_gen_for_question_answering():
# question answering data gen # question answering data gen
# `labels` is the type not the token id for token classification, 0 or 1 # `labels` is the type not the token id for token classification, 0 or 1
data = data_gen() data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64) start_positions = torch.tensor([[0], [0]], dtype=torch.int64)
data["start_positions"] = start_positions data["start_positions"] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64) end_positions = torch.tensor([[1], [1]], dtype=torch.int64)
data["end_positions"] = end_positions data["end_positions"] = end_positions
return data return data
@ -46,14 +61,20 @@ def data_gen_for_token_classification():
# token classification data gen # token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1 # `labels` is the type not the token id for token classification, 0 or 1
data = data_gen() data = data_gen()
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) data["labels"] = torch.tensor(
[
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
],
dtype=torch.int64,
)
return data return data
def data_gen_for_sequence_classification(): def data_gen_for_sequence_classification():
# sequence classification data gen # sequence classification data gen
data = data_gen() data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64) data["labels"] = torch.tensor([[1], [1]], dtype=torch.int64)
return data return data
@ -61,12 +82,18 @@ def date_gen_for_double_heads():
num_choices = 2 num_choices = 2
batch_size = 2 batch_size = 2
input_ids = torch.tensor( input_ids = torch.tensor(
[[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], [
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
[15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
],
dtype=torch.int64,
)
attention_mask = torch.tensor(
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
dtype=torch.int64, dtype=torch.int64,
) )
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64) mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64)
mc_token_ids = mc_token_ids.expand((batch_size, num_choices)) mc_token_ids = mc_token_ids.expand((batch_size, num_choices))
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
@ -103,6 +130,7 @@ config = transformers.GPT2Config(
hidden_dropout=0, hidden_dropout=0,
problem_type="single_label_classification", problem_type="single_label_classification",
pad_token_id=50256, pad_token_id=50256,
tie_word_embeddings=True,
) )
config_for_token_classification = copy.deepcopy(config) config_for_token_classification = copy.deepcopy(config)
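A quick, hedged sanity check on the new fixtures (illustrative only, not part of the diff): the 16-token, two-sequence inputs above chunk evenly along the sequence dimension for the 2-way and 4-way sequence-parallel settings exercised by the test configs added later in this PR, so no repeat-to-LCM padding is needed.

import torch

# hypothetical check; sp sizes 2 and 4 mirror the "sp_size"/"tp_size" values used by the new test configs
input_ids = torch.zeros(2, 16, dtype=torch.int64)  # same shape as the fixture above
for sp_size in (2, 4):
    shards = torch.chunk(input_ids, sp_size, dim=-1)
    assert all(shard.shape[-1] == 16 // sp_size for shard in shards)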


@@ -28,9 +28,19 @@ if HAS_LLAMA:
        # -----------------------------------
        input_ids = torch.Tensor(
-            [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]]
+            [
+                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
+                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
+            ]
        ).long()
-        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
+        attention_mask = torch.Tensor(
+            [
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+            ]
+        ).long()
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # label is needed for casual lm


@@ -44,7 +44,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    bert_model = model_fn()
-    enable_all_optimization = True if tp_size > 1 else False
+    enable_flash_attention = True if tp_size > 1 else False
+    enable_fused_normalization = True if tp_size > 1 else False
+    enable_jit_fused = True if tp_size > 1 else False

    with shared_tempdir() as tempdir:
        pretrained_path = os.path.join(tempdir, "pretrained")

@@ -54,7 +57,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
        plugin = GeminiPlugin(
            **placement_config,
            tp_size=tp_size,
-            enable_all_optimization=enable_all_optimization,
+            enable_flash_attention=enable_flash_attention,
+            enable_fused_normalization=enable_fused_normalization,
+            enable_jit_fused=enable_jit_fused,
            extra_dp_size=extra_dp_size,
        )
        booster = Booster(plugin=plugin)

@@ -80,7 +85,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = lambda x: x.mean()
-    enable_all_optimization = True if tp_size > 1 else False
+    enable_flash_attention = True if tp_size > 1 else False
+    enable_fused_normalization = True if tp_size > 1 else False
+    enable_jit_fused = True if tp_size > 1 else False
    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
    plugin = GeminiPlugin(
        **placement_config,

@@ -88,7 +95,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
        initial_scale=(2**14),
        tp_size=tp_size,
        extra_dp_size=extra_dp_size,
-        enable_all_optimization=enable_all_optimization,
+        enable_flash_attention=enable_flash_attention,
+        enable_fused_normalization=enable_fused_normalization,
+        enable_jit_fused=enable_jit_fused,
    )
    booster = Booster(plugin=plugin)


@@ -84,6 +84,30 @@ def check_process_group_mesh_with_cases():
        2: [2],
        3: [3],
    }
+    TPxPP_RANKS_IN_GROUP = {
+        0: [0, 1, 2, 3],
+        1: [0, 1, 2, 3],
+        2: [0, 1, 2, 3],
+        3: [0, 1, 2, 3],
+    }
+    DPxTP_RANKS_IN_GROUP = {
+        0: [0, 1],
+        1: [0, 1],
+        2: [2, 3],
+        3: [2, 3],
+    }
+    TPxPP_PARTIAL_INDICES = {
+        0: [[0, 1], [0]],
+        1: [[1], [0, 1]],
+        2: [[0], [0, 1]],
+        3: [[0, 1], [1]],
+    }
+    TPxPP_RANKS_IN_GROUP_PARTIAL = {
+        0: [0, 1],
+        1: [1, 3],
+        2: [0, 2],
+        3: [2, 3],
+    }

    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE)

@@ -107,6 +131,12 @@ def check_process_group_mesh_with_cases():
    assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank]
    dp_group = pg_mesh.get_group_along_axis(DP_DIM)
    assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank]
+    dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM])
+    assert pg_mesh.get_ranks_in_group(dpxtp_group) == DPxTP_RANKS_IN_GROUP[rank]
+    tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM])
+    assert pg_mesh.get_ranks_in_group(tpxpp_group) == TPxPP_RANKS_IN_GROUP[rank]
+    tpxpp_group_partial = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM], TPxPP_PARTIAL_INDICES[rank])
+    assert pg_mesh.get_ranks_in_group(tpxpp_group_partial) == TPxPP_RANKS_IN_GROUP_PARTIAL[rank]

    # check prev rank
    if RANK_TO_COORDINATE[rank][TP_DIM] != 0:

@@ -56,13 +56,18 @@ def rearrange(tensor: torch.Tensor, dim: int):
    return rearanged_tensor


-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = Conv1D(192, 48).cuda()
    with ctx:
        linear_copy = Conv1D(192, 48).cuda()
    linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap
+        linear_copy,
+        process_group=None,
+        gather_output=True,
+        seq_parallel_mode=seq_parallel_mode,
+        n_fused=3,
+        overlap=overlap,
    )

    assert linear.weight.shape == torch.Size([48, 192])

@@ -79,7 +84,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool)
    # check computation correctness
    x = torch.rand(1, 4, 48).cuda()
    out = linear(x)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
    gather_out = linear_conv_col(x_for_shard)
    assert_close(rearrange(out, -1), gather_out)

@@ -91,14 +98,14 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool)
    assert_close(target_grad, linear_conv_col.weight.grad)


-def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
+def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = Conv1D(192, 48).cuda()
    with ctx:
        linear_copy = Conv1D(192, 48).cuda()
    linear_row = GPT2FusedLinearConv1D_Row.from_native_module(
-        linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
+        linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
    )

    assert linear.weight.shape == torch.Size([48, 192])

@@ -115,7 +122,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
    x = torch.rand(1, 4, 48).cuda()
    out = linear(x)
    gather_out = linear_row(x)
-    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
    assert_close(target_out, gather_out)

    # check backward correctness

@@ -128,11 +135,11 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):

@parameterize("lazy_init", [False, True])
-@parameterize("seq_parallel", [False, True])
+@parameterize("seq_parallel_mode", ["split_gather", None])
@parameterize("overlap", [True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
-    check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
-    check_linear_conv_1d_row(lazy_init, seq_parallel)
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
+    check_linear_conv_1d_row(lazy_init, seq_parallel_mode)


def run_dist(rank, world_size, port):
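As background for the seq_parallel_mode="split_gather" cases above: in Megatron-style sequence parallelism the row-parallel linear's output is reduce-scattered along the sequence dimension instead of all-reduced, which is why the tests compare against torch.chunk(out, 2, dim=1)[rank]. Below is a minimal sketch of that step with an assumed helper name, not the colossalai implementation.

import torch
import torch.distributed as dist

def reduce_scatter_along_seq(partial_out: torch.Tensor, group=None) -> torch.Tensor:
    # partial_out: each rank's partial result [batch, seq_len, hidden];
    # returns this rank's summed sequence shard [batch, seq_len // world_size, hidden]
    world_size = dist.get_world_size(group)
    shards = [s.contiguous() for s in torch.chunk(partial_out, world_size, dim=1)]
    out = torch.empty_like(shards[0])
    dist.reduce_scatter(out, shards, group=group)
    return out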


@@ -15,13 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


-def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap
+        linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap
    )

    # ensure that the parameters are distributed

@@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
    x_for_shard.requires_grad_(True)

    out = linear(x_for_unshard)

@@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
    assert x_for_unshard.grad is not None
    target_unshard_gard = (
        x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
        else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
    )
    assert_close(target_unshard_gard, x_for_shard.grad)


-def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
+def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_row = Linear1D_Row.from_native_module(
-        linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
+        linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
    )

    assert linear_row.weight.shape == torch.Size([128, 16])

@@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
    # run forward
    out = linear(x_for_unshard)
    gather_out = linear_row(x_for_shard)
-    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
    assert_close(target_out, gather_out)

    # check backward correctness

@@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
    assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear_1 = nn.Linear(32, 128).cuda()

@@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
        linear_1_copy = nn.Linear(32, 128).cuda()
        linear_2_copy = nn.Linear(128, 32).cuda()
    linear_col = Linear1D_Col.from_native_module(
-        linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap
+        linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap
    )
    linear_row = Linear1D_Row.from_native_module(
-        linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel
+        linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode
    )

    linear_1.load_state_dict(linear_col.state_dict())

@@ -141,13 +143,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
    x_for_shard.requires_grad_(True)

    # run forward
    unshard_out = linear_2(linear_1(x_for_unshard))
    shard_out = linear_row(linear_col(x_for_shard))
-    target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = (
+        unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    )
    assert_close(target_out, shard_out)

    # check backward correctness

@@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
    assert x_for_unshard.grad is not None
    target_unshard_gard = (
        x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
        else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
    )
    assert_close(target_unshard_gard, x_for_shard.grad)


@parameterize("lazy_init", [False, True])
-@parameterize("seq_parallel", [False, True])
+@parameterize("seq_parallel_mode", [None, "split_gather"])
@parameterize("overlap", [True])
-def run_dist_linear_test(lazy_init, seq_parallel, overlap):
-    check_linear_1d_col(lazy_init, seq_parallel, overlap)
-    check_linear_1d_row(lazy_init, seq_parallel)
-    check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
+def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
+    check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
+    check_linear_1d_row(lazy_init, seq_parallel_mode)
+    check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)


def check_dist_linear(rank, world_size, port):
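Similarly, for Linear1D_Col under "split_gather" the sequence-sharded input is all-gathered along dim 1 before the column-parallel matmul, and the input gradient is split back to per-rank shards in backward, which is what the chunked expectations above encode. A rough forward-only sketch with an assumed helper name, not the colossalai implementation:

import torch
import torch.distributed as dist

def gather_sequence_shards(x_shard: torch.Tensor, group=None) -> torch.Tensor:
    # x_shard: [batch, seq_len // world_size, hidden]; returns the full [batch, seq_len, hidden]
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(x_shard) for _ in range(world_size)]
    dist.all_gather(gathered, x_shard.contiguous(), group=group)
    return torch.cat(gathered, dim=1)  # inverse of torch.chunk(x, world_size, dim=1) used in the tests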


@@ -0,0 +1,178 @@
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import all_to_all_comm
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


class SequenceParallelAttention(torch.nn.Module):
    """Initialization.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        heads_num: torch.Tensor,
        hidden_dim: torch.Tensor,
        enable_sequence_parallellism: bool = False,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super(SequenceParallelAttention, self).__init__()
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.heads_num = heads_num
        self.hidden_dim = hidden_dim
        assert hidden_dim % heads_num == 0
        self.head_dim = hidden_dim // heads_num
        self.enable_sequence_parallellism = enable_sequence_parallellism

        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def attn(self, q, k, v):
        batch_size, seq_len = q.shape[0], q.shape[1]

        scale = self.head_dim**0.5
        qk = torch.matmul(q, k.transpose(-2, -1)) / scale
        weights = F.softmax(qk, dim=-1)

        attention_score = torch.matmul(weights, v)

        return attention_score

    def forward(self, x) -> Tensor:
        bsz, q_len, _ = x.size()

        seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallellism else q_len
        num_heads = (
            self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallellism else self.heads_num
        )

        # in shape : e.g., [s/p:h:]
        query_states = self.q(x)
        key_states = self.k(x)
        value_states = self.v(x)

        if self.enable_sequence_parallellism:
            query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx)
            key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx)
            value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx)

        query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)

        # out shape : e.g., [s:h/p:]
        attn_score = self.attn(query_states, key_states, value_states)
        attn_score = attn_score.transpose(1, 2).contiguous()
        attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim)
        if self.enable_sequence_parallellism:
            attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx)

        # output e.g., [s/p::h]
        output = self.out(attn_score)

        return output


def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    seq_len = seq_len
    hidden_dim = hidden_dim
    head_num = head_num
    batch_size = batch_size
    world_size = dist.get_world_size()

    x = torch.randn(batch_size, seq_len, hidden_dim).cuda()
    x_unshard = x.clone()
    x_unshard.requires_grad_(True)
    x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()]
    x_input.requires_grad_(True)

    # Multi-head Attention
    mha = SequenceParallelAttention(head_num, hidden_dim).cuda()
    # Multi-head Attention forward
    mha_out = mha(x_unshard)

    # Sequence parallel Attention
    sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda()
    sp_attn.load_state_dict(copy.deepcopy(mha.state_dict()))
    # Sequence parallel Attention forward
    dist_attn_out = sp_attn(x_input)

    # gather the output of sequence parallel attention
    out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)]
    dist.all_gather(out_list, dist_attn_out)
    seq_out = torch.cat(out_list, dim=1)

    # forward result check
    assert_close(seq_out, mha_out)

    # Multi-head Attention backward
    mha_out.sum().backward()
    q_grad = mha.q.weight.grad
    k_grad = mha.k.weight.grad
    v_grad = mha.v.weight.grad
    o_grad = mha.out.weight.grad
    x_grad = x_unshard.grad

    # Sequence parallel Attention backward
    dist_attn_out.sum().backward()
    q_grad_seq = sp_attn.q.weight.grad
    k_grad_seq = sp_attn.k.weight.grad
    v_grad_seq = sp_attn.v.weight.grad
    o_grad_seq = sp_attn.out.weight.grad
    x_grad_seq = x_input.grad
    # all_reduce the grad of sequence parallel attention weight
    dist.all_reduce(q_grad_seq)
    dist.all_reduce(k_grad_seq)
    dist.all_reduce(v_grad_seq)
    dist.all_reduce(o_grad_seq)
    # gather the grad of sequence parallel attention input
    x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)]
    dist.all_gather(x_grad_seq_list, x_grad_seq)
    x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1)

    # backward result check
    assert_close(q_grad_seq, q_grad)
    assert_close(k_grad_seq, k_grad)
    assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4)
    assert_close(o_grad_seq, o_grad)
    assert_close(x_grad_seq_gather, x_grad)


@parameterize("seq_len", [128])
@parameterize("hidden_dim", [64])
@parameterize("head_num", [4])
@parameterize("batch_size", [1])
def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
    seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size)


def check_all2all_attn(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_seq_parallel_attn()


@rerun_if_address_is_in_use()
def test_all_to_all_attention():
    spawn(check_all2all_attn, nprocs=4)


if __name__ == "__main__":
    test_all_to_all_attention()
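For readers new to the all-to-all ("Ulysses-style") layout switch that all_to_all_comm performs above, here is a self-contained sketch of the idea: an all-to-all turns a sequence-sharded tensor [bsz, seq/P, heads, head_dim] into a head-sharded tensor [bsz, seq, heads/P, head_dim]. The function name and permutation strategy below are illustrative assumptions, not the all_to_all_comm implementation.

import torch
import torch.distributed as dist

def seq_shard_to_head_shard(x: torch.Tensor, group=None) -> torch.Tensor:
    # x: [bsz, seq_len // P, num_heads, head_dim] on each of P ranks; num_heads must be divisible by P.
    # Returns [bsz, seq_len, num_heads // P, head_dim]; requires an initialized process group.
    p = dist.get_world_size(group)
    bsz, shard_seq, heads, head_dim = x.shape
    x = x.permute(2, 1, 0, 3).contiguous()        # [heads, seq/P, bsz, head_dim]; dim 0 is split across ranks
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)   # rank r ends up with head block r of every rank's sequence shard
    out = out.reshape(p, heads // p, shard_seq, bsz, head_dim)
    return out.permute(3, 0, 2, 1, 4).reshape(bsz, p * shard_seq, heads // p, head_dim)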


@@ -1,5 +1,4 @@
import copy
-import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional

@@ -123,7 +122,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
    sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        ctx.materialize(org_model)

    org_model = org_model.cuda()
    org_optimizer = Adam(org_model.parameters(), lr=1e-3)
    sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)

@@ -162,24 +160,22 @@ def run_forward_backward_with_hybrid_plugin(
    data = data_gen_fn()

-    if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
-        seq_len = data["input_ids"].shape[-1]
-        lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
-        times = lcm // seq_len
-        input_shape = data["input_ids"].shape
-        for k, v in data.items():
-            if v.shape == input_shape:
-                data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))
+    shard_test_data = {}
+    for k, v in data.items():
+        shard_test_data[k] = data[k].clone()
+    unshard_test_data = {}
+    for k, v in data.items():
+        unshard_test_data[k] = data[k].clone()

    sharded_model.train()
    if booster.plugin.stage_manager is not None:
-        for k, v in data.items():
+        for k, v in shard_test_data.items():
            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                new_shape = [1] * v.dim()
                new_shape[0] = 4
-                data[k] = v.to("cuda").repeat(*new_shape)
+                shard_test_data[k] = v.to("cuda").repeat(*new_shape)

-        data_iter = iter([data])
+        data_iter = iter([shard_test_data])
        sharded_output = booster.execute_pipeline(
            data_iter,
            sharded_model,

@@ -189,17 +185,22 @@ def run_forward_backward_with_hybrid_plugin(
            return_outputs=True,
        )

        sharded_loss = sharded_output["loss"]

    else:
-        data = {k: v.cuda() for k, v in data.items()}
-        sharded_output = sharded_model(**data)
+        shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
+        sharded_output = sharded_model(**shard_test_data)
        sharded_loss = criterion(sharded_output)
        sharded_optimizer.backward(sharded_loss)

    org_model.train()
-    data = {k: v.cuda() for k, v in data.items()}
-    org_output = org_model(**data)
+    if booster.plugin.stage_manager is not None:
+        for k, v in unshard_test_data.items():
+            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+                new_shape = [1] * v.dim()
+                new_shape[0] = 4
+                unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
+
+    unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
+    org_output = org_model(**unshard_test_data)
    org_loss = criterion(org_output)
    org_loss.backward()

@@ -212,7 +213,6 @@ def check_output_hidden_state(
    stage_manager: Optional[PipelineStageManager] = None,
    atol: float = 1e-5,
    rtol: float = 1e-3,
-    dim: int = 0,
):
    org_hidden_state = org_output.last_hidden_state


@@ -100,6 +100,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
        {
            "tp_size": 2,
            "pp_size": 1,

@@ -154,7 +176,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
def run_bert_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)


@@ -99,6 +99,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
        {
            "tp_size": 2,
            "pp_size": 2,


@@ -135,6 +135,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
        {
            "tp_size": 2,
            "pp_size": 2,


@@ -131,6 +131,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
        {
            "tp_size": 2,
            "pp_size": 2,


@@ -2,6 +2,8 @@ import os

import pytest
import torch
+import torch.distributed as dist
+from torch.testing import assert_close

import colossalai
from colossalai.logging import disable_existing_loggers
@@ -46,6 +48,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
    col_layer_for_check = ["layers[0].self_attn.o_proj"]
+    # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
+    norm_layer_for_check = ["layers[0].input_layernorm", "layers[0].post_attention_layernorm"]
+
+    # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enabled
+    if stage_manager is None:
+        norm_layer_for_check.append("norm")
+
+    # Check the grad when using ZeRO-1 and ZeRO-2
+    if (
+        booster.plugin.zero_stage in [1, 2]
+        and booster.plugin.shard_config.enable_sequence_parallelism
+        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
+    ):
+        for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
+            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+            grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank
+            grad = grads[grad_index]
+            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
+            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+
    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
@@ -60,8 +82,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        col_layer_grads = get_grad_tensors_for_check(
            llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )
+        norm_layer_grads = get_grad_tensors_for_check(
+            llama_model,
+            shard_llama_model,
+            norm_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
+        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()
@@ -98,6 +131,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
        {
            "tp_size": 2,
            "pp_size": 2,