pull/6071/head
wangbluo 1 month ago
parent fd92789af2
commit bc7eeade33

@@ -857,7 +857,6 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         dropout_p = self.attn_dropout.p if self.training else 0.0
         sp_mode = shard_config.sequence_parallelism_mode
-        shard_config.sequence_parallel_process_group
         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
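
For context, a minimal, self-contained sketch of the dispatch pattern this hunk touches: the patched forward reads the sequence-parallelism mode from the shard config and, for "ring_attn", routes query/key/value to a ring-attention path instead of the default scaled-dot-product path. This is not ColossalAI's actual implementation; ShardConfig below is a simplified stand-in, and ring_attention / dispatch_attention are hypothetical helpers used only for illustration.

    # Sketch of sp_mode-based attention dispatch (assumed names, not ColossalAI's API).
    from dataclasses import dataclass
    from typing import Optional

    import torch
    import torch.nn.functional as F


    @dataclass
    class ShardConfig:
        # e.g. "ring_attn", "all_to_all", or None
        sequence_parallelism_mode: Optional[str] = None


    def ring_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float) -> torch.Tensor:
        # Placeholder for a ring-attention kernel; a real implementation would
        # exchange K/V blocks across the sequence-parallel process group.
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)


    def dispatch_attention(q, k, v, shard_config: ShardConfig, dropout_p: float = 0.0) -> torch.Tensor:
        sp_mode = shard_config.sequence_parallelism_mode
        if sp_mode == "ring_attn":
            # Sequence-parallel path: each rank holds a slice of the sequence.
            return ring_attention(q, k, v, dropout_p)
        # Default path: ordinary flash/SDPA attention over the full sequence.
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)


    if __name__ == "__main__":
        q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
        out = dispatch_attention(q, k, v, ShardConfig(sequence_parallelism_mode="ring_attn"))
        print(out.shape)  # torch.Size([1, 8, 16, 64])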
