mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1192 lines
54 KiB
1192 lines
54 KiB
import logging
|
|
import random
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import CrossEntropyLoss
|
|
from transformers.modeling_attn_mask_utils import (
|
|
_prepare_4d_causal_attention_mask,
|
|
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
)
|
|
from transformers.modeling_outputs import (
|
|
BaseModelOutput,
|
|
BaseModelOutputWithPastAndCrossAttentions,
|
|
Seq2SeqLMOutput,
|
|
Seq2SeqModelOutput,
|
|
SequenceClassifierOutput,
|
|
)
|
|
from transformers.models.whisper.modeling_whisper import (
|
|
_HIDDEN_STATES_START_POSITION,
|
|
WhisperDecoder,
|
|
WhisperEncoder,
|
|
WhisperForAudioClassification,
|
|
WhisperForConditionalGeneration,
|
|
WhisperModel,
|
|
shift_tokens_right,
|
|
)
|
|
from transformers.utils import logging
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
from colossalai.shardformer.layer import ColoAttention
|
|
from colossalai.shardformer.shard import ShardConfig
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
def _get_attention_mask(
    self: WhisperDecoder,
    shard_config: ShardConfig,
    hidden_states: torch.Tensor,
    past_key_values_length: int,
    attention_mask: Optional[torch.FloatTensor],
    head_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
):
    """Build the decoder self-attention mask in the form the active attention backend consumes.

    Args:
        self: the decoder whose ``_use_flash_attention_2`` / ``_use_sdpa`` flags pick the HF mask format
            when ColossalAI flash attention is disabled.
        shard_config: when ``shard_config.enable_flash_attention`` is set, the mask is built by
            ``ColoAttention.prepare_attn_kwargs`` instead of the transformers helpers.
        hidden_states: decoder input; its first two dimensions are read as (batch size, sequence length),
            and its dtype/device are forwarded to the mask builders.
        past_key_values_length: number of cached positions that precede ``hidden_states``.
        attention_mask: user-supplied mask, may be ``None``.
        head_mask: when not ``None``, the SDPA mask builder is not used.
        output_attentions: when true, the SDPA mask builder is not used.

    Returns:
        * flash attention enabled: the kwargs returned by ``ColoAttention.prepare_attn_kwargs``;
        * HF flash-attention-2: the given mask if it contains a zero, otherwise ``None``;
        * SDPA (no head mask, no attention outputs): ``_prepare_4d_causal_attention_mask_for_sdpa(...)``;
        * otherwise: ``_prepare_4d_causal_attention_mask(...)``.
    """
    bsz, tgt_len = hidden_states.shape[:2]

    if shard_config.enable_flash_attention:
        # Keys cover the cached prefix plus the current positions.
        src_len = past_key_values_length + tgt_len
        return ColoAttention.prepare_attn_kwargs(
            (bsz, 1, tgt_len, src_len),
            hidden_states.dtype,
            hidden_states.device,
            attention_mask,
            is_causal=True,
        )

    input_shape = (bsz, tgt_len)

    if self._use_flash_attention_2:
        # The 2D mask is passed through the layers; drop it entirely when it masks nothing.
        has_masked_position = attention_mask is not None and 0 in attention_mask
        return attention_mask if has_masked_position else None

    if self._use_sdpa and head_mask is None and not output_attentions:
        # SDPA supports neither head masks nor returning attention weights, so only use its mask otherwise.
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, input_shape, hidden_states, past_key_values_length
        )

    # Eager attention: a 4D causal mask is passed through the layers.
    return _prepare_4d_causal_attention_mask(attention_mask, input_shape, hidden_states, past_key_values_length)
