diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index f67aa84e4..dcb178520 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -593,10 +593,6 @@ class T5PipelineForwards: def get_t5_flash_attention_forward(): - try: - from xformers.ops import memory_efficient_attention as me_attention - except: - raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.") from transformers.models.t5.modeling_t5 import T5Attention def forward( @@ -632,11 +628,11 @@ def get_t5_flash_attention_forward(): def shape(states): """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim) + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) def unshape(states): """reshape""" - return states.view(batch_size, -1, self.inner_dim) + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) def project(hidden_states, proj_layer, key_value_states, past_key_value): """projects hidden states correctly to key/query states""" @@ -653,8 +649,8 @@ def get_t5_flash_attention_forward(): if key_value_states is None: # self-attn # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=1) - elif past_key_value.shape[1] != key_value_states.shape[1]: + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: # checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning # cross-attn @@ -701,10 +697,15 @@ def get_t5_flash_attention_forward(): else: position_bias_masked = position_bias - position_bias_masked = position_bias_masked.contiguous() - attn_output = me_attention( - query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0 - ) + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=position_bias_masked, + dropout_p=self.dropout, + scale=1.0, + ) attn_output = unshape(attn_output) attn_output = self.o(attn_output)