pull/6071/head
wangbluo 2024-10-14 18:01:53 +08:00
parent d891e50617
commit 23199e34cc
4 changed files with 19 additions and 55 deletions

View File

@@ -7,6 +7,7 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange

+from colossalai.cluster import ProcessGroupMesh
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
@@ -431,7 +432,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
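
For intuition, here is a minimal sketch (not part of the patch) of the 2D layout this method aims at, assuming sp_size = 4 and inner_ring_size = 2, so num_rings = 2; plain Python lists stand in for the process groups the method actually creates.

# Hypothetical example values; the real method derives num_rings inside
# get_double_ring_groups and builds torch.distributed groups, not lists.
sp_size, inner_ring_size = 4, 2
num_rings = sp_size // inner_ring_size  # -> 2

# Inner rings: contiguous blocks of inner_ring_size ranks (short hops within a node).
inner_rings = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(num_rings)]
# -> [[0, 1], [2, 3]]

# Inter rings: ranks at the same position of each inner ring, strided by num_rings.
inter_rings = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]
# -> [[0, 2], [1, 3]]
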
@@ -442,7 +443,6 @@ class RingAttention(torch.autograd.Function):
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
         sp_size = dist.get_world_size(sp_group)
-        tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
         sp_rank = dist.get_rank(sp_group)

         assert inner_ring_size is not None
@@ -465,43 +465,20 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group = None
         inter_ring_group = None

-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-        num_ring_size = world_size // num_rings
-
-        if tp_size > 1:
-            # Create inner ring groups
-            ranks = []
-            for i in range(num_rings):
-                start = i * num_ring_size
-                end = (i + 1) * num_ring_size
-                for idx in range(start, end):
-                    inner_rank = [idx + k * tp_size for k in range(inner_ring_size) if idx + k * tp_size < end]
-                    if len(inner_rank) == inner_ring_size and inner_rank not in ranks:
-                        ranks.append(inner_rank)
-                        group = dist.new_group(inner_rank)
-                        if rank in inner_rank:
-                            inner_ring_group = group
-            # Create inter ring groups
-            for i in range(num_ring_size):
-                inter_rank = [i + j * num_ring_size for j in range(num_rings)]
-                group = dist.new_group(inter_rank)
-                if rank in inter_rank:
-                    inter_ring_group = group
-        else:
-            # Create inner ring groups
-            for i in range(inner_ring_size):
-                ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-                group = dist.new_group(ranks)
-                if sp_rank in ranks:
-                    inner_ring_group = group
-            # Create inter ring groups
-            for i in range(num_rings):
-                ranks = list(range(i, sp_size, num_rings))
-                group = dist.new_group(ranks)
-                if sp_rank in ranks:
-                    inter_ring_group = group
+        inter_axis, inner_axis = 0, 1
+        pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size)
+
+        # Create inner ring groups
+        for i in range(inner_ring_size):
+            ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
+            group = pg_mesh.get_group_along_axis(inner_axis)
+            if sp_rank in ranks:
+                inner_ring_group = group
+        # Create inter ring groups
+        for i in range(num_rings):
+            ranks = list(range(i, sp_size, num_rings))
+            group = pg_mesh.get_group_along_axis(inter_axis)
+            if sp_rank in ranks:
+                inter_ring_group = group
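
The rewritten body swaps the manual dist.new_group bookkeeping for a ProcessGroupMesh laid out as (num_rings, inner_ring_size). Below is a minimal standalone sketch of that pattern, assuming the world size equals num_rings * inner_ring_size; as used here, get_group_along_axis returns the subgroup along the given axis that contains the calling rank.

# Sketch only: the patch builds the mesh inside get_double_ring_groups; this
# helper is an assumption-laden illustration of the same pattern and assumes
# torch.distributed is already initialized (e.g. via colossalai.launch).
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh


def build_ring_groups(num_rings: int, inner_ring_size: int):
    # The mesh must cover every rank in the default group.
    assert dist.get_world_size() == num_rings * inner_ring_size
    # Axis 0 indexes which ring a rank belongs to (inter), axis 1 its position
    # inside that ring (inner), matching `inter_axis, inner_axis = 0, 1` above.
    pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size)
    inner_ring_group = pg_mesh.get_group_along_axis(1)  # peers within the same ring
    inter_ring_group = pg_mesh.get_group_along_axis(0)  # same position across rings
    return inner_ring_group, inter_ring_group
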
@@ -522,7 +499,6 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
-        tp_group=None,
         **kwargs,
     ):
         """
@@ -570,9 +546,7 @@ class RingAttention(torch.autograd.Function):
         if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
-                sp_group, tp_group, inner_ring_size
-            )
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
@@ -619,7 +593,6 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type == AttnMaskType.PADDED_CAUSAL,
             inner_ring_group,
             inter_ring_group,
-            tp_group,
         )

         if attention_mask_type == AttnMaskType.PADDED_CAUSAL:

View File

@@ -858,7 +858,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
-        tp_group = shard_config.tensor_parallel_process_group

         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
@@ -870,7 +869,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
-                tp_group=tp_group,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

View File

@@ -563,8 +563,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        tp_group = shard_config.tensor_parallel_process_group
-
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,
@@ -573,7 +571,6 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 sp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
-                tp_group=tp_group,
             )

         elif shard_config.enable_flash_attention:

View File

@@ -5,7 +5,6 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
 from torch.testing import assert_close

 import colossalai
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
 from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -21,10 +20,8 @@ from colossalai.utils import get_current_device
 def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
     torch.cuda.manual_seed(2)
     device = get_current_device()
-    sp_axis, tp_axis = 0, 1
-    pg_mesh = ProcessGroupMesh(sp_size, tp_size)
-    tp_group = pg_mesh.get_group_along_axis(tp_axis)
-    sp_group = pg_mesh.get_group_along_axis(sp_axis)
+    sp_group = dist.group.WORLD
+    sp_size = dist.get_world_size()
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
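
With the mesh plumbing removed from the test, the sequence-parallel group is simply the world group. A hedged sketch of how the updated helper might be driven follows; the spawn wrapper and the argument values are assumptions for illustration, not taken from the patch.

# check_ring_attn is the helper defined above in this test file.
import torch

import colossalai
from colossalai.testing import spawn


def run_dist(rank, world_size, port):
    # Standard ColossalAI test bootstrap; every spawned rank joins
    # dist.group.WORLD, which check_ring_attn now uses directly as sp_group.
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port)
    check_ring_attn(seq_len=4096, bs=1, nheads=8, d=128, dtype=torch.bfloat16, sp_size=world_size)


if __name__ == "__main__":
    spawn(run_dist, nprocs=4)  # hypothetical 4-GPU run
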
@@ -47,7 +44,6 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, sp_size, tp_size=1):
         AttnMaskType.CAUSAL,
         return_softmax=True,
         inner_ring_size=max(2, sp_size // 2),
-        tp_group=tp_group,
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(