pull/6071/head
wangbluo 2024-09-25 18:56:18 +08:00
parent 65c8297710
commit 6fb1322db1
3 changed files with 13 additions and 9 deletions

View File

@@ -4,7 +4,6 @@ from typing import Callable, Dict, Optional, Tuple
 import torch
 import torch.distributed
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn.functional as F
 from einops import rearrange
@@ -432,7 +431,7 @@ class RingAttention(torch.autograd.Function):
     INTER_RING_GROUP_COPY: dist.ProcessGroup = None

     @staticmethod
-    def get_double_ring_groups(sp_group,tp_group, inner_ring_size=None):
+    def get_double_ring_groups(sp_group, tp_group, inner_ring_size=None):
         """
         Get 2D ring groups for the given process group. Generally, to avoid congestion, the inner ring size
         shouldn't be larger than the number of NICs on each node.
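
As a quick illustration of the docstring's guideline (a sketch with made-up hardware numbers, not values from this PR): with 8 GPUs and 4 NICs per node, the inner ring would be capped at 4 so that concurrent inter-node transfers do not oversubscribe the NICs.

    # Illustrative only: hypothetical per-node GPU/NIC counts, not from this PR.
    gpus_per_node = 8
    nics_per_node = 4
    # Follow the docstring's guideline: inner ring no larger than the NIC count.
    inner_ring_size = min(gpus_per_node, nics_per_node)  # -> 4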
@@ -443,8 +442,8 @@ class RingAttention(torch.autograd.Function):
             Tuple[dist.ProcessGroup, dist.ProcessGroup]: Inner-ring process group and inter-ring process group.
         """
         sp_size = dist.get_world_size(sp_group)
-        sp_rank = dist.get_rank(sp_group)
-        tp_size = dist.get_world_size(tp_group)
+        dist.get_rank(sp_group)
+        dist.get_world_size(tp_group)

         if inner_ring_size is None:
             if torch.cuda.device_count() >= dist.get_world_size():
@@ -475,11 +474,12 @@ class RingAttention(torch.autograd.Function):
         world_size = dist.get_world_size()
         rank = dist.get_rank()
-        groups = int(world_size/ sp_size)
+        groups = int(world_size / sp_size)

         # Create inner ring groups
         for group_id in range(groups):
             for i in range(inner_ring_size):
-                ranks = list(range(i +(group_id*sp_size), (1+group_id)*sp_size, inner_ring_size))
+                ranks = list(range(i + (group_id * sp_size), (1 + group_id) * sp_size, inner_ring_size))
                 group = dist.new_group(ranks)
                 if rank in ranks:
                     inner_ring_group = group
@@ -487,7 +487,7 @@ class RingAttention(torch.autograd.Function):
         # Create inter ring groups
         for group_id in range(groups):
             for i in range(num_rings):
-                ranks = list(range(i+group_id * num_rings, world_size, sp_size))
+                ranks = list(range(i + group_id * num_rings, world_size, sp_size))
                 group = dist.new_group(ranks)
                 if rank in ranks:
                     inter_ring_group = group
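
To sanity-check the reformatted indexing, here is a standalone dry run of the two `range(...)` expressions from the hunks above (sizes are made-up examples, not from the PR; no `torch.distributed` calls needed):

    # Dry run of the rank arithmetic above with made-up sizes:
    # 8 ranks total, sp_size=4, inner_ring_size=2.
    world_size, sp_size, inner_ring_size = 8, 4, 2
    num_rings = sp_size // inner_ring_size
    groups = world_size // sp_size  # same value as int(world_size / sp_size)

    inner_rings = [
        list(range(i + group_id * sp_size, (1 + group_id) * sp_size, inner_ring_size))
        for group_id in range(groups)
        for i in range(inner_ring_size)
    ]
    inter_rings = [
        list(range(i + group_id * num_rings, world_size, sp_size))
        for group_id in range(groups)
        for i in range(num_rings)
    ]
    print(inner_rings)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(inter_rings)  # [[0, 4], [1, 5], [2, 6], [3, 7]]

In this dry run, every rank lands in exactly one inner-ring and one inter-ring group, matching the 2D ring decomposition the docstring describes.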
@@ -500,7 +500,7 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
-        tp_group : Optional[dist.ProcessGroup],
+        tp_group: Optional[dist.ProcessGroup],
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
@@ -557,7 +557,9 @@ class RingAttention(torch.autograd.Function):
         if RingAttention.SP_GROUP is not sp_group:
             RingAttention.SP_GROUP = sp_group
-            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(sp_group, tp_group, inner_ring_size)
+            inner_ring_group, inter_ring_group = RingAttention.get_double_ring_groups(
+                sp_group, tp_group, inner_ring_size
+            )
             RingAttention.INNER_RING_GROUP = inner_ring_group
             RingAttention.INTER_RING_GROUP = inter_ring_group
         else:

View File

@@ -859,6 +859,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
         tp_group = shard_config.tensor_parallel_process_group
+
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
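
Tying the two files together: given the `attention` signature shown above (q, k, v, sp_group, then the new `tp_group`, then `attention_mask_type`), call sites like this one presumably pass the tensor-parallel group through positionally. A minimal sketch of the call shape; everything past `query` is inferred from the signature hunk, since the diff truncates this call site:

    # Sketch of the updated call shape, inferred from the signature hunk above;
    # the truncated diff does not show the remaining arguments at this call site.
    attn_output = RingAttention.attention(
        query,                # attention inputs from the surrounding forward
        key,
        value,
        sp_group,             # sequence-parallel process group
        tp_group,             # new in this PR: Optional[dist.ProcessGroup]
        attention_mask_type,  # mask-type constant expected by RingAttention
    )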

View File

@@ -564,6 +564,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         tp_group = shard_config.tensor_parallel_process_group
+
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query_states,