Update flash_attention_patch.py

To be compatible with a recent change in the Transformers library, where a new argument 'padding_mask' was added to the forward function of the attention layer.
https://github.com/huggingface/transformers/pull/25598
Zian(Andy) Zheng 2023-10-13 16:46:33 +08:00
parent 611a5a80ca
commit 7768afbad0
1 changed file with 1 addition and 0 deletions


@@ -65,6 +65,7 @@ def attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
     Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
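
For context, a minimal sketch of the patched signature is shown below. It is not the full patch: the parameters above the hunk (hidden_states, attention_mask, position_ids) are assumed to follow the standard LlamaAttention.forward interface, and the flash-attention body is elided. The point is that the added **kwargs lets the patched forward accept the padding_mask keyword that newer Transformers versions pass to the attention layer, instead of raising a TypeError.

from typing import Optional, Tuple

import torch


def attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,  # absorbs `padding_mask` (and any future extra keyword arguments)
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
    """
    # ... flash-attention computation (elided in this sketch) ...

With this signature, a call such as attention_forward(module, hidden_states, padding_mask=mask) no longer fails, because the unexpected keyword is swallowed by **kwargs.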