from typing import Optional, Tuple

import torch
from torch import nn


def get_whisper_flash_attention_forward():
    from transformers.models.whisper.modeling_whisper import WhisperAttention

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
        return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()

    def forward(
        self: WhisperAttention,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get key, value proj
        # `past_key_value[0].shape[1] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (is_cross_attention and past_key_value is not None
                and past_key_value[0].shape[1] == key_value_states.shape[1]):
            # reuse k, v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            key_states = torch.cat([past_key_value[0], key_states], dim=1)
            value_states = torch.cat([past_key_value[1], value_states], dim=1)
        else:
            # self_attention
            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # get query proj
        query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)

        src_len = key_states.size(1)
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                                 f" {layer_head_mask.size()}")

        attn_type = None
        flash_attention_mask = None
        if self.is_decoder:
            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                    )
                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
                attn_type = AttnMaskType.paddedcausal

        attention = ColoAttention(embed_dim=self.embed_dim,
                                  num_heads=self.num_heads,
                                  dropout=self.dropout,
                                  scale=self.scaling)
        attn_output = attention(query_states,
                                key_states,
                                value_states,
                                attn_mask=flash_attention_mask,
                                attn_mask_type=attn_type)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value

    return forward

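
# Illustrative usage sketch (an assumption, not part of the original integration):
# the returned forward is normally wired in through a shardformer policy, but the
# same binding can be done by plain method patching on the attention modules.
# The helper name below is hypothetical and only demonstrates the idea.
def _example_patch_whisper_flash_attention(model) -> None:
    """Bind the flash-attention forward to every encoder self-attention module.

    Minimal demonstration assuming ``model`` is a
    ``transformers.WhisperForConditionalGeneration`` instance.
    """
    import types

    flash_forward = get_whisper_flash_attention_forward()
    for layer in model.model.encoder.layers:
        # Shadow the class-level forward with a bound instance method.
        layer.self_attn.forward = types.MethodType(flash_forward, layer.self_attn)
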
""" residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs return forward def get_jit_fused_whisper_decoder_layer_forward(): from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer def forward( self: WhisperDecoderLayer, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, ) -> torch.Tensor: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. encoder_hidden_states (`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
""" residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, ) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) # Cross-Attention Block cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) if use_cache: outputs += (present_key_value,) return outputs return forward