mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* fix bugs when useing sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megtron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several model
* remove useless code
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

pull/5556/head
parent 7e0ec5a85c
commit 8e412a548e
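A minimal usage sketch of the two sequence-parallelism styles this commit wires into HybridParallelPlugin, assuming four GPUs launched with torchrun and ColossalAI's usual launch/Booster API of that release; only the plugin arguments come from the hunks below, everything else is an illustrative assumption, not part of the diff.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# Megatron-style SP ("split_gather" or "ring"): the sequence is sharded inside the
# tensor-parallel group, so tp_size must be > 1 and any sp_size is ignored.
plugin = HybridParallelPlugin(
    tp_size=4,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",
)

# Ulysses-style SP ("all_to_all"): a dedicated sp group of size sp_size is used,
# and tp_size / pp_size must both stay 1.
# plugin = HybridParallelPlugin(tp_size=1, pp_size=1, sp_size=4,
#                               enable_sequence_parallelism=True,
#                               sequence_parallelism_mode="all_to_all")

booster = Booster(plugin=plugin)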
@@ -34,7 +34,8 @@ from colossalai.zero.low_level import LowLevelZeroOptimizer

 from .pp_plugin_base import PipelinePluginBase

-DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
+DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]

 PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

@@ -53,6 +54,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         shard_config: ShardConfig,
         dp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         use_ddp: bool,
         ddp_config: dict,
         custom_policy: Policy,
@@ -61,6 +63,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.shard_config = shard_config
         self.dp_group = dp_group
         self.tp_group = tp_group
+        self.sp_group = sp_group
         self.use_dpp = use_ddp
         self.require_grad_sync = True

@@ -168,13 +171,24 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         Returns:
             None
         """
-        if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+        if self.shard_config.enable_sequence_parallelism:
+            if self.shard_config.sequence_parallelism_mode == "all_to_all":
+                return
+
+            if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
+                # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized
+                # across the tensor parallelism group.
+                group = self.tp_group
+            else:
+                raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}")
+
             if grads is not None:
                 # Synchronize provided gradient tensors across the tensor parallelism group.
-                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+                SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads)
             else:
                 # Synchronize gradients from the model across the tensor parallelism group.
-                SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+                SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module)

     def forward(self, *args, **kwargs):
         if self.convert_fn is not None:
@@ -727,10 +741,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)

         if self.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
-            SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+            SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
             return

@@ -891,6 +904,7 @@ class HybridParallelPlugin(PipelinePluginBase):
     Args:
         tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
         pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        sp_size (int): The size of sequence parallelism.
         precision (str, optional): Specifies the precision of parameters during training.
             Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
             Defaults to 'fp16'.
@@ -903,6 +917,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
         enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
@@ -938,6 +953,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self,
         tp_size: int,
         pp_size: int,
+        sp_size: int = None,
         precision: str = "fp16",
         zero_stage: int = 0,
         enable_all_optimization: bool = False,
@@ -945,6 +961,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_flash_attention: bool = False,
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
+        sequence_parallelism_mode: str = None,
         enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
@@ -976,14 +993,41 @@ class HybridParallelPlugin(PipelinePluginBase):
         super().__init__()
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
-        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

         if enable_sequence_parallelism:
-            assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
+            self.sequence_parallelism_mode = sequence_parallelism_mode if sequence_parallelism_mode is not None else "1"
+            assert (
+                self.sequence_parallelism_mode in SUPPORT_SP_MODE
+            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+            if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+                assert (
+                    tp_size > 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
+                if sp_size != 1:
+                    warnings.warn(
+                        f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
+                    )
+                self.sp_size = 1
+                self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            elif self.sequence_parallelism_mode in ["all_to_all"]:
+                assert (
+                    tp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with tensor parallelism"
+                assert (
+                    pp_size == 1
+                ), f"Sequence parallelism mode {self.sequence_parallelism_mode} cannot be used with pipeline parallelism"
+                self.sp_size = dist.get_world_size() if sp_size is None else sp_size
+                self.dp_size = dist.get_world_size() // (self.sp_size * pp_size)
+        else:
+            self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+            assert (
+                sp_size == 1 or sp_size is None
+            ), f"sp_size can only be set to a >1 number when enable_sequence_parallelism is True"
+            self.sp_size = 1

         self.tp_size = tp_size
         self.pp_size = pp_size
-        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
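A plain-Python restatement of the sizing rules introduced in this hunk; the helper and the numbers below are only an illustration, not part of ColossalAI.

# Mirrors the dp/sp sizing branches added to HybridParallelPlugin.__init__ above.
def resolve_sizes(world_size, tp_size, pp_size, sp_size, mode):
    if mode in ("split_gather", "ring"):
        # sp reuses the tensor-parallel group, so the requested sp_size is ignored
        return world_size // (tp_size * pp_size), 1
    elif mode == "all_to_all":
        sp = world_size if sp_size is None else sp_size
        return world_size // (sp * pp_size), sp
    return world_size // (tp_size * pp_size), 1  # sequence parallelism disabled

print(resolve_sizes(8, tp_size=4, pp_size=1, sp_size=None, mode="split_gather"))  # (2, 1)
print(resolve_sizes(8, tp_size=1, pp_size=1, sp_size=4, mode="all_to_all"))       # (2, 4)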
@@ -992,7 +1036,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
@@ -1033,9 +1077,14 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
+            self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        else:
+            self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)

         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
+            sequence_parallel_process_group=self.sp_group,
             pipeline_stage_manager=self.stage_manager,
             enable_tensor_parallelism=self.tp_size > 1,
             enable_all_optimization=self.enable_all_optimization,
@@ -1043,6 +1092,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_flash_attention=self.enable_flash_attention,
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
+            sequence_parallelism_mode=sequence_parallelism_mode,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             gradient_checkpoint_config=gradient_checkpoint_config,
@@ -1113,13 +1163,23 @@ class HybridParallelPlugin(PipelinePluginBase):
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
-            use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+            use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
+                self.dp_size == 1
+                and self.pp_size == 1
+                and self.enable_sequence_parallelism
+                and self.sequence_parallelism_mode == "all_to_all"
+            )
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
+            else:
+                dp_group = self.dp_group
             model = HybridParallelModule(
                 model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=self.dp_group,
+                dp_group=dp_group,
                 tp_group=self.tp_group,
+                sp_group=self.sp_group,
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
@@ -1149,7 +1209,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                     tp_process_group=self.tp_group,
                 )
             else:
-                if self.dp_size == 1:
+                zero_dp_size = dist.get_world_size(dp_group)
+                if zero_dp_size == 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you are not intended to use cpu_offload, please consider set zero_stage=0."
@@ -1161,7 +1222,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     model,
                     use_pipeline=self.enable_pipeline_parallelism,
                     param_info=param_info,
-                    dp_process_group=self.dp_group,
+                    dp_process_group=dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     verbose=True,
@@ -254,6 +254,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
         self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
         self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        # TODO: Currently moe only support partially sequence parallel
+        self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             pipeline_stage_manager=self.stage_manager,
@@ -365,6 +368,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 shard_config=self.shard_config,
                 dp_group=self.dp_group,
                 tp_group=self.tp_group,
+                sp_group=self.sp_group,
                 use_ddp=use_ddp,
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
@@ -161,7 +161,7 @@ class ProcessGroupMesh:

     @staticmethod
     def get_coords_along_axis(
-        base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int]
+        base_coord: Tuple[int, ...], axis: Union[int, List[int]], indices_at_axis: Union[List[int], List[List[int]]]
     ) -> List[Tuple[int, ...]]:
         """Get coordinates along the given axis.

@@ -173,13 +173,28 @@ class ProcessGroupMesh:
         Returns:
             List[Tuple[int, ...]]: Coordinates along the axis.
         """
+        if isinstance(axis, int):
+            axis = [axis,]
+            assert isinstance(indices_at_axis[0], int)
+            indices_at_axis = [indices_at_axis,]
+
+        def add_index(base_coord, axis, indices_at_axis):
             coords_in_group = []
             for idx in indices_at_axis:
                 coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :])
             return coords_in_group
+
+        coords_in_group = [base_coord]
+        for ax, indices_at_ax in zip(axis, indices_at_axis):
+            new_coords_in_group = []
+            for coords in coords_in_group:
+                new_coords_in_group += add_index(coords, ax, indices_at_ax)
+            coords_in_group = new_coords_in_group
+
+        return coords_in_group

     def create_group_along_axis(
-        self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
+        self, axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None
     ) -> ProcessGroup:
         """Create all process groups along the given axis, and return the one which the current process belongs to.

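A standalone illustration of the multi-axis coordinate expansion added above, on a toy (DP=2, PP=1, TP=2, SP=2) mesh; it re-implements the same loop in plain Python instead of calling ProcessGroupMesh, so the mesh shape and coordinates below are assumptions for the example only.

shape = (2, 1, 2, 2)  # (DP, PP, TP, SP) sizes of an 8-rank toy mesh
DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3

def coords_along(base_coord, axes, indices_per_axis):
    # same expansion as get_coords_along_axis: expand one axis at a time
    coords = [base_coord]
    for ax, indices in zip(axes, indices_per_axis):
        coords = [c[:ax] + (i,) + c[ax + 1:] for c in coords for i in indices]
    return coords

# Expanding along [DP_AXIS, SP_AXIS] from the reduced coordinate (0, 0, 1, 0)
# yields the four mesh coordinates that form one fused DP x SP group.
group = coords_along((0, 0, 1, 0), [DP_AXIS, SP_AXIS],
                     [range(shape[DP_AXIS]), range(shape[SP_AXIS])])
print(group)  # [(0, 0, 1, 0), (0, 0, 1, 1), (1, 0, 1, 0), (1, 0, 1, 1)]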
@@ -191,10 +206,17 @@ class ProcessGroupMesh:
         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
         """
-        indices_at_axis = indices_at_axis or list(range(self._shape[axis]))
+        if isinstance(axis, int):
+            axis = [axis,]
+            if indices_at_axis is not None:
+                assert isinstance(indices_at_axis[0], int)
+                indices_at_axis = [indices_at_axis,]
+
+        indices_at_axis = indices_at_axis or [list(range(self._shape[ax])) for ax in axis]
         reduced_shape = list(self._shape)
         # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
-        reduced_shape[axis] = 1
+        for ax in axis:
+            reduced_shape[ax] = 1
         target_group = None
         # use Cartesian product to generate all combinations of coordinates
         for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
@@ -225,4 +247,3 @@ class ProcessGroupMesh:
             # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
             return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
         return self._ranks_to_group[ranks_in_group]
@@ -1,4 +1,5 @@
 from .attn import AttnMaskType, ColoAttention
+from ._operation import all_to_all_comm
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
@@ -26,4 +27,5 @@ __all__ = [
     "ParallelModule",
     "AttnMaskType",
     "ColoAttention",
+    "all_to_all_comm",
 ]
@@ -167,6 +167,97 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         return grad_input, grad_weight, grad_bias, None, None, None


+def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
+    # currently only support one single tensor as output
+    group_size = dist.get_world_size(process_group)
+    cur_rank = dist.get_rank(process_group)
+
+    # output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
+
+    # initialization of ring communication
+    recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+    rank_map = list(dist.get_process_group_ranks(process_group))
+    recv_rank = rank_map[recv_rank]
+    send_rank = rank_map[send_rank]
+    recv_tensors = {}
+    send_tensors = {}
+    for k, v in input_to_gather.items():
+        recv_tensors[k] = torch.empty_like(v)
+        send_tensors[k] = v.clone()
+
+    def communicate_step():
+        comm_ops = []
+        for k in recv_tensors:
+            comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
+            comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
+        return dist.batch_isend_irecv(comm_ops)
+
+    def switch_step():
+        for k in recv_tensors:
+            send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
+
+    output_tensors = []
+
+    handles = communicate_step()
+    # first round: special case, retrive from local tensor
+    output_tensors.append(func(**input_to_gather, **input_local))
+    for i in range(group_size - 2):
+        for handle in handles:
+            handle.wait()
+
+        switch_step()
+
+        handles = communicate_step()
+
+        # actual computation
+        output_tensors.append(func(**send_tensors, **input_local))
+
+    # final round: special case, no need to send/recv again
+    for handle in handles:
+        handle.wait()
+    output_tensors.append(func(**recv_tensors, **input_local))
+
+    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
+
+
+class _GatherForwardReduceScatterBackward(torch.autograd.Function):
+    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        return _gather(input_, dim, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # do reduce-scatter
+        new_shape = list(grad_output.shape)
+        assert (
+            new_shape[dim] % dist.get_world_size(process_group) == 0
+        ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+        new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
+        grad_list = [
+            item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
+        ]
+        output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
+        dist.reduce_scatter(output, grad_list, group=process_group)
+
+        return output, None, None
+
+
 class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """Gather input from sequence parallel in forward and reduce-scatter gradient in backward

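The helper above overlaps the matmul with a ring exchange: every rank starts from its own chunk, receives its right neighbour's chunk each round, and finally rotates the partial results back into all-gather order. A plain-Python sketch of just that bookkeeping (no communication, no torch; the function name is made up for the example):

def ring_gather_order(group_size, cur_rank):
    chunks_seen = []
    chunk = cur_rank                      # round 0: compute on the local chunk
    for _ in range(group_size):
        chunks_seen.append(chunk)
        chunk = (chunk + 1) % group_size  # each round the next rank's chunk arrives
    # same rotation as the torch.cat(...) at the end of _ring_as_gather
    return chunks_seen[group_size - cur_rank:] + chunks_seen[:group_size - cur_rank]

for rank in range(4):
    print(rank, ring_gather_order(4, rank))  # every rank ends with chunks [0, 1, 2, 3]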
@@ -178,7 +269,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -186,8 +277,21 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.overlap = overlap

-        input_parallel = _gather(input_, dim, process_group)
+        if ring is True:
+            input_to_gather = {"input": input_}
+            input_local = {"weight": weight}
+
+            output = _ring_as_gather(
+                F.linear,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+            )
+
+            if bias is not None:
+                output += bias
+        else:
+            input_parallel = _gather(input_, dim, process_group)
             if bias is not None:
                 output = F.linear(input_parallel, weight, bias)
             else:
@@ -294,11 +398,146 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


+def _ring_as_reducescatter(
+    func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1
+):
+    # currently only support one single tensor as output
+    group_size = dist.get_world_size(process_group)
+    cur_rank = dist.get_rank(process_group)
+
+    # initialization of ring communication
+    recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
+    send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
+    rank_map = list(dist.get_process_group_ranks(process_group))
+    recv_rank = rank_map[recv_rank]
+    send_rank = rank_map[send_rank]
+    input_tensors = []
+    for _ in range(group_size):
+        input_tensors.append({})
+    for k, v in input_to_reducescatter.items():
+        input_shape = v.shape
+        assert input_shape[reducescatter_dim] % group_size == 0
+        _input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
+        for i in range(group_size):
+            input_tensors[i][k] = _input_tensors[i]
+    input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
+    input_tensors.reverse()
+
+    output_tensor = func(**input_tensors[0], **input_local)
+    recv_tensor = torch.empty_like(output_tensor)
+    send_tensor = output_tensor.clone()
+
+    def communicate_step():
+        recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
+        send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
+        return dist.batch_isend_irecv([recv_op, send_op])
+
+    handles = communicate_step()
+    # first round: special case, retrive from local tensor
+    for i in range(group_size - 2):
+        # actual computation
+        output_tensor = func(**input_tensors[i + 1], **input_local)
+
+        for handle in handles:
+            handle.wait()
+        output_tensor += recv_tensor
+
+        tmp_tensor = send_tensor
+        send_tensor = output_tensor
+        output_tensor = tmp_tensor
+
+        handles = communicate_step()
+
+    # final round: special case, no need to send/recv again
+    output_tensor = func(**input_tensors[-1], **input_local)
+    for handle in handles:
+        handle.wait()
+    output_tensor += recv_tensor
+    return output_tensor
+
+
 class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
-    """Gather input from sequence parallel in forward and reduce-scatter gradient in backward
+    """Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring
+
+    Args:
+        input_ (`torch.Tensor`): The input tensor from sequence parallel region.
+        process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
+        overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
+
+    """
+
+    @staticmethod
+    def forward(ctx, input_, weight, bias, process_group, dim, ring):
+        ctx.save_for_backward(input_, weight, bias)
+        ctx.use_bias = bias is not None
+        ctx.process_group = process_group
+        ctx.dim = dim
+
+        if ring is True:
+            input_to_reducescatter = {"input": input_}
+            input_local = {"weight": weight}
+
+            if bias is not None:
+                input_to_reducescatter["bias"] = bias
+
+            output = _ring_as_reducescatter(
+                F.linear,
+                input_to_reducescatter=input_to_reducescatter,
+                input_local=input_local,
+                process_group=process_group,
+            )
+        else:
+            if bias is not None:
+                partial_output = F.linear(input_, weight, bias)
+            else:
+                partial_output = F.linear(input_, weight)
+
+            output_shape = list(partial_output.shape)
+            assert (
+                output_shape[dim] % dist.get_world_size(process_group) == 0
+            ), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
+            output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
+
+            output_list = [
+                item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
+            ]
+            output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
+            dist.reduce_scatter(output, output_list, group=process_group)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight, bias = ctx.saved_tensors
+        use_bias = ctx.use_bias
+        dim = ctx.dim
+        process_group = ctx.process_group
+
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        if use_bias:
+            bias = bias.view(bias.shape)
+
+        grad_output = _gather(grad_output, dim, process_group)
+
+        # TODO Need to fully optimize
+        total_input = input_
+        grad_input = grad_output.matmul(weight)
+        grad_output = grad_output.contiguous()
+        # Convert the tensor shapes to 2D for execution compatibility
+        if len(grad_output.shape) > 2:
+            grad_output = grad_output.view(-1, grad_output.shape[-1])
+            total_input = total_input.view(-1, total_input.shape[-1])
+        grad_weight = grad_output.t().matmul(total_input)
+        grad_bias = grad_output.sum(dim=0) if use_bias else None
+
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
+class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
+    """Reduce-scatter input from sequence parallel in forward and gather gradient in backward

     Args:
         input_ (`torch.Tensor`): The input tensor from sequence parallel region.
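The ring reduce-scatter above runs group_size - 1 exchange rounds and accumulates partial sums as they arrive; the snippet below only checks the end state each rank should hold, computed directly rather than via the ring, as a plain-Python illustration with made-up numbers.

def ring_reducescatter_result(per_rank_chunks):
    # per_rank_chunks[r][c] is rank r's local value for chunk c;
    # after the ring pass, rank r holds the sum over all ranks of chunk r.
    world = len(per_rank_chunks)
    return [sum(per_rank_chunks[src][r] for src in range(world)) for r in range(world)]

chunks = [[10 * r + c for c in range(4)] for r in range(4)]
print(ring_reducescatter_result(chunks))  # [60, 64, 68, 72]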
@@ -343,7 +582,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
+    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring):
         ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
@@ -351,6 +590,21 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
         ctx.dim = dim
         ctx.overlap = overlap

+        if ring is True:
+            input_to_gather = {}
+            input_local = {}
+            input_to_gather["input"] = input_
+            input_local["other"] = weight
+
+            output = _ring_as_gather(
+                torch.matmul,
+                input_to_gather=input_to_gather,
+                input_local=input_local,
+                process_group=process_group,
+                gather_dim=dim,
+            )
+
+        else:
             input_parallel = _gather(input_, dim, process_group)

             output = torch.matmul(input_parallel, weight)
@@ -433,7 +687,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
             # wait until reduce-scatter finished
             reducescatter_handle.wait()

-        return output, grad_weight, grad_bias, None, None, None, None
+        return output, grad_weight, grad_bias, None, None, None, None, None


 class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -448,14 +702,17 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, dim, process_group):
+    def forward(ctx, input_, dim, process_group, grad_scale=None):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.grad_scale = grad_scale
         return _split(input_, dim, process_group)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output, ctx.dim, ctx.process_group), None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return _gather(grad_output, ctx.dim, ctx.process_group), None, None, None


 class _ReduceForward(torch.autograd.Function):
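The new grad_scale argument only rescales the incoming gradient inside backward before it is gathered or split back; a single-process autograd sketch of that behaviour (the class below is illustrative and unrelated to any process group):

import torch

class _ScaleGradOnly(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, grad_scale=None):
        ctx.grad_scale = grad_scale
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # same pattern as the grad_scale hook added above
        if ctx.grad_scale is not None:
            grad_output = grad_output * ctx.grad_scale
        return grad_output, None

x = torch.ones(3, requires_grad=True)
_ScaleGradOnly.apply(x, 0.5).sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000])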
@@ -505,14 +762,50 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, input_, dim, process_group):
+    def forward(ctx, input_, dim, process_group, grad_scale=None):
         ctx.process_group = process_group
         ctx.dim = dim
+        ctx.grad_scale = grad_scale
         return _gather(input_, dim, process_group)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output, ctx.dim, ctx.process_group), None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return _split(grad_output, ctx.dim, ctx.process_group), None, None, None
+
+
+class _AllToAll(torch.autograd.Function):
+    """All-to-all communication.
+
+    Args:
+        input_: input matrix
+        process_group: communication group
+        scatter_dim: scatter dimension
+        gather_dim: gather dimension
+    """
+
+    @staticmethod
+    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
+        ctx.process_group = process_group
+        ctx.scatter_dim = scatter_dim
+        ctx.gather_dim = gather_dim
+        world_size = dist.get_world_size(process_group)
+        bsz, _, _ = input_.shape
+
+        # using all_to_all_single when batch size is 1
+        if bsz == 1:
+            return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
+        else:
+            return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        process_group = ctx.process_group
+        scatter_dim = ctx.gather_dim
+        gather_dim = ctx.scatter_dim
+        return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
+        return (return_grad, None, None, None)


 class HookParameter(torch.autograd.Function):
@@ -608,6 +901,40 @@ def _reduce_scatter(input_, dim=1, process_group=None):
     return output


+def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
+    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
+    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
+    dist.all_to_all(output_list, input_list, group=group)
+    return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
+    inp_shape = list(input_.shape)
+    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
+    if scatter_dim < 2:
+        input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
+    else:
+        input_t = (
+            input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
+            .transpose(0, 1)
+            .contiguous()
+        )
+
+    output = torch.empty_like(input_t)
+    dist.all_to_all_single(output, input_t, group=group)
+
+    if scatter_dim < 2:
+        output = output.transpose(0, 1).contiguous()
+
+    return output.reshape(
+        inp_shape[:gather_dim]
+        + [
+            inp_shape[gather_dim] * seq_world_size,
+        ]
+        + inp_shape[gather_dim + 1 :]
+    ).contiguous()
+
+
 def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
     return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)

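A single-process emulation of what the _all_to_all helper above computes, to make the scatter_dim/gather_dim shape effect concrete; torch.distributed is replaced by local chunking, and the sizes are assumptions for the example. With the default scatter_dim=2, gather_dim=1 each sp rank trades hidden/head shards for the full sequence.

import torch

def emulate_all_to_all(per_rank_inputs, scatter_dim, gather_dim):
    world = len(per_rank_inputs)
    chunks = [torch.tensor_split(t, world, dim=scatter_dim) for t in per_rank_inputs]
    # rank r receives chunk r from every rank and concatenates along gather_dim
    return [torch.cat([chunks[src][r] for src in range(world)], dim=gather_dim) for r in range(world)]

sp = 4
# each sp rank holds [batch=2, seq/sp=8, hidden=16]; scatter hidden, gather sequence
inputs = [torch.randn(2, 8, 16) for _ in range(sp)]
outputs = emulate_all_to_all(inputs, scatter_dim=2, gather_dim=1)
print(outputs[0].shape)  # torch.Size([2, 32, 4]) -> full sequence, 1/sp of the hidden dim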
@@ -617,31 +944,39 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):


 def linear_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _LinearWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )


-def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
-    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+def gather_forward_reducescatter_backward(input_, process_group, dim):
+    return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim)
+
+
+def reducescatter_forward_gather_backward(input_, process_group, dim):
+    return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
+
+
+def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
+    return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)


 def matmul_gather_forward_reducescatter_backward(
-    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
 ):
     return _MatmulWithGatherForwardReduceScatterBackward.apply(
-        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
+        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
     )


-def gather_forward_split_backward(input_, dim, process_group):
-    return _GatherForwardSplitBackward.apply(input_, dim, process_group)
+def gather_forward_split_backward(input_, dim, process_group, grad_scale=None):
+    return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale)


-def split_forward_gather_backward(input_, dim, process_group):
-    return _SplitForwardGatherBackward.apply(input_, dim, process_group)
+def split_forward_gather_backward(input_, dim, process_group, grad_scale=None):
+    return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale)


 def reduce_forward(input_, process_group):
@@ -650,3 +985,7 @@ def reduce_forward(input_, process_group):

 def reduce_backward(input_, process_group):
     return _ReduceBackward.apply(input_, process_group)
+
+
+def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
+    return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
@@ -23,11 +23,13 @@ from colossalai.tensor.d_tensor.api import (
 )

 from ._operation import (
+    gather_forward_reducescatter_backward,
     gather_forward_split_backward,
     linear_gather_forward_reducescatter_backward,
     linear_reducescatter_forward_gather_backward,
     linear_with_async_comm,
     reduce_forward,
+    reducescatter_forward_gather_backward,
     split_forward_gather_backward,
 )
 from .parallel_module import ParallelModule
@@ -74,7 +76,7 @@ class Linear1D_Col(ParallelModule):
         device: torch.device = None,
         process_group: ProcessGroup = None,
         gather_output: bool = False,
-        seq_parallel: bool = False,
+        seq_parallel_mode: str = None,
         seq_parallel_dim: int = 1,
         overlap: torch.cuda.Stream = None,
         skip_bias_add: bool = False,
@@ -89,7 +91,7 @@ class Linear1D_Col(ParallelModule):
         self.in_features = in_features
         self.out_features = out_features
         self.gather_output = gather_output
-        self.seq_parallel = seq_parallel
+        self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
         self.overlap = overlap
         self.skip_bias_add = skip_bias_add
@@ -196,12 +198,18 @@ class Linear1D_Col(ParallelModule):

         # Matrix multiply.
         bias = self.bias if not self.skip_bias_add else None
-        if self.seq_parallel:
-            output_parallel = linear_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
-            )
-        else:
+
+        if self.seq_parallel_mode is None:
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+        elif self.seq_parallel_mode == "split_gather":
+            input_parallel = gather_forward_reducescatter_backward(
+                input_parallel, self.process_group, self.seq_parallel_dim
+            )
+            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
+        elif self.seq_parallel_mode == "ring":
+            output_parallel = linear_gather_forward_reducescatter_backward(
+                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
+            )

         if self.gather_output:
             # All-gather across the partitions.
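A shape-only sketch of the "split_gather" branch above: each rank holds seq/sp tokens, the activation is all-gathered along the sequence dimension in forward (with the gradient reduce-scattered in backward), and the column-sharded weight then yields out_features/tp outputs per rank. The torch.cat below merely stands in for gather_forward_reducescatter_backward, and all sizes are assumptions for the example.

import torch

tp = 2                                      # split_gather reuses the TP group as the sp group
full_weight = torch.randn(64, 32)           # [out_features, in_features]
w_shard = full_weight.chunk(tp, dim=0)[0]   # column-parallel shard on this rank: [32, 32]

seq_shard = torch.randn(4, 8 // tp, 32)     # local activation: [batch, seq/tp, hidden]
gathered = torch.cat([seq_shard, torch.randn_like(seq_shard)], dim=1)  # stand-in for the all-gather
out = torch.nn.functional.linear(gathered, w_shard)
print(out.shape)  # torch.Size([4, 8, 32]) -> full sequence, out_features/tp per rank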
@@ -225,7 +233,8 @@ class Linear1D_Row(ParallelModule):
         dtype (`torch.dtype`): The dtype of parameters, defaults to None.
         parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
         process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
-        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+        seq_parallel_mode (`str`): The type of sp mode, it will use sequence parallel when `seq_parallel_mode` is not None. Defaults to None.
+        seq_parallel_dim (`int`): Which dim will sequence parallelism split and gather the sequence.
         skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
             which is preserved for kernel fusion, defaults to False
         weight_initializer (:class:`typing.Callable`, optional):
@@ -245,7 +254,7 @@ class Linear1D_Row(ParallelModule):
         dtype: torch.dtype = None,
         device: torch.device = None,
         process_group: ProcessGroup = None,
-        seq_parallel: bool = False,
+        seq_parallel_mode: str = None,
         seq_parallel_dim: int = 1,
         parallel_input: bool = True,
         skip_bias_add: bool = False,
@@ -265,7 +274,7 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
-        self.seq_parallel = seq_parallel
+        self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)

@ -403,18 +412,26 @@ class Linear1D_Row(ParallelModule):
|
||||||
output_parallel_list[i], group=self.process_group, async_op=True
|
output_parallel_list[i], group=self.process_group, async_op=True
|
||||||
)
|
)
|
||||||
handle_list.append(handle)
|
handle_list.append(handle)
|
||||||
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
|
|
||||||
for handle in handle_list:
|
for handle in handle_list:
|
||||||
handle.wait()
|
handle.wait()
|
||||||
output = torch.cat(output_parallel_list, dim=-1)
|
output = torch.cat(output_parallel_list, dim=-1)
|
||||||
else:
|
else:
|
||||||
output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
|
if self.seq_parallel_mode is None:
|
||||||
if self.seq_parallel:
|
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||||
output = linear_reducescatter_forward_gather_backward(
|
output = reduce_forward(output_parallel, self.process_group)
|
||||||
|
elif self.seq_parallel_mode == "split_gather":
|
||||||
|
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
|
||||||
|
output = reducescatter_forward_gather_backward(
|
||||||
output_parallel, self.process_group, self.seq_parallel_dim
|
output_parallel, self.process_group, self.seq_parallel_dim
|
||||||
)
|
)
|
||||||
else:
|
elif self.seq_parallel_mode == "ring":
|
||||||
output = reduce_forward(output_parallel, self.process_group)
|
output = linear_reducescatter_forward_gather_backward(
|
||||||
|
input_,
|
||||||
|
self.weight,
|
||||||
|
process_group=self.process_group,
|
||||||
|
dim=self.seq_parallel_dim,
|
||||||
|
ring=True,
|
||||||
|
)
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
if not self.skip_bias_add:
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
|
|
|
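For reference, the branch above replaces the row-parallel layer's unconditional all-reduce with a reduce-scatter along the sequence dimension when `seq_parallel_mode` is "split_gather". The following is a minimal sketch of that collective difference using raw torch.distributed calls; it assumes an already initialized process group and illustrative helper names, not the ColossalAI autograd wrappers.

```python
# Minimal sketch (assumes an initialized process group; tensors shaped [batch, seq, hidden]).
import torch
import torch.distributed as dist


def row_parallel_output_no_sp(partial_out: torch.Tensor, group=None) -> torch.Tensor:
    # Without sequence parallelism each rank holds a partial sum over its input shard;
    # an all-reduce produces the full output on every rank.
    dist.all_reduce(partial_out, op=dist.ReduceOp.SUM, group=group)
    return partial_out


def row_parallel_output_split_gather(partial_out: torch.Tensor, group=None, seq_dim: int = 1) -> torch.Tensor:
    # With "split_gather" the same sum is computed, but each rank keeps only its own
    # slice of the sequence dimension (reduce-scatter), so activations stay sharded
    # between transformer blocks.
    world_size = dist.get_world_size(group)
    chunks = [c.contiguous() for c in partial_out.chunk(world_size, dim=seq_dim)]
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM, group=group)
    return out
```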
@@ -25,12 +25,12 @@ from colossalai.tensor.d_tensor.api import (

from ._operation import (
    gather_forward_split_backward,
-   linear_reducescatter_forward_gather_backward,
    linear_with_async_comm,
    matmul_gather_forward_reducescatter_backward,
    matmul_with_async_comm,
    reduce_backward,
    reduce_forward,
+   reducescatter_forward_gather_backward,
    split_forward_gather_backward,
)
from .parallel_module import ParallelModule

@@ -150,7 +150,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        device (`torch.device`): The device of parameters, defaults to None.
        n_fused (int): The number items fused, defaults to 3 (QKV).
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
-       seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+       seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
        gather_output (bool, optional): If true, call all-gather on output and make Y available
            to all GPUs, otherwise, every GPU will have its output
            which is :math:`Y_i = XA_i`, defaults to False

@@ -175,7 +175,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        process_group: ProcessGroup = None,
        async_communication: bool = False,
        gather_output: bool = False,
-       seq_parallel: bool = False,
+       seq_parallel_mode: str = None,
        overlap: bool = False,
        skip_bias_add: bool = False,
        n_fused: int = 3,

@@ -190,7 +190,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
-       self.seq_parallel = seq_parallel
+       self.seq_parallel_mode = seq_parallel_mode
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
        self.device = device

@@ -312,17 +312,22 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None

-       if self.seq_parallel:
-           input_parallel = input_
-           output_parallel = matmul_gather_forward_reducescatter_backward(
-               input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
-           )
-       else:
+       if self.seq_parallel_mode is None:
            # Set up backprop all-reduce.
            input_parallel = reduce_backward(input_, self.process_group)
            output_parallel = matmul_with_async_comm(
                input_parallel, self.weight, bias, self.process_group, self.async_communication
            )
+       elif self.seq_parallel_mode == "split_gather":
+           input_parallel = input_
+           output_parallel = matmul_gather_forward_reducescatter_backward(
+               input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
+           )
+       elif self.seq_parallel_mode == "ring":
+           input_parallel = input_
+           output_parallel = matmul_gather_forward_reducescatter_backward(
+               input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap, True
+           )

        if self.gather_output:
            # All-gather across the partitions.
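The column-parallel branches above rely on a gather-forward / reduce-scatter-backward pattern: the sequence-sharded activations are all-gathered before the matmul, and the gradient is reduce-scattered back to the shards. A hedged, simplified autograd sketch of that pattern follows; it assumes an initialized process group and is not the fused or communication-overlapped ColossalAI implementation.

```python
# Illustrative autograd sketch of gather-forward / reduce-scatter-backward.
import torch
import torch.distributed as dist


class GatherForwardReduceScatterBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, group, seq_dim: int):
        ctx.group, ctx.seq_dim = group, seq_dim
        world_size = dist.get_world_size(group)
        parts = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(parts, x.contiguous(), group=group)
        return torch.cat(parts, dim=seq_dim)  # full sequence is available for the matmul

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        world_size = dist.get_world_size(ctx.group)
        chunks = [c.contiguous() for c in grad_out.chunk(world_size, dim=ctx.seq_dim)]
        grad_in = torch.empty_like(chunks[0])
        dist.reduce_scatter(grad_in, chunks, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_in, None, None
```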
@@ -347,7 +352,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
-       seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
+       seq_parallel_mode (str): If set to ``None``, it will not use sequence parallel, otherwise will use corresponding mode of sequence parallel, defaults to None.
            which is preserved for kernel fusion, defaults to False
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.

@@ -366,7 +371,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
-       seq_parallel: bool = False,
+       seq_parallel_mode: str = None,
        parallel_input: bool = True,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,

@@ -385,7 +390,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
-       self.seq_parallel = seq_parallel
+       self.seq_parallel_mode = seq_parallel_mode
        self.num_partitions = dist.get_world_size(self.process_group)

        if skip_bias_add and not bias:

@@ -528,11 +533,15 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
                handle.wait()
            output = torch.cat(output_parallel_list, dim=-1)
        else:
-           output_parallel = torch.matmul(input_, self.weight)
-           if self.seq_parallel:
-               output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
-           else:
-               output = reduce_forward(output_parallel, self.process_group)
+           if self.seq_parallel_mode is None:
+               output_parallel = torch.matmul(input_, self.weight)
+               output = reduce_forward(output_parallel, self.process_group)
+           elif self.seq_parallel_mode == "split_gather":
+               output_parallel = torch.matmul(input_, self.weight)
+               output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+           elif self.seq_parallel_mode == "ring":
+               output_parallel = torch.matmul(input_, self.weight)
+               output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)

        if not self.skip_bias_add:
            if self.bias is not None:

@@ -702,7 +711,6 @@ class FusedLinear1D_Col(ParallelModule):
        # process_group=process_group,
        # is_transposed=False)
        # linear_1d.bias.data.copy_(sharded_bias.data)
-       print(linear_1d.weight.shape)
        return linear_1d

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -35,17 +35,21 @@ class SeqParallelUtils:
        return getattr(param, "partial_derived", False)

    @staticmethod
-   def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+   def allreduce_partial_data_grad(
+       process_group: ProcessGroup,
+       model: nn.Module = None,
+       grads: List[torch.Tensor] = None,
+   ):
        """
        Allreduce partial derived gradients across the specified process group.

        This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.

        Args:
-           tp_group (ProcessGroup): The process group for gradient synchronization.
+           process_group (ProcessGroup): The process group for gradient synchronization.
            model (nn.Module): The model from which gradients will be synchronized.
            grads (List[torch.Tensor]): The list of gradients to be synchronized.
+           only_sp_partial (bool): Whether to handle all the parameters or only parameters marked as partial derived.
        Raises:
            AssertionError: If both `model` and `grads` are provided or neither is provided.
        """

@@ -53,22 +57,26 @@ class SeqParallelUtils:
        assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be not None."

        # Get the size of the process group, which determines whether synchronization is needed.
-       tp_size = get_world_size(tp_group) if tp_group is not None else 1
+       group_size = get_world_size(process_group) if process_group is not None else 1

-       if tp_size == 1:
+       if group_size == 1:
            # If the process group size is 1, no synchronization is required.
            return

        if model is not None:
            # If `model` is provided, extract partial derived gradients from the model's parameters.
            grads = []

            for p in model.parameters():
-               if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
-                   grads.append(p.grad.data)
+               if p.grad is not None:
+                   if SeqParallelUtils.is_sp_partial_derived_param(p):
+                       grads.append(p.grad.data)

            # Flatten and reduce the gradients using the specified process group.
+           if len(grads) == 0:
+               return
            coalesced = _flatten_dense_tensors(grads)
-           dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+           dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)

            # Unflatten the synchronized gradients and update the model's gradients.
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):

@@ -76,7 +84,7 @@ class SeqParallelUtils:
        else:
            # If `grads` are provided explicitly, synchronize those gradients directly.
            coalesced = _flatten_dense_tensors(grads)
-           dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+           dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=process_group)
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
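The gradient sync above coalesces all partially derived gradients into a single buffer so that one all-reduce covers them. A standalone sketch of the flatten, all-reduce, and unflatten round trip (assumes an initialized process group; this mirrors, rather than reproduces, the `SeqParallelUtils` code):

```python
# Standalone sketch of coalesced gradient synchronization.
from typing import List

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_coalesced(grads: List[torch.Tensor], group=None) -> None:
    if len(grads) == 0:
        return  # nothing marked as partially derived on this rank
    coalesced = _flatten_dense_tensors(grads)  # one contiguous buffer, one collective
    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)  # write the summed values back into the original tensors
```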
@@ -186,6 +186,7 @@ class BertPipelineForwards:
        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
        if shard_config is not None and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = split_forward_gather_backward(
                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
                )

@@ -240,6 +241,7 @@ class BertPipelineForwards:

        # When sequence parallelism done, gather the output tensor in forward and split it in backward
        if shard_config is not None and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = gather_forward_split_backward(
                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
                )
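The BERT hunks above, like the BLOOM, ChatGLM and GPT-2 hunks that follow, gate the existing split/gather on the new "split_gather" mode: the hidden states are split along the sequence dimension before the transformer blocks and gathered back afterwards. A shape-only, single-process sketch of that bookkeeping (hypothetical helper names; the real code uses autograd-aware, distributed wrappers):

```python
# Shape-only sketch of the "split_gather" bookkeeping around the encoder.
import torch


def split_sequence(hidden_states: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    # [batch, seq, hidden] -> [batch, seq / world_size, hidden] on each rank
    return hidden_states.chunk(world_size, dim=dim)[rank].contiguous()


def gather_sequence(local_states: torch.Tensor, world_size: int, dim: int = 1) -> torch.Tensor:
    # Placeholder for the all-gather in the distributed version; here it only shows the shape.
    return torch.cat([local_states] * world_size, dim=dim)


x = torch.randn(2, 8, 16)
shard = split_sequence(x, rank=0, world_size=4)
print(shard.shape)                                  # torch.Size([2, 2, 16])
print(gather_sequence(shard, world_size=4).shape)   # torch.Size([2, 8, 16])
```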
@@ -213,7 +213,8 @@ class BloomPipelineForwards:

        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
-       if shard_config.enable_sequence_parallelism:
+       if shard_config and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = split_forward_gather_backward(
                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
                )

@@ -261,7 +262,8 @@ class BloomPipelineForwards:
        all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # When sequence parallelism done, gather the output tensor in forward and split it in backward
-       if shard_config.enable_sequence_parallelism:
+       if shard_config and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = gather_forward_split_backward(
                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
                )
@@ -191,11 +191,10 @@ class ChatGLMPipelineForwards:
        all_hidden_states = () if output_hidden_states else None
        start_idx, end_idx = stage_index[0], stage_index[1]

-       if shard_config.enable_sequence_parallelism:
-           hidden_states = split_forward_gather_backward(
-               hidden_states,
-               dim=0,
-               process_group=shard_config.tensor_parallel_process_group,
-           )
+       if shard_config and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
+               hidden_states = split_forward_gather_backward(
+                   hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+               )
        for idx in range(start_idx, end_idx):
            layer = self.encoder._get_layer(idx)

@@ -222,11 +221,10 @@ class ChatGLMPipelineForwards:
            if use_cache:
                presents = presents + (kv_cache,)

-       if shard_config.enable_sequence_parallelism:
-           hidden_states = gather_forward_split_backward(
-               hidden_states,
-               dim=0,
-               process_group=shard_config.tensor_parallel_process_group,
-           )
+       if shard_config and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
+               hidden_states = gather_forward_split_backward(
+                   hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+               )
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
@@ -218,7 +218,8 @@ class GPT2PipelineForwards:

        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
-       if shard_config.enable_sequence_parallelism:
+       if shard_config and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = split_forward_gather_backward(
                    hidden_states,
                    dim=1,

@@ -278,7 +279,8 @@ class GPT2PipelineForwards:
            all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

        # When sequence parallelism done, gather the output tensor in forward and split it in backward
-       if shard_config.enable_sequence_parallelism:
+       if shard_config and shard_config.enable_sequence_parallelism:
+           if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = gather_forward_split_backward(
                    hidden_states,
                    dim=1,

@@ -1141,7 +1143,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
        hidden_states = split_forward_gather_backward(
            hidden_states,
            dim=1,
-           process_group=shard_config.tensor_parallel_process_group,
+           process_group=shard_config.sequence_parallel_process_group,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

@@ -1208,7 +1210,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
        hidden_states = gather_forward_split_backward(
            hidden_states,
            dim=1,
-           process_group=shard_config.tensor_parallel_process_group,
+           process_group=shard_config.sequence_parallel_process_group,
        )

        hidden_states = self.ln_f(hidden_states)
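The GPT-2 forward now splits and gathers over `shard_config.sequence_parallel_process_group` rather than the tensor-parallel group, so the sequence axis can live on its own dimension of the device mesh. A hedged sketch of carving separate TP and SP groups out of a flat world with `torch.distributed.new_group` (the layout below is illustrative; the plugin derives these groups from its own process-group mesh):

```python
# Illustrative 2D grouping: world_size = tp_size * sp_size ranks, TP groups are
# contiguous blocks and SP groups are strided. Assumes dist.init_process_group ran.
import torch.distributed as dist


def build_tp_sp_groups(tp_size: int = 4, sp_size: int = 2):
    world_size = dist.get_world_size()
    assert world_size == tp_size * sp_size
    rank = dist.get_rank()
    tp_group = sp_group = None
    for sp_idx in range(sp_size):
        ranks = [sp_idx * tp_size + i for i in range(tp_size)]
        group = dist.new_group(ranks)  # every rank must create every group
        if rank in ranks:
            tp_group = group
    for tp_idx in range(tp_size):
        ranks = [i * tp_size + tp_idx for i in range(sp_size)]
        group = dist.new_group(ranks)
        if rank in ranks:
            sp_group = group
    return tp_group, sp_group
```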
@@ -1,18 +1,32 @@
+import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
-from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+from transformers.models.llama.modeling_llama import (
+    LlamaForCausalLM,
+    LlamaForSequenceClassification,
+    LlamaModel,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import (
+    all_to_all_comm,
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
from colossalai.shardformer.shard import ShardConfig

from ..layer import ColoAttention, cross_entropy_1d

@@ -438,7 +452,7 @@ class LlamaPipelineForwards:
        return {"hidden_states": hidden_states}


-def get_llama_flash_attention_forward(shard_config: ShardConfig):
+def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

    llama_version = 2

@@ -459,18 +473,30 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

+       if sp_mode in ["split_gather", "ring"]:
+           q_len *= sp_size
        assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

-       query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-       key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-       value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+       query_states = self.q_proj(hidden_states)
+       key_states = self.k_proj(hidden_states)
+       value_states = self.v_proj(hidden_states)
+
+       # sp: all-to-all communication when introducing sequence parallel
+       if sp_mode == "all_to_all":
+           query_states = all_to_all_comm(query_states, sp_group)
+           key_states = all_to_all_comm(key_states, sp_group)
+           value_states = all_to_all_comm(value_states, sp_group)
+           bsz, q_len, _ = query_states.size()
+
+       query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+       key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+       value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:

@@ -490,6 +516,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

+       # sp: all-to-all communication when introducing sequence parallel
+       if sp_mode == "all_to_all":
+           attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
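The "all_to_all" branch above re-partitions the projections Ulysses-style: before attention the tensors go from sequence-sharded with all heads to full-sequence with a head shard, and the inverse all-to-all restores the sequence shard afterwards. A hedged sketch of the re-partitioning behind `all_to_all_comm` (assumes an initialized process group and dimensions divisible by the group size; it is not the ColossalAI implementation):

```python
# Hedged sketch of Ulysses-style all-to-all re-partitioning.
import torch
import torch.distributed as dist


def all_to_all_repartition(x: torch.Tensor, group=None, scatter_dim: int = 2, gather_dim: int = 1) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # Split the dimension to scatter into per-rank chunks ...
    send = [c.contiguous() for c in x.chunk(world_size, dim=scatter_dim)]
    recv = [torch.empty_like(send[0]) for _ in range(world_size)]
    dist.all_to_all(recv, send, group=group)
    # ... and concatenate what was received along the dimension to gather.
    return torch.cat(recv, dim=gather_dim)
```

With q/k/v shaped [batch, seq/sp_size, num_heads * head_dim], scattering dim 2 and gathering dim 1 yields a full-sequence tensor holding num_heads / sp_size heads per rank; the output projection then reverses this with scatter_dim=1, gather_dim=2, as in the diff above.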
@@ -726,3 +755,261 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
        )

    return forward
+
+
+def get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        # sp: modify sp_len when sequence parallel mode is ring
+        if sp_mode in ["split_gather", "ring"]:
+            q_len *= sp_size
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        # sp: all-to-all communication when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            query_states = all_to_all_comm(query_states, sp_group)
+            key_states = all_to_all_comm(key_states, sp_group)
+            value_states = all_to_all_comm(value_states, sp_group)
+            bsz, q_len, _ = query_states.size()
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        # sp: all-to-all communication when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+        else:
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+        return attn_output, attn_weights, past_key_value
+
+    return forward
+
+
+def get_llama_seq_parallel_model_forward(sp_mode, sp_size, sp_group):
+    logger = logging.get_logger(__name__)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            # modify past_key_values_length when using sequence parallel
+            past_key_values_length *= sp_size
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if sp_mode in ["ring", "split_gather"]:
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, attention_mask.shape, inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                )
+
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        if sp_mode == "ring" or sp_mode == "split_gather":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+        elif sp_mode == "all_to_all":
+            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    return forward
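In the model forward added above, the embeddings are split along the sequence dimension before the decoder layers and gathered after the final norm; in "all_to_all" mode the split passes a 1/sp_size gradient scale and the gather passes grad_scale=sp_size. A minimal sketch of a split-forward / gather-backward op with such a scale, mirroring (not reproducing) `split_forward_gather_backward`, assuming an initialized process group:

```python
# Minimal sketch of split-forward / gather-backward with an optional gradient scale.
import torch
import torch.distributed as dist


class SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, group, grad_scale=None):
        ctx.dim, ctx.group, ctx.grad_scale = dim, group, grad_scale
        rank, world_size = dist.get_rank(group), dist.get_world_size(group)
        return x.chunk(world_size, dim=dim)[rank].contiguous()  # keep only the local sequence shard

    @staticmethod
    def backward(ctx, grad_out):
        world_size = dist.get_world_size(ctx.group)
        if ctx.grad_scale is not None:
            grad_out = grad_out * ctx.grad_scale  # e.g. 1/sp_size on split, sp_size on gather
        parts = [torch.empty_like(grad_out) for _ in range(world_size)]
        dist.all_gather(parts, grad_out.contiguous(), group=ctx.group)
        return torch.cat(parts, dim=ctx.dim), None, None, None
```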
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List


@@ -66,8 +67,17 @@ class BertPolicy(Policy):
        else:
            norm_cls = col_nn.LayerNorm

-       use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+       sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+       assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
+       if sp_mode == "ring":
+           warnings.warn(
+               f"For Bert, sequence parallelism currently does not support mode {sp_mode}; it will be set to split_gather"
+           )
+           sp_mode = "split_gather"
+
        overlap = self.shard_config.enable_sequence_overlap
+       sp_partial_derived = sp_mode == "split_gather"

        if self.shard_config.enable_tensor_parallelism:
            policy[BertLayer] = ModulePolicyDescription(
                attribute_replacement={

@@ -85,7 +95,7 @@ class BertPolicy(Policy):
                        suffix="attention.self.query",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
-                           "seq_parallel": use_sequence_parallel,
+                           "seq_parallel_mode": sp_mode,
                            "overlap": overlap,
                        },
                    ),

@@ -93,7 +103,7 @@ class BertPolicy(Policy):
                        suffix="attention.self.key",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
-                           "seq_parallel": use_sequence_parallel,
+                           "seq_parallel_mode": sp_mode,
                            "overlap": overlap,
                        },
                    ),

@@ -101,7 +111,7 @@ class BertPolicy(Policy):
                        suffix="attention.self.value",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
-                           "seq_parallel": use_sequence_parallel,
+                           "seq_parallel_mode": sp_mode,
                            "overlap": overlap,
                        },
                    ),

@@ -112,7 +122,7 @@ class BertPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=col_nn.Linear1D_Row,
-                       kwargs={"seq_parallel": use_sequence_parallel},
+                       kwargs={"seq_parallel_mode": sp_mode},
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",

@@ -122,14 +132,14 @@ class BertPolicy(Policy):
                        suffix="intermediate.dense",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
-                           "seq_parallel": use_sequence_parallel,
+                           "seq_parallel_mode": sp_mode,
                            "overlap": overlap,
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dense",
                        target_module=col_nn.Linear1D_Row,
-                       kwargs={"seq_parallel": use_sequence_parallel},
+                       kwargs={"seq_parallel_mode": sp_mode},
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dropout",

@@ -151,7 +161,7 @@ class BertPolicy(Policy):
                ]
            )

-       if use_sequence_parallel:
+       if sp_mode == "split_gather":
            self.append_or_create_method_replacement(
                description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,

@@ -165,12 +175,12 @@ class BertPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="attention.output.LayerNorm",
                        target_module=norm_cls,
-                       kwargs={"sp_partial_derived": use_sequence_parallel},
+                       kwargs={"sp_partial_derived": sp_partial_derived},
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.LayerNorm",
                        target_module=norm_cls,
-                       kwargs={"sp_partial_derived": use_sequence_parallel},
+                       kwargs={"sp_partial_derived": sp_partial_derived},
                    ),
                ],
                policy=policy,
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List


@@ -55,8 +56,18 @@ class BloomPolicy(Policy):
            norm_cls = col_nn.FusedLayerNorm
        else:
            norm_cls = col_nn.LayerNorm
-       use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+
+       sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+       assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
+       if sp_mode == "ring":
+           warnings.warn(
+               f"For BLOOM, sequence parallelism currently does not support mode {sp_mode}; it will be set to split_gather"
+           )
+           sp_mode = "split_gather"
+
        overlap = self.shard_config.enable_sequence_overlap
+       sp_partial_derived = sp_mode == "split_gather"

        if self.shard_config.enable_tensor_parallelism:
            policy[BloomBlock] = ModulePolicyDescription(
                attribute_replacement={

@@ -70,12 +81,12 @@ class BloomPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
-                       kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                       kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=col_nn.Linear1D_Row,
-                       kwargs={"seq_parallel": use_sequence_parallel},
+                       kwargs={"seq_parallel_mode": sp_mode},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.attention_dropout",

@@ -84,12 +95,12 @@ class BloomPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_h_to_4h",
                        target_module=col_nn.Linear1D_Col,
-                       kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                       kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.dense_4h_to_h",
                        target_module=col_nn.Linear1D_Row,
-                       kwargs={"seq_parallel": use_sequence_parallel},
+                       kwargs={"seq_parallel_mode": sp_mode},
                    ),
                ],
            )

@@ -132,19 +143,19 @@ class BloomPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=norm_cls,
-                       kwargs={"sp_partial_derived": use_sequence_parallel},
+                       kwargs={"sp_partial_derived": sp_partial_derived},
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=norm_cls,
-                       kwargs={"sp_partial_derived": use_sequence_parallel},
+                       kwargs={"sp_partial_derived": sp_partial_derived},
                    ),
                ],
                policy=policy,
                target_key=BloomBlock,
            )

-       if use_sequence_parallel:
+       if sp_mode == "split_gather":
            self.append_or_create_method_replacement(
                description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
@@ -1,3 +1,4 @@
+import warnings
from functools import partial
from typing import Callable, Dict, List, Union


@@ -55,8 +56,17 @@ class ChatGLMPolicy(Policy):
            norm_cls = col_nn.RMSNorm
        else:
            norm_cls = col_nn.LayerNorm
-       use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+
+       sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+       assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
+       if sp_mode == "ring":
+           warnings.warn(
+               f"For ChatGLM2, sequence parallelism currently does not support mode {sp_mode}; it will be set to split_gather"
+           )
+           sp_mode = "split_gather"
        overlap = self.shard_config.enable_sequence_overlap
+       sp_partial_derived = sp_mode == "split_gather"

        if self.shard_config.enable_tensor_parallelism:
            policy[ChatGLMModel] = ModulePolicyDescription(
                attribute_replacement={},

@@ -91,12 +101,12 @@ class ChatGLMPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
-                       kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap},
+                       kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=col_nn.Linear1D_Row,
-                       kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0},
+                       kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.core_attention.attention_dropout",

@@ -110,12 +120,12 @@ class ChatGLMPolicy(Policy):
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=norm_cls,
-                       kwargs={"sp_partial_derived": use_sequence_parallel},
+                       kwargs={"sp_partial_derived": sp_partial_derived},
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=norm_cls,
-                       kwargs={"sp_partial_derived": use_sequence_parallel},
+                       kwargs={"sp_partial_derived": sp_partial_derived},
                    ),
                ],
                policy=policy,

@@ -145,7 +155,7 @@ class ChatGLMPolicy(Policy):
            )

        # use sequence parallel
-       if use_sequence_parallel:
+       if sp_mode == "split_gather":
            self.append_or_create_method_replacement(
                description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List

@@ -50,8 +51,25 @@ class GPT2Policy(Policy):
             norm_cls = col_nn.FusedLayerNorm
         else:
             norm_cls = col_nn.LayerNorm
-        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
+        if sp_mode == "ring":
+            warnings.warn(
+                f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
+            )
+            sp_mode = "split_gather"
         overlap = self.shard_config.enable_sequence_overlap
+        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        use_flash_attention = self.shard_config.enable_flash_attention
+        # todo: currently sp cannot be used with flashattention
+        if sp_mode in ["split_gather", "ring", "all_to_all"]:
+            if use_flash_attention:
+                warnings.warn(
+                    f"Sequence parallelism mode {sp_mode} cannot be used with FlashAttention, will disable FlashAttention automatically."
+                )
+                self.shard_config.enable_flash_attention = False
+                use_flash_attention = False
         if self.shard_config.enable_tensor_parallelism:
             policy[GPT2Model] = ModulePolicyDescription(
                 sub_module_replacement=[
@@ -78,7 +96,7 @@ class GPT2Policy(Policy):
                         target_module=col_nn.GPT2FusedLinearConv1D_Col,
                         kwargs={
                             "n_fused": 3,
-                            "seq_parallel": use_sequence_parallel,
+                            "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
                         },
                     ),
@@ -86,7 +104,7 @@ class GPT2Policy(Policy):
                         suffix="attn.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
                         kwargs={
-                            "seq_parallel": use_sequence_parallel,
+                            "seq_parallel_mode": sp_mode,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -94,14 +112,16 @@ class GPT2Policy(Policy):
                         target_module=col_nn.GPT2FusedLinearConv1D_Col,
                         kwargs={
                             "n_fused": 1,
-                            "seq_parallel": use_sequence_parallel,
+                            "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_proj",
                         target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                        kwargs={"seq_parallel": use_sequence_parallel},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",
@@ -133,25 +153,25 @@ class GPT2Policy(Policy):
                 SubModuleReplacementDescription(
                     suffix="ln_1",
                     target_module=norm_cls,
-                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
                 SubModuleReplacementDescription(
                     suffix="ln_2",
                     target_module=norm_cls,
-                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
                 SubModuleReplacementDescription(
                     suffix="ln_cross_attn",
                     target_module=norm_cls,
                     ignore_if_not_exist=True,
-                    kwargs={"sp_partial_derived": use_sequence_parallel},
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
             ],
             policy=policy,
             target_key=GPT2Block,
         )

-        if self.shard_config.enable_flash_attention:
+        if use_flash_attention:
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_gpt2_flash_attention_forward(),
@@ -164,7 +184,7 @@ class GPT2Policy(Policy):
                     "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
                 }

-        if self.shard_config.enable_sequence_parallelism:
+        if sp_mode is not None:
             policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}

         return policy

@@ -12,6 +12,8 @@ from ..modeling.llama import (
     LlamaPipelineForwards,
     get_llama_flash_attention_forward,
     get_llama_model_forward_for_flash_attn,
+    get_llama_seq_parallel_attention_forward,
+    get_llama_seq_parallel_model_forward,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -45,9 +47,74 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm

-        if self.shard_config.enable_sequence_parallelism:
+        if self.pipeline_stage_manager is not None:
             self.shard_config.enable_sequence_parallelism = False
-            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+            self.shard_config.enable_sequence_overlap = False
+            self.shard_config.sequence_parallelism_mode = None
+            warnings.warn(
+                f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
+            )
+        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
+        sp_group = (
+            self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
+        )
+        sp_partial_derived = sp_mode in ["split_gather", "ring"]
+
+        use_flash_attention = self.shard_config.enable_flash_attention
+        # Currently sp cannot to be used with flashattention
+        if sp_mode in ["split_gather", "ring", "all_to_all"]:
+            if use_flash_attention:
+                warnings.warn(
+                    f"Sequence parallelism mode {sp_mode} need to be used with FlashAttention, will disable FlashAttention automatically."
+                )
+                use_flash_attention = False
+
+        if sp_mode in ["split_gather", "ring"]:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_seq_parallel_model_forward(
+                        sp_mode=sp_mode, sp_size=sp_size, sp_group=sp_group
+                    ),
+                },
+                policy=policy,
+                target_key=LlamaModel,
+            )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                },
+                policy=policy,
+                target_key=LlamaAttention,
+            )
+        elif sp_mode == "all_to_all":
+            decoder_attribute_replacement = {
+                "num_heads": self.model.config.num_attention_heads // sp_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+            policy[LlamaAttention] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+            )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
+                },
+                policy=policy,
+                target_key=LlamaAttention,
+            )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_seq_parallel_model_forward(
+                        sp_mode=sp_mode,
+                        sp_size=sp_size,
+                        sp_group=sp_group,
+                    ),
+                },
+                policy=policy,
+                target_key=LlamaModel,
+            )
+
         if self.shard_config.enable_tensor_parallelism:
             decoder_attribute_replacement = {
@@ -65,30 +132,37 @@ class LlamaPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="self_attn.q_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.k_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.v_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="self_attn.o_proj",
                     target_module=Linear1D_Row,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.gate_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.up_proj",
                     target_module=Linear1D_Col,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.down_proj",
                     target_module=Linear1D_Row,
+                    kwargs=dict(seq_parallel_mode=sp_mode),
                 ),
             ],
         )
@@ -108,10 +182,12 @@ class LlamaPolicy(Policy):
                 SubModuleReplacementDescription(
                     suffix="input_layernorm",
                     target_module=norm_cls,
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
                 SubModuleReplacementDescription(
                     suffix="post_attention_layernorm",
                     target_module=norm_cls,
+                    kwargs={"sp_partial_derived": sp_partial_derived},
                 ),
             ],
             policy=policy,
@@ -122,16 +198,17 @@ class LlamaPolicy(Policy):
             description=SubModuleReplacementDescription(
                 suffix="norm",
                 target_module=norm_cls,
+                kwargs={"sp_partial_derived": sp_partial_derived},
             ),
             policy=policy,
             target_key=LlamaModel,
         )

         # use flash attention
-        if self.shard_config.enable_flash_attention:
+        if use_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(self.shard_config),
+                    "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -243,7 +320,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):

         policy = super().module_policy()

-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and not self.shard_config.enable_sequence_parallelism:
             # add a new item for casual lm
             new_item = {
                 LlamaForCausalLM: ModulePolicyDescription(

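The `all_to_all` branch above leaves the sharded linear layers untouched and instead shrinks the per-rank head counts, so that after the all-to-all exchange each rank attends over the full sequence with `num_heads // sp_size` heads. A minimal sketch of that arithmetic, with illustrative config values rather than ones taken from the diff:

```python
from transformers import LlamaConfig

# Illustrative values; in the policy they come from self.model.config and
# self.shard_config.sequence_parallel_size.
config = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)
sp_size = 4

decoder_attribute_replacement = {"num_heads": config.num_attention_heads // sp_size}
if getattr(config, "num_key_value_heads", False):
    decoder_attribute_replacement["num_key_value_heads"] = config.num_key_value_heads // sp_size

# Each rank keeps 8 query heads and 2 KV heads, but sees the full sequence after all-to-all.
print(decoder_attribute_replacement)  # {'num_heads': 8, 'num_key_value_heads': 2}
```
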
@@ -1,3 +1,4 @@
+import warnings
 from dataclasses import dataclass, field
 from typing import Any, Dict, Optional

@@ -9,6 +10,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from .grad_ckpt_config import GradientCheckpointConfig

 __all__ = ["ShardConfig"]
+SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]


 @dataclass
@@ -29,13 +31,15 @@ class ShardConfig:
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
+    sequence_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
     enable_tensor_parallelism: bool = True
+    enable_all_optimization: bool = False
     enable_fused_normalization: bool = False
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
-    enable_all_optimization: bool = False
     enable_sequence_parallelism: bool = False
+    sequence_parallelism_mode: str = None
     enable_sequence_overlap: bool = False
     parallel_output: bool = True
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
@@ -50,22 +54,57 @@ class ShardConfig:
     def tensor_parallel_size(self):
         return self._tensor_parallel_size

+    @property
+    def sequence_parallel_size(self):
+        return self._sequence_parallel_size
+
     def __post_init__(self):
-        if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
-            raise ValueError(
-                "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True"
-            )
-        if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
-            raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
-        if not self.enable_tensor_parallelism:
-            self._tensor_parallel_size = 1
-        else:
-            # get the parallel size
-            self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
         # turn on all optimization if all_optimization is set to True
         if self.enable_all_optimization:
             self._turn_on_all_optimization()
+
+        if self.enable_sequence_parallelism:
+            self.sequence_parallelism_mode = (
+                "split_gather" if self.sequence_parallelism_mode is None else self.sequence_parallelism_mode
+            )
+            assert (
+                self.sequence_parallelism_mode in SUPPORT_SP_MODE
+            ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
+            if self.sequence_parallelism_mode in ["split_gather", "ring"]:
+                assert (
+                    self.enable_tensor_parallelism
+                ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
+            elif self.sequence_parallelism_mode in ["all_to_all"]:
+                assert (
+                    not self.enable_tensor_parallelism
+                ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
+                if self.enable_sequence_overlap:
+                    self.enable_sequence_overlap = False
+                    warnings.warn(
+                        f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
+                    )
+        else:
+            if self.sequence_parallelism_mode:
+                self.sequence_parallelism_mode = None
+                warnings.warn(
+                    f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
+                )
+            assert (
+                not self.enable_sequence_overlap
+            ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"
+
+        # get the tensor parallel size
+        if not self.enable_tensor_parallelism:
+            self._tensor_parallel_size = 1
+        else:
+            self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
+
+        # get the sequence parallel size
+        if not self.enable_sequence_parallelism:
+            self._sequence_parallel_size = 1
+        else:
+            self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)

     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.

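The reworked `__post_init__` pins down the mode constraints: `split_gather` is the default when sequence parallelism is enabled without an explicit mode, `split_gather`/`ring` require tensor parallelism to stay on, and `all_to_all` requires it to be off and ignores `enable_sequence_overlap`. A hedged usage sketch (it assumes `ShardConfig` is importable from `colossalai.shardformer` and that `torch.distributed` has already been initialized, e.g. via `colossalai.launch`, since the sizes are queried from process groups):

```python
from colossalai.shardformer import ShardConfig

# Megatron-style sequence parallelism: mode defaults to "split_gather" and must be
# combined with tensor parallelism.
sp_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",
)

# Ulysses-style sequence parallelism: tensor parallelism off, with its own process
# group (None falls back to the default group; a real plugin would pass a subgroup).
ulysses_config = ShardConfig(
    enable_tensor_parallelism=False,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    sequence_parallel_process_group=None,
)

print(sp_config.tensor_parallel_size, ulysses_config.sequence_parallel_size)
```
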
@@ -79,6 +79,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
+
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._logger = get_dist_logger()
         self._verbose = verbose
@@ -494,7 +495,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # clear reduced grads
         if self._overlap_communication:
             get_accelerator().synchronize()
-
         self.zero_grad()

     def backward_by_grad(self, tensor, grad):

@@ -18,8 +18,23 @@ def data_gen():
     # tokenized_input = tokenizer(input, return_tensors='pt')
     # input_ids = tokenized_input['input_ids']
     # attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    # input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+    # attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    input_ids = torch.tensor(
+        [
+            [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
+            [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
+        ],
+        dtype=torch.int64,
+    )
+    attention_mask = torch.tensor(
+        [
+            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+        ],
+        dtype=torch.int64,
+    )

     return dict(input_ids=input_ids, attention_mask=attention_mask)


@@ -35,9 +50,9 @@ def data_gen_for_question_answering():
     # question answering data gen
     # `labels` is the type not the token id for token classification, 0 or 1
     data = data_gen()
-    start_positions = torch.tensor([0], dtype=torch.int64)
+    start_positions = torch.tensor([[0], [0]], dtype=torch.int64)
     data["start_positions"] = start_positions
-    end_positions = torch.tensor([1], dtype=torch.int64)
+    end_positions = torch.tensor([[1], [1]], dtype=torch.int64)
     data["end_positions"] = end_positions
     return data

@@ -46,14 +61,20 @@ def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
     data = data_gen()
-    data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
+    data["labels"] = torch.tensor(
+        [
+            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
+            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1],
+        ],
+        dtype=torch.int64,
+    )
     return data


 def data_gen_for_sequence_classification():
     # sequence classification data gen
     data = data_gen()
-    data["labels"] = torch.tensor([1], dtype=torch.int64)
+    data["labels"] = torch.tensor([[1], [1]], dtype=torch.int64)
     return data


@@ -61,12 +82,18 @@ def date_gen_for_double_heads():
     num_choices = 2
     batch_size = 2
     input_ids = torch.tensor(
-        [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
+        [
+            [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
+            [15496, 11, 616, 3290, 318, 13779, 318, 13779, 15496, 11, 616, 3290, 318, 13779, 318, 13779],
+        ],
+        dtype=torch.int64,
+    )
+    attention_mask = torch.tensor(
+        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
         dtype=torch.int64,
     )
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
-    mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)

+    mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
     mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64)
     mc_token_ids = mc_token_ids.expand((batch_size, num_choices))
     multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
@@ -103,6 +130,7 @@ config = transformers.GPT2Config(
     hidden_dropout=0,
     problem_type="single_label_classification",
     pad_token_id=50256,
+    tie_word_embeddings=True,
 )

 config_for_token_classification = copy.deepcopy(config)

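The GPT2 test inputs are widened to two samples of 16 tokens, presumably so the sequence axis splits evenly across a tensor/sequence parallel group of up to 4 ranks in the new sequence-parallel test configs. A quick self-contained check of that divisibility (sizes are illustrative):

```python
import torch

# data_gen() now returns [batch=2, seq=16]; the sequence must divide evenly so that
# torch.chunk along dim=1 gives equal shards on every rank.
input_ids = torch.zeros(2, 16, dtype=torch.int64)
for group_size in (1, 2, 4):
    assert input_ids.shape[1] % group_size == 0
    shards = torch.chunk(input_ids, group_size, dim=1)
    assert all(shard.shape == (2, 16 // group_size) for shard in shards)
```
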
@@ -28,9 +28,19 @@ if HAS_LLAMA:
     # -----------------------------------

     input_ids = torch.Tensor(
-        [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]]
+        [
+            [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
+            [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
+        ]
     ).long()
-    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
+    attention_mask = torch.Tensor(
+        [
+            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+        ]
+    ).long()

     return dict(input_ids=input_ids, attention_mask=attention_mask)

     # label is needed for casual lm

@@ -44,7 +44,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b

     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
-    enable_all_optimization = True if tp_size > 1 else False
+    enable_flash_attention = True if tp_size > 1 else False
+    enable_fused_normalization = True if tp_size > 1 else False
+    enable_jit_fused = True if tp_size > 1 else False

     with shared_tempdir() as tempdir:
         pretrained_path = os.path.join(tempdir, "pretrained")
@@ -54,7 +57,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
         plugin = GeminiPlugin(
             **placement_config,
             tp_size=tp_size,
-            enable_all_optimization=enable_all_optimization,
+            enable_flash_attention=enable_flash_attention,
+            enable_fused_normalization=enable_fused_normalization,
+            enable_jit_fused=enable_jit_fused,
             extra_dp_size=extra_dp_size,
         )
         booster = Booster(plugin=plugin)
@@ -80,7 +85,9 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    enable_all_optimization = True if tp_size > 1 else False
+    enable_flash_attention = True if tp_size > 1 else False
+    enable_fused_normalization = True if tp_size > 1 else False
+    enable_jit_fused = True if tp_size > 1 else False
     extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
     plugin = GeminiPlugin(
         **placement_config,
@@ -88,7 +95,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
         initial_scale=(2**14),
         tp_size=tp_size,
         extra_dp_size=extra_dp_size,
-        enable_all_optimization=enable_all_optimization,
+        enable_flash_attention=enable_flash_attention,
+        enable_fused_normalization=enable_fused_normalization,
+        enable_jit_fused=enable_jit_fused,
     )
     booster = Booster(plugin=plugin)

@@ -84,6 +84,30 @@ def check_process_group_mesh_with_cases():
         2: [2],
         3: [3],
     }
+    TPxPP_RANKS_IN_GROUP = {
+        0: [0, 1, 2, 3],
+        1: [0, 1, 2, 3],
+        2: [0, 1, 2, 3],
+        3: [0, 1, 2, 3],
+    }
+    DPxTP_RANKS_IN_GROUP = {
+        0: [0, 1],
+        1: [0, 1],
+        2: [2, 3],
+        3: [2, 3],
+    }
+    TPxPP_PARTIAL_INDICES = {
+        0: [[0, 1], [0]],
+        1: [[1], [0, 1]],
+        2: [[0], [0, 1]],
+        3: [[0, 1], [1]],
+    }
+    TPxPP_RANKS_IN_GROUP_PARTIAL = {
+        0: [0, 1],
+        1: [1, 3],
+        2: [0, 2],
+        3: [2, 3],
+    }

     pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE)

@@ -107,6 +131,12 @@ def check_process_group_mesh_with_cases():
     assert pg_mesh.get_ranks_in_group(pp_group) == PP_RANKS_IN_GROUP[rank]
     dp_group = pg_mesh.get_group_along_axis(DP_DIM)
     assert pg_mesh.get_ranks_in_group(dp_group) == DP_RANKS_IN_GROUP[rank]
+    dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM])
+    assert pg_mesh.get_ranks_in_group(dpxtp_group) == DPxTP_RANKS_IN_GROUP[rank]
+    tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM])
+    assert pg_mesh.get_ranks_in_group(tpxpp_group) == TPxPP_RANKS_IN_GROUP[rank]
+    tpxpp_group_partial = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM], TPxPP_PARTIAL_INDICES[rank])
+    assert pg_mesh.get_ranks_in_group(tpxpp_group_partial) == TPxPP_RANKS_IN_GROUP_PARTIAL[rank]

     # check prev rank
     if RANK_TO_COORDINATE[rank][TP_DIM] != 0:

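`create_group_along_axis` is exercised here with a list of axes, which fuses two mesh dimensions into a single process group, for example the dp x tp group used to keep the ZeRO/data-parallel group separate from the sequence-parallel group. A hedged sketch of the call pattern, meant to run inside an initialized 4-rank `torch.distributed` job; the axis sizes are illustrative:

```python
from colossalai.cluster import ProcessGroupMesh

DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
# Axis sizes are illustrative; their product must equal the world size (4 here).
pg_mesh = ProcessGroupMesh(1, 2, 2)

# A single axis yields the familiar 1D group ...
tp_group = pg_mesh.get_group_along_axis(TP_DIM)
# ... while a list of axes fuses the dimensions into one 2D group.
dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM])
tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM])

print(pg_mesh.get_ranks_in_group(dpxtp_group), pg_mesh.get_ranks_in_group(tpxpp_group))
```
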
@@ -56,13 +56,18 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor


-def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = Conv1D(192, 48).cuda()
     with ctx:
         linear_copy = Conv1D(192, 48).cuda()
     linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap
+        linear_copy,
+        process_group=None,
+        gather_output=True,
+        seq_parallel_mode=seq_parallel_mode,
+        n_fused=3,
+        overlap=overlap,
     )

     assert linear.weight.shape == torch.Size([48, 192])
@@ -79,7 +84,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool)
     # check computation correctness
     x = torch.rand(1, 4, 48).cuda()
     out = linear(x)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
     gather_out = linear_conv_col(x_for_shard)
     assert_close(rearrange(out, -1), gather_out)

@@ -91,14 +98,14 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool)
     assert_close(target_grad, linear_conv_col.weight.grad)


-def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
+def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear = Conv1D(192, 48).cuda()
     with ctx:
         linear_copy = Conv1D(192, 48).cuda()
     linear_row = GPT2FusedLinearConv1D_Row.from_native_module(
-        linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
+        linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
     )

     assert linear.weight.shape == torch.Size([48, 192])
@@ -115,7 +122,7 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
     x = torch.rand(1, 4, 48).cuda()
     out = linear(x)
     gather_out = linear_row(x)
-    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
     assert_close(target_out, gather_out)

     # check backward correctness
@@ -128,11 +135,11 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):


 @parameterize("lazy_init", [False, True])
-@parameterize("seq_parallel", [False, True])
+@parameterize("seq_parallel_mode", ["split_gather", None])
 @parameterize("overlap", [True])
-def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
-    check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
-    check_linear_conv_1d_row(lazy_init, seq_parallel)
+def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
+    check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
+    check_linear_conv_1d_row(lazy_init, seq_parallel_mode)


 def run_dist(rank, world_size, port):

@@ -15,13 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


-def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_1d_col(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
     linear_col = Linear1D_Col.from_native_module(
-        linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap
+        linear_copy, process_group=None, gather_output=True, seq_parallel_mode=seq_parallel_mode, overlap=overlap
     )

     # ensure that the parameters are distributed
@@ -43,7 +43,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
     x_for_shard.requires_grad_(True)

     out = linear(x_for_unshard)
@@ -63,20 +65,20 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
     assert x_for_unshard.grad is not None
     target_unshard_gard = (
         x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
         else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
     )
     assert_close(target_unshard_gard, x_for_shard.grad)


-def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
+def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear = nn.Linear(32, 128).cuda()
     with ctx:
         linear_copy = nn.Linear(32, 128).cuda()
     linear_row = Linear1D_Row.from_native_module(
-        linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel
+        linear_copy, process_group=None, parallel_input=False, seq_parallel_mode=seq_parallel_mode
     )

     assert linear_row.weight.shape == torch.Size([128, 16])
@@ -98,7 +100,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     # run forward
     out = linear(x_for_unshard)
     gather_out = linear_row(x_for_shard)
-    target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = out if seq_parallel_mode is None else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
     assert_close(target_out, gather_out)

     # check backward correctness
@@ -115,7 +117,7 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
     assert_close(x_for_unshard.grad, x_for_shard.grad)


-def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
+def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()

     linear_1 = nn.Linear(32, 128).cuda()
@@ -125,10 +127,10 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
         linear_1_copy = nn.Linear(32, 128).cuda()
         linear_2_copy = nn.Linear(128, 32).cuda()
     linear_col = Linear1D_Col.from_native_module(
-        linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap
+        linear_1_copy, process_group=None, gather_output=False, seq_parallel_mode=seq_parallel_mode, overlap=overlap
     )
     linear_row = Linear1D_Row.from_native_module(
-        linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel
+        linear_2_copy, process_group=None, parallel_input=True, seq_parallel_mode=seq_parallel_mode
     )

     linear_1.load_state_dict(linear_col.state_dict())
@@ -141,13 +143,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
     x = torch.rand(2, 4, 32).cuda()
     x_for_unshard = x.expand_as(x.clone())
     x_for_unshard.requires_grad_(True)
-    x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    x_for_shard = (
+        x.expand_as(x.clone()) if seq_parallel_mode is None else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
+    )
     x_for_shard.requires_grad_(True)

     # run forward
     unshard_out = linear_2(linear_1(x_for_unshard))
     shard_out = linear_row(linear_col(x_for_shard))
-    target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    target_out = (
+        unshard_out if seq_parallel_mode is None else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
+    )
     assert_close(target_out, shard_out)

     # check backward correctness
@@ -163,19 +169,19 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool
     assert x_for_unshard.grad is not None
     target_unshard_gard = (
         x_for_unshard.grad
-        if seq_parallel is False
+        if seq_parallel_mode is None
         else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
     )
     assert_close(target_unshard_gard, x_for_shard.grad)


 @parameterize("lazy_init", [False, True])
-@parameterize("seq_parallel", [False, True])
+@parameterize("seq_parallel_mode", [None, "split_gather"])
 @parameterize("overlap", [True])
-def run_dist_linear_test(lazy_init, seq_parallel, overlap):
-    check_linear_1d_col(lazy_init, seq_parallel, overlap)
-    check_linear_1d_row(lazy_init, seq_parallel)
-    check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
+def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
+    check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
+    check_linear_1d_row(lazy_init, seq_parallel_mode)
+    check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)


 def check_dist_linear(rank, world_size, port):

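Under `seq_parallel_mode="split_gather"` these tests only change the data layout: the column-parallel layer receives each rank's local sequence chunk and all-gathers it internally, and the row-parallel layer hands back just the local chunk of the result. A condensed sketch of that convention, assuming an initialized 2-rank group exactly like the tests above:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row

# Column then row parallel linears, wrapped the same way the tests wrap them.
linear_col = Linear1D_Col.from_native_module(
    nn.Linear(32, 128).cuda(), process_group=None, gather_output=False, seq_parallel_mode="split_gather"
)
linear_row = Linear1D_Row.from_native_module(
    nn.Linear(128, 32).cuda(), process_group=None, parallel_input=True, seq_parallel_mode="split_gather"
)

x = torch.rand(2, 4, 32).cuda()                       # [batch, seq, hidden]
x_local = torch.chunk(x, 2, dim=1)[dist.get_rank()]   # each rank feeds seq / 2
y_local = linear_row(linear_col(x_local))             # output is still sequence-sharded
print(y_local.shape)                                  # torch.Size([2, 2, 32]) on each rank
```
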
@@ -0,0 +1,178 @@
+import copy
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.shardformer.layer import all_to_all_comm
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+class SequenceParallelAttention(torch.nn.Module):
+    """Initialization.
+
+    Arguments:
+        local_attention (Module): local attention with q,k,v
+        sequence_process_group (ProcessGroup): sequence parallel process group
+        scatter_idx (int): scatter_idx for all2all comm
+        gather_idx (int): gather_idx for all2all comm
+    """
+
+    def __init__(
+        self,
+        heads_num: torch.Tensor,
+        hidden_dim: torch.Tensor,
+        enable_sequence_parallellism: bool = False,
+        sequence_process_group: dist.ProcessGroup = None,
+        scatter_idx: int = 2,
+        gather_idx: int = 1,
+    ) -> None:
+        super(SequenceParallelAttention, self).__init__()
+        self.spg = sequence_process_group
+        self.scatter_idx = scatter_idx
+        self.gather_idx = gather_idx
+        self.heads_num = heads_num
+        self.hidden_dim = hidden_dim
+        assert hidden_dim % heads_num == 0
+        self.head_dim = hidden_dim // heads_num
+        self.enable_sequence_parallellism = enable_sequence_parallellism
+
+        self.q = nn.Linear(hidden_dim, hidden_dim)
+        self.k = nn.Linear(hidden_dim, hidden_dim)
+        self.v = nn.Linear(hidden_dim, hidden_dim)
+        self.out = nn.Linear(hidden_dim, hidden_dim)
+
+    def attn(self, q, k, v):
+        batch_size, seq_len = q.shape[0], q.shape[1]
+
+        scale = self.head_dim**0.5
+        qk = torch.matmul(q, k.transpose(-2, -1)) / scale
+        weights = F.softmax(qk, dim=-1)
+
+        attention_score = torch.matmul(weights, v)
+
+        return attention_score
+
+    def forward(self, x) -> Tensor:
+        bsz, q_len, _ = x.size()
+
+        seq_len = q_len * dist.get_world_size(self.spg) if self.enable_sequence_parallellism else q_len
+        num_heads = (
+            self.heads_num // dist.get_world_size(self.spg) if self.enable_sequence_parallellism else self.heads_num
+        )
+
+        # in shape : e.g.,  [s/p:h:]
+        query_states = self.q(x)
+        key_states = self.k(x)
+        value_states = self.v(x)
+
+        if self.enable_sequence_parallellism:
+            query_states = all_to_all_comm(query_states, self.spg, self.scatter_idx, self.gather_idx)
+            key_states = all_to_all_comm(key_states, self.spg, self.scatter_idx, self.gather_idx)
+            value_states = all_to_all_comm(value_states, self.spg, self.scatter_idx, self.gather_idx)
+
+        query_states = query_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, seq_len, num_heads, self.head_dim).transpose(1, 2)
+        # out shape : e.g., [s:h/p:]
+        attn_score = self.attn(query_states, key_states, value_states)
+        attn_score = attn_score.transpose(1, 2).contiguous()
+        attn_score = attn_score.reshape(bsz, seq_len, num_heads * self.head_dim)
+        if self.enable_sequence_parallellism:
+            attn_score = all_to_all_comm(attn_score, self.spg, self.gather_idx, self.scatter_idx)
+
+        # output e.g., [s/p::h]
+        output = self.out(attn_score)
+
+        return output
+
+
+def seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
+    seq_len = seq_len
+    hidden_dim = hidden_dim
+    head_num = head_num
+    batch_size = batch_size
+    world_size = dist.get_world_size()
+
+    x = torch.randn(batch_size, seq_len, hidden_dim).cuda()
+    x_unshard = x.clone()
+    x_unshard.requires_grad_(True)
+    x_input = torch.chunk(x.clone(), world_size, dim=1)[dist.get_rank()]
+    x_input.requires_grad_(True)
+
+    # Multi-head Attention
+    mha = SequenceParallelAttention(head_num, hidden_dim).cuda()
+    # Multi-head Attention forward
+    mha_out = mha(x_unshard)
+
+    # Sequence parallel Attention
+    sp_attn = SequenceParallelAttention(head_num, hidden_dim, True).cuda()
+    sp_attn.load_state_dict(copy.deepcopy(mha.state_dict()))
+    # Sequence parallel Attention forward
+    dist_attn_out = sp_attn(x_input)
+
+    # gather the output of sequence parallel attention
+    out_list = [torch.empty_like(dist_attn_out) for _ in range(world_size)]
+    dist.all_gather(out_list, dist_attn_out)
+    seq_out = torch.cat(out_list, dim=1)
+
+    # forward result check
+    assert_close(seq_out, mha_out)
+
+    # Multi-head Attention backward
+    mha_out.sum().backward()
+    q_grad = mha.q.weight.grad
+    k_grad = mha.k.weight.grad
+    v_grad = mha.v.weight.grad
+    o_grad = mha.out.weight.grad
+    x_grad = x_unshard.grad
+
+    # Sequence parallel Attention backward
+    dist_attn_out.sum().backward()
+    q_grad_seq = sp_attn.q.weight.grad
+    k_grad_seq = sp_attn.k.weight.grad
+    v_grad_seq = sp_attn.v.weight.grad
+    o_grad_seq = sp_attn.out.weight.grad
+    x_grad_seq = x_input.grad
+    # all_reduce the grad of sequence parallel attention weight
+    dist.all_reduce(q_grad_seq)
+    dist.all_reduce(k_grad_seq)
+    dist.all_reduce(v_grad_seq)
+    dist.all_reduce(o_grad_seq)
+    # gather the grad of sequence parallel attention input
+    x_grad_seq_list = [torch.empty_like(x_grad_seq) for _ in range(world_size)]
+    dist.all_gather(x_grad_seq_list, x_grad_seq)
+    x_grad_seq_gather = torch.cat(x_grad_seq_list, dim=1)
+
+    # backward result check
+    assert_close(q_grad_seq, q_grad)
+    assert_close(k_grad_seq, k_grad)
+    assert_close(v_grad_seq, v_grad, atol=1e-4, rtol=1e-4)
+    assert_close(o_grad_seq, o_grad)
+    assert_close(x_grad_seq_gather, x_grad)
+
+
+@parameterize("seq_len", [128])
+@parameterize("hidden_dim", [64])
+@parameterize("head_num", [4])
+@parameterize("batch_size", [1])
+def run_seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size):
+    seq_parallel_attn(seq_len, hidden_dim, head_num, batch_size)
+
+
+def check_all2all_attn(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_seq_parallel_attn()
+
+
+@rerun_if_address_is_in_use()
+def test_all_to_all_attention():
+    spawn(check_all2all_attn, nprocs=4)
+
+
+if __name__ == "__main__":
+    test_all_to_all_attention()

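The new test is in effect a reference implementation of Ulysses-style attention: `all_to_all_comm(x, group, scatter_dim, gather_dim)` trades a hidden/head shard for a sequence shard before attention and swaps it back afterwards. A shape-only sketch of that exchange, to be run inside an initialized distributed job with CUDA; the sizes mirror the test defaults:

```python
import torch
import torch.distributed as dist

from colossalai.shardformer.layer import all_to_all_comm

# P = sequence-parallel world size; sizes mirror the test defaults (seq=128, hidden=64).
P = dist.get_world_size()
bsz, seq, hidden = 1, 128, 64
x_local = torch.randn(bsz, seq // P, hidden).cuda()   # each rank owns a sequence chunk

# Before attention: scatter hidden (dim 2), gather sequence (dim 1).
q_full_seq = all_to_all_comm(x_local, None, 2, 1)
assert q_full_seq.shape == (bsz, seq, hidden // P)    # full sequence, local heads

# After attention: the inverse exchange restores the sequence shard.
back = all_to_all_comm(q_full_seq, None, 1, 2)
assert back.shape == (bsz, seq // P, hidden)
```
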
@@ -1,5 +1,4 @@
 import copy
-import math
 from contextlib import nullcontext
 from typing import Any, Callable, Dict, List, Optional

@@ -123,7 +122,6 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
     sharded_model = copy.deepcopy(org_model)
     if use_lazy_init:
         ctx.materialize(org_model)
-
     org_model = org_model.cuda()
     org_optimizer = Adam(org_model.parameters(), lr=1e-3)
     sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
@@ -162,24 +160,22 @@ def run_forward_backward_with_hybrid_plugin(

     data = data_gen_fn()

-    if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
-        seq_len = data["input_ids"].shape[-1]
-        lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
-        times = lcm // seq_len
-        input_shape = data["input_ids"].shape
-        for k, v in data.items():
-            if v.shape == input_shape:
-                data[k] = v.repeat((1,) * (v.dim() - 1) + (times,))
+    shard_test_data = {}
+    for k, v in data.items():
+        shard_test_data[k] = data[k].clone()
+    unshard_test_data = {}
+    for k, v in data.items():
+        unshard_test_data[k] = data[k].clone()

     sharded_model.train()
     if booster.plugin.stage_manager is not None:
-        for k, v in data.items():
+        for k, v in shard_test_data.items():
             if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
                 new_shape = [1] * v.dim()
                 new_shape[0] = 4
-                data[k] = v.to("cuda").repeat(*new_shape)
+                shard_test_data[k] = v.to("cuda").repeat(*new_shape)

-        data_iter = iter([data])
+        data_iter = iter([shard_test_data])
         sharded_output = booster.execute_pipeline(
             data_iter,
             sharded_model,
@@ -189,17 +185,22 @@ def run_forward_backward_with_hybrid_plugin(
             return_outputs=True,
         )
         sharded_loss = sharded_output["loss"]
-    else:
-        data = {k: v.cuda() for k, v in data.items()}
-        sharded_output = sharded_model(**data)
+
+    else:
+        shard_test_data = {k: v.cuda() for k, v in shard_test_data.items()}
+        sharded_output = sharded_model(**shard_test_data)
         sharded_loss = criterion(sharded_output)
         sharded_optimizer.backward(sharded_loss)

     org_model.train()
-    data = {k: v.cuda() for k, v in data.items()}
-    org_output = org_model(**data)
+    if booster.plugin.stage_manager is not None:
+        for k, v in unshard_test_data.items():
+            if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+                new_shape = [1] * v.dim()
+                new_shape[0] = 4
+                unshard_test_data[k] = v.to("cuda").repeat(*new_shape)
+    unshard_test_data = {k: v.cuda() for k, v in unshard_test_data.items()}
+    org_output = org_model(**unshard_test_data)
     org_loss = criterion(org_output)
     org_loss.backward()

@@ -212,7 +213,6 @@ def check_output_hidden_state(
     stage_manager: Optional[PipelineStageManager] = None,
     atol: float = 1e-5,
     rtol: float = 1e-3,
-    dim: int = 0,
 ):
     org_hidden_state = org_output.last_hidden_state

@@ -100,6 +100,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 1,
@@ -154,7 +176,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 def run_bert_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
-
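The two configs added above exercise the tensor-parallel-based sequence-parallel modes, "ring" and "split_gather", both layered on top of tp_size=4. The sketch below is a single-process toy of the split/gather idea only, not the ShardFormer implementation: plain chunk/cat calls stand in for the real all-gather and reduce-scatter collectives.

import torch

# Toy single-process illustration of the split/gather idea behind the
# "split_gather" configs: token-wise work runs on a local slice of the
# sequence, and the slices are gathered back before any step that needs
# the full sequence.
tp_size = 4
batch, seq_len, hidden = 2, 16, 8
x = torch.randn(batch, seq_len, hidden)

# "Split": each simulated rank holds seq_len // tp_size tokens.
local_chunks = list(x.chunk(tp_size, dim=1))

# Token-wise work (e.g. LayerNorm / dropout) only needs the local slice.
norm = torch.nn.LayerNorm(hidden)
local_out = [norm(chunk) for chunk in local_chunks]

# "Gather": reassemble the full sequence where attention needs every token.
gathered = torch.cat(local_out, dim=1)

# Sanity check: splitting then gathering matches running on the full sequence.
assert torch.allclose(gathered, norm(x), atol=1e-6)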
@@ -99,6 +99,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,
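The same pair of configs is added for this model as well. For the "ring" mode specifically, each rank keeps its own query chunk while key chunks circulate around the ring, so attention against the full sequence is assembled block by block. The toy below simulates one full rotation in a single process (lists indexed by rank stand in for the peer-to-peer sends) and checks that the assembled score matrix matches the monolithic computation; it illustrates the idea, not the library's implementation.

import torch

# Toy single-process sketch of the ring idea exercised by the "ring" configs:
# every simulated rank keeps its own query chunk, while key chunks are passed
# around the ring so each rank eventually sees scores against the full sequence.
world_size = 4
seq_len, dim = 16, 8
q = torch.randn(seq_len, dim)
k = torch.randn(seq_len, dim)

q_chunks = list(q.chunk(world_size, dim=0))
k_chunks = list(k.chunk(world_size, dim=0))

# Each rank accumulates score blocks as key chunks rotate past it.
score_blocks = [[None] * world_size for _ in range(world_size)]
for step in range(world_size):
    for rank in range(world_size):
        # At this step, `rank` holds the key chunk originally owned by `src`.
        src = (rank - step) % world_size
        score_blocks[rank][src] = q_chunks[rank] @ k_chunks[src].T

# After a full rotation, concatenating the blocks per rank and stacking the
# ranks reproduces the monolithic attention-score matrix.
ring_scores = torch.cat([torch.cat(blocks, dim=1) for blocks in score_blocks], dim=0)
assert torch.allclose(ring_scores, q @ k.T, atol=1e-5)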
@@ -135,6 +135,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,
@@ -131,6 +131,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,
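Each test_config dict above is a bundle of keyword arguments describing one parallel layout. A minimal sketch of turning such a dict into a booster follows, assuming, as the test harness appears to do, that the dict is forwarded to HybridParallelPlugin; the use_lazy_init key is consumed by the test harness itself and is therefore omitted here, and the exact launch API may differ between ColossalAI versions. Run with torchrun --nproc_per_node=4.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

test_config = {
    "tp_size": 4,
    "pp_size": 1,
    "num_microbatches": 1,
    "enable_sequence_parallelism": True,
    "sequence_parallelism_mode": "split_gather",
    "enable_flash_attention": False,
    "precision": "fp16",
    "initial_scale": 1,
}

colossalai.launch_from_torch(config={})       # initialize torch.distributed; newer versions drop the config argument
plugin = HybridParallelPlugin(**test_config)  # builds the TP/PP/SP process groups from the config dict
booster = Booster(plugin=plugin)              # model/optimizer/dataloader are then wrapped via booster.boost(...)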
@@ -2,6 +2,8 @@ import os
 
 import pytest
 import torch
+import torch.distributed as dist
+from torch.testing import assert_close
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
@@ -46,6 +48,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
     col_layer_for_check = ["layers[0].self_attn.o_proj"]
+    # Here we check the grad of layernorm because an all-reduce operation should be performed during sequence parallelism
+    norm_layer_for_check = ["layers[0].input_layernorm", "layers[0].post_attention_layernorm"]
+
+    # During pipeline parallelism, we cannot get the grad of norm layer during first stage, so we only check this when pp is not enabled
+    if stage_manager is None:
+        norm_layer_for_check.append("norm")
+
+    # Check the grad when using ZeRO-1 and ZeRO-2
+    if (
+        booster.plugin.zero_stage in [1, 2]
+        and booster.plugin.shard_config.enable_sequence_parallelism
+        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
+    ):
+        for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+            working_p = sharded_optimizer._param_store.master_to_working_param[id(p2)]
+            grads = sharded_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(working_p))
+            grad_index = 0 if sharded_optimizer._partition_grads else sharded_optimizer._local_rank
+            grad = grads[grad_index]
+            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
+            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
 
     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
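The new branch above checks gradients when ZeRO-1/2 is combined with all_to_all sequence parallelism: the dense p1.grad is flattened, chunked by world size, and the local slice is compared against the partitioned gradient held by the low-level ZeRO optimizer. The standalone toy below reproduces just that slicing arithmetic in one process, with a plain list standing in for the optimizer's gradient store.

import torch

# ZeRO-1/2 keeps only a 1/world_size slice of each flattened gradient per
# rank, so the dense gradient must be flattened and chunked the same way
# before comparing. Everything here is simulated in one process.
world_size = 4
full_grad = torch.randn(6, 10)          # dense .grad of some parameter

# What the sharded optimizer would hold on each simulated rank
# (padding details of the real flattening are ignored for simplicity).
partitioned = list(full_grad.reshape(-1).chunk(world_size))

for rank in range(world_size):
    # Mirror of: p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
    sharded_grad = full_grad.view(-1).chunk(world_size)[rank]
    grad = partitioned[rank]
    torch.testing.assert_close(sharded_grad, grad[: sharded_grad.shape[0]])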
@@ -60,8 +82,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         col_layer_grads = get_grad_tensors_for_check(
             llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
         )
+        norm_layer_grads = get_grad_tensors_for_check(
+            llama_model,
+            shard_llama_model,
+            norm_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
         grads_to_check.update(col_layer_grads)
         grads_to_check.update(row_layer_grads)
+        grads_to_check.update(norm_layer_grads)
 
     # optimizer executes step
     org_optimizer.step()
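The norm-layer gradients added to grads_to_check above are a deliberate probe: a LayerNorm weight gradient is a sum over tokens, so when the sequence is split across ranks each rank only holds a partial sum and sequence parallelism must all-reduce it. The toy below demonstrates that property in one process by comparing the full-sequence gradient with the sum of per-chunk gradients (the sum plays the role of the all-reduce).

import torch

torch.manual_seed(0)
seq_len, hidden, world_size = 16, 8, 4
x = torch.randn(seq_len, hidden)

def layernorm_weight_grad(inp):
    # Fresh LayerNorm with fixed parameters so gradients are comparable across calls.
    ln = torch.nn.LayerNorm(hidden)
    ln.weight.data.fill_(1.0)
    ln.bias.data.zero_()
    ln(inp).sum().backward()
    return ln.weight.grad.clone()

full_grad = layernorm_weight_grad(x)

# Each simulated rank computes the gradient from its own sequence slice;
# summing the slices plays the role of the all-reduce.
partial = sum(layernorm_weight_grad(chunk) for chunk in x.chunk(world_size, dim=0))
torch.testing.assert_close(partial, full_grad)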
@@ -98,6 +131,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": True,
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "ring",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp32",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "sp_size": 2,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "use_lazy_init": True,
+            "zero_stage": 2,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
+        {
+            "tp_size": 1,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
         {
             "tp_size": 2,
             "pp_size": 2,
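The all_to_all configs added above (tp_size=1, some with sp_size=2, with and without ZeRO-2) target the Ulysses-style mode, where an all-to-all exchange turns an "every head, local sequence slice" layout into a "local heads, full sequence" layout before attention. The single-process sketch below mimics only that layout change with reshapes; dist.all_to_all and the real head/sequence sizes are replaced by small toy values.

import torch

sp_size = 2                      # number of sequence-parallel ranks
seq_len, num_heads, head_dim = 8, 4, 3
full = torch.randn(seq_len, num_heads, head_dim)

# Each simulated rank starts with every head but only seq_len // sp_size tokens.
per_rank_in = list(full.chunk(sp_size, dim=0))

# The all-to-all: rank `src` sends head-group `r` of its token slice to rank `r`.
per_rank_out = []
for r in range(sp_size):
    pieces = [per_rank_in[src].chunk(sp_size, dim=1)[r] for src in range(sp_size)]
    per_rank_out.append(torch.cat(pieces, dim=0))   # full sequence, num_heads // sp_size heads

# After the exchange, concatenating along the head dimension recovers the
# original tensor: no information is lost, only the sharding axis moves.
assert torch.allclose(torch.cat(per_rank_out, dim=1), full)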