# Mirror of https://github.com/hpcaitech/ColossalAI (Python source, 250 lines, 12 KiB)
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
def get_whisper_flash_attention_forward():
    """Build a replacement ``forward`` for ``WhisperAttention`` backed by ColossalAI's ``ColoAttention``.

    The ``transformers`` and ColossalAI kernel imports are deferred until this factory is called.

    Returns:
        A ``forward(self, ...)`` function whose ``self`` is a ``WhisperAttention`` module.
    """

    from transformers.models.whisper.modeling_whisper import WhisperAttention

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
        # Split the channel dimension into heads -> (bsz, seq_len, num_heads, head_dim).
        # The head axis stays AFTER the sequence axis, so the sequence length of every
        # projected (and cached) key/value tensor in `forward` lives on dim 1.
        return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()

    def forward(
        self: WhisperAttention,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: query-side input; its size is unpacked as ``(bsz, tgt_len, channels)``.
            key_value_states: when given, the layer acts as cross-attention and keys/values are
                projected from this tensor instead of ``hidden_states``.
            past_key_value: cached ``(key_states, value_states)``, each laid out as
                ``(bsz, seq_len, num_heads, head_dim)``.
            attention_mask: mask of size ``(bsz, 1, tgt_len, src_len)`` whose zero entries mark
                positions that may be attended to. Only consulted when ``self.is_decoder``.
            layer_head_mask: shape-checked against ``(num_heads,)`` but NOT applied by the
                fused kernel.
            output_attentions: ignored; attention weights are never materialized.

        Returns:
            ``(attn_output, None, past_key_value)``. For decoder layers ``past_key_value`` is the
            updated ``(key_states, value_states)`` cache; otherwise it is returned unchanged.

        Raises:
            ValueError: if ``layer_head_mask`` or (for decoder layers) ``attention_mask`` has an
                unexpected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get key, value proj
        # `past_key_value[0].shape[1] == key_value_states.shape[1]` checks that the sequence
        # length of the cached keys (dim 1 in the (bsz, seq_len, num_heads, head_dim) layout
        # produced by `shape`) matches the provided `key_value_states`, to support prefix tuning
        if (is_cross_attention and past_key_value is not None
                and past_key_value[0].shape[1] == key_value_states.shape[1]):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
        elif past_key_value is not None:
            # reuse k, v, self_attention: append the new step(s) after the cached ones (seq dim 1)
            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            key_states = torch.cat([past_key_value[0], key_states], dim=1)
            value_states = torch.cat([past_key_value[1], value_states], dim=1)
        else:
            # self_attention
            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # get query proj
        query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)

        src_len = key_states.size(1)
        if layer_head_mask is not None:
            # Only validated: the fused kernel below has no head-mask input, so the mask is not applied.
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                                 f" {layer_head_mask.size()}")

        attn_type = None
        flash_attention_mask = None

        if self.is_decoder:
            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                    )
                # Reduce the 4-D mask to a per-key mask of shape (bsz, src_len) by keeping only the
                # last query row; after the inversion True marks keys that may be attended to.
                # Causality itself is requested from the kernel through `paddedcausal`.
                # NOTE(review): assumes the last query row sees every non-padded key (true for a
                # causal + padding mask) — confirm for other mask layouts.
                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
                attn_type = AttnMaskType.paddedcausal

        # `ColoAttention` is constructed on every call, and a freshly constructed nn.Module is always
        # in training mode, so it cannot observe this layer's train()/eval() state. Decide the
        # dropout probability here so that attention dropout is disabled outside training.
        attention = ColoAttention(embed_dim=self.embed_dim,
                                  num_heads=self.num_heads,
                                  dropout=self.dropout if self.training else 0.0,
                                  scale=self.scaling)
        attn_output = attention(query_states,
                                key_states,
                                value_states,
                                attn_mask=flash_attention_mask,
                                attn_mask_type=attn_type)

        attn_output = self.out_proj(attn_output)

        # Attention weights are not materialized by the fused kernel, hence the `None`.
        return attn_output, None, past_key_value

    return forward
|
|
|
|
|
|
def get_jit_fused_whisper_encoder_layer_forward():
    """Build a replacement ``forward`` for ``WhisperEncoderLayer`` that routes residual adds through
    ``self.dropout_add``.

    ``self.dropout_add`` must be provided on the layer by whoever binds this forward. It is
    called as ``dropout_add(x, residual, p, training)`` — presumably a JIT-fused
    ``dropout(x) + residual``; TODO confirm against the policy that installs it.

    Returns:
        A ``forward(self, ...)`` function whose ``self`` is a ``WhisperEncoderLayer``.
    """

    from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer

    def forward(
        self: WhisperEncoderLayer,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, extended with `attn_weights` from `self.self_attn` when
            `output_attentions` is true.
        """
        # Pre-norm self-attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        # Pre-norm feed-forward sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        # fp16 overflow guard: if any inf/NaN is present, clamp to just below the fp16 maximum.
        if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any()
                                                     or torch.isnan(hidden_states).any()):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    return forward
|
|
|
|
|
|
def get_jit_fused_whisper_decoder_layer_forward():
    """Build a replacement ``forward`` for ``WhisperDecoderLayer`` that routes residual adds through
    ``self.dropout_add``.

    ``self.dropout_add`` must be provided on the layer by whoever binds this forward. It is
    called as ``dropout_add(x, residual, p, training)`` — presumably a JIT-fused
    ``dropout(x) + residual``; TODO confirm against the policy that installs it.

    Returns:
        A ``forward(self, ...)`` function whose ``self`` is a ``WhisperDecoderLayer``.
    """

    from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer

    def forward(
        self: WhisperDecoderLayer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for the self-attention heads of this layer, of size
                `(num_heads,)` of `self.self_attn`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states;
                `past_key_value[:2]` is the self-attention cache and `past_key_value[-2:]` the
                cross-attention cache.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                Whether to append the updated key/value cache to the outputs.

        Returns:
            `(hidden_states,)`, extended with `(self_attn_weights, cross_attn_weights)` when
            `output_attentions` is true and with `present_key_value` when `use_cache` is true.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        # Cross-Attention Block (skipped entirely when no encoder output is supplied)
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

    return forward
|