diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 0674451a4..812f3e629 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1177,7 +1177,10 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config=gradient_checkpoint_config, fp8_communication=fp8_communication, inner_ring_size=inner_ring_size, + pg_mesh=self.pg_mesh, + sp_axis=self.sp_axis, ) + self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5f0e9261c..bbd99d162 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function): INTER_RING_GROUP_COPY: dist.ProcessGroup = None @staticmethod - def get_double_ring_groups(sp_group, inner_ring_size=None): + def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None): """ Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size shouldn't be larger than the number of NICs on each node. @@ -441,21 +441,17 @@ class RingAttention(torch.autograd.Function): Returns: Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. """ + assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization." + + sp_group = pg_mesh.get_group_along_axis(sp_axis) sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) - if inner_ring_size is None: - if torch.cuda.device_count() >= dist.get_world_size(): - # single node, no need to consider NICs - return sp_group, sp_group - if sp_size <= 4: - inner_ring_size = min(2, sp_size) - else: - inner_ring_size = min(4, sp_size) - else: - assert ( - inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 - ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + assert inner_ring_size is not None + + assert ( + inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 + ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" if inner_ring_size == sp_size: return sp_group, sp_group @@ -474,14 +470,14 @@ class RingAttention(torch.autograd.Function): # Create inner ring groups for i in range(inner_ring_size): ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) - group = dist.new_group(ranks) + group = pg_mesh.get_group_along_axis(sp_axis, ranks) if sp_rank in ranks: inner_ring_group = group # Create inter ring groups for i in range(num_rings): ranks = list(range(i, sp_size, num_rings)) - group = dist.new_group(ranks) + group = pg_mesh.get_group_along_axis(sp_axis, ranks) if sp_rank in ranks: inter_ring_group = group @@ -492,7 +488,7 @@ class RingAttention(torch.autograd.Function): q, # (B, H, Sq, D) k, v, - sp_group, + sp_axis, attention_mask_type, cu_seqlens=None, max_seqlen=None, @@ -502,6 +498,7 @@ class RingAttention(torch.autograd.Function): deterministic=False, return_softmax=False, inner_ring_size=None, + pg_mesh=None, **kwargs, ): """ @@ -512,7 +509,7 @@ class RingAttention(torch.autograd.Function): q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D] v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D] - sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism + sp_axis (Optional[int]): Sp axis for the global pg mesh. sp_tream (torch.cuda.Stream): An different stream for output correction. cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths of the sequences in the batch, used to index into q. @@ -537,7 +534,6 @@ class RingAttention(torch.autograd.Function): RingAttention.ATTN_DONE = torch.cuda.Event() if RingAttention.SP_STREAM is None: RingAttention.SP_STREAM = torch.cuda.Stream() - assert ( q.shape[2] == k.shape[2] ), "Q, K and V having different sequence lengths (inference or cross-attn)\ @@ -546,11 +542,13 @@ class RingAttention(torch.autograd.Function): attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES ), f"Mask type {attention_mask_type} is not supported yet." - clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) + assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization." - if RingAttention.SP_GROUP is not sp_group: + clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) + sp_group = pg_mesh.get_group_along_axis(sp_axis) + if inner_ring_size != None: RingAttention.SP_GROUP = sp_group - inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size) RingAttention.INNER_RING_GROUP = inner_ring_group RingAttention.INTER_RING_GROUP = inter_ring_group else: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 798fca88f..d550484da 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -857,17 +857,17 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) dropout_p = self.attn_dropout.p if self.training else 0.0 sp_mode = shard_config.sequence_parallelism_mode - sp_group = shard_config.sequence_parallel_process_group if sp_mode == "ring_attn": attn_output = RingAttention.attention( query, key, value, - sp_group, + sp_axis=shard_config.sp_axis, **attention_mask, dropout_p=dropout_p, scale=scale, inner_ring_size=shard_config.inner_ring_size, + pg_mesh=shard_config.pg_mesh, ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 2a5b60287..9a0da82f5 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -569,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s query_states, key_states, value_states, - sp_group, + sp_axis=shard_config.sp_axis, **attention_mask, inner_ring_size=shard_config.inner_ring_size, + pg_mesh=shard_config.pg_mesh, ) elif shard_config.enable_flash_attention: diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 911226e5c..4d4a1803b 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -49,6 +49,8 @@ class ShardConfig: extra_kwargs: Dict[str, Any] = field(default_factory=dict) # For ring attention + sp_axis: Optional[int] = None + pg_mesh: Optional[int] = None inner_ring_size: Optional[int] = None # for moe related moe_dp_group: Optional[ProcessGroup] = None diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 1c7647a7d..6ebd8da73 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu from torch.testing import assert_close import colossalai +from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag @@ -17,11 +18,14 @@ from colossalai.utils import get_current_device @parameterize("nheads", [5]) @parameterize("d", [128]) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_ring_attn(seq_len, bs, nheads, d, dtype): +def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): torch.cuda.manual_seed(2) device = get_current_device() sp_group = dist.group.WORLD + dp_size, pp_size, tp_size = 1, 1, 1 sp_size = dist.get_world_size() + sp_axis = 2 + pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size) # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) @@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): q, k, v, - sp_group, + sp_axis, AttnMaskType.CAUSAL, return_softmax=True, - inner_ring_size=max(2, sp_size // 2), - # inner_ring_size=4 + inner_ring_size=inner_ring_size, + pg_mesh=pg_mesh, ) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( @@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() + sp_axis = 2 atol = rtol = 7e-3 torch.cuda.manual_seed(2) # Prepare varlen attention mask @@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): q_ring, k_ring, v_ring, - sp_group, + sp_axis, **mask_info, pad_output=False, return_softmax=True, + pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1), # deterministic=True ) ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) @@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): def launch_single_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) check_packed_seq() - check_ring_attn() + check_ring_attn(inner_ring_size=None) def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - check_ring_attn() + check_ring_attn(inner_ring_size=2) @rerun_if_address_is_in_use()