@@ -593,10 +593,6 @@ class T5PipelineForwards:
 
 
 def get_t5_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
     from transformers.models.t5.modeling_t5 import T5Attention
 
     def forward(
@@ -632,11 +628,11 @@ def get_t5_flash_attention_forward():
 
         def shape(states):
             """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
         def unshape(states):
             """reshape"""
-            return states.view(batch_size, -1, self.inner_dim)
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
 
         def project(hidden_states, proj_layer, key_value_states, past_key_value):
             """projects hidden states correctly to key/query states"""
@@ -653,8 +649,8 @@ def get_t5_flash_attention_forward():
                 if key_value_states is None:
                     # self-attn
                     # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
-                elif past_key_value.shape[1] != key_value_states.shape[1]:
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                elif past_key_value.shape[2] != key_value_states.shape[1]:
                     # checking that the `sequence_length` of the `past_key_value` is the same as
                     # the provided `key_value_states` to support prefix tuning
                     # cross-attn
@@ -701,9 +697,14 @@ def get_t5_flash_attention_forward():
         else:
             position_bias_masked = position_bias
 
-        position_bias_masked = position_bias_masked.contiguous()
-        attn_output = me_attention(
-            query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
-        )
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=position_bias_masked,
+                dropout_p=self.dropout,
+                scale=1.0,
+            )
         attn_output = unshape(attn_output)
         attn_output = self.o(attn_output)
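
Note: the tensor-layout change is what ties these hunks together. xformers' memory_efficient_attention consumes (batch, seq_len, n_heads, head_dim) tensors, while torch.nn.functional.scaled_dot_product_attention expects (batch, n_heads, seq_len, head_dim); that is why shape()/unshape() gain a .transpose(1, 2) and the KV-cache concatenation and shape check move from axis 1 to axis 2. The following minimal, standalone sketch illustrates the new layout and call; the shapes and variable names are illustrative only (not from the patch), and the scale= keyword is assumed available since the patched code itself uses it.

import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 8, 16, 64

# Old layout: me_attention took (batch, seq, heads, dim), so no transpose was needed
# and the sequence axis used for KV-cache concatenation was dim=1.
q_bshd = torch.randn(batch, seq, heads, dim)

# New layout: SDPA wants (batch, heads, seq, dim); the sequence axis becomes dim=2,
# matching the dim=1 -> dim=2 change in project() and the added .transpose(1, 2) in shape().
q = q_bshd.transpose(1, 2)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)
bias = torch.zeros(batch, heads, seq, seq)  # additive bias, like T5's relative position bias

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=0.0, scale=1.0)

# With scale=1.0 and no dropout, SDPA reduces to softmax(q @ k^T + bias) @ v,
# i.e. T5-style unscaled attention (the default scale would be 1/sqrt(dim)).
manual = torch.softmax(q @ k.transpose(-1, -2) + bias, dim=-1) @ v
print(out.shape, torch.allclose(out, manual, atol=1e-5))  # torch.Size([2, 8, 16, 64]) True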
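Design note: scale=1.0 is carried over from the old xformers call on purpose. T5 omits the usual 1/sqrt(head_dim) scaling of attention logits, and its learned relative position bias is trained against unscaled scores, so SDPA's default scaling must be overridden. The torch.backends.cuda.sdp_kernel context in the new code only selects which fused SDPA backends PyTorch may dispatch to (flash and memory-efficient here); on recent PyTorch releases that context manager is deprecated in favor of torch.nn.attention.sdpa_kernel.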