From 388179f9664d52ceed60d26e31b90981cceb0e95 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 29 Jan 2024 17:38:46 +0800
Subject: [PATCH] [tests] fix t5 test. (#5322)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen

* fix t5 test

---------

Co-authored-by: Wenhao Chen
---
 colossalai/shardformer/modeling/t5.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index f67aa84e4..dcb178520 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -593,10 +593,6 @@ class T5PipelineForwards:
 
 
 def get_t5_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
     from transformers.models.t5.modeling_t5 import T5Attention
 
     def forward(
@@ -632,11 +628,11 @@ def get_t5_flash_attention_forward():
 
         def shape(states):
             """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
         def unshape(states):
             """reshape"""
-            return states.view(batch_size, -1, self.inner_dim)
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
 
         def project(hidden_states, proj_layer, key_value_states, past_key_value):
             """projects hidden states correctly to key/query states"""
@@ -653,8 +649,8 @@ def get_t5_flash_attention_forward():
                 if key_value_states is None:
                     # self-attn
                     # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
-                elif past_key_value.shape[1] != key_value_states.shape[1]:
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                elif past_key_value.shape[2] != key_value_states.shape[1]:
                     # checking that the `sequence_length` of the `past_key_value` is the same as
                     # the provided `key_value_states` to support prefix tuning
                     # cross-attn
@@ -701,10 +697,15 @@ def get_t5_flash_attention_forward():
         else:
             position_bias_masked = position_bias
 
-        position_bias_masked = position_bias_masked.contiguous()
-        attn_output = me_attention(
-            query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
-        )
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=position_bias_masked,
+                dropout_p=self.dropout,
+                scale=1.0,
+            )
         attn_output = unshape(attn_output)
         attn_output = self.o(attn_output)
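
For readers who want to sanity-check the new code path outside the model, here is a minimal standalone sketch of the layout change and the torch.nn.functional.scaled_dot_product_attention call the diff switches to. It is not part of the patch: the tensor sizes, the module-level shape/unshape helpers, and the random inputs are illustrative assumptions; only the (batch, heads, seq, head_dim) layout, the additive position-bias mask, and scale=1.0 mirror the change above. Assumes PyTorch >= 2.1 for the scale keyword.

# Standalone sketch (not from the patch): made-up shapes, illustrative only.
import torch
import torch.nn.functional as F

batch_size, seq_len, n_heads, head_dim = 2, 16, 8, 64
inner_dim = n_heads * head_dim


def shape(states):
    # (batch, seq, inner_dim) -> (batch, n_heads, seq, head_dim), as in the patched shape()
    return states.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)


def unshape(states):
    # (batch, n_heads, seq, head_dim) -> (batch, seq, inner_dim), as in the patched unshape()
    return states.transpose(1, 2).contiguous().view(batch_size, -1, inner_dim)


hidden = torch.randn(batch_size, seq_len, inner_dim)
q, k, v = shape(hidden), shape(hidden), shape(hidden)

# T5's relative position bias enters as an additive float mask with one slice per head.
position_bias = torch.randn(batch_size, n_heads, seq_len, seq_len)

# scale=1.0 matches T5, which folds the 1/sqrt(d) factor into its weight initialization.
attn_output = F.scaled_dot_product_attention(
    q, k, v, attn_mask=position_bias, dropout_p=0.0, scale=1.0
)
print(unshape(attn_output).shape)  # torch.Size([2, 16, 512])

The transpose(1, 2) added to shape/unshape is what moves the tensors from xformers' (batch, seq, heads, head_dim) convention to SDPA's (batch, heads, seq, head_dim) convention, which is also why the past key/value concatenation and shape check move from dim 1 to dim 2 in the diff.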