fix the attn

pull/6071/head
wangbluo 2 months ago
parent cfd9eda628
commit 65c8297710

@@ -500,7 +500,7 @@ class RingAttention(torch.autograd.Function):
k,
v,
sp_group,
-tp_group,
+tp_group : Optional[dist.ProcessGroup],
attention_mask_type,
cu_seqlens=None,
max_seqlen=None,

@@ -858,12 +858,14 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
tp_group = shard_config.tensor_parallel_process_group
if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query,
key,
value,
sp_group,
tp_group,
**attention_mask,
dropout_p=dropout_p,
scale=scale,

Loading…
Cancel
Save