From cfd9eda628cb9a3e7c6ecb15daeca5b93741fa12 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 25 Sep 2024 18:34:29 +0800 Subject: [PATCH 01/26] fix the ring attn --- colossalai/shardformer/layer/attn.py | 37 +++++++++++++++--------- colossalai/shardformer/modeling/llama.py | 2 ++ 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 5f0e9261c..1f897c1be 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Tuple import torch import torch.distributed import torch.distributed as dist +from torch.distributed import ProcessGroup import torch.nn.functional as F from einops import rearrange @@ -431,7 +432,7 @@ class RingAttention(torch.autograd.Function): INTER_RING_GROUP_COPY: dist.ProcessGroup = None @staticmethod - def get_double_ring_groups(sp_group, inner_ring_size=None): + def get_double_ring_groups(sp_group,tp_group, inner_ring_size=None): """ Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size shouldn't be larger than the number of NICs on each node. @@ -443,6 +444,7 @@ class RingAttention(torch.autograd.Function): """ sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) + tp_size = dist.get_world_size(tp_group) if inner_ring_size is None: if torch.cuda.device_count() >= dist.get_world_size(): @@ -471,19 +473,24 @@ class RingAttention(torch.autograd.Function): inner_ring_group = None inter_ring_group = None + world_size = dist.get_world_size() + rank = dist.get_rank() + groups = int(world_size/ sp_size) # Create inner ring groups - for i in range(inner_ring_size): - ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) - group = dist.new_group(ranks) - if sp_rank in ranks: - inner_ring_group = group + for group_id in range(groups): + for i in range(inner_ring_size): + ranks = list(range(i +(group_id*sp_size), (1+group_id)*sp_size, inner_ring_size)) + group = dist.new_group(ranks) + if rank in ranks: + inner_ring_group = group # Create inter ring groups - for i in range(num_rings): - ranks = list(range(i, sp_size, num_rings)) - group = dist.new_group(ranks) - if sp_rank in ranks: - inter_ring_group = group + for group_id in range(groups): + for i in range(num_rings): + ranks = list(range(i+group_id * num_rings, world_size, sp_size)) + group = dist.new_group(ranks) + if rank in ranks: + inter_ring_group = group return inner_ring_group, inter_ring_group @@ -493,6 +500,7 @@ class RingAttention(torch.autograd.Function): k, v, sp_group, + tp_group, attention_mask_type, cu_seqlens=None, max_seqlen=None, @@ -537,7 +545,6 @@ class RingAttention(torch.autograd.Function): RingAttention.ATTN_DONE = torch.cuda.Event() if RingAttention.SP_STREAM is None: RingAttention.SP_STREAM = torch.cuda.Stream() - assert ( q.shape[2] == k.shape[2] ), "Q, K and V having different sequence lengths (inference or cross-attn)\ @@ -550,7 +557,7 @@ class RingAttention(torch.autograd.Function): if RingAttention.SP_GROUP is not sp_group: RingAttention.SP_GROUP = sp_group - inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, tp_group, inner_ring_size) RingAttention.INNER_RING_GROUP = inner_ring_group RingAttention.INTER_RING_GROUP = inter_ring_group else: @@ -597,6 +604,7 @@ class 
RingAttention(torch.autograd.Function): attention_mask_type == AttnMaskType.PADDED_CAUSAL, inner_ring_group, inter_ring_group, + tp_group, ) if attention_mask_type == AttnMaskType.PADDED_CAUSAL: @@ -627,6 +635,7 @@ class RingAttention(torch.autograd.Function): is_packed: Optional[bool] = False, inner_ring_group: Optional[dist.ProcessGroup] = None, inter_ring_group: Optional[dist.ProcessGroup] = None, + tp_group: Optional[dist.ProcessGroup] = None, ): cu_seqlens_q = cu_seqlens_kv = cu_seqlens @@ -1123,7 +1132,7 @@ class RingAttention(torch.autograd.Function): if not is_packed: dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)] - return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None) + return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None) @staticmethod def prepare_varlen_batch( diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 47c17e749..fc5bcac6b 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -563,12 +563,14 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + tp_group = shard_config.tensor_parallel_process_group if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, key_states, value_states, sp_group, + tp_group, **attention_mask, inner_ring_size=shard_config.inner_ring_size, ) From 65c829771058433a9f7b88299bd1925d05d53554 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 25 Sep 2024 18:51:03 +0800 Subject: [PATCH 02/26] fix the attn --- colossalai/shardformer/layer/attn.py | 2 +- colossalai/shardformer/modeling/gpt2.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 1f897c1be..419932a00 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -500,7 +500,7 @@ class RingAttention(torch.autograd.Function): k, v, sp_group, - tp_group, + tp_group : Optional[dist.ProcessGroup], attention_mask_type, cu_seqlens=None, max_seqlen=None, diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 798fca88f..8f476ab86 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -858,12 +858,14 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group + tp_group = shard_config.tensor_parallel_process_group if sp_mode == "ring_attn": attn_output = RingAttention.attention( query, key, value, sp_group, + tp_group, **attention_mask, dropout_p=dropout_p, scale=scale, From 6fb1322db1525c90f7f80cebad4b447e009bacd4 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 25 Sep 2024 18:56:18 +0800 Subject: [PATCH 03/26] fix --- colossalai/shardformer/layer/attn.py | 20 +++++++++++--------- colossalai/shardformer/modeling/gpt2.py | 1 + colossalai/shardformer/modeling/llama.py | 1 + 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 419932a00..15ad09baa 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -4,7 +4,6 @@ from 
typing import Callable, Dict, Optional, Tuple import torch import torch.distributed import torch.distributed as dist -from torch.distributed import ProcessGroup import torch.nn.functional as F from einops import rearrange @@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function): INTER_RING_GROUP_COPY: dist.ProcessGroup = None @staticmethod - def get_double_ring_groups(sp_group,tp_group, inner_ring_size=None): + def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None): """ Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size shouldn't be larger than the number of NICs on each node. @@ -443,8 +442,8 @@ class RingAttention(torch.autograd.Function): Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. """ sp_size = dist.get_world_size(sp_group) - sp_rank = dist.get_rank(sp_group) - tp_size = dist.get_world_size(tp_group) + dist.get_rank(sp_group) + dist.get_world_size(tp_group) if inner_ring_size is None: if torch.cuda.device_count() >= dist.get_world_size(): @@ -475,11 +474,12 @@ class RingAttention(torch.autograd.Function): world_size = dist.get_world_size() rank = dist.get_rank() - groups = int(world_size/ sp_size) + groups = int(world_size / sp_size) + # Create inner ring groups for group_id in range(groups): for i in range(inner_ring_size): - ranks = list(range(i +(group_id*sp_size), (1+group_id)*sp_size, inner_ring_size)) + ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group @@ -487,7 +487,7 @@ class RingAttention(torch.autograd.Function): # Create inter ring groups for group_id in range(groups): for i in range(num_rings): - ranks = list(range(i+group_id * num_rings, world_size, sp_size)) + ranks = list(range(i + group_id * num_rings, world_size, sp_size)) group = dist.new_group(ranks) if rank in ranks: inter_ring_group = group @@ -500,7 +500,7 @@ class RingAttention(torch.autograd.Function): k, v, sp_group, - tp_group : Optional[dist.ProcessGroup], + tp_group: Optional[dist.ProcessGroup], attention_mask_type, cu_seqlens=None, max_seqlen=None, @@ -557,7 +557,9 @@ class RingAttention(torch.autograd.Function): if RingAttention.SP_GROUP is not sp_group: RingAttention.SP_GROUP = sp_group - inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, tp_group, inner_ring_size) + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups( + sp_group, tp_group, inner_ring_size + ) RingAttention.INNER_RING_GROUP = inner_ring_group RingAttention.INTER_RING_GROUP = inter_ring_group else: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8f476ab86..01d47bcd0 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -859,6 +859,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group tp_group = shard_config.tensor_parallel_process_group + if sp_mode == "ring_attn": attn_output = RingAttention.attention( query, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index fc5bcac6b..b5f505dce 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -564,6 +564,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, 
sp_mode=None, s value_states = repeat_kv(value_states, self.num_key_value_groups) tp_group = shard_config.tensor_parallel_process_group + if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, From 91ed32c2569b18f03a870f82f7f3ddb4b2da4e4f Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 25 Sep 2024 19:00:38 +0800 Subject: [PATCH 04/26] fix --- colossalai/shardformer/layer/attn.py | 2 +- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/modeling/llama.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 15ad09baa..2cc6f3163 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -500,8 +500,8 @@ class RingAttention(torch.autograd.Function): k, v, sp_group, - tp_group: Optional[dist.ProcessGroup], attention_mask_type, + tp_group=None, cu_seqlens=None, max_seqlen=None, valid_indices=None, diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 01d47bcd0..6be75a3c6 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -866,7 +866,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) key, value, sp_group, - tp_group, + tp_group=tp_group, **attention_mask, dropout_p=dropout_p, scale=scale, diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index b5f505dce..08f4bc90d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -571,7 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s key_states, value_states, sp_group, - tp_group, + tp_group=tp_group, **attention_mask, inner_ring_size=shard_config.inner_ring_size, ) From 6705dad41b001b75a94f1f92da48df666f5e7b55 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 25 Sep 2024 19:02:21 +0800 Subject: [PATCH 05/26] fix --- colossalai/shardformer/layer/attn.py | 2 +- colossalai/shardformer/modeling/gpt2.py | 2 +- colossalai/shardformer/modeling/llama.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 2cc6f3163..b7738c7e2 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -501,7 +501,6 @@ class RingAttention(torch.autograd.Function): v, sp_group, attention_mask_type, - tp_group=None, cu_seqlens=None, max_seqlen=None, valid_indices=None, @@ -510,6 +509,7 @@ class RingAttention(torch.autograd.Function): deterministic=False, return_softmax=False, inner_ring_size=None, + tp_group=None, **kwargs, ): """ diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 6be75a3c6..43b3d2c5e 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -866,11 +866,11 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) key, value, sp_group, - tp_group=tp_group, **attention_mask, dropout_p=dropout_p, scale=scale, inner_ring_size=shard_config.inner_ring_size, + tp_group=tp_group, ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 08f4bc90d..aa761fd21 100644 --- 
a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -571,9 +571,9 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s key_states, value_states, sp_group, - tp_group=tp_group, **attention_mask, inner_ring_size=shard_config.inner_ring_size, + tp_group=tp_group, ) elif shard_config.enable_flash_attention: From 3fab92166e9547359296d29ebda5b4ce03437502 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 26 Sep 2024 18:03:09 +0800 Subject: [PATCH 06/26] fix --- colossalai/shardformer/layer/attn.py | 46 +++++++++++++++++++--------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index b7738c7e2..21dbe718e 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -442,8 +442,7 @@ class RingAttention(torch.autograd.Function): Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. """ sp_size = dist.get_world_size(sp_group) - dist.get_rank(sp_group) - dist.get_world_size(tp_group) + tp_size = dist.get_world_size(tp_group) if inner_ring_size is None: if torch.cuda.device_count() >= dist.get_world_size(): @@ -476,21 +475,38 @@ class RingAttention(torch.autograd.Function): rank = dist.get_rank() groups = int(world_size / sp_size) - # Create inner ring groups - for group_id in range(groups): - for i in range(inner_ring_size): - ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) - group = dist.new_group(ranks) + if tp_size > 1: + for group_id in range(groups): + for i in range(inner_ring_size): + ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) + group = dist.new_group(ranks) + if rank in ranks: + inner_ring_group = group + for group_id in range(groups): + for i in range(num_rings): + ranks = list(range(i + group_id * num_rings, world_size, sp_size)) + group = dist.new_group(ranks) + if rank in ranks: + inter_ring_group = group + else: + for i in range(sp_size // 2): + ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1)) if rank in ranks: + print( + "rank:", + rank, + "inner ranks:", + ranks, + ) + group = dist.new_group(ranks) inner_ring_group = group - - # Create inter ring groups - for group_id in range(groups): - for i in range(num_rings): - ranks = list(range(i + group_id * num_rings, world_size, sp_size)) - group = dist.new_group(ranks) - if rank in ranks: - inter_ring_group = group + for group_id in range(num_rings): + for i in range(num_rings): + ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size)) + ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7] + if rank in ranks: + group = dist.new_group(ranks) + inter_ring_group = group return inner_ring_group, inter_ring_group From 3532f77b90ccec4e94e3d8dc2b2e0ac76008cab9 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 9 Oct 2024 10:57:19 +0800 Subject: [PATCH 07/26] fix --- colossalai/shardformer/layer/attn.py | 49 ++++++++++++++-------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 21dbe718e..36a4ec963 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -443,6 +443,7 @@ class RingAttention(torch.autograd.Function): """ sp_size = dist.get_world_size(sp_group) tp_size = 
dist.get_world_size(tp_group) + sp_rank = dist.get_rank(sp_group) if inner_ring_size is None: if torch.cuda.device_count() >= dist.get_world_size(): @@ -467,46 +468,44 @@ class RingAttention(torch.autograd.Function): f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", ranks=[0], ) - num_rings = sp_size // inner_ring_size + inner_ring_group = None inter_ring_group = None world_size = dist.get_world_size() rank = dist.get_rank() - groups = int(world_size / sp_size) + + inner_rings = world_size // sp_size + num_rings = sp_size // inner_ring_size if tp_size > 1: - for group_id in range(groups): - for i in range(inner_ring_size): - ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)) + for i in range(inner_rings): + for j in range(sp_size // tp_size): + # find inner ring group in one sp group + ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group - for group_id in range(groups): - for i in range(num_rings): - ranks = list(range(i + group_id * num_rings, world_size, sp_size)) + for i in range(inner_rings): + for j in range(sp_size // tp_size): + ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size)) group = dist.new_group(ranks) if rank in ranks: inter_ring_group = group else: - for i in range(sp_size // 2): - ranks = list(range((i) * num_rings, (i + 1) * num_rings, 1)) - if rank in ranks: - print( - "rank:", - rank, - "inner ranks:", - ranks, - ) - group = dist.new_group(ranks) + # Create inner ring groups + for i in range(inner_ring_size): + ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) + group = dist.new_group(ranks) + if sp_rank in ranks: inner_ring_group = group - for group_id in range(num_rings): - for i in range(num_rings): - ranks = list(range(i + group_id * num_rings, world_size, inner_ring_size)) - ranks = [0, 1, 4, 5] if rank == 0 or rank == 1 or rank == 4 or rank == 5 else [2, 3, 6, 7] - if rank in ranks: - group = dist.new_group(ranks) - inter_ring_group = group + + # Create inter ring groups + for i in range(num_rings): + ranks = list(range(i, sp_size, num_rings)) + group = dist.new_group(ranks) + if sp_rank in ranks: + inter_ring_group = group return inner_ring_group, inter_ring_group From b635dd06696f8c5deac29333ae85ce8cf68bd550 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Wed, 9 Oct 2024 14:05:26 +0800 Subject: [PATCH 08/26] fix --- colossalai/shardformer/layer/attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 36a4ec963..3853860c4 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -442,7 +442,7 @@ class RingAttention(torch.autograd.Function): Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. 
""" sp_size = dist.get_world_size(sp_group) - tp_size = dist.get_world_size(tp_group) + tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 sp_rank = dist.get_rank(sp_group) if inner_ring_size is None: From f98384aef67f413310d507d3dc873578e7ee6a39 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 10 Oct 2024 15:17:06 +0800 Subject: [PATCH 09/26] fix --- colossalai/shardformer/layer/attn.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 3853860c4..75718f608 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -468,27 +468,29 @@ class RingAttention(torch.autograd.Function): f"Using 2D Ring Attention with inner ring size {inner_ring_size} to maximze NIC util for inter-node comm. Cross your fingers for speed-ups!", ranks=[0], ) - + num_rings = sp_size // inner_ring_size inner_ring_group = None inter_ring_group = None world_size = dist.get_world_size() rank = dist.get_rank() - inner_rings = world_size // sp_size - num_rings = sp_size // inner_ring_size + num_ring_size = world_size // num_rings + num_inner_group = num_ring_size // inner_ring_size if tp_size > 1: - for i in range(inner_rings): - for j in range(sp_size // tp_size): - # find inner ring group in one sp group - ranks = list(range(j + i * sp_size, j + (i + 1) * sp_size, tp_size)) + # inner_ring_size = 2 + for i in range(num_rings): + for j in range(num_inner_group): + # find inner ring group in one sp groups + ranks = list(range(j + i * num_ring_size, j + (i + 1) * num_ring_size, tp_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group - for i in range(inner_rings): - for j in range(sp_size // tp_size): - ranks = list(range(j + i * (sp_size // tp_size), inner_rings + (i + 1) * sp_size, sp_size)) + for i in range(num_rings): + for j in range(num_inner_group): + start = j + (i * num_inner_group) + ranks = list(range(start, start + num_ring_size + 1, num_ring_size)) group = dist.new_group(ranks) if rank in ranks: inter_ring_group = group From 5ecc27e1509575e605c47279f243491e20968558 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 10 Oct 2024 15:35:52 +0800 Subject: [PATCH 10/26] fix --- colossalai/shardformer/layer/attn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 75718f608..a191694c1 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -479,7 +479,6 @@ class RingAttention(torch.autograd.Function): num_inner_group = num_ring_size // inner_ring_size if tp_size > 1: - # inner_ring_size = 2 for i in range(num_rings): for j in range(num_inner_group): # find inner ring group in one sp groups From efe3042bb235064b4ebff8f9549d6eb93d653302 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 10 Oct 2024 18:38:47 +0800 Subject: [PATCH 11/26] fix --- colossalai/shardformer/layer/attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index a191694c1..7c25bee1a 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -482,7 +482,8 @@ class RingAttention(torch.autograd.Function): for i in range(num_rings): for j in range(num_inner_group): # find inner ring group in one sp groups - ranks = list(range(j + i * 
num_ring_size, j + (i + 1) * num_ring_size, tp_size)) + start = j + i * num_ring_size + ranks = list(range(start, start + tp_size * inner_ring_size, tp_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group From 0002ae5956bfbab15810ce95ea8db70865fb80f1 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 11 Oct 2024 14:16:21 +0800 Subject: [PATCH 12/26] fix --- colossalai/shardformer/layer/attn.py | 37 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 7c25bee1a..7af8271eb 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -476,24 +476,31 @@ class RingAttention(torch.autograd.Function): rank = dist.get_rank() num_ring_size = world_size // num_rings - num_inner_group = num_ring_size // inner_ring_size if tp_size > 1: + ranks = [] for i in range(num_rings): - for j in range(num_inner_group): - # find inner ring group in one sp groups - start = j + i * num_ring_size - ranks = list(range(start, start + tp_size * inner_ring_size, tp_size)) - group = dist.new_group(ranks) - if rank in ranks: - inner_ring_group = group - for i in range(num_rings): - for j in range(num_inner_group): - start = j + (i * num_inner_group) - ranks = list(range(start, start + num_ring_size + 1, num_ring_size)) - group = dist.new_group(ranks) - if rank in ranks: - inter_ring_group = group + start = i * num_ring_size + end = (i + 1) * num_ring_size + for idx in range(start, end): + inner_rank = [] + for k in range(inner_ring_size): + current_num = idx + k * tp_size + if current_num >= end: + break + inner_rank.append(current_num) + if len(inner_rank) == inner_ring_size and inner_rank not in ranks: + ranks.append(inner_rank) + group = dist.new_group(inner_rank) + if rank in inner_rank: + inner_ring_group = group + + for i in range(num_ring_size): + inter_rank = [i + j * num_ring_size for j in range(num_rings)] + group = dist.new_group(inter_rank) + if rank in inter_rank: + inter_ring_group = group + else: # Create inner ring groups for i in range(inner_ring_size): From 1507a7528fd86fba5a86427ff7e1283709f863d0 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 11 Oct 2024 06:20:34 +0000 Subject: [PATCH 13/26] fix --- colossalai/shardformer/layer/attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 7af8271eb..411b4dbcc 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -478,6 +478,7 @@ class RingAttention(torch.autograd.Function): num_ring_size = world_size // num_rings if tp_size > 1: + # Create inner ring groups ranks = [] for i in range(num_rings): start = i * num_ring_size @@ -494,7 +495,7 @@ class RingAttention(torch.autograd.Function): group = dist.new_group(inner_rank) if rank in inner_rank: inner_ring_group = group - + # Create inter ring groups for i in range(num_ring_size): inter_rank = [i + j * num_ring_size for j in range(num_rings)] group = dist.new_group(inter_rank) From 4e0e99bb6a356cb734194aa985d837e3744b0b92 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 11 Oct 2024 17:31:40 +0800 Subject: [PATCH 14/26] fix the test --- .../test_layer/test_ring_attn.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py 
b/tests/test_shardformer/test_layer/test_ring_attn.py index 1c7647a7d..38dfcb239 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu from torch.testing import assert_close import colossalai +from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag @@ -17,11 +18,15 @@ from colossalai.utils import get_current_device @parameterize("nheads", [5]) @parameterize("d", [128]) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_ring_attn(seq_len, bs, nheads, d, dtype): +def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1): torch.cuda.manual_seed(2) device = get_current_device() - sp_group = dist.group.WORLD - sp_size = dist.get_world_size() + sp_axis, tp_axis = 0, 1 + pg_mesh = ProcessGroupMesh(sp_size, tp_size) + tp_group = pg_mesh.get_group_along_axis(tp_axis) + sp_group = pg_mesh.get_group_along_axis(sp_axis) + # sp_group = dist.group.WORLD + # sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) @@ -44,6 +49,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype): AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=max(2, sp_size // 2), + tp_group=tp_group, # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) @@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): def launch_single_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) check_packed_seq() - check_ring_attn() + check_ring_attn(sp_size=world_size) def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - check_ring_attn() + check_ring_attn(sp_size=4, tp_size=2) @rerun_if_address_is_in_use() @@ -176,7 +182,7 @@ def test_ring_attn(world_size): @rerun_if_address_is_in_use() -@parameterize("world_size", [4]) +@parameterize("world_size", [8]) def test_double_ring(world_size): spawn(launch_double_ring, nprocs=world_size) From 703bb5c18dd3a701209017de828e9cdc8ca58356 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 11 Oct 2024 17:34:20 +0800 Subject: [PATCH 15/26] fix the test --- tests/test_shardformer/test_layer/test_ring_attn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 38dfcb239..b9291a061 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -25,8 +25,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1): pg_mesh = ProcessGroupMesh(sp_size, tp_size) tp_group = pg_mesh.get_group_along_axis(tp_axis) sp_group = pg_mesh.get_group_along_axis(sp_axis) - # sp_group = dist.group.WORLD - # sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) @@ 
-50,7 +48,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1): return_softmax=True, inner_ring_size=max(2, sp_size // 2), tp_group=tp_group, - # inner_ring_size=4 ) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( From e1e86f9f1fdc7a885d9183b19567e440e3e8b34f Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 14 Oct 2024 11:45:35 +0800 Subject: [PATCH 16/26] fix --- colossalai/shardformer/layer/attn.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 411b4dbcc..aa013e526 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -484,12 +484,7 @@ class RingAttention(torch.autograd.Function): start = i * num_ring_size end = (i + 1) * num_ring_size for idx in range(start, end): - inner_rank = [] - for k in range(inner_ring_size): - current_num = idx + k * tp_size - if current_num >= end: - break - inner_rank.append(current_num) + inner_rank = [idx + k * tp_size for k in range(inner_ring_size) if idx + k * tp_size < end] if len(inner_rank) == inner_ring_size and inner_rank not in ranks: ranks.append(inner_rank) group = dist.new_group(inner_rank) From d891e50617b07bc09f3213069b857ba8c0325d2c Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 14 Oct 2024 14:56:05 +0800 Subject: [PATCH 17/26] fix --- colossalai/shardformer/layer/attn.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index aa013e526..36a239142 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -445,18 +445,11 @@ class RingAttention(torch.autograd.Function): tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 sp_rank = dist.get_rank(sp_group) - if inner_ring_size is None: - if torch.cuda.device_count() >= dist.get_world_size(): - # single node, no need to consider NICs - return sp_group, sp_group - if sp_size <= 4: - inner_ring_size = min(2, sp_size) - else: - inner_ring_size = min(4, sp_size) - else: - assert ( - inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 - ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" + assert inner_ring_size is not None + + assert ( + inner_ring_size <= sp_size and sp_size % inner_ring_size == 0 + ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}" if inner_ring_size == sp_size: return sp_group, sp_group @@ -575,7 +568,7 @@ class RingAttention(torch.autograd.Function): clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) - if RingAttention.SP_GROUP is not sp_group: + if inner_ring_size != None: RingAttention.SP_GROUP = sp_group inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups( sp_group, tp_group, inner_ring_size From 23199e34cc0ae6fdd6c9297ef088bf4dc4713b52 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 14 Oct 2024 18:01:53 +0800 Subject: [PATCH 18/26] fix --- colossalai/shardformer/layer/attn.py | 63 ++++++------------- colossalai/shardformer/modeling/gpt2.py | 2 - colossalai/shardformer/modeling/llama.py | 3 - .../test_layer/test_ring_attn.py | 8 +-- 4 files changed, 20 insertions(+), 56 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 36a239142..c49113a24 100644 --- 
a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -7,6 +7,7 @@ import torch.distributed as dist import torch.nn.functional as F from einops import rearrange +from colossalai.cluster import ProcessGroupMesh from colossalai.kernel.kernel_loader import ( FlashAttentionDaoLoader, FlashAttentionForFloatAndCustomMaskLoader, @@ -431,7 +432,7 @@ class RingAttention(torch.autograd.Function): INTER_RING_GROUP_COPY: dist.ProcessGroup = None @staticmethod - def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None): + def get_double_ring_groups(sp_group, inner_ring_size=None): """ Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size shouldn't be larger than the number of NICs on each node. @@ -442,7 +443,6 @@ class RingAttention(torch.autograd.Function): Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. """ sp_size = dist.get_world_size(sp_group) - tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 sp_rank = dist.get_rank(sp_group) assert inner_ring_size is not None @@ -465,45 +465,22 @@ class RingAttention(torch.autograd.Function): inner_ring_group = None inter_ring_group = None - world_size = dist.get_world_size() - rank = dist.get_rank() - - num_ring_size = world_size // num_rings - - if tp_size > 1: - # Create inner ring groups - ranks = [] - for i in range(num_rings): - start = i * num_ring_size - end = (i + 1) * num_ring_size - for idx in range(start, end): - inner_rank = [idx + k * tp_size for k in range(inner_ring_size) if idx + k * tp_size < end] - if len(inner_rank) == inner_ring_size and inner_rank not in ranks: - ranks.append(inner_rank) - group = dist.new_group(inner_rank) - if rank in inner_rank: - inner_ring_group = group - # Create inter ring groups - for i in range(num_ring_size): - inter_rank = [i + j * num_ring_size for j in range(num_rings)] - group = dist.new_group(inter_rank) - if rank in inter_rank: - inter_ring_group = group + inter_axis, inner_axis = 0, 1 + pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size) - else: - # Create inner ring groups - for i in range(inner_ring_size): - ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) - group = dist.new_group(ranks) - if sp_rank in ranks: - inner_ring_group = group - - # Create inter ring groups - for i in range(num_rings): - ranks = list(range(i, sp_size, num_rings)) - group = dist.new_group(ranks) - if sp_rank in ranks: - inter_ring_group = group + # Create inner ring groups + for i in range(inner_ring_size): + ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) + group = pg_mesh.get_group_along_axis(inner_axis) + if sp_rank in ranks: + inner_ring_group = group + + # Create inter ring groups + for i in range(num_rings): + ranks = list(range(i, sp_size, num_rings)) + group = pg_mesh.get_group_along_axis(inter_axis) + if sp_rank in ranks: + inter_ring_group = group return inner_ring_group, inter_ring_group @@ -522,7 +499,6 @@ class RingAttention(torch.autograd.Function): deterministic=False, return_softmax=False, inner_ring_size=None, - tp_group=None, **kwargs, ): """ @@ -570,9 +546,7 @@ class RingAttention(torch.autograd.Function): if inner_ring_size != None: RingAttention.SP_GROUP = sp_group - inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups( - sp_group, tp_group, inner_ring_size - ) + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) 
RingAttention.INNER_RING_GROUP = inner_ring_group RingAttention.INTER_RING_GROUP = inter_ring_group else: @@ -619,7 +593,6 @@ class RingAttention(torch.autograd.Function): attention_mask_type == AttnMaskType.PADDED_CAUSAL, inner_ring_group, inter_ring_group, - tp_group, ) if attention_mask_type == AttnMaskType.PADDED_CAUSAL: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 43b3d2c5e..aec8e84a0 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -858,7 +858,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group - tp_group = shard_config.tensor_parallel_process_group if sp_mode == "ring_attn": attn_output = RingAttention.attention( @@ -870,7 +869,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) dropout_p=dropout_p, scale=scale, inner_ring_size=shard_config.inner_ring_size, - tp_group=tp_group, ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index aa761fd21..47c17e749 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -563,8 +563,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - tp_group = shard_config.tensor_parallel_process_group - if sp_mode == "ring_attn": attn_output = RingAttention.attention( query_states, @@ -573,7 +571,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s sp_group, **attention_mask, inner_ring_size=shard_config.inner_ring_size, - tp_group=tp_group, ) elif shard_config.enable_flash_attention: diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index b9291a061..61ea0677f 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -5,7 +5,6 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu from torch.testing import assert_close import colossalai -from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag @@ -21,10 +20,8 @@ from colossalai.utils import get_current_device def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1): torch.cuda.manual_seed(2) device = get_current_device() - sp_axis, tp_axis = 0, 1 - pg_mesh = ProcessGroupMesh(sp_size, tp_size) - tp_group = pg_mesh.get_group_along_axis(tp_axis) - sp_group = pg_mesh.get_group_along_axis(sp_axis) + sp_group = dist.group.WORLD + sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) @@ -47,7 +44,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1): AttnMaskType.CAUSAL, return_softmax=True, 
inner_ring_size=max(2, sp_size // 2), - tp_group=tp_group, ) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( From 3201377e94d3cc8d2753214e4a136944f230ea42 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 14 Oct 2024 18:06:24 +0800 Subject: [PATCH 19/26] fix --- colossalai/shardformer/modeling/gpt2.py | 1 - tests/test_shardformer/test_layer/test_ring_attn.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index aec8e84a0..798fca88f 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -858,7 +858,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) sp_mode = shard_config.sequence_parallelism_mode sp_group = shard_config.sequence_parallel_process_group - if sp_mode == "ring_attn": attn_output = RingAttention.attention( query, diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 61ea0677f..bcb2c1f8a 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -17,7 +17,7 @@ from colossalai.utils import get_current_device @parameterize("nheads", [5]) @parameterize("d", [128]) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1): +def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size): torch.cuda.manual_seed(2) device = get_current_device() sp_group = dist.group.WORLD @@ -165,7 +165,7 @@ def launch_single_ring(rank, world_size, port): def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - check_ring_attn(sp_size=4, tp_size=2) + check_ring_attn(sp_size=world_size) @rerun_if_address_is_in_use() @@ -175,7 +175,7 @@ def test_ring_attn(world_size): @rerun_if_address_is_in_use() -@parameterize("world_size", [8]) +@parameterize("world_size", [4]) def test_double_ring(world_size): spawn(launch_double_ring, nprocs=world_size) From fe9208feaca07159dbff387342346df510ab8a0d Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 14 Oct 2024 18:07:56 +0800 Subject: [PATCH 20/26] fix --- colossalai/shardformer/layer/attn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index c49113a24..1f39a44b3 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -623,7 +623,6 @@ class RingAttention(torch.autograd.Function): is_packed: Optional[bool] = False, inner_ring_group: Optional[dist.ProcessGroup] = None, inter_ring_group: Optional[dist.ProcessGroup] = None, - tp_group: Optional[dist.ProcessGroup] = None, ): cu_seqlens_q = cu_seqlens_kv = cu_seqlens @@ -1120,7 +1119,7 @@ class RingAttention(torch.autograd.Function): if not is_packed: dq, dk, dv = [x.view(b, sq, *x.shape[-2:]) for x in (dq, dk, dv)] - return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None) + return (dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None) @staticmethod def prepare_varlen_batch( From 8ff7d0c78048b4231e1f772a8227282ae7d5822a Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 14 Oct 2024 18:16:03 +0800 Subject: [PATCH 21/26] fix --- tests/test_shardformer/test_layer/test_ring_attn.py | 9 ++++----- 1 file changed, 
4 insertions(+), 5 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index bcb2c1f8a..0ffea2016 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -17,11 +17,10 @@ from colossalai.utils import get_current_device @parameterize("nheads", [5]) @parameterize("d", [128]) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size): +def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): torch.cuda.manual_seed(2) device = get_current_device() sp_group = dist.group.WORLD - sp_size = dist.get_world_size() # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) @@ -43,7 +42,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size): sp_group, AttnMaskType.CAUSAL, return_softmax=True, - inner_ring_size=max(2, sp_size // 2), + inner_ring_size=inner_ring_size, ) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( @@ -160,12 +159,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): def launch_single_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) check_packed_seq() - check_ring_attn(sp_size=world_size) + check_ring_attn(inner_ring_size=None) def launch_double_ring(rank, world_size, port): colossalai.launch(rank, world_size, "localhost", port) - check_ring_attn(sp_size=world_size) + check_ring_attn(inner_ring_size=2) @rerun_if_address_is_in_use() From 3dc08c8a5a03d6bb28d865d2dbaae948c27e7e9a Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 15 Oct 2024 11:01:34 +0800 Subject: [PATCH 22/26] fix --- .../booster/plugin/hybrid_parallel_plugin.py | 2 ++ colossalai/shardformer/layer/attn.py | 15 +++++++-------- colossalai/shardformer/modeling/gpt2.py | 1 + colossalai/shardformer/modeling/llama.py | 1 + colossalai/shardformer/shard/shard_config.py | 1 + .../test_shardformer/test_layer/test_ring_attn.py | 13 ++++++------- 6 files changed, 18 insertions(+), 15 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bb663f6a6..d15a9f397 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1180,7 +1180,9 @@ class HybridParallelPlugin(PipelinePluginBase): gradient_checkpoint_config=gradient_checkpoint_config, fp8_communication=fp8_communication, inner_ring_size=inner_ring_size, + pg_mesh=self.pg_mesh, ) + self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 1f39a44b3..6a972f075 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -7,7 +7,6 @@ import torch.distributed as dist import torch.nn.functional as F from einops import rearrange -from colossalai.cluster import ProcessGroupMesh from colossalai.kernel.kernel_loader import ( FlashAttentionDaoLoader, FlashAttentionForFloatAndCustomMaskLoader, @@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function): INTER_RING_GROUP_COPY: dist.ProcessGroup = None @staticmethod - def get_double_ring_groups(sp_group, inner_ring_size=None): + def 
get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None): """ Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size shouldn't be larger than the number of NICs on each node. @@ -465,20 +464,17 @@ class RingAttention(torch.autograd.Function): inner_ring_group = None inter_ring_group = None - inter_axis, inner_axis = 0, 1 - pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size) - # Create inner ring groups for i in range(inner_ring_size): ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) - group = pg_mesh.get_group_along_axis(inner_axis) + group = pg_mesh.get_group_along_axis(2, ranks) if sp_rank in ranks: inner_ring_group = group # Create inter ring groups for i in range(num_rings): ranks = list(range(i, sp_size, num_rings)) - group = pg_mesh.get_group_along_axis(inter_axis) + group = pg_mesh.get_group_along_axis(2, ranks) if sp_rank in ranks: inter_ring_group = group @@ -499,6 +495,7 @@ class RingAttention(torch.autograd.Function): deterministic=False, return_softmax=False, inner_ring_size=None, + pg_mesh=None, **kwargs, ): """ @@ -546,7 +543,9 @@ class RingAttention(torch.autograd.Function): if inner_ring_size != None: RingAttention.SP_GROUP = sp_group - inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size) + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups( + sp_group, pg_mesh, inner_ring_size + ) RingAttention.INNER_RING_GROUP = inner_ring_group RingAttention.INTER_RING_GROUP = inter_ring_group else: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 798fca88f..d75e9b14a 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -868,6 +868,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) dropout_p=dropout_p, scale=scale, inner_ring_size=shard_config.inner_ring_size, + pg_mesh=shard_config.pg_mesh, ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 47c17e749..ff7390643 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -571,6 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s sp_group, **attention_mask, inner_ring_size=shard_config.inner_ring_size, + pg_mesh=shard_config.pg_mesh, ) elif shard_config.enable_flash_attention: diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 1219119bb..14963f7a5 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -51,6 +51,7 @@ class ShardConfig: extra_kwargs: Dict[str, Any] = field(default_factory=dict) # For ring attention + pg_mesh: Optional[int] = None inner_ring_size: Optional[int] = None # for moe related moe_dp_group: Optional[ProcessGroup] = None diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 0ffea2016..48f25dea3 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu from torch.testing import assert_close import colossalai +from colossalai.cluster import 
ProcessGroupMesh from colossalai.shardformer.layer import AttnMaskType from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag @@ -21,6 +22,9 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): torch.cuda.manual_seed(2) device = get_current_device() sp_group = dist.group.WORLD + dp_size, pp_size, tp_size = 1, 1, 1 + sp_size = dist.get_world_size() + pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size) # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215) @@ -36,13 +40,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): # Ring attention vs single GPU ring_out, ring_lse = RingAttention.attention( - q, - k, - v, - sp_group, - AttnMaskType.CAUSAL, - return_softmax=True, - inner_ring_size=inner_ring_size, + q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh ) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( @@ -125,6 +123,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): **mask_info, pad_output=False, return_softmax=True, + pg_mesh=dist.group.WORLD, # deterministic=True ) ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) From 6be9862aafb4230ead0bcb95c65b5568043e2c13 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 15 Oct 2024 11:56:49 +0800 Subject: [PATCH 23/26] fix --- .../booster/plugin/hybrid_parallel_plugin.py | 1 + colossalai/shardformer/layer/attn.py | 14 +++++++++----- colossalai/shardformer/modeling/gpt2.py | 1 + colossalai/shardformer/modeling/llama.py | 1 + colossalai/shardformer/shard/shard_config.py | 1 + .../test_shardformer/test_layer/test_ring_attn.py | 11 ++++++++++- 6 files changed, 23 insertions(+), 6 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d15a9f397..de0cd242f 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1181,6 +1181,7 @@ class HybridParallelPlugin(PipelinePluginBase): fp8_communication=fp8_communication, inner_ring_size=inner_ring_size, pg_mesh=self.pg_mesh, + sp_axis=self.sp_axis, ) self.amp_config = dict( diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 6a972f075..1d8a89ce0 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function): INTER_RING_GROUP_COPY: dist.ProcessGroup = None @staticmethod - def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None): + def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None): """ Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size shouldn't be larger than the number of NICs on each node. @@ -441,6 +441,9 @@ class RingAttention(torch.autograd.Function): Returns: Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group. """ + assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization." 
+ + sp_group = pg_mesh.get_group_along_axis(sp_axis) sp_size = dist.get_world_size(sp_group) sp_rank = dist.get_rank(sp_group) @@ -496,6 +499,7 @@ class RingAttention(torch.autograd.Function): return_softmax=False, inner_ring_size=None, pg_mesh=None, + sp_axis=None, **kwargs, ): """ @@ -506,7 +510,7 @@ class RingAttention(torch.autograd.Function): q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D] k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D] v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D] - sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism + sp_axis (Optional[int]): Sp axis for the global pg mesh. sp_tream (torch.cuda.Stream): An different stream for output correction. cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths of the sequences in the batch, used to index into q. @@ -539,13 +543,13 @@ class RingAttention(torch.autograd.Function): attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES ), f"Mask type {attention_mask_type} is not supported yet." + assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization." + clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) if inner_ring_size != None: RingAttention.SP_GROUP = sp_group - inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups( - sp_group, pg_mesh, inner_ring_size - ) + inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size) RingAttention.INNER_RING_GROUP = inner_ring_group RingAttention.INTER_RING_GROUP = inter_ring_group else: diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index d75e9b14a..1eaed167c 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -869,6 +869,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) scale=scale, inner_ring_size=shard_config.inner_ring_size, pg_mesh=shard_config.pg_mesh, + sp_axis=shard_config.sp_axis, ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ff7390643..ca0751f04 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -572,6 +572,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s **attention_mask, inner_ring_size=shard_config.inner_ring_size, pg_mesh=shard_config.pg_mesh, + sp_axis=shard_config.sp_axis, ) elif shard_config.enable_flash_attention: diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 14963f7a5..7b31e6928 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -51,6 +51,7 @@ class ShardConfig: extra_kwargs: Dict[str, Any] = field(default_factory=dict) # For ring attention + sp_axis: Optional[int] = None pg_mesh: Optional[int] = None inner_ring_size: Optional[int] = None # for moe related diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 48f25dea3..23e8e5b78 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -24,6 +24,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): 
sp_group = dist.group.WORLD dp_size, pp_size, tp_size = 1, 1, 1 sp_size = dist.get_world_size() + sp_axis = 2 pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size) # Some outliers may seem large, but our errors are still lower than # than Megatron-LM context parallel's @@ -40,7 +41,15 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): # Ring attention vs single GPU ring_out, ring_lse = RingAttention.attention( - q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh + q, + k, + v, + sp_group, + AttnMaskType.CAUSAL, + return_softmax=True, + inner_ring_size=inner_ring_size, + pg_mesh=pg_mesh, + sp_axis=sp_axis, ) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( From fd92789af27d3d7269529508c316db62040568aa Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 15 Oct 2024 13:26:44 +0800 Subject: [PATCH 24/26] fix --- colossalai/shardformer/layer/attn.py | 5 ++--- colossalai/shardformer/modeling/gpt2.py | 5 ++--- colossalai/shardformer/modeling/llama.py | 3 +-- tests/test_shardformer/test_layer/test_ring_attn.py | 8 ++++---- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 1d8a89ce0..1a175f426 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -488,7 +488,7 @@ class RingAttention(torch.autograd.Function): q, # (B, H, Sq, D) k, v, - sp_group, + sp_axis, attention_mask_type, cu_seqlens=None, max_seqlen=None, @@ -499,7 +499,6 @@ class RingAttention(torch.autograd.Function): return_softmax=False, inner_ring_size=None, pg_mesh=None, - sp_axis=None, **kwargs, ): """ @@ -546,7 +545,7 @@ class RingAttention(torch.autograd.Function): assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization." 
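# Caller-side sketch (not part of the patch) of the signature this commit
# settles on: the mesh axis is passed positionally where sp_group used to go,
# and the group is resolved from pg_mesh inside attention(). Assumes 4 GPUs
# launched with `torchrun --nproc_per_node=4`; the import paths and launch
# call are best-effort guesses against current ColossalAI and may need adjusting.
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.utils import get_current_device

colossalai.launch_from_torch()
dp_size, pp_size, sp_size, tp_size = 1, 1, 4, 1
sp_axis = 2                                   # index of the sp dimension in the mesh below
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)

b, h, sq, d = 2, 8, 4096, 128
device = get_current_device()
# Each rank holds its local shard of the sequence: sq / sp_size tokens.
q, k, v = (torch.randn(b, h, sq // sp_size, d, device=device, dtype=torch.bfloat16)
           for _ in range(3))

out = RingAttention.attention(
    q, k, v,
    sp_axis,                                  # was: sp_group
    AttnMaskType.CAUSAL,
    inner_ring_size=2,
    pg_mesh=pg_mesh,
)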
clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg)) - + sp_group = pg_mesh.get_group_along_axis(sp_axis) if inner_ring_size != None: RingAttention.SP_GROUP = sp_group inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 1eaed167c..90ffba8cf 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -857,19 +857,18 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) dropout_p = self.attn_dropout.p if self.training else 0.0 sp_mode = shard_config.sequence_parallelism_mode - sp_group = shard_config.sequence_parallel_process_group + shard_config.sequence_parallel_process_group if sp_mode == "ring_attn": attn_output = RingAttention.attention( query, key, value, - sp_group, + sp_axis=shard_config.sp_axis, **attention_mask, dropout_p=dropout_p, scale=scale, inner_ring_size=shard_config.inner_ring_size, pg_mesh=shard_config.pg_mesh, - sp_axis=shard_config.sp_axis, ) else: attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ca0751f04..2e6363060 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -568,11 +568,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s query_states, key_states, value_states, - sp_group, + sp_axis=shard_config.sp_axis, **attention_mask, inner_ring_size=shard_config.inner_ring_size, pg_mesh=shard_config.pg_mesh, - sp_axis=shard_config.sp_axis, ) elif shard_config.enable_flash_attention: diff --git a/tests/test_shardformer/test_layer/test_ring_attn.py b/tests/test_shardformer/test_layer/test_ring_attn.py index 23e8e5b78..6ebd8da73 100644 --- a/tests/test_shardformer/test_layer/test_ring_attn.py +++ b/tests/test_shardformer/test_layer/test_ring_attn.py @@ -44,12 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size): q, k, v, - sp_group, + sp_axis, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh, - sp_axis=sp_axis, ) ring_out = ring_out.transpose(1, 2) out, lse, _ = flash_attn_qkvpacked_func( @@ -88,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): device = get_current_device() sp_group = dist.group.WORLD sp_size = dist.get_world_size() + sp_axis = 2 atol = rtol = 7e-3 torch.cuda.manual_seed(2) # Prepare varlen attention mask @@ -128,11 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype): q_ring, k_ring, v_ring, - sp_group, + sp_axis, **mask_info, pad_output=False, return_softmax=True, - pg_mesh=dist.group.WORLD, + pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1), # deterministic=True ) ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d) From bc7eeade33e33e3a7c2df26fedab707f3a62d6fe Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 15 Oct 2024 13:28:33 +0800 Subject: [PATCH 25/26] fix --- colossalai/shardformer/modeling/gpt2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 90ffba8cf..d550484da 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -857,7 +857,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None) 
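# Sketch (not part of the patch) of what the modeling code still needs from
# ShardConfig once the group is resolved inside RingAttention: only sp_axis,
# pg_mesh and inner_ring_size. The bare `shard_config.sequence_parallel_process_group`
# expression removed below was a leftover no-op (an attribute access whose result
# is discarded). `resolve_sp_group` is a hypothetical stand-in, not ColossalAI
# code; it mirrors the resolution step added in attn.py.
from typing import Optional

import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

def resolve_sp_group(pg_mesh: Optional[ProcessGroupMesh], sp_axis: int) -> dist.ProcessGroup:
    """Resolve the sequence-parallel group from the global mesh, as attn.py now does."""
    assert pg_mesh is not None, "pg_mesh is required: pass shard_config.pg_mesh"
    return pg_mesh.get_group_along_axis(sp_axis)

# Intended usage under these assumptions:
#   sp_group = resolve_sp_group(shard_config.pg_mesh, shard_config.sp_axis)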
dropout_p = self.attn_dropout.p if self.training else 0.0 sp_mode = shard_config.sequence_parallelism_mode - shard_config.sequence_parallel_process_group if sp_mode == "ring_attn": attn_output = RingAttention.attention( query, From 83cf2f84fb0c08a351a5affc71527556ff8912bc Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 15 Oct 2024 14:50:27 +0800 Subject: [PATCH 26/26] fix --- colossalai/shardformer/layer/attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index 1a175f426..bbd99d162 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -470,14 +470,14 @@ class RingAttention(torch.autograd.Function): # Create inner ring groups for i in range(inner_ring_size): ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) - group = pg_mesh.get_group_along_axis(2, ranks) + group = pg_mesh.get_group_along_axis(sp_axis, ranks) if sp_rank in ranks: inner_ring_group = group # Create inter ring groups for i in range(num_rings): ranks = list(range(i, sp_size, num_rings)) - group = pg_mesh.get_group_along_axis(2, ranks) + group = pg_mesh.get_group_along_axis(sp_axis, ranks) if sp_rank in ranks: inter_ring_group = group
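
For reference, below is a dependency-free sketch of the 2D ring layout these groups are meant to realize over the sequence-parallel ranks: inner rings group consecutive ranks (latency-critical, intra-node traffic), while inter rings connect the ranks sitting at the same position of each inner ring (cross-node traffic). It illustrates the intended layout only and is not a transcription of get_double_ring_groups.

def double_ring_layout(sp_size: int, inner_ring_size: int):
    """Return (inner_rings, inter_rings) as lists of rank lists."""
    assert sp_size % inner_ring_size == 0, "inner ring size must divide sp size"
    num_rings = sp_size // inner_ring_size
    # Inner rings: consecutive blocks of `inner_ring_size` ranks.
    inner = [list(range(r * inner_ring_size, (r + 1) * inner_ring_size))
             for r in range(num_rings)]
    # Inter rings: the ranks at the same offset within each block.
    inter = [list(range(i, sp_size, inner_ring_size))
             for i in range(inner_ring_size)]
    return inner, inter

if __name__ == "__main__":
    inner, inter = double_ring_layout(sp_size=8, inner_ring_size=4)
    print(inner)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(inter)  # [[0, 4], [1, 5], [2, 6], [3, 7]]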