|
|
|
|
|
|
def get_whisper_flash_attention_forward():
    """Return a replacement ``WhisperAttention.forward`` that computes attention with ``ColoAttention``.

    The returned function never produces attention weights (its second output is always ``None``),
    and it rejects a non-``None`` ``layer_head_mask``.
    """
    from transformers.models.whisper.modeling_whisper import WhisperAttention

    def forward(
        self: WhisperAttention,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[dict] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention over ``hidden_states`` (batch x time x channel) using ``ColoAttention``.

        Args:
            key_value_states: encoder states; when given, this call is decoder cross-attention.
            past_key_value: cached ``(key, value)`` pair from a previous step, or ``None``.
            attention_mask: keyword arguments for ``ColoAttention.attention`` (``None`` means no mask).
            layer_head_mask: must be ``None``.
            output_attentions: not used; attention weights are never returned.

        Returns:
            ``(attn_output, None, past_key_value)``. For decoder layers ``past_key_value`` is the
            ``(key_states, value_states)`` pair used in this call; otherwise it is returned as passed in.
        """
        assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
        # The encoder passes no mask; ColoAttention takes its mask as keyword arguments.
        attn_kwargs = {} if attention_mask is None else attention_mask

        is_cross_attention = key_value_states is not None
        bsz, tgt_len, _ = hidden_states.size()

        projected_query = self.q_proj(hidden_states)

        # Cached cross-attention K/V are reused only when their length matches the encoder states,
        # which keeps prefix tuning working.
        reuse_cross_cache = (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        )
        if reuse_cross_cache:
            key_states, value_states = past_key_value[0], past_key_value[1]
        elif is_cross_attention:
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            if past_key_value is not None:
                # Uni-directional self-attention: extend the cached K/V along dim 2 with the new positions.
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.is_decoder:
            # Hand the K/V back as the new cache: cross-attention reuses them verbatim on the next call,
            # self-attention concatenates onto them. Encoder layers leave past_key_value untouched.
            past_key_value = (key_states, value_states)

        query_states = self._shape(projected_query, tgt_len, bsz)

        attn_output = ColoAttention.attention(
            query_states,
            key_states,
            value_states,
            **attn_kwargs,
            dropout_p=self.dropout if self.training else 0.0,
            scale=self.scaling,
        )

        # Swap dims 1 and 2, then flatten with the configured embed_dim rather than the input's hidden size,
        # since attn_output can be partitioned across GPUs under tensor parallelism.
        merged = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.embed_dim)
        return self.out_proj(merged), None, past_key_value

    return forward
|
|
|
|
|
|
def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
    """Return a replacement ``WhisperDecoder.forward`` whose self-attention mask comes from ``_get_attention_mask``.

    Args:
        shard_config: captured by the returned closure and handed to ``_get_attention_mask``; with
            ``enable_flash_attention`` set, the mask is produced by ``ColoAttention.prepare_attn_kwargs``.
    """

    def forward(
        self: WhisperDecoder,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        position_ids=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the full Whisper decoder stack.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions``, or when ``return_dict`` is false a tuple of its
            non-``None`` fields in the order: last hidden state, cache, hidden states, self-attentions,
            cross-attentions.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is None:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # Length of the cached prefix (dim 2 of the first cached key).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Forward head_mask / output_attentions so the helper never selects the SDPA-only mask for a call
        # whose layers need the eager 4D causal mask.
        attention_mask = _get_attention_mask(
            self,
            shard_config,
            inputs_embeds,
            past_key_values_length,
            attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )

        # embed positions
        if input_ids is not None:
            positions = self.embed_positions(
                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
            )
        else:
            positions = self.embed_positions(
                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
            )

        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                # Report the size of the mask that failed, not always head_mask (which may be None).
                assert attn_mask.size()[0] == (len(self.layers)), (
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_head_mask = head_mask[idx] if head_mask is not None else None
            cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Same checkpointing entry point as the pipeline forwards, so any configured
                # gradient-checkpointing kwargs are honoured.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    None,  # encoder attention mask
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the two attention-weight slots when those are returned.
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    return forward
|
|
|
|
|
|
def get_jit_fused_whisper_encoder_layer_forward():
    """Return a replacement ``WhisperEncoderLayer.forward`` whose residual additions go through ``self.dropout_add``.

    ``self.dropout_add(x, residual, p, training)`` must already be attached to the layer; presumably it is a
    fused dropout + residual add installed by the shard policy (NOTE(review): confirm against the policy).
    """
    from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer

    def forward(
        self: WhisperEncoderLayer,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> Tuple:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            A tuple ``(hidden_states,)``, extended to ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is true.
        """
        # Pre-norm self-attention sub-block.
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        # Pre-norm feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        # In fp16, pull inf/nan values back inside the representable range.
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    return forward
|
|
|
|
|
|
def get_jit_fused_whisper_decoder_layer_forward():
    """Return a replacement ``WhisperDecoderLayer.forward`` whose residual additions go through ``self.dropout_add``.

    ``self.dropout_add(x, residual, p, training)`` must already be attached to the layer; presumably it is a
    fused dropout + residual add installed by the shard policy (NOTE(review): confirm against the policy).
    """
    from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer

    def forward(
        self: WhisperDecoderLayer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for self-attention heads in a given layer of size
                `(decoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states; the first
                two entries belong to self-attention and the last two to cross-attention.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*): whether to append the present key/value states to the outputs.

        Returns:
            A tuple starting with ``hidden_states``, followed by ``(self_attn_weights, cross_attn_weights)`` when
            ``output_attentions`` is true, followed by ``present_key_value`` when ``use_cache`` is true.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        # Cross-Attention Block (only when encoder states are provided)
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

    return forward
|
|
|
|
|
|
class WhisperPipelineForwards:
|
|
"""
|
|
This class serves as a micro library for forward function substitution of Whisper models
|
|
under pipeline setting.
|
|
"""
|
|
|
|
@staticmethod
def whisper_encoder_forward(
    self: WhisperEncoder,
    input_features,
    attention_mask=None,
    head_mask=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    encoder_states=None,
    all_attentions=None,
    stage_index: Optional[List[int]] = None,
    decoder_starting_stage: Optional[int] = None,
    shard_config: Optional[ShardConfig] = None,
):
    r"""
    Pipeline-stage forward of the Whisper encoder: runs encoder layers ``stage_index[0]:stage_index[1]``.

    Args:
        input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
            Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
            `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
            `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
            and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`].
            Only read on the first pipeline stage.
        attention_mask (`torch.Tensor`, *optional*):
            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
            but it is not used. By default the silence in the input log mel spectrogram are ignored.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        stage_manager (`PipelineStageManager`):
            Provides ``stage``, the index of the current pipeline stage.
        hidden_states (`torch.FloatTensor`, *optional*):
            Output of the previous stage; required on every stage except the first.
        encoder_states / all_attentions (*optional*):
            Accumulators for per-layer hidden states / attention weights. They are re-initialised on the first
            stage and extended in place of the input on later stages.
        stage_index (`List[int]`):
            ``[start, end)`` indices of the encoder layers run by this stage.
        decoder_starting_stage (`int`):
            Index of the first decoder stage; the stage right before it is treated as the last encoder stage.
        shard_config (`ShardConfig`, *optional*):
            Not read by this function.

    Returns:
        On the last encoder stage, a ``BaseModelOutput`` (or a tuple of its non-``None`` fields when
        ``return_dict`` is false). On other stages, a dict with ``hidden_states`` and ``head_mask`` for the next
        stage.

    Raises:
        ValueError: if ``hidden_states`` is ``None`` on a stage other than the first.
    """
    stage = stage_manager.stage
    at_first_stage = stage == 0
    at_last_stage = stage == decoder_starting_stage - 1

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Process inputs if at the first stage of encoder.
    if at_first_stage:
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # Swap dims 1 and 2 so positions can be added by broadcasting against the embedding table.
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

    else:
        if hidden_states is None:
            raise ValueError(
                "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
            )

    start_idx, end_idx = stage_index[0], stage_index[1]

    for idx in range(start_idx, end_idx):
        encoder_layer = self.layers[idx]

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):  # skip the layer
            layer_outputs = (None, None)
        else:
            layer_head_mask = head_mask[idx] if head_mask is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    None,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    None,
                    layer_head_mask=layer_head_mask,
                    output_attentions=output_attentions,
                )

            # Only a layer that actually ran updates hidden_states; a dropped layer must not
            # overwrite them with the (None, None) placeholder.
            hidden_states = layer_outputs[0]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)

    if at_last_stage:
        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )

    else:
        return {"hidden_states": hidden_states, "head_mask": head_mask}
|
|
|
|
@staticmethod
def whisper_decoder_forward(
    self: WhisperDecoder,
    input_ids=None,
    attention_mask=None,
    encoder_hidden_states=None,
    head_mask=None,
    cross_attn_head_mask=None,
    past_key_values=None,
    inputs_embeds=None,
    position_ids=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    decoder_starting_stage: Optional[int] = None,
    shard_config: Optional[ShardConfig] = None,
):
    r"""
    Pipeline-stage forward of the Whisper decoder: runs decoder layers ``stage_index[0]:stage_index[1]``.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it.

            Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            of the decoder.
        head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
            on hidden heads. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
            shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
            all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        position_ids (*optional*):
            Forwarded to `self.embed_positions` on the first decoder stage.
        use_cache (`bool`, *optional*):
            Whether to collect the present key/value states of the layers run on this stage.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        stage_manager (`PipelineStageManager`):
            Provides ``stage`` and ``num_stages`` of the pipeline.
        hidden_states (`torch.FloatTensor`, *optional*):
            Output of the previous stage; required on every decoder stage except the first.
        stage_index (`List[int]`):
            ``[start, end)`` indices of the decoder layers run by this stage.
        decoder_starting_stage (`int`):
            Index of the first pipeline stage that runs decoder layers.
        shard_config (`ShardConfig`, *optional*):
            Forwarded to `_get_attention_mask`.

    Returns:
        On the last pipeline stage, a ``BaseModelOutputWithPastAndCrossAttentions`` (or a tuple of its
        non-``None`` fields when ``return_dict`` is false). On other stages, a dict with ``head_mask``,
        ``cross_attn_head_mask`` and ``hidden_states`` for the next stage.

    Raises:
        ValueError: on the first decoder stage if both or neither of ``input_ids`` / ``inputs_embeds`` are given;
            on later stages if ``hidden_states`` is ``None``.
    """
    stage = stage_manager.stage
    at_first_stage = stage == decoder_starting_stage
    at_last_stage = stage == stage_manager.num_stages - 1

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Checkpointed layers are re-run without a cache, so caching is switched off on every decoder stage
    # (not only the first one) before the cache accumulator is created.
    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
        )
        use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
    next_decoder_cache = () if use_cache else None

    # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
    for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
        if attn_mask is not None:
            # Report the size of the mask that failed, not always head_mask (which may be None).
            assert attn_mask.size()[0] == (len(self.layers)), (
                f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                f" {attn_mask.size()[0]}."
            )

    # past_key_values_length
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    if at_first_stage:
        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.view(-1, input_ids.size()[-1])
        elif inputs_embeds is None:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # head_mask / output_attentions decide whether the SDPA-only mask may be used.
        attention_mask = _get_attention_mask(
            self,
            shard_config,
            inputs_embeds,
            past_key_values_length,
            attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )

        # embed positions
        if input_ids is not None:
            positions = self.embed_positions(
                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
            )
        else:
            positions = self.embed_positions(
                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
            )

        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

    else:
        if hidden_states is None:
            raise ValueError(
                "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder."
            )
        attention_mask = _get_attention_mask(
            self,
            shard_config,
            hidden_states,
            past_key_values_length,
            attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )

    start_idx, end_idx = stage_index[0], stage_index[1]

    for idx in range(start_idx, end_idx):
        decoder_layer = self.layers[idx]

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
        dropout_probability = random.uniform(0, 1)
        if self.training and (dropout_probability < self.layerdrop):
            continue

        past_key_value = past_key_values[idx] if past_key_values is not None else None
        layer_head_mask = head_mask[idx] if head_mask is not None else None
        cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                None,  # encoder attention mask
                layer_head_mask,
                cross_attn_layer_head_mask,
                None,  # past_key_value
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache entry follows the two attention-weight slots when those are returned.
            next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

            if encoder_hidden_states is not None:
                all_cross_attentions += (layer_outputs[2],)

    if at_last_stage:
        hidden_states = self.layer_norm(hidden_states)
        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

    else:
        return {
            "head_mask": head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "hidden_states": hidden_states,
        }
|
|
|
|
@staticmethod
|
|
def whisper_model_forward(
    self: WhisperModel,
    input_features: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
    decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    decoder_starting_stage: Optional[int] = None,
    shard_config: Optional[ShardConfig] = None,
):
    r"""
    Run one pipeline stage of `WhisperModel.forward`.

    Stages whose index is below `decoder_starting_stage` run encoder layers; the remaining
    stages run decoder layers. Only the last pipeline stage produces a model output; every
    other stage returns the intermediate values handed to the next stage.

    Pipeline-specific args:
        stage_manager: provides the current stage index and whether it is the last stage.
        hidden_states: activations received from the previous stage. Required on every
            decoder stage except the first decoder stage.
        encoder_hidden_states: encoder output forwarded between decoder stages. Required on
            decoder stages unless `encoder_outputs` is given.
        stage_index: layer range handled by this stage; forwarded to the encoder/decoder forward.
        decoder_starting_stage: index of the first stage that runs decoder layers.
        shard_config: forwarded to the decoder forward.

    Note:
        Non-empty `past_key_values`, `output_attentions=True`, `output_hidden_states=True` and
        `use_cache=True` are not supported under pipeline parallelism. They are reset (with a
        warning) after config defaults are applied, so values coming from `self.config` are
        disabled too.

    Returns:
        - last encoder stage: `{"encoder_hidden_states": encoder_outputs[0]}`
        - other encoder stages: the output of `WhisperPipelineForwards.whisper_encoder_forward`
        - non-last decoder stages: the decoder's intermediate dict, extended with
          `"encoder_hidden_states"`
        - last stage: a `Seq2SeqModelOutput`, or a tuple when `return_dict` is False

    Raises:
        ValueError: if a decoder stage gets neither `encoder_outputs` nor
            `encoder_hidden_states`, or a non-first decoder stage gets no `hidden_states`.

    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoFeatureExtractor, WhisperModel
    >>> from datasets import load_dataset

    >>> model = WhisperModel.from_pretrained("openai/whisper-base")
    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
    >>> input_features = inputs.input_features
    >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
    >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
    >>> list(last_hidden_state.shape)
    [1, 2, 512]
    ```"""
    # Resolve config defaults first, so that the unsupported-feature guards below also catch
    # values inherited from the model config (e.g. config.use_cache=True), not only explicit ones.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
    if past_key_values:
        logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.")
        past_key_values = None
    if output_attentions:
        logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
        output_attentions = False
    if output_hidden_states:
        logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
        output_hidden_states = False
    if use_cache:
        logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
        use_cache = False

    in_decoder = stage_manager.stage >= decoder_starting_stage
    if not in_decoder:
        if encoder_outputs is None:
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
                self.encoder,
                input_features,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                stage_manager=stage_manager,
                hidden_states=hidden_states,
                stage_index=stage_index,
                decoder_starting_stage=decoder_starting_stage,
            )

            if stage_manager.stage == decoder_starting_stage - 1:
                # last stage of encoder: only its final hidden states are sent downstream
                return {"encoder_hidden_states": encoder_outputs[0]}
            else:
                return encoder_outputs
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        # NOTE(review): on an encoder stage with user-supplied encoder_outputs, execution falls
        # through to the decoder path below — confirm this is intended.
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None),
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

    at_last_decoder_stage = stage_manager.is_last_stage()
    at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
    if encoder_outputs is not None:
        encoder_hidden_states = encoder_outputs[0]
    elif encoder_hidden_states is None:
        raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")

    if not at_first_decoder_stage and hidden_states is None:
        raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")

    # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
    decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(
        self.decoder,
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        past_key_values=past_key_values,
        inputs_embeds=decoder_inputs_embeds,
        position_ids=decoder_position_ids,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        stage_manager=stage_manager,
        hidden_states=hidden_states,
        stage_index=stage_index,
        decoder_starting_stage=decoder_starting_stage,
        shard_config=shard_config,
    )

    # Directly return outputs of overloaded Whisper forward if not at last stage.
    if not at_last_decoder_stage:
        # encoder_hidden_states should be passed to the next stage
        decoder_outputs["encoder_hidden_states"] = encoder_hidden_states
        return decoder_outputs

    if not return_dict:
        # Decoder stages normally receive only `encoder_hidden_states` (encoder_outputs is None),
        # so fall back to it rather than concatenating None onto the decoder tuple; this mirrors
        # `encoder_last_hidden_state=encoder_hidden_states` in the return_dict branch below.
        if encoder_outputs is None:
            encoder_tuple = (encoder_hidden_states,)
        elif isinstance(encoder_outputs, BaseModelOutput):
            encoder_tuple = encoder_outputs.to_tuple()
        else:
            encoder_tuple = tuple(encoder_outputs)
        return decoder_outputs + encoder_tuple

    return Seq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_hidden_states,
    )
|
|
|
|
@staticmethod
|
|
def whisper_for_conditional_generation_forward(
    self: WhisperForConditionalGeneration,
    input_features: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
    decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    stage_index: Optional[List[int]] = None,
    decoder_starting_stage: Optional[int] = None,
    shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
    r"""
    Run one pipeline stage of `WhisperForConditionalGeneration.forward`.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
        or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
        only computed for the tokens with labels in `[0, ..., config.vocab_size]`. When given and no decoder
        inputs are provided, `decoder_input_ids` are derived from them via `shift_tokens_right`.

    The pipeline-specific arguments (`stage_manager`, `hidden_states`, `encoder_hidden_states`,
    `stage_index`, `decoder_starting_stage`, `shard_config`) are forwarded unchanged to
    `WhisperPipelineForwards.whisper_model_forward`.

    Returns:
        - encoder stages and non-last decoder stages: the intermediate output of
          `whisper_model_forward`, passed through unchanged
        - last stage: a `Seq2SeqLMOutput` (or a tuple when `return_dict` is False) whose logits are
          `self.proj_out` applied to the decoder's last hidden state, plus the cross-entropy loss
          when `labels` is given

    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
    >>> input_features = inputs.input_features

    >>> generated_ids = model.generate(inputs=input_features)

    >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    >>> transcription
    ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
    ```"""
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if labels is not None:
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )
    in_decoder = stage_manager.stage >= decoder_starting_stage
    at_last_decoder_stage = stage_manager.is_last_stage()
    outputs = WhisperPipelineForwards.whisper_model_forward(
        self.model,
        input_features,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        past_key_values=past_key_values,
        decoder_inputs_embeds=decoder_inputs_embeds,
        decoder_position_ids=decoder_position_ids,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        stage_manager=stage_manager,
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        stage_index=stage_index,
        decoder_starting_stage=decoder_starting_stage,
        shard_config=shard_config,
    )

    # Intermediate stages just forward the model's stage output. On non-last decoder stages
    # `whisper_model_forward` has already stored the *resolved* encoder_hidden_states (taken from
    # encoder_outputs when those were supplied); overwriting it here with the raw parameter
    # could replace a valid tensor with None.
    if not in_decoder or not at_last_decoder_stage:
        return outputs

    lm_logits = self.proj_out(outputs[0])

    loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss()
        # move labels to correct device to enable PP
        labels = labels.to(lm_logits.device)
        loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

    if not return_dict:
        output = (lm_logits,) + outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return Seq2SeqLMOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        cross_attentions=outputs.cross_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
    )
|
|
|
|
@staticmethod
|
|
def whisper_for_audio_classification_forward(
    self: WhisperForAudioClassification,
    input_features: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    labels: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    hidden_states: Optional[torch.FloatTensor] = None,
    encoder_states=None,
    all_attentions=None,
    stage_index: Optional[List[int]] = None,
    decoder_starting_stage: Optional[int] = None,
    shard_config: Optional[ShardConfig] = None,
):
    r"""
    Pipeline-stage version of
    transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
    Please refer to original code of transformers for more details.

    Only the encoder is used. Every stage runs its slice of encoder layers via
    `WhisperPipelineForwards.whisper_encoder_forward`; non-last stages return that result
    as-is. The last stage projects the (optionally layer-weighted) encoder hidden states,
    mean-pools them over dim 1, classifies them, and computes a cross-entropy loss when
    `labels` is given.

    NOTE(review): `encoder_outputs`, `encoder_states`, `all_attentions` and `shard_config` are
    accepted but not used by this body (any passed `encoder_outputs` is replaced by the
    encoder forward's result) — confirm whether they should be forwarded.

    Returns:
        - non-last stages: the output of `whisper_encoder_forward`
        - last stage: a `SequenceClassifierOutput`, or a tuple when `return_dict` is False
    """
    cfg = self.config

    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if cfg.use_weighted_layer_sum:
        # The weighted layer sum needs the hidden states of every encoder layer.
        output_hidden_states = True
    elif output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # audio_classification only holds encoder
    encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
        self.encoder,
        input_features,
        head_mask=head_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        stage_manager=stage_manager,
        hidden_states=hidden_states,
        stage_index=stage_index,
        decoder_starting_stage=decoder_starting_stage,
    )

    # Intermediate stages hand their activations on to the next stage.
    if not stage_manager.is_last_stage():
        return encoder_outputs

    if cfg.use_weighted_layer_sum:
        per_layer = torch.stack(encoder_outputs[_HIDDEN_STATES_START_POSITION], dim=1)
        layer_mix = nn.functional.softmax(self.layer_weights, dim=-1)
        features = (per_layer * layer_mix.view(-1, 1, 1)).sum(dim=1)
    else:
        features = encoder_outputs[0]

    pooled_output = self.projector(features).mean(dim=1)
    logits = self.classifier(pooled_output)

    loss = None
    if labels is not None:
        # move labels to correct device to enable PP
        labels = labels.to(logits.device)
        loss = CrossEntropyLoss()(logits.view(-1, cfg.num_labels), labels.view(-1))

    if not return_dict:
        output = (logits,) + encoder_outputs[1:]
        return output if loss is None else (loss,) + output

    return SequenceClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
|