pull/6071/head
wangbluo 2024-10-15 11:01:34 +08:00
parent 8ff7d0c780
commit 3dc08c8a5a
6 changed files with 18 additions and 15 deletions

View File

@@ -1180,7 +1180,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
+            pg_mesh=self.pg_mesh,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,

View File

@@ -7,7 +7,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange
-from colossalai.cluster import ProcessGroupMesh
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
@@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None
     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -465,20 +464,17 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group = None
         inter_ring_group = None
-        inter_axis, inner_axis = 0, 1
-        pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size)
         # Create inner ring groups
         for i in range(inner_ring_size):
             ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = pg_mesh.get_group_along_axis(inner_axis)
+            group = pg_mesh.get_group_along_axis(2, ranks)
             if sp_rank in ranks:
                 inner_ring_group = group
         # Create inter ring groups
         for i in range(num_rings):
             ranks = list(range(i, sp_size, num_rings))
-            group = pg_mesh.get_group_along_axis(inter_axis)
+            group = pg_mesh.get_group_along_axis(2, ranks)
             if sp_rank in ranks:
                 inter_ring_group = group
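For context on the hunk above: only the source of the ProcessGroupMesh changes (the caller now supplies it instead of the function building its own); the rank layout of the two ring levels stays the same. A minimal, torch-free sketch of that layout, using hypothetical sizes sp_size=4 and inner_ring_size=2 (the real code additionally resolves each rank list to a process group via pg_mesh.get_group_along_axis(2, ranks)):

    # Mirror the two loops in the hunk above with plain Python lists.
    sp_size = 4                          # assumed sequence-parallel world size
    inner_ring_size = 2                  # assumed inner ring size (<= NICs per node)
    num_rings = sp_size // inner_ring_size

    # Inner rings: blocks of consecutive ranks (intra-node neighbours).
    inner_rings = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size)) for i in range(inner_ring_size)]
    # Inter rings: ranks holding the same position in each block, strided by num_rings.
    inter_rings = [list(range(i, sp_size, num_rings)) for i in range(num_rings)]

    print(inner_rings)  # [[0, 1], [2, 3]]
    print(inter_rings)  # [[0, 2], [1, 3]]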
@@ -499,6 +495,7 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        pg_mesh=None,
         **kwargs,
     ):
         """
@@ -546,7 +543,9 @@ class RingAttention(torch.autograd.Function):
         if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
+                sp_group, pg_mesh, inner_ring_size
+            )
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:

View File

@@ -868,6 +868,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

View File

@@ -571,6 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 sp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         elif shard_config.enable_flash_attention:

View File

@@ -51,6 +51,7 @@ class ShardConfig:
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # For ring attention
+    pg_mesh: Optional[int] = None
     inner_ring_size: Optional[int] = None
     # for moe related
     moe_dp_group: Optional[ProcessGroup] = None

View File

@@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
 from torch.testing import assert_close
 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
 from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -21,6 +22,9 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     torch.cuda.manual_seed(2)
     device = get_current_device()
     sp_group = dist.group.WORLD
+    dp_size, pp_size, tp_size = 1, 1, 1
+    sp_size = dist.get_world_size()
+    pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
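A note on the mesh built in this hunk: ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size) places the sequence-parallel dimension at axis index 2, which is the axis get_double_ring_groups now hard-codes in pg_mesh.get_group_along_axis(2, ranks). A rough sketch of that assumption in plain Python (row-major rank ordering is assumed here):

    import itertools

    shape = (1, 1, 4, 1)  # (dp, pp, sp, tp) for a hypothetical 4-rank run
    sp_axis = 2
    groups = {}
    for rank, coord in enumerate(itertools.product(*(range(s) for s in shape))):
        key = coord[:sp_axis] + coord[sp_axis + 1:]  # fix every axis except sp
        groups.setdefault(key, []).append(rank)
    print(list(groups.values()))  # [[0, 1, 2, 3]] -- all ranks fall into one sp group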
@@ -36,13 +40,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     # Ring attention vs single GPU
     ring_out, ring_lse = RingAttention.attention(
-        q,
-        k,
-        v,
-        sp_group,
-        AttnMaskType.CAUSAL,
-        return_softmax=True,
-        inner_ring_size=inner_ring_size,
+        q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
@@ -125,6 +123,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
         **mask_info,
         pad_output=False,
         return_softmax=True,
+        pg_mesh=dist.group.WORLD,
         # deterministic=True
     )
     ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)