
fix

pull/6071/head
wangbluo · 1 month ago · commit 6be9862aaf
6 changed files:
  colossalai/booster/plugin/hybrid_parallel_plugin.py (1 change)
  colossalai/shardformer/layer/attn.py (14 changes)
  colossalai/shardformer/modeling/gpt2.py (1 change)
  colossalai/shardformer/modeling/llama.py (1 change)
  colossalai/shardformer/shard/shard_config.py (1 change)
  tests/test_shardformer/test_layer/test_ring_attn.py (11 changes)

colossalai/booster/plugin/hybrid_parallel_plugin.py (1 change)

@@ -1181,6 +1181,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
             pg_mesh=self.pg_mesh,
+            sp_axis=self.sp_axis,
         )
         self.amp_config = dict(
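
This hunk is the top of the plumbing chain: the plugin now hands its sp_axis to the ShardConfig it builds, alongside pg_mesh and inner_ring_size. A minimal runnable sketch of that threading, using hypothetical stand-in classes (MiniShardConfig, MiniPlugin) rather than the real HybridParallelPlugin and ShardConfig:

# Sketch of the plumbing this hunk adds; the classes are simplified
# stand-ins, not the real ColossalAI types.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class MiniShardConfig:
    sp_axis: Optional[int] = None
    pg_mesh: Optional[Any] = None
    inner_ring_size: Optional[int] = None

class MiniPlugin:
    def __init__(self, pg_mesh: Any, sp_axis: int, inner_ring_size: Optional[int] = None):
        self.pg_mesh = pg_mesh
        self.sp_axis = sp_axis
        # The plugin owns the mesh and the axis index; the shard config
        # just carries them down to the attention layer.
        self.shard_config = MiniShardConfig(
            sp_axis=self.sp_axis,
            pg_mesh=self.pg_mesh,
            inner_ring_size=inner_ring_size,
        )

plugin = MiniPlugin(pg_mesh="mesh-handle", sp_axis=2)
assert plugin.shard_config.sp_axis == 2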

colossalai/shardformer/layer/attn.py (14 changes)

@@ -431,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None):
+    def get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
@@ -441,6 +441,9 @@ class RingAttention(torch.autograd.Function):
         Returns:
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."
+        sp_group = pg_mesh.get_group_along_axis(sp_axis)
         sp_size = dist.get_world_size(sp_group)
         sp_rank = dist.get_rank(sp_group)
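
This is the key behavioral change of the commit: get_double_ring_groups no longer receives a ready-made sp_group but resolves it from the mesh via pg_mesh.get_group_along_axis(sp_axis). A hedged sketch of what "the group along an axis" means in rank terms; groups_along_axis below is an illustrative helper doing pure rank math, not the ColossalAI API, which returns dist.ProcessGroup objects:

# Rank-level illustration of resolving groups along one mesh axis.
import itertools

def groups_along_axis(shape, axis):
    """Return lists of global ranks that differ only along `axis`."""
    # Row-major strides, matching how a flat rank maps to mesh coordinates.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    groups = []
    other_axes = [range(s) for i, s in enumerate(shape) if i != axis]
    for coords in itertools.product(*other_axes):
        ranks = []
        for j in range(shape[axis]):
            full = list(coords)
            full.insert(axis, j)
            ranks.append(sum(c * s for c, s in zip(full, strides)))
        groups.append(ranks)
    return groups

# dp=1, pp=1, sp=4, tp=2: the sp groups (axis 2) stride over tp.
print(groups_along_axis((1, 1, 4, 2), axis=2))
# [[0, 2, 4, 6], [1, 3, 5, 7]]

With the mesh layout used elsewhere in this commit, (dp, pp, sp, tp), picking axis 2 collects exactly the ranks that differ only in their sequence-parallel coordinate.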
@@ -496,6 +499,7 @@ class RingAttention(torch.autograd.Function):
         return_softmax=False,
         inner_ring_size=None,
         pg_mesh=None,
+        sp_axis=None,
         **kwargs,
     ):
         """
@@ -506,7 +510,7 @@ class RingAttention(torch.autograd.Function):
             q (torch.Tensor): Query tensor. Shape should be [B, nHeads, Sq, D]
             k (torch.Tensor): Key tensor. Shape should be [B, nHeads, Sq, D]
             v (torch.Tensor): Value tensor. Shape should be [B, nHeads, Sq, D]
-            sp_group (Optional[dist.ProcessGroup]): Process group for sequence parallelism
+            sp_axis (Optional[int]): The sequence-parallel axis of the global pg mesh.
             sp_stream (torch.cuda.Stream): A different stream for output correction.
             cu_seqlens (Optional[torch.Tensor], optional): The cumulative sequence lengths
                 of the sequences in the batch, used to index into q.
@@ -539,13 +543,13 @@ class RingAttention(torch.autograd.Function):
             attention_mask_type in RingAttention.SUPPORTED_MASK_TYPES
         ), f"Mask type {attention_mask_type} is not supported yet."
+        assert pg_mesh is not None, f"Error: The pg mesh is None! please check the process group initialization."

         clone_pg = lambda pg: dist.new_group(dist.get_process_group_ranks(pg))

         if inner_ring_size != None:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
-                sp_group, pg_mesh, inner_ring_size
-            )
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_axis, pg_mesh, inner_ring_size)
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:
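
For readers new to the double-ring design this method serves: one sequence-parallel ring of sp_size ranks is split into contiguous inner rings of inner_ring_size ranks each (per the docstring above, ideally no larger than the NIC count per node) plus strided inter rings connecting them. A hedged sketch of the rank split only, assuming this contiguous/strided layout; the real method builds dist.ProcessGroups:

# Rank math for the inner/inter ring split (illustrative assumption).
def double_ring_ranks(sp_size, inner_ring_size):
    assert sp_size % inner_ring_size == 0
    num_rings = sp_size // inner_ring_size
    # Inner rings: contiguous blocks of ranks, kept within a node.
    inner = [list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
             for i in range(num_rings)]
    # Inter rings: one rank per inner ring, strided across nodes.
    inter = [list(range(i, sp_size, inner_ring_size))
             for i in range(inner_ring_size)]
    return inner, inter

inner, inter = double_ring_ranks(8, 4)
print(inner)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(inter)  # [[0, 4], [1, 5], [2, 6], [3, 7]]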

colossalai/shardformer/modeling/gpt2.py (1 change)

@@ -869,6 +869,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
                 pg_mesh=shard_config.pg_mesh,
+                sp_axis=shard_config.sp_axis,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
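
This hunk and the llama.py one below share a single call-site pattern: everything RingAttention needs to rebuild its groups now rides on shard_config. A toy sketch of that pattern; ring_attention_stub and StubShardConfig are hypothetical stand-ins, not the real APIs:

# Keyword plumbing from a shard config into a ring-attention entry point.
def ring_attention_stub(q, k, v, *, dropout_p=0.0, scale=None,
                        inner_ring_size=None, pg_mesh=None, sp_axis=None):
    # All group-related state arrives via keywords sourced from the config.
    assert pg_mesh is not None and sp_axis is not None
    return (inner_ring_size, pg_mesh, sp_axis)

class StubShardConfig:
    inner_ring_size = 4
    pg_mesh = object()  # stand-in for a ProcessGroupMesh handle
    sp_axis = 2

cfg = StubShardConfig()
print(ring_attention_stub(None, None, None,
                          inner_ring_size=cfg.inner_ring_size,
                          pg_mesh=cfg.pg_mesh,
                          sp_axis=cfg.sp_axis))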

colossalai/shardformer/modeling/llama.py (1 change)

@@ -572,6 +572,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 **attention_mask,
                 inner_ring_size=shard_config.inner_ring_size,
                 pg_mesh=shard_config.pg_mesh,
+                sp_axis=shard_config.sp_axis,
             )
         elif shard_config.enable_flash_attention:

colossalai/shardformer/shard/shard_config.py (1 change)

@@ -51,6 +51,7 @@ class ShardConfig:
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # For ring attention
+    sp_axis: Optional[int] = None
     pg_mesh: Optional[int] = None
     inner_ring_size: Optional[int] = None
     # for moe related
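
Taken together, these three fields form the ring-attention section of ShardConfig. A self-contained sketch, assuming a plain dataclass (the real ShardConfig has many more fields; note the diff annotates pg_mesh as Optional[int], although it carries a ProcessGroupMesh handle in practice, so Any is used below):

# Minimal mirror of the ring-attention config fields.
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class RingAttnConfig:
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
    # For ring attention
    sp_axis: Optional[int] = None
    pg_mesh: Optional[Any] = None
    inner_ring_size: Optional[int] = None

cfg = RingAttnConfig(sp_axis=2, inner_ring_size=4)
print(cfg)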

tests/test_shardformer/test_layer/test_ring_attn.py (11 changes)

@@ -24,6 +24,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     sp_group = dist.group.WORLD
     dp_size, pp_size, tp_size = 1, 1, 1
     sp_size = dist.get_world_size()
+    sp_axis = 2
     pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
     # Some outliers may seem large, but our errors are still lower
     # than Megatron-LM context parallel's
@@ -40,7 +41,15 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
     # Ring attention vs single GPU
     ring_out, ring_lse = RingAttention.attention(
-        q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh
+        q,
+        k,
+        v,
+        sp_group,
+        AttnMaskType.CAUSAL,
+        return_softmax=True,
+        inner_ring_size=inner_ring_size,
+        pg_mesh=pg_mesh,
+        sp_axis=sp_axis,
     )
     ring_out = ring_out.transpose(1, 2)
     out, lse, _ = flash_attn_qkvpacked_func(
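
Why sp_axis = 2 in this test: the mesh is constructed as ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size), so the axis order is (dp, pp, sp, tp) and sequence parallelism sits at index 2. A one-screen check of that indexing; the tuple below is a stand-in for the real mesh object:

# Axis bookkeeping for the test's ProcessGroupMesh layout.
AXES = ("dp", "pp", "sp", "tp")
dp_size, pp_size, sp_size, tp_size = 1, 1, 8, 1
mesh_shape = (dp_size, pp_size, sp_size, tp_size)
sp_axis = AXES.index("sp")
assert sp_axis == 2 and mesh_shape[sp_axis] == sp_size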
