mirror of https://github.com/hpcaitech/ColossalAI
fix
parent
8ff7d0c780
commit
3dc08c8a5a
|
@ -1180,7 +1180,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
fp8_communication=fp8_communication,
|
||||
inner_ring_size=inner_ring_size,
|
||||
pg_mesh=self.pg_mesh,
|
||||
)
|
||||
|
||||
self.amp_config = dict(
|
||||
initial_scale=initial_scale,
|
||||
growth_factor=growth_factor,
|
||||
|
|
|
@ -7,7 +7,6 @@ import torch.distributed as dist
|
|||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.kernel.kernel_loader import (
|
||||
FlashAttentionDaoLoader,
|
||||
FlashAttentionForFloatAndCustomMaskLoader,
|
||||
|
@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function):
|
|||
INTER_RING_GROUP_COPY: dist.ProcessGroup = None
|
||||
|
||||
@staticmethod
|
||||
def get_double_ring_groups(sp_group, inner_ring_size=None):
|
||||
def get_double_ring_groups(sp_group, pg_mesh, inner_ring_size=None):
|
||||
"""
|
||||
Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
|
||||
shouldn't be larger than the number of NICs on each node.
|
||||
|
@ -465,20 +464,17 @@ class RingAttention(torch.autograd.Function):
|
|||
inner_ring_group = None
|
||||
inter_ring_group = None
|
||||
|
||||
inter_axis, inner_axis = 0, 1
|
||||
pg_mesh = ProcessGroupMesh(num_rings, inner_ring_size)
|
||||
|
||||
# Create inner ring groups
|
||||
for i in range(inner_ring_size):
|
||||
ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
|
||||
group = pg_mesh.get_group_along_axis(inner_axis)
|
||||
group = pg_mesh.get_group_along_axis(2, ranks)
|
||||
if sp_rank in ranks:
|
||||
inner_ring_group = group
|
||||
|
||||
# Create inter ring groups
|
||||
for i in range(num_rings):
|
||||
ranks = list(range(i, sp_size, num_rings))
|
||||
group = pg_mesh.get_group_along_axis(inter_axis)
|
||||
group = pg_mesh.get_group_along_axis(2, ranks)
|
||||
if sp_rank in ranks:
|
||||
inter_ring_group = group
|
||||
|
||||
|
@ -499,6 +495,7 @@ class RingAttention(torch.autograd.Function):
|
|||
deterministic=False,
|
||||
return_softmax=False,
|
||||
inner_ring_size=None,
|
||||
pg_mesh=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -546,7 +543,9 @@ class RingAttention(torch.autograd.Function):
|
|||
|
||||
if inner_ring_size != None:
|
||||
RingAttention.SP_GROUP = sp_group
|
||||
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, inner_ring_size)
|
||||
inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
|
||||
sp_group, pg_mesh, inner_ring_size
|
||||
)
|
||||
RingAttention.INNER_RING_GROUP = inner_ring_group
|
||||
RingAttention.INTER_RING_GROUP = inter_ring_group
|
||||
else:
|
||||
|
|
|
@ -868,6 +868,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
|
|||
dropout_p=dropout_p,
|
||||
scale=scale,
|
||||
inner_ring_size=shard_config.inner_ring_size,
|
||||
pg_mesh=shard_config.pg_mesh,
|
||||
)
|
||||
else:
|
||||
attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
|
||||
|
|
|
@ -571,6 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
|||
sp_group,
|
||||
**attention_mask,
|
||||
inner_ring_size=shard_config.inner_ring_size,
|
||||
pg_mesh=shard_config.pg_mesh,
|
||||
)
|
||||
|
||||
elif shard_config.enable_flash_attention:
|
||||
|
|
|
@ -51,6 +51,7 @@ class ShardConfig:
|
|||
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# For ring attention
|
||||
pg_mesh: Optional[int] = None
|
||||
inner_ring_size: Optional[int] = None
|
||||
# for moe related
|
||||
moe_dp_group: Optional[ProcessGroup] = None
|
||||
|
|
|
@ -5,6 +5,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_fu
|
|||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.shardformer.layer import AttnMaskType
|
||||
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
|
||||
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
|
||||
|
@ -21,6 +22,9 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
|
|||
torch.cuda.manual_seed(2)
|
||||
device = get_current_device()
|
||||
sp_group = dist.group.WORLD
|
||||
dp_size, pp_size, tp_size = 1, 1, 1
|
||||
sp_size = dist.get_world_size()
|
||||
pg_mesh = ProcessGroupMesh(dp_size, pp_size, sp_size, tp_size)
|
||||
# Some outliers may seem large, but our errors are still lower than
|
||||
# than Megatron-LM context parallel's
|
||||
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
|
||||
|
@ -36,13 +40,7 @@ def check_ring_attn(seq_len, bs, nheads, d, dtype, inner_ring_size):
|
|||
|
||||
# Ring attention vs single GPU
|
||||
ring_out, ring_lse = RingAttention.attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sp_group,
|
||||
AttnMaskType.CAUSAL,
|
||||
return_softmax=True,
|
||||
inner_ring_size=inner_ring_size,
|
||||
q, k, v, sp_group, AttnMaskType.CAUSAL, return_softmax=True, inner_ring_size=inner_ring_size, pg_mesh=pg_mesh
|
||||
)
|
||||
ring_out = ring_out.transpose(1, 2)
|
||||
out, lse, _ = flash_attn_qkvpacked_func(
|
||||
|
@ -125,6 +123,7 @@ def check_packed_seq(seqlen, bs, nheads, d, dtype):
|
|||
**mask_info,
|
||||
pad_output=False,
|
||||
return_softmax=True,
|
||||
pg_mesh=dist.group.WORLD,
|
||||
# deterministic=True
|
||||
)
|
||||
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
|
||||
|
|
Loading…
Reference in New Issue