
fix

Branch: pull/6071/head
Author: wangbluo, 1 month ago
Commit: 3dc08c8a5a
Changed files:
1. colossalai/booster/plugin/hybrid_parallel_plugin.py (2 changed lines)
2. colossalai/shardformer/layer/attn.py (15 changed lines)
3. colossalai/shardformer/modeling/gpt2.py (1 changed line)
4. colossalai/shardformer/modeling/llama.py (1 changed line)
5. colossalai/shardformer/shard/shard_config.py (1 changed line)
6. tests/test_shardformer/test_layer/test_ring_attn.py (13 changed lines)

colossalai/booster/plugin/hybrid_parallel_plugin.py (2 changed lines)

@@ -1180,7 +1180,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
+            pg_mesh=self.pg_mesh,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
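For context, the only change here is that the plugin now forwards the ProcessGroupMesh it already owns into the ShardConfig it builds, so RingAttention can reuse that mesh instead of constructing its own (see the attn.py hunks below). A minimal sketch of the wiring, not the plugin's code; the axis layout and sizes are illustrative and assume the default process group is already initialized (e.g. via colossalai.launch):

    import torch.distributed as dist

    from colossalai.cluster import ProcessGroupMesh
    from colossalai.shardformer import ShardConfig

    # Assumed (dp, pp, sp, tp) layout, matching the updated test below; the
    # real plugin derives the sizes from its constructor arguments.
    pg_mesh = ProcessGroupMesh(1, 1, dist.get_world_size(), 1)

    shard_config = ShardConfig(
        enable_flash_attention=True,
        inner_ring_size=4,      # illustrative; must divide the sp size
        pg_mesh=pg_mesh,        # new field introduced by this commit
    )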

colossalai/shardformer/layer/attn.py (15 changed lines)

@@ -7,7 +7,6 @@ import torch.distributed as dist
 import torch.nn.functional as F
 from einops import rearrange

-from colossalai.cluster import ProcessGroupMesh
 from colossalai.kernel.kernel_loader import (
     FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
@@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -465,20 +464,17 @@ class RingAttention(torch.autograd.Function):
         inner_ring_group = None
         inter_ring_group = None

-        inter_axis, inner_axis = 0, 1
-        pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size)
-
         # Create inner ring groups
         for i in range(inner_ring_size):
             ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = pg_mesh.get_group_along_axis(inner_axis)
+            group = pg_mesh.get_group_along_axis(2, ranks)
             if sp_rank in ranks:
                 inner_ring_group = group

         # Create inter ring groups
         for i in range(num_rings):
             ranks = list(range(i, sp_size, num_rings))
-            group = pg_mesh.get_group_along_axis(inter_axis)
+            group = pg_mesh.get_group_along_axis(2, ranks)
             if sp_rank in ranks:
                 inter_ring_group = group
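The point of this hunk: get_double_ring_groups no longer builds a throwaway ProcessGroupMesh(num_rings, inner_ring_size) of its own; it receives the caller's mesh and asks it for subgroups with explicit rank lists. A standalone sketch of that pattern, assuming an initialized default process group and the (dp, pp, sp, tp) mesh layout used by the updated test further down (so axis 2 is the sequence-parallel axis referenced by get_group_along_axis(2, ranks) above); sizes are illustrative:

    import torch.distributed as dist

    from colossalai.cluster import ProcessGroupMesh

    world_size = dist.get_world_size()
    pg_mesh = ProcessGroupMesh(1, 1, world_size, 1)  # (dp, pp, sp, tp)

    inner_ring_size = 4   # e.g. number of NICs per node; must divide world_size
    rank = dist.get_rank()

    # Contiguous inner-ring ranks this rank belongs to, e.g. [0, 1, 2, 3].
    start = (rank // inner_ring_size) * inner_ring_size
    inner_ranks = list(range(start, start + inner_ring_size))
    inner_ring_group = pg_mesh.get_group_along_axis(2, inner_ranks)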
@@ -499,6 +495,7 @@ class RingAttention(torch.autograd.Function):
         deterministic=False,
         return_softmax=False,
         inner_ring_size=None,
+        pg_mesh=None,
         **kwargs,
     ):
         """
@@ -546,7 +543,9 @@ class RingAttention(torch.autograd.Function):
         if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
+                sp_group, pg_mesh, inner_ring_size
+            )
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:

colossalai/shardformer/modeling/gpt2.py (1 changed line)

@@ -868,6 +868,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 dropout_p=dropout_p,
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)

colossalai/shardformer/modeling/llama.py (1 changed line)

@@ -571,6 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 sp_group,
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
+                pg_mesh=shard_config.pg_mesh,
             )
         elif shard_config.enable_flash_attention:
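Both modeling files make the same one-line change: the ring-attention branch now forwards shard_config.pg_mesh alongside shard_config.inner_ring_size. A hedged sketch of that call-site pattern, condensed from the two hunks (argument names follow the llama hunk; this is not the verbatim modeling code):

    from colossalai.shardformer.layer.attn import RingAttention

    def ring_attention_branch(query_states, key_states, value_states, sp_group, attention_mask, shard_config):
        # attention_mask here is the dict of mask kwargs prepared by the modeling code.
        return RingAttention.attention(
            query_states,
            key_states,
            value_states,
            sp_group,
            **attention_mask,
            inner_ring_size=shard_config.inner_ring_size,
            pg_mesh=shard_config.pg_mesh,  # new argument threaded through in this commit
        )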

colossalai/shardformer/shard/shard_config.py (1 changed line)

@@ -51,6 +51,7 @@ class ShardConfig:
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # For ring attention
+    pg_mesh: Optional[int] = None
     inner_ring_size: Optional[int] = None

     # for moe related
     moe_dp_group: Optional[ProcessGroup] = None

tests/test_shardformer/test_layer/test_ring_attn.py (13 changed lines)

@@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
 from torch.testing import assert_close

 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.shardformer.layer import AttnMaskType
 from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
 from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
@@ -21,6 +22,9 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     torch.cuda.manual_seed(2)
     device = get_current_device()
     sp_group = dist.group.WORLD
+    dp_size, pp_size, tp_size = 1, 1, 1
+    sp_size = dist.get_world_size()
+    pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
     # Some outliers may seem large, but our errors are still lower than
     # than Megatron-LM context parallel's
     # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
@@ -36,13 +40,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):

     # Ring attention vs single GPU
     ring_out, ring_lse = RingAttention.attention(
-        q,
-        k,
-        v,
-        sp_group,
-        AttnMaskType.CAUSAL,
-        return_softmax=True,
-        inner_ring_size=inner_ring_size,
+        q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
@@ -125,6 +123,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
         **mask_info,
         pad_output=False,
         return_softmax=True,
+        pg_mesh=dist.group.WORLD,
         # deterministic=True
     )
     ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
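Putting the test changes together: the test now builds the same four-axis ProcessGroupMesh the plugin uses and hands it to RingAttention.attention. A condensed, hedged sketch of the updated flow (shapes, dtype, and sizes are illustrative; torch.distributed must already be initialized, e.g. via colossalai.launch, and the real test zigzag-splits a full sequence before calling the kernel):

    import torch
    import torch.distributed as dist

    from colossalai.cluster import ProcessGroupMesh
    from colossalai.shardformer.layer import AttnMaskType
    from colossalai.shardformer.layer.attn import RingAttention

    def run_ring_attn(bs=2, nheads=8, seq_len=4096, d=128, dtype=torch.bfloat16, inner_ring_size=None):
        sp_group = dist.group.WORLD
        sp_size = dist.get_world_size()
        dp_size, pp_size, tp_size = 1, 1, 1
        pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)

        # Local shard of q/k/v in (batch, heads, local_seq, head_dim) layout.
        shape = (bs, nheads, seq_len // sp_size, d)
        q, k, v = (torch.randn(shape, dtype=dtype, device="cuda") for _ in range(3))

        ring_out, ring_lse = RingAttention.attention(
            q, k, v, sp_group, AttnMaskType.CAUSAL,
            return_softmax=True,
            inner_ring_size=inner_ring_size,
            pg_mesh=pg_mesh,  # new argument added by this commit
        )
        return ring_out, ring_lse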
