Merge pull request #6071 from wangbluo/ring_attention

[Ring Attention] Fix the 2D ring attention when using multiple machines
pull/6092/head
Wang Binluo 2024-10-15 15:17:21 +08:00 committed by GitHub
commit dcd41d0973
6 changed files with 41 additions and 31 deletions

View File

@@ -1177,7 +1177,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
+            pg_mesh=self.pg_mesh,
+            sp_axis=self.sp_axis,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
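The plugin now forwards its process-group mesh and the index of the sequence-parallel axis into the shard config, so ring attention can pull its groups out of the mesh that already exists instead of creating new ones with dist.new_group. As a rough illustration of what an sp_axis index selects (mesh sizes and the dp/pp/sp/tp axis order here are assumptions, mirroring the test at the bottom of this commit), ranks that differ only along that axis form one sequence-parallel group:

import itertools

import numpy as np

# Assumed layout: 1 x 1 x 4 x 2 = 8 ranks arranged as (dp, pp, sp, tp).
dp_size, pp_size, sp_size, tp_size = 1, 1, 4, 2
mesh = np.arange(dp_size * pp_size * sp_size * tp_size).reshape(dp_size, pp_size, sp_size, tp_size)
sp_axis = 2  # the mesh dimension the ring-attention groups span

# Ranks that differ only along sp_axis belong to the same sequence-parallel group.
sp_groups = [
    mesh[d, p, :, t].tolist()
    for d, p, t in itertools.product(range(dp_size), range(pp_size), range(tp_size))
]
print(sp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]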

View File

@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -441,21 +441,17 @@ class RingAttention(torch.autograd.Function):
         Returns:
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
-        if inner_ring_size is None:
-            if torch.cuda.device_count() >= dist.get_world_size():
-                # single node, no need to consider NICs
-                return sp_group, sp_group
-            if sp_size <= 4:
-                inner_ring_size = min(2, sp_size)
-            else:
-                inner_ring_size = min(4, sp_size)
-        else:
-            assert (
-                inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
-            ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+        assert inner_ring_size is not None
+        assert (
+            inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+        ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

         if inner_ring_size == sp_size:
             return sp_group, sp_group
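Note that the fallback that used to pick an inner ring size automatically is gone: callers must now pass inner_ring_size explicitly, and it has to divide the sequence-parallel world size. A quick standalone check of which pairs satisfy the new assertion (the sizes below are illustrative only):

# Pairs (sp_size, inner_ring_size) that pass the divisibility assertion above.
valid = [(sp, inner) for sp in (4, 8, 16) for inner in (2, 4, 8) if inner <= sp and sp % inner == 0]
print(valid)  # [(4, 2), (4, 4), (8, 2), (8, 4), (8, 8), (16, 2), (16, 4), (16, 8)]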
@@ -474,14 +470,14 @@ class RingAttention(torch.autograd.Function):
         # Create inner ring groups
         for i in range(inner_ring_size):
             ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inner_ring_group = group

         # Create inter ring groups
         for i in range(num_rings):
             ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inter_ring_group = group
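For a concrete picture of what the two loops build, here is a standalone sketch (pure Python, no torch.distributed, sizes assumed) of the rank layout for sp_size = 4 with inner_ring_size = 2; each list is what gets handed to pg_mesh.get_group_along_axis(sp_axis, ranks):

sp_size, inner_ring_size = 4, 2          # assumed: 2 nodes with 2 ranks each
num_rings = sp_size // inner_ring_size

# Same index arithmetic as the two loops above.
inner_rings = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(inner_ring_size)]
inter_rings = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]

print(inner_rings)  # [[0, 1], [2, 3]]  rings within a node
print(inter_rings)  # [[0, 2], [1, 3]]  rings across nodes

Because every group now comes from the shared ProcessGroupMesh rather than an ad-hoc dist.new_group call, all ranks resolve the same communicators even when the mesh spans several machines, which appears to be the multi-node failure this PR addresses.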
@@ -492,7 +488,7 @@ class RingAttention(torch.autograd.Function):
         q,  # (B, H, Sq, D)
         k,
         v,
-        sp_group,
+        sp_axis,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -502,6 +498,7 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        pg_mesh=None,
         **kwargs,
     ):
         """
@@ -512,7 +509,7 @@ class RingAttention(torch.autograd.Function):
             q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
             k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, Sq, D]
             v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, Sq, D]
-            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+            sp_axis (Optional[int]): Sp axis for the global pg mesh.
             sp_tream (torch.cuda.Stream): An different stream for output correction.
             cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
@@ -537,7 +534,6 @@ class RingAttention(torch.autograd.Function):
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -546,11 +542,13 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
         ), f"Mask type {attention_mask_type} is not supported yet."
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
         clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
-        if RingAttention.SP_GROUP is not sp_group:
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
+        if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
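Putting the new entry point together, a minimal usage sketch under assumed settings (4 GPUs launched with torchrun, bf16 tensors, arbitrary shapes; in real training this wiring is done by HybridParallelPlugin and ShardConfig as shown in the remaining files):

# torchrun --nproc_per_node=4 ring_attn_sketch.py   (hypothetical script name)
import torch
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import RingAttention

colossalai.launch_from_torch()
sp_size = dist.get_world_size()
sp_axis = 2                                   # (dp, pp, sp, tp) axis order, matching the test below
pg_mesh = ProcessGroupMesh(1, 1, sp_size, 1)  # put every rank on the sequence-parallel axis

b, h, sq, d = 2, 5, 1024, 128                 # per-rank (already zigzag-split) shard of the sequence
q, k, v = (
    torch.randn(b, h, sq, d, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)

attn_out = RingAttention.attention(
    q, k, v,
    sp_axis,                  # axis index replaces the old sp_group argument
    AttnMaskType.CAUSAL,
    inner_ring_size=2,        # must divide sp_size
    pg_mesh=pg_mesh,          # ring groups are derived from this mesh
)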

View File

@@ -857,17 +857,17 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         dropout_p = self.attn_dropout.p if self.training else 0.0

         sp_mode = shard_config.sequence_parallelism_mode
-        sp_group = shard_config.sequence_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
                 key,
                 value,
-                sp_group,
+                sp_axis=shard_config.sp_axis,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

View File

@@ -569,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 query_states,
                 key_states,
                 value_states,
-                sp_group,
+                sp_axis=shard_config.sp_axis,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         elif shard_config.enable_flash_attention:

View File

@@ -49,6 +49,8 @@ class ShardConfig:
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)

     # For ring attention
+    sp_axis: Optional[int] = None
+    pg_mesh: Optional[int] = None
     inner_ring_size: Optional[int] = None
     # for moe related
     moe_dp_group: Optional[ProcessGroup] = None

View File

@@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
 from torch.testing import assert_close

 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
 from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -17,11 +18,14 @@ from colossalai.utils import get_current_device
 @parameterize("nheads", [5])
 @parameterize("d", [128])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_ring_attn(seq_len, bs, nheads, d, dtype):
+def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     torch.cuda.manual_seed(2)
     device = get_current_device()
     sp_group = dist.group.WORLD
+    dp_size, pp_size, tp_size = 1, 1, 1
     sp_size = dist.get_world_size()
+    sp_axis = 2
+    pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
         q,
         k,
         v,
-        sp_group,
+        sp_axis,
         AttnMaskType.CAUSAL,
         return_softmax=True,
-        inner_ring_size=max(2, sp_size // 2),
-        # inner_ring_size=4
+        inner_ring_size=inner_ring_size,
+        pg_mesh=pg_mesh,
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
@@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
     device = get_current_device()
     sp_group = dist.group.WORLD
     sp_size = dist.get_world_size()
+    sp_axis = 2
     atol = rtol = 7e-3
     torch.cuda.manual_seed(2)
     # Prepare varlen attention mask
@@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
         q_ring,
         k_ring,
         v_ring,
-        sp_group,
+        sp_axis,
         **mask_info,
         pad_output=False,
         return_softmax=True,
+        pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
         # deterministic=True
     )
     ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
@@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
 def launch_single_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
     check_packed_seq()
-    check_ring_attn()
+    check_ring_attn(inner_ring_size=None)


 def launch_double_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
-    check_ring_attn()
+    check_ring_attn(inner_ring_size=2)


 @rerun_if_address_is_in_use()
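The decorated test entry points that drive these launchers sit outside the shown hunks; a hedged sketch of how they are typically wired up with colossalai's testing helpers (function names and GPU counts here are assumptions — the double-ring run only needs a world size that inner_ring_size=2 divides):

from colossalai.testing import rerun_if_address_is_in_use, spawn

@rerun_if_address_is_in_use()
def test_single_ring_attn():
    spawn(launch_single_ring, nprocs=2)   # assumed 2-GPU single-ring run

@rerun_if_address_is_in_use()
def test_double_ring_attn():
    spawn(launch_double_ring, nprocs=4)   # assumed 4-GPU run so inner_ring_size=2 divides sp_size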