@@ -593,10 +593,6 @@ class T5PipelineForwards:
 
 
 def get_t5_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
     from transformers.models.t5.modeling_t5 import T5Attention
 
     def forward(
@@ -632,11 +628,11 @@ def get_t5_flash_attention_forward():
 
         def shape(states):
             """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
         def unshape(states):
             """reshape"""
-            return states.view(batch_size, -1, self.inner_dim)
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
 
         def project(hidden_states, proj_layer, key_value_states, past_key_value):
             """projects hidden states correctly to key/query states"""
@@ -653,8 +649,8 @@ def get_t5_flash_attention_forward():
                 if key_value_states is None:
                     # self-attn
                     # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
-                elif past_key_value.shape[1] != key_value_states.shape[1]:
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                elif past_key_value.shape[2] != key_value_states.shape[1]:
                     # checking that the `sequence_length` of the `past_key_value` is the same as
                     # the provided `key_value_states` to support prefix tuning
                     # cross-attn
@@ -701,10 +697,15 @@ def get_t5_flash_attention_forward():
         else:
             position_bias_masked = position_bias
 
-        position_bias_masked = position_bias_masked.contiguous()
-        attn_output = me_attention(
-            query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
-        )
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=position_bias_masked,
+                dropout_p=self.dropout,
+                scale=1.0,
+            )
         attn_output = unshape(attn_output)
         attn_output = self.o(attn_output)
 
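A minimal sketch (not part of the patch) of why `shape`/`unshape` gain a `transpose(1, 2)` and the KV-cache concatenation moves from `dim=1` to `dim=2`: xformers' `memory_efficient_attention` consumes tensors laid out as `(batch, seq, heads, head_dim)`, whereas `torch.nn.functional.scaled_dot_product_attention` consumes `(batch, heads, seq, head_dim)`, so the sequence axis shifts from dim 1 to dim 2. All sizes below are illustrative.

```python
import torch

# Illustrative sizes only; not taken from the patched module.
batch_size, n_heads, seq_len, past_len, head_dim = 2, 8, 16, 4, 64

# xformers layout: (batch, seq, heads, head_dim) -> sequence axis is dim 1
states_xf = torch.randn(batch_size, seq_len, n_heads, head_dim)
past_xf = torch.randn(batch_size, past_len, n_heads, head_dim)
cache_xf = torch.cat([past_xf, states_xf], dim=1)  # old code: dim=1

# SDPA layout: (batch, heads, seq, head_dim) -> sequence axis is dim 2
states_sdpa = states_xf.transpose(1, 2)  # what the new shape() produces
past_sdpa = past_xf.transpose(1, 2)
cache_sdpa = torch.cat([past_sdpa, states_sdpa], dim=2)  # new code: dim=2

# Same cache contents, just a different axis order.
assert torch.equal(cache_sdpa, cache_xf.transpose(1, 2))
```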
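And a sketch of the call correspondence itself, again with illustrative shapes: the additive relative-position bias that xformers took as `attn_bias` (with dropout as `p`) maps to SDPA's `attn_mask` (with dropout as `dropout_p`). `scale=1.0` is kept in both calls because T5 applies no `1/sqrt(head_dim)` scaling to the attention scores.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only; not the patched module's tensors.
batch_size, n_heads, seq_len, head_dim = 2, 8, 16, 64
query_states = torch.randn(batch_size, n_heads, seq_len, head_dim)
key_states = torch.randn(batch_size, n_heads, seq_len, head_dim)
value_states = torch.randn(batch_size, n_heads, seq_len, head_dim)
position_bias_masked = torch.randn(batch_size, n_heads, seq_len, seq_len)

# attn_mask accepts an additive float bias, matching xformers' attn_bias;
# scale=1.0 disables SDPA's default 1/sqrt(head_dim) scaling, as T5 requires.
attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=position_bias_masked,
    dropout_p=0.0,
    scale=1.0,
)
assert attn_output.shape == (batch_size, n_heads, seq_len, head_dim)
```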