@@ -771,11 +771,12 @@ def get_gpt2_flash_attention_forward():
         attn_mask_type = AttnMaskType.causal
         flash_attention_mask = None
         if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+            if not torch.all(flash_attention_mask):
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type == AttnMaskType.paddedcausal
+                else:
+                    attn_mask_type = AttnMaskType.padding
 
         scale = value.size(-1) ** -0.5
         if self.scale_attn_by_inverse_layer_idx:
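For context, the hunk above changes when the padded-causal path is taken: instead of reacting to any non-None `attention_mask`, the new code builds `flash_attention_mask` first and only switches away from `AttnMaskType.causal` when that mask actually contains padded positions. The sketch below isolates that selection logic; it is a minimal illustration, not part of the patch. `AttnMaskType` is a stand-in enum for the ColossalAI one, `pick_mask_type` is a hypothetical helper name, an HF-style additive mask of shape (batch, 1, 1, seq_len) is assumed, and the sketch uses an assignment (`=`) to `AttnMaskType.paddedcausal` where the hunk itself writes a comparison (`==`).

from enum import Enum, auto
from typing import Optional, Tuple

import torch


class AttnMaskType(Enum):
    # Stand-in for colossalai's AttnMaskType; only the members used here.
    causal = auto()
    padding = auto()
    paddedcausal = auto()


def pick_mask_type(attention_mask: Optional[torch.Tensor]) -> Tuple[AttnMaskType, Optional[torch.Tensor]]:
    # Hypothetical helper mirroring the patched hunk. Assumes an HF-style
    # additive mask of shape (batch, 1, 1, seq_len): 0.0 keeps a position,
    # a large negative value masks it out.
    attn_mask_type = AttnMaskType.causal
    flash_attention_mask = None
    if attention_mask is not None:
        # After the bool cast and inversion, True marks valid tokens and
        # False marks padded ones; shape becomes (batch, seq_len).
        flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
        # Only switch away from the plain causal kernel when padding is
        # actually present somewhere in the batch.
        if not torch.all(flash_attention_mask):
            attn_mask_type = AttnMaskType.paddedcausal  # assignment assumed; the hunk writes `==`
    return attn_mask_type, flash_attention_mask


# Example: two sequences of length 4, the second one padded at its last position.
mask = torch.zeros(2, 1, 1, 4)
mask[1, :, :, -1] = torch.finfo(torch.float32).min
print(pick_mask_type(mask)[0])                      # AttnMaskType.paddedcausal
print(pick_mask_type(torch.zeros(2, 1, 1, 4))[0])   # AttnMaskType.causal

Under these assumptions, a batch with no padded positions keeps the plain causal kernel, and only batches that really contain padding take the padded-causal path, which is the behavioural point of the added `torch.all` check.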