mirror of https://github.com/hpcaitech/ColossalAI
[tests] fix t5 test. (#5322)
* [ci] fix shardformer tests. (#5255)
* fix ci fix
* revert: revert p2p
* feat: add enable_metadata_cache option
* revert: enable t5 tests
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* fix t5 test
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>

pull/5330/head
parent a6709afe66
commit 388179f966
@@ -593,10 +593,6 @@ class T5PipelineForwards:


 def get_t5_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
     from transformers.models.t5.modeling_t5 import T5Attention

     def forward(
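Note on the hunk above: the try/except around the xformers import is dropped because the rewritten forward (see the last hunk) calls PyTorch's built-in torch.nn.functional.scaled_dot_product_attention instead of xformers' memory_efficient_attention, so the hard dependency and its ImportError guard are no longer needed. A minimal standalone sketch, not part of the patch, assuming PyTorch 2.0 or newer and arbitrary tensor shapes:

# Sketch only: fused attention is available from the core library,
# with no optional xformers import to guard.
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out = F.scaled_dot_product_attention(q, k, v)  # picks a fused kernel when one applies
print(out.shape)  # torch.Size([2, 8, 16, 64])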
@@ -632,11 +628,11 @@ def get_t5_flash_attention_forward():

         def shape(states):
             """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

         def unshape(states):
             """reshape"""
-            return states.view(batch_size, -1, self.inner_dim)
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

         def project(hidden_states, proj_layer, key_value_states, past_key_value):
             """projects hidden states correctly to key/query states"""
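Note on the shape/unshape change: xformers' memory_efficient_attention takes tensors laid out as (batch, seq_len, n_heads, head_dim), while scaled_dot_product_attention expects (batch, n_heads, seq_len, head_dim), so shape() now transposes the head and sequence axes and unshape() transposes them back before flattening. A minimal sketch with made-up sizes; batch_size, n_heads, key_value_proj_dim and inner_dim only mirror the T5Attention attribute names:

import torch

batch_size, seq_len, n_heads, key_value_proj_dim = 2, 16, 8, 64
inner_dim = n_heads * key_value_proj_dim
states = torch.randn(batch_size, seq_len, inner_dim)

# New shape(): (batch, seq, inner) -> (batch, heads, seq, head_dim)
shaped = states.view(batch_size, -1, n_heads, key_value_proj_dim).transpose(1, 2)
assert shaped.shape == (batch_size, n_heads, seq_len, key_value_proj_dim)

# New unshape(): move heads back next to head_dim, then flatten
unshaped = shaped.transpose(1, 2).contiguous().view(batch_size, -1, inner_dim)
assert unshaped.shape == states.shape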
@@ -653,8 +649,8 @@ def get_t5_flash_attention_forward():
                 if key_value_states is None:
                     # self-attn
                     # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
-                elif past_key_value.shape[1] != key_value_states.shape[1]:
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                elif past_key_value.shape[2] != key_value_states.shape[1]:
                     # checking that the `sequence_length` of the `past_key_value` is the same as
                     # the provided `key_value_states` to support prefix tuning
                     # cross-attn
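Note on the dim change: with the projected states now in (batch, n_heads, seq_len, head_dim) layout, the cached key/value tensors grow along dim=2 and the prefix-tuning length check reads shape[2] instead of shape[1]. A small sketch with illustrative shapes, not taken from the patch:

import torch

batch_size, n_heads, head_dim = 2, 8, 64
past_key_value = torch.randn(batch_size, n_heads, 5, head_dim)  # 5 cached positions
hidden_states = torch.randn(batch_size, n_heads, 1, head_dim)   # 1 new position

# Sequence length lives at dim=2 in the new layout, so the cache is extended there.
merged = torch.cat([past_key_value, hidden_states], dim=2)
assert merged.shape[2] == 6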
@@ -701,10 +697,15 @@ def get_t5_flash_attention_forward():
         else:
             position_bias_masked = position_bias

         position_bias_masked = position_bias_masked.contiguous()
-        attn_output = me_attention(
-            query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
-        )
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=position_bias_masked,
+                dropout_p=self.dropout,
+                scale=1.0,
+            )
         attn_output = unshape(attn_output)
         attn_output = self.o(attn_output)

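Note on the final hunk: the xformers call is swapped for torch.nn.functional.scaled_dot_product_attention inside a torch.backends.cuda.sdp_kernel context that allows the flash and memory-efficient backends; attn_bias= becomes attn_mask= and p= becomes dropout_p=. scale=1.0 is kept because T5 does not apply the usual 1/sqrt(head_dim) scaling to attention scores. A minimal standalone sketch of the new call, assuming a PyTorch release recent enough to accept the scale keyword; shapes are arbitrary:

import torch
import torch.nn.functional as F

batch_size, n_heads, seq_len, head_dim = 2, 8, 16, 64
query_states = torch.randn(batch_size, n_heads, seq_len, head_dim)
key_states = torch.randn(batch_size, n_heads, seq_len, head_dim)
value_states = torch.randn(batch_size, n_heads, seq_len, head_dim)
position_bias_masked = torch.zeros(batch_size, n_heads, seq_len, seq_len)  # additive bias/mask

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=position_bias_masked,  # xformers' attn_bias=
    dropout_p=0.0,                   # xformers' p=
    scale=1.0,                       # T5 skips the 1/sqrt(d) scaling
)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])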