From 6fb1322db1525c90f7f80cebad4b447e009bacd4 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 25 Sep 2024 18:56:18 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py     | 20 +++++++++++---------
 colossalai/shardformer/modeling/gpt2.py  |  1 +
 colossalai/shardformer/modeling/llama.py |  1 +
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 419932a00..15ad09baa 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -4,7 +4,6 @@ from typing import Callable, Dict, Optional, Tuple
 import torch
 import torch.distributed
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn.functional as F
 from einops import rearrange
 
@@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
 
     @staticmethod
-    def get_double_ring_groups(sp_group,tp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -443,8 +442,8 @@ class RingAttention(torch.autograd.Function):
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
         sp_size = dist.get_world_size(sp_group)
-        sp_rank = dist.get_rank(sp_group)
-        tp_size = dist.get_world_size(tp_group)
+        dist.get_rank(sp_group)
+        dist.get_world_size(tp_group)
 
         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -475,11 +474,12 @@ class RingAttention(torch.autograd.Function):
 
         world_size = dist.get_world_size()
         rank = dist.get_rank()
-        groups = int(world_size/ sp_size)
+        groups = int(world_size / sp_size)
+
         # Create inner ring groups
         for group_id in range(groups):
             for i in range(inner_ring_size):
-                ranks = list(range(i +(group_id*sp_size), (1+group_id)*sp_size, inner_ring_size))
+                ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
                 group = dist.new_group(ranks)
                 if rank in ranks:
                     inner_ring_group = group
@@ -487,7 +487,7 @@ class RingAttention(torch.autograd.Function):
         # Create inter ring groups
         for group_id in range(groups):
             for i in range(num_rings):
-                ranks = list(range(i+group_id * num_rings, world_size, sp_size))
+                ranks = list(range(i + group_id * num_rings, world_size, sp_size))
                 group = dist.new_group(ranks)
                 if rank in ranks:
                     inter_ring_group = group
@@ -500,7 +500,7 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
-        tp_group : Optional[dist.ProcessGroup],
+        tp_group: Optional[dist.ProcessGroup],
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -557,7 +557,9 @@ class RingAttention(torch.autograd.Function):
 
             if RingAttention.SP_GROUP is not sp_group:
                 RingAttention.SP_GROUP = sp_group
-                inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, tp_group, inner_ring_size)
+                inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
+                    sp_group, tp_group, inner_ring_size
+                )
                 RingAttention.INNER_RING_GROUP = inner_ring_group
                 RingAttention.INTER_RING_GROUP = inter_ring_group
             else:
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 8f476ab86..01d47bcd0 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -859,6 +859,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
         tp_group = shard_config.tensor_parallel_process_group
+
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index fc5bcac6b..b5f505dce 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -564,6 +564,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         tp_group = shard_config.tensor_parallel_process_group
+
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
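Reviewer note (not part of the patch): a minimal standalone sketch of the rank layout that the reworked loops in RingAttention.get_double_ring_groups compute. The helper name double_ring_rank_layout and the example sizes are hypothetical, and num_rings is assumed to equal sp_size // inner_ring_size, which is how the existing variable appears to be used; the code below only mirrors the range() arithmetic from the attn.py hunks and creates no real process groups.

# Hypothetical helper; mirrors the loop arithmetic of the patched
# RingAttention.get_double_ring_groups without touching torch.distributed.
def double_ring_rank_layout(world_size: int, sp_size: int, inner_ring_size: int):
    num_rings = sp_size // inner_ring_size  # assumption: matches the existing num_rings variable
    groups = world_size // sp_size          # equivalent to the patch's int(world_size / sp_size) here

    inner_rings, inter_rings = [], []

    # Inner rings: ranks strided by inner_ring_size inside each sequence-parallel group.
    for group_id in range(groups):
        for i in range(inner_ring_size):
            inner_rings.append(list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size)))

    # Inter rings: ranks strided by sp_size across the whole world.
    for group_id in range(groups):
        for i in range(num_rings):
            inter_rings.append(list(range(i + group_id * num_rings, world_size, sp_size)))

    return inner_rings, inter_rings


if __name__ == "__main__":
    inner, inter = double_ring_rank_layout(world_size=8, sp_size=4, inner_ring_size=2)
    print(inner)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(inter)  # [[0, 4], [1, 5], [2, 6], [3, 7]]

With these example sizes every rank lands in exactly one inner ring and one inter ring, which makes the new range() expressions easy to eyeball during review.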