
Merge pull request #6071 from wangbluo/ring_attention

[Ring Attention] Fix the 2D ring attention when using multiple machines

Wang Binluo, committed by GitHub
Commit: dcd41d0973
Changed files:
  1. colossalai/booster/plugin/hybrid_parallel_plugin.py (3 changes)
  2. colossalai/shardformer/layer/attn.py (40 changes)
  3. colossalai/shardformer/modeling/gpt2.py (4 changes)
  4. colossalai/shardformer/modeling/llama.py (3 changes)
  5. colossalai/shardformer/shard/shard_config.py (2 changes)
  6. tests/test_shardformer/test_layer/test_ring_attn.py (20 changes)

colossalai/booster/plugin/hybrid_parallel_plugin.py (3 changes)

@@ -1177,7 +1177,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
+            inner_ring_size=inner_ring_size,
+            pg_mesh=self.pg_mesh,
+            sp_axis=self.sp_axis,
         )

         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,

colossalai/shardformer/layer/attn.py (40 changes)

@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -441,21 +441,17 @@ class RingAttention(torch.autograd.Function):
         Returns:
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
-        if inner_ring_size is None:
-            if torch.cuda.device_count() >= dist.get_world_size():
-                # single node, no need to consider NICs
-                return sp_group, sp_group
-            if sp_size <= 4:
-                inner_ring_size = min(2, sp_size)
-            else:
-                inner_ring_size = min(4, sp_size)
-        else:
-            assert (
-                inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
-            ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"
+        assert inner_ring_size is not None
+        assert (
+            inner_ring_size <= sp_size and sp_size % inner_ring_size == 0
+        ), f"Error: sp_size {sp_size} should be divisible by inner_ring_size {inner_ring_size}"

         if inner_ring_size == sp_size:
             return sp_group, sp_group
@@ -474,14 +470,14 @@ class RingAttention(torch.autograd.Function):
         # Create inner ring groups
         for i in range(inner_ring_size):
             ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inner_ring_group = group

         # Create inter ring groups
         for i in range(num_rings):
             ranks = list(range(i, sp_size, num_rings))
-            group = dist.new_group(ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inter_ring_group = group
@@ -492,7 +488,7 @@ class RingAttention(torch.autograd.Function):
         q,  # (B, H, Sq, D)
         k,
         v,
-        sp_group,
+        sp_axis,
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -502,6 +498,7 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        pg_mesh=None,
         **kwargs,
     ):
         """
@@ -512,7 +509,7 @@ class RingAttention(torch.autograd.Function):
             q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
             k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
             v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
-            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+            sp_axis (Optional[int]): The sequence-parallel axis of the global pg mesh.
             sp_stream (torch.cuda.Stream): A different stream for output correction.
             cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
@@ -537,7 +534,6 @@ class RingAttention(torch.autograd.Function):
             RingAttention.ATTN_DONE = torch.cuda.Event()
         if RingAttention.SP_STREAM is None:
             RingAttention.SP_STREAM = torch.cuda.Stream()
-
         assert (
             q.shape[2] == k.shape[2]
         ), "Q, K and V having different sequence lengths (inference or cross-attn)\
@@ -546,11 +542,13 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
         ), f"Mask type {attention_mask_type} is not supported yet."
-        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."

-        if RingAttention.SP_GROUP is not sp_group:
+        clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
+        if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
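
The grouping arithmetic above can be sanity-checked without any distributed setup. The sketch below (plain Python, not from the diff; all values illustrative) shows the intended 2D decomposition: contiguous inner rings over fast intra-node links, and strided inter rings over the slower cross-node links. It matches the two loops above when inner_ring_size equals num_rings, i.e. sp_size == inner_ring_size ** 2.

# Standalone sketch of the 2D ring decomposition built by get_double_ring_groups.
sp_size = 16         # total sequence-parallel ranks, e.g. 4 nodes x 4 GPUs
inner_ring_size = 4  # ring within one node; keep <= NICs per node
num_rings = sp_size // inner_ring_size
assert sp_size % inner_ring_size == 0  # same divisibility check as the diff

# Inner rings: contiguous blocks of ranks (intra-node).
inner_rings = [list(range(r * inner_ring_size, (r + 1) * inner_ring_size)) for r in range(num_rings)]
# Inter rings: the same offset in every block (cross-node).
inter_rings = [list(range(offset, sp_size, inner_ring_size)) for offset in range(inner_ring_size)]

print(inner_rings)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(inter_rings)  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]

Each rank ends up in exactly one inner ring and one inter ring, which is what get_double_ring_groups returns as the two process groups.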

colossalai/shardformer/modeling/gpt2.py (4 changes)

@@ -857,17 +857,17 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         dropout_p = self.attn_dropout.p if self.training else 0.0

         sp_mode = shard_config.sequence_parallelism_mode
-        sp_group = shard_config.sequence_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
                 key,
                 value,
-                sp_group,
+                sp_axis=shard_config.sp_axis,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

colossalai/shardformer/modeling/llama.py (3 changes)

@@ -569,9 +569,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 query_states,
                 key_states,
                 value_states,
-                sp_group,
+                sp_axis=shard_config.sp_axis,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )

         elif shard_config.enable_flash_attention:
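
Both modeling call sites now share one convention: pass the axis index plus the global mesh instead of a concrete process group. A condensed sketch of the new call shape (shapes, mesh, and dtype are assumptions; a real causal call also expects zigzag-split inputs, as in the tests below):

# Condensed sketch of the new RingAttention call convention (illustrative).
import torch
import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention

sp_size = dist.get_world_size()
pg_mesh = ProcessGroupMesh(1, 1, sp_size, 1)  # axis order: (dp, pp, sp, tp)
b, h, sq, d = 2, 8, 4096 // sp_size, 64  # local shard of the sequence
q, k, v = (torch.randn(b, h, sq, d, device="cuda", dtype=torch.bfloat16) for _ in range(3))

out = RingAttention.attention(
    q, k, v,
    sp_axis=2,  # replaces the old sp_group argument
    attention_mask_type=AttnMaskType.CAUSAL,
    inner_ring_size=None,  # None keeps the plain 1D ring
    pg_mesh=pg_mesh,  # new: groups are derived from the global mesh
)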

colossalai/shardformer/shard/shard_config.py (2 changes)

@@ -49,6 +49,8 @@ class ShardConfig:
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)

     # For ring attention
+    sp_axis: Optional[int] = None
+    pg_mesh: Optional[int] = None
     inner_ring_size: Optional[int] = None
     # for moe related
     moe_dp_group: Optional[ProcessGroup] = None
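
A minimal sketch of how these two new fields are populated and consumed (assumes an already-initialized 8-rank torch.distributed context; the (dp, pp, sp, tp) axis order matches the updated test below):

import torch.distributed as dist
from colossalai.cluster import ProcessGroupMesh

dp_size, pp_size, sp_size, tp_size = 1, 1, dist.get_world_size(), 1
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
sp_axis = 2  # index of the sp dimension in the mesh above

# ShardConfig merely carries pg_mesh and sp_axis; RingAttention resolves the
# actual group lazily from the mesh, which is what fixes multi-machine runs:
sp_group = pg_mesh.get_group_along_axis(sp_axis)
assert dist.get_world_size(sp_group) == sp_size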

tests/test_shardformer/test_layer/test_ring_attn.py (20 changes)

@@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
 from torch.testing import assert_close

 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
 from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -17,11 +18,14 @@ from colossalai.utils import get_current_device
 @parameterize("nheads", [5])
 @parameterize("d", [128])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_ring_attn(seq_len, bs, nheads, d, dtype):
+def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     torch.cuda.manual_seed(2)
     device = get_current_device()
-    sp_group = dist.group.WORLD
+    dp_size, pp_size, tp_size = 1, 1, 1
     sp_size = dist.get_world_size()
+    sp_axis = 2
+    pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
     # Some outliers may seem large, but our errors are still lower
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -40,11 +44,11 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype):
         q,
         k,
         v,
-        sp_group,
+        sp_axis,
         AttnMaskType.CAUSAL,
         return_softmax=True,
-        inner_ring_size=max(2, sp_size // 2),
-        # inner_ring_size=4
+        inner_ring_size=inner_ring_size,
+        pg_mesh=pg_mesh,
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
@@ -83,6 +87,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
     device = get_current_device()
     sp_group = dist.group.WORLD
     sp_size = dist.get_world_size()
+    sp_axis = 2
     atol = rtol = 7e-3
     torch.cuda.manual_seed(2)
     # Prepare varlen attention mask
@@ -123,10 +128,11 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
         q_ring,
         k_ring,
         v_ring,
-        sp_group,
+        sp_axis,
         **mask_info,
         pad_output=False,
         return_softmax=True,
+        pg_mesh=ProcessGroupMesh(1, 1, sp_size, 1),
         # deterministic=True
     )
     ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
@@ -161,12 +167,12 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
 def launch_single_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
     check_packed_seq()
-    check_ring_attn()
+    check_ring_attn(inner_ring_size=None)


 def launch_double_ring(rank, world_size, port):
     colossalai.launch(rank, world_size, "localhost", port)
-    check_ring_attn()
+    check_ring_attn(inner_ring_size=2)


 @rerun_if_address_is_in_use()
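
For completeness, a sketch of the drivers that would invoke these launchers (entry-point names and process counts are assumptions based on the diff, not the verbatim test file):

from colossalai.testing import rerun_if_address_is_in_use, spawn

@rerun_if_address_is_in_use()
def test_ring_attn():
    spawn(launch_single_ring, nprocs=2)  # single-ring path (inner_ring_size=None)

@rerun_if_address_is_in_use()
def test_double_ring():
    spawn(launch_double_ring, nprocs=4)  # 4 ranks -> inner rings of size 2

if __name__ == "__main__":
    test_ring_attn()
    test_double_ring()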
