Browse Source

[tests] fix t5 test. (#5322)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* fix t5 test

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
pull/5330/head
flybird11111 10 months ago committed by GitHub
parent
commit
388179f966
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 25
      colossalai/shardformer/modeling/t5.py

25
colossalai/shardformer/modeling/t5.py

@ -593,10 +593,6 @@ class T5PipelineForwards:
def get_t5_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.t5.modeling_t5 import T5Attention
def forward(
@ -632,11 +628,11 @@ def get_t5_flash_attention_forward():
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.view(batch_size, -1, self.inner_dim)
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
@ -653,8 +649,8 @@ def get_t5_flash_attention_forward():
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
elif past_key_value.shape[1] != key_value_states.shape[1]:
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
@ -701,10 +697,15 @@ def get_t5_flash_attention_forward():
else:
position_bias_masked = position_bias
position_bias_masked = position_bias_masked.contiguous()
attn_output = me_attention(
query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=position_bias_masked,
dropout_p=self.dropout,
scale=1.0,
)
attn_output = unshape(attn_output)
attn_output = self.o(attn_output)

Loading…
Cancel
Save