fix the attn

pull/6071/head
wangbluo 2024-09-25 18:51:03 +08:00
parent cfd9eda628
commit 65c8297710
2 changed files with 3 additions and 1 deletion


@@ -500,7 +500,7 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
-        tp_group,
+        tp_group: Optional[dist.ProcessGroup],
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
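This hunk only annotates the existing `tp_group` parameter as `Optional[dist.ProcessGroup]`, making explicit that callers may pass no tensor-parallel group. A minimal sketch of how such an optional group parameter is typically declared and guarded; the default value and the function body are illustrative assumptions, not the ColossalAI implementation:

```python
# Sketch only: an Optional[dist.ProcessGroup] parameter as in the annotated
# signature above. The default and the body are assumptions for illustration.
from typing import Optional

import torch
import torch.distributed as dist


def attention_sketch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sp_group: dist.ProcessGroup,
    tp_group: Optional[dist.ProcessGroup] = None,  # hypothetical default
) -> torch.Tensor:
    # Guard the optional group: treat a missing tensor-parallel group as
    # world size 1 (no tensor parallelism).
    tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
    _ = tp_size  # a real implementation would use tp_group for cross-rank coordination
    return q  # placeholder return
```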


@@ -858,12 +858,14 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
+        tp_group = shard_config.tensor_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
                 key,
                 value,
                 sp_group,
+                tp_group,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
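On the caller side, the GPT-2 forward now reads the tensor-parallel group off the ShardConfig and forwards it positionally right after `sp_group`, matching the new signature. A minimal sketch of the updated call path; the variable names follow the diff context, `attention_mask` is a dict of mask kwargs as the `**attention_mask` expansion implies, and the import path is an assumption based on the upstream ColossalAI layout:

```python
# Sketch of the updated call site. The RingAttention import path is an
# assumption; the positional argument order mirrors the diff above.
from colossalai.shardformer.layer.attn import RingAttention


def run_ring_attention(query, key, value, attention_mask, shard_config,
                       dropout_p=0.0, scale=None):
    sp_mode = shard_config.sequence_parallelism_mode
    sp_group = shard_config.sequence_parallel_process_group
    tp_group = shard_config.tensor_parallel_process_group  # added by this commit
    if sp_mode == "ring_attn":
        # tp_group is now passed positionally, right after sp_group.
        return RingAttention.attention(
            query,
            key,
            value,
            sp_group,
            tp_group,
            **attention_mask,
            dropout_p=dropout_p,
            scale=scale,
        )
    raise NotImplementedError(f"unsupported sequence parallel mode: {sp_mode}")
```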