mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Pipeline/whisper (#4456)
* add some base tests and policies * finish whisper base model * add conditional generation * finish basic tests * whisper * finish whisper * finish whisper * del useless whisper test * fix * add argmin to replace * finish revisionpull/4484/head
parent
a27e0bb494
commit
8739aa7fa0
|
@ -1,7 +1,26 @@
|
||||||
from typing import Optional, Tuple
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers.modeling_outputs import (
|
||||||
|
BaseModelOutput,
|
||||||
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
Seq2SeqLMOutput,
|
||||||
|
Seq2SeqModelOutput,
|
||||||
|
SequenceClassifierOutput,
|
||||||
|
)
|
||||||
|
from transformers.models.whisper.modeling_whisper import (
|
||||||
|
WhisperEncoder,
|
||||||
|
WhisperForAudioClassification,
|
||||||
|
WhisperForConditionalGeneration,
|
||||||
|
WhisperModel,
|
||||||
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
|
||||||
|
|
||||||
def get_whisper_flash_attention_forward():
|
def get_whisper_flash_attention_forward():
|
||||||
|
@ -247,3 +266,697 @@ def get_jit_fused_whisper_decoder_layer_forward():
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
return forward
|
return forward
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperPipelineForwards:
|
||||||
|
'''
|
||||||
|
This class serves as a micro library for forward function substitution of Llama models
|
||||||
|
under pipeline setting.
|
||||||
|
'''
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def whisper_encoder_forward(
|
||||||
|
self: WhisperEncoder,
|
||||||
|
input_features,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_states=None,
|
||||||
|
all_attentions=None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
decoder_starting_stage: Optional[int] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
|
||||||
|
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
|
||||||
|
obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
|
||||||
|
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
|
||||||
|
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
|
||||||
|
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
|
||||||
|
attention_mask (`torch.Tensor`)`, *optional*):
|
||||||
|
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
|
||||||
|
but it is not used. By default the silence in the input log mel spectrogram are ignored.
|
||||||
|
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
stage = stage_manager.stage
|
||||||
|
at_first_stage = (stage == 0)
|
||||||
|
at_last_stage = (stage == decoder_starting_stage - 1)
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (output_hidden_states
|
||||||
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# Process inputs if at the first stage of encoder.
|
||||||
|
if at_first_stage:
|
||||||
|
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
|
||||||
|
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
||||||
|
|
||||||
|
inputs_embeds = inputs_embeds.permute(0, 2, 1)
|
||||||
|
embed_pos = self.embed_positions.weight
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds + embed_pos
|
||||||
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
encoder_states = () if output_hidden_states else None
|
||||||
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
|
if head_mask is not None:
|
||||||
|
assert head_mask.size()[0] == (
|
||||||
|
len(self.layers)
|
||||||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||||||
|
|
||||||
|
else:
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError(
|
||||||
|
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
|
||||||
|
|
||||||
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
|
|
||||||
|
for idx in range(start_idx, end_idx):
|
||||||
|
encoder_layer = self.layers[idx]
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
|
dropout_probability = random.uniform(0, 1)
|
||||||
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
|
layer_outputs = (None, None)
|
||||||
|
else:
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(encoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
None,
|
||||||
|
(head_mask[idx] if head_mask is not None else None),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = encoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
None,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
if at_last_stage:
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
if output_hidden_states:
|
||||||
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||||
|
return BaseModelOutput(last_hidden_state=hidden_states,
|
||||||
|
hidden_states=encoder_states,
|
||||||
|
attentions=all_attentions)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return {'hidden_states': hidden_states, 'head_mask': head_mask}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def whisper_decoder_forward(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
use_cache=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
decoder_starting_stage: Optional[int] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||||
|
provide it.
|
||||||
|
|
||||||
|
Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||||
|
of the decoder.
|
||||||
|
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||||
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||||||
|
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
|
||||||
|
on hidden heads. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
||||||
|
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
||||||
|
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||||
|
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||||
|
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
||||||
|
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
||||||
|
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
||||||
|
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
||||||
|
embedding lookup matrix.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||||
|
for more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
stage = stage_manager.stage
|
||||||
|
at_first_stage = (stage == decoder_starting_stage)
|
||||||
|
at_last_stage = (stage == stage_manager.num_stages - 1)
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (output_hidden_states
|
||||||
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||||
|
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||||
|
if attn_mask is not None:
|
||||||
|
assert attn_mask.size()[0] == (len(self.layers)), (
|
||||||
|
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
||||||
|
f" {head_mask.size()[0]}.")
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
|
if at_first_stage:
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
if input_ids is not None:
|
||||||
|
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
||||||
|
else:
|
||||||
|
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
|
||||||
|
|
||||||
|
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
|
||||||
|
past_key_values_length)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds + positions
|
||||||
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError(
|
||||||
|
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
|
||||||
|
input_shape = hidden_states.size()[:-1]
|
||||||
|
|
||||||
|
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states,
|
||||||
|
past_key_values_length)
|
||||||
|
|
||||||
|
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||||
|
|
||||||
|
for idx in range(start_idx, end_idx):
|
||||||
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
|
decoder_layer = self.layers[idx]
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
dropout_probability = random.uniform(0, 1)
|
||||||
|
if self.training and (dropout_probability < self.layerdrop):
|
||||||
|
continue
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, output_attentions, use_cache)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
None, # encoder attention mask
|
||||||
|
head_mask[idx] if head_mask is not None else None,
|
||||||
|
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
|
||||||
|
None, # past_key_value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
|
cross_attn_layer_head_mask=(cross_attn_head_mask[idx]
|
||||||
|
if cross_attn_head_mask is not None else None),
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
|
if at_last_stage:
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
|
||||||
|
if v is not None)
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
'head_mask': head_mask,
|
||||||
|
'cross_attn_head_mask': cross_attn_head_mask,
|
||||||
|
'hidden_states': hidden_states,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def whisper_model_forward(
|
||||||
|
self: WhisperModel,
|
||||||
|
input_features: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
decoder_starting_stage: Optional[int] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoFeatureExtractor, WhisperModel
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
|
||||||
|
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
|
||||||
|
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
|
||||||
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||||
|
>>> input_features = inputs.input_features
|
||||||
|
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
|
||||||
|
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||||
|
>>> list(last_hidden_state.shape)
|
||||||
|
[1, 2, 512]
|
||||||
|
```"""
|
||||||
|
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||||
|
if past_key_values:
|
||||||
|
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
|
||||||
|
past_key_values = None
|
||||||
|
if output_attentions:
|
||||||
|
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||||
|
output_attentions = False
|
||||||
|
if output_hidden_states:
|
||||||
|
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||||
|
output_hidden_states = False
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (output_hidden_states
|
||||||
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
in_decoder = stage_manager.stage >= decoder_starting_stage
|
||||||
|
if not in_decoder:
|
||||||
|
if encoder_outputs is None:
|
||||||
|
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
|
||||||
|
self.encoder,
|
||||||
|
input_features,
|
||||||
|
head_mask=head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
stage_index=stage_index,
|
||||||
|
decoder_starting_stage=decoder_starting_stage)
|
||||||
|
|
||||||
|
if stage_manager.stage == decoder_starting_stage - 1:
|
||||||
|
# last stage of encoder
|
||||||
|
return {'encoder_hidden_states': encoder_outputs[0]}
|
||||||
|
else:
|
||||||
|
return encoder_outputs
|
||||||
|
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
||||||
|
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
||||||
|
encoder_outputs = BaseModelOutput(
|
||||||
|
last_hidden_state=encoder_outputs[0],
|
||||||
|
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||||
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
at_last_decoder_stage = stage_manager.is_last_stage()
|
||||||
|
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
|
||||||
|
if encoder_outputs is not None:
|
||||||
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
|
elif encoder_hidden_states is None:
|
||||||
|
raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
|
||||||
|
|
||||||
|
if not at_first_decoder_stage and hidden_states is None:
|
||||||
|
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||||
|
decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder,
|
||||||
|
input_ids=decoder_input_ids,
|
||||||
|
attention_mask=decoder_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
stage_index=stage_index,
|
||||||
|
decoder_starting_stage=decoder_starting_stage)
|
||||||
|
|
||||||
|
# Directly return outputs of overloaded Whisper forward if not at last stage.
|
||||||
|
if not at_last_decoder_stage:
|
||||||
|
# encoder_hidden_states should be passed to the next stage
|
||||||
|
decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
|
||||||
|
return decoder_outputs
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
return Seq2SeqModelOutput(
|
||||||
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||||
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
|
encoder_last_hidden_state=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def whisper_for_conditional_generation_forward(
|
||||||
|
self: WhisperForConditionalGeneration,
|
||||||
|
input_features: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
decoder_starting_stage: Optional[int] = None,
|
||||||
|
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
|
||||||
|
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
|
||||||
|
only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> import torch
|
||||||
|
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
|
||||||
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
|
||||||
|
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||||
|
>>> input_features = inputs.input_features
|
||||||
|
|
||||||
|
>>> generated_ids = model.generate(inputs=input_features)
|
||||||
|
|
||||||
|
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
>>> transcription
|
||||||
|
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||||
|
```"""
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||||||
|
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id,
|
||||||
|
self.config.decoder_start_token_id)
|
||||||
|
in_decoder = stage_manager.stage >= decoder_starting_stage
|
||||||
|
at_last_decoder_stage = stage_manager.is_last_stage()
|
||||||
|
outputs = WhisperPipelineForwards.whisper_model_forward(self.model,
|
||||||
|
input_features,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
encoder_outputs=encoder_outputs,
|
||||||
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
decoder_head_mask=decoder_head_mask,
|
||||||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
stage_index=stage_index,
|
||||||
|
decoder_starting_stage=decoder_starting_stage)
|
||||||
|
if not in_decoder:
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
if not at_last_decoder_stage:
|
||||||
|
# encoder_hidden_states should be passed to the next stage
|
||||||
|
outputs['encoder_hidden_states'] = encoder_hidden_states
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
lm_logits = self.proj_out(outputs[0])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
# move labels to correct device to enable PP
|
||||||
|
labels = labels.to(lm_logits.device)
|
||||||
|
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (lm_logits,) + outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return Seq2SeqLMOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=lm_logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||||
|
decoder_attentions=outputs.decoder_attentions,
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||||
|
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||||
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def whisper_for_audio_classification_forward(
|
||||||
|
self: WhisperForAudioClassification,
|
||||||
|
input_features: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
stage_manager: Optional[PipelineStageManager] = None,
|
||||||
|
hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_states=None,
|
||||||
|
all_attentions=None,
|
||||||
|
stage_index: Optional[List[int]] = None,
|
||||||
|
decoder_starting_stage: Optional[int] = None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
|
||||||
|
Please refer to original code of transformers for more details.
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (output_hidden_states
|
||||||
|
if output_hidden_states is not None else self.config.output_hidden_states)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# audio_classification only holds encoder
|
||||||
|
encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
|
||||||
|
self.encoder,
|
||||||
|
input_features,
|
||||||
|
head_mask=head_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
stage_index=stage_index,
|
||||||
|
decoder_starting_stage=decoder_starting_stage,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not stage_manager.is_last_stage():
|
||||||
|
return encoder_outputs
|
||||||
|
|
||||||
|
if self.config.use_weighted_layer_sum:
|
||||||
|
hidden_states = torch.stack(encoder_outputs, dim=1)
|
||||||
|
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||||
|
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||||
|
else:
|
||||||
|
hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
hidden_states = self.projector(hidden_states)
|
||||||
|
pooled_output = hidden_states.mean(dim=1)
|
||||||
|
|
||||||
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
# move labels to correct device to enable PP
|
||||||
|
labels = labels.to(logits.device)
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + encoder_outputs[1:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return SequenceClassifierOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
|
@ -304,15 +304,6 @@ class BlipPolicy(Policy):
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
binding_map = {
|
|
||||||
'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v in binding_map.items():
|
|
||||||
src_mod = getattr_(self.model, k)
|
|
||||||
dst_mod = getattr_(self.model, v)
|
|
||||||
dst_mod.weight = src_mod.weight
|
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from colossalai.shardformer.layer import (
|
from colossalai.shardformer.layer import (
|
||||||
|
@ -228,13 +229,7 @@ class T5BasePolicy(Policy):
|
||||||
def objective(num_encoder_stages):
|
def objective(num_encoder_stages):
|
||||||
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
|
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
|
||||||
|
|
||||||
num_encoder_stages = 0
|
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
|
||||||
optimal_diff = 2**31 - 1
|
|
||||||
for i in range(1, num_stages):
|
|
||||||
attempt = objective(i)
|
|
||||||
if attempt < optimal_diff:
|
|
||||||
num_encoder_stages = i
|
|
||||||
optimal_diff = attempt
|
|
||||||
num_decoder_stages = num_stages - num_encoder_stages
|
num_decoder_stages = num_stages - num_encoder_stages
|
||||||
|
|
||||||
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
|
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||||
|
|
|
@ -1,10 +1,16 @@
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||||
from ..modeling.whisper import (
|
from ..modeling.whisper import (
|
||||||
|
WhisperPipelineForwards,
|
||||||
get_jit_fused_whisper_decoder_layer_forward,
|
get_jit_fused_whisper_decoder_layer_forward,
|
||||||
get_jit_fused_whisper_encoder_layer_forward,
|
get_jit_fused_whisper_encoder_layer_forward,
|
||||||
get_whisper_flash_attention_forward,
|
get_whisper_flash_attention_forward,
|
||||||
|
@ -12,7 +18,8 @@ from ..modeling.whisper import (
|
||||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
|
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
|
||||||
|
'WhisperForAudioClassificationPolicy'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -223,6 +230,146 @@ class WhisperPolicy(Policy):
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
|
||||||
|
num_stages: int) -> Tuple[List[int], int]:
|
||||||
|
"""
|
||||||
|
Distribute whisper layers into stages when pipeline parallel is used.
|
||||||
|
Return the layer distribution as a list and the starting stage of decoder.
|
||||||
|
If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# number of encoder layers must be a positive integer
|
||||||
|
if num_encoder_layers <= 0:
|
||||||
|
raise ValueError("The number of encoder layers for whisper must be a positive integer.")
|
||||||
|
|
||||||
|
# number of layers should be large enough to fill in every stage
|
||||||
|
if num_encoder_layers + num_decoder_layers < num_stages:
|
||||||
|
raise ValueError("The total number of layers can't be smaller than number of stages.")
|
||||||
|
|
||||||
|
# in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
|
||||||
|
if num_decoder_layers == 0:
|
||||||
|
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
|
||||||
|
|
||||||
|
# the number of stages distributed between encoder and decoder is optmized in this way:
|
||||||
|
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
|
||||||
|
# s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
|
||||||
|
def objective(num_encoder_stages):
|
||||||
|
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
|
||||||
|
|
||||||
|
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
|
||||||
|
num_decoder_stages = num_stages - num_encoder_stages
|
||||||
|
|
||||||
|
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
|
||||||
|
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
|
||||||
|
return encoder_distribution + decoder_distribution, num_encoder_stages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
|
||||||
|
decoder_starting_stage: int) -> Tuple[bool, int, int]:
|
||||||
|
"""
|
||||||
|
Input the distribution of layers among stages, the current stage and the first stage of decoder.
|
||||||
|
Return the starting/ending idx of layers in encoder/decoder
|
||||||
|
"""
|
||||||
|
if stage < decoder_starting_stage:
|
||||||
|
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
|
||||||
|
else:
|
||||||
|
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[nn.Module]:
|
||||||
|
|
||||||
|
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
|
if self.model.__class__.__name__ == 'WhisperModel':
|
||||||
|
model = self.model
|
||||||
|
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
||||||
|
model = self.model.model
|
||||||
|
else:
|
||||||
|
model = None
|
||||||
|
|
||||||
|
if model:
|
||||||
|
encoder = self.model.get_encoder()
|
||||||
|
decoder = self.model.get_decoder()
|
||||||
|
else:
|
||||||
|
# whisper for audio classification holds encoder only
|
||||||
|
encoder = self.model.encoder
|
||||||
|
decoder = None
|
||||||
|
|
||||||
|
num_encoder_layers = len(encoder.layers)
|
||||||
|
if decoder:
|
||||||
|
num_decoder_layers = len(decoder.layers)
|
||||||
|
else:
|
||||||
|
num_decoder_layers = 0
|
||||||
|
|
||||||
|
held_layers = []
|
||||||
|
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||||
|
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
|
||||||
|
start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
|
||||||
|
decoder_starting_stage)
|
||||||
|
|
||||||
|
if stage_manager.stage < decoder_starting_stage:
|
||||||
|
# current stage is in whisper's encoder
|
||||||
|
if stage_manager.is_first_stage():
|
||||||
|
held_layers.append(encoder.embed_positions)
|
||||||
|
held_layers.append(encoder.conv1)
|
||||||
|
held_layers.append(encoder.conv2)
|
||||||
|
if stage_manager.stage == decoder_starting_stage - 1:
|
||||||
|
held_layers.append(encoder.layer_norm)
|
||||||
|
held_layers.extend(encoder.layers[start_idx:end_idx])
|
||||||
|
else:
|
||||||
|
# current stage is in whisper's decoder
|
||||||
|
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
|
||||||
|
# the case encoder and decoder put in same stage should be add in the future.
|
||||||
|
if stage_manager.stage == decoder_starting_stage:
|
||||||
|
held_layers.append(decoder.embed_tokens)
|
||||||
|
held_layers.append(decoder.embed_positions)
|
||||||
|
if stage_manager.is_last_stage():
|
||||||
|
held_layers.append(decoder.layer_norm)
|
||||||
|
held_layers.extend(decoder.layers[start_idx:end_idx])
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||||
|
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||||
|
to customized forward method, and add this changing to policy."""
|
||||||
|
if not self.pipeline_stage_manager:
|
||||||
|
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
|
||||||
|
if self.model.__class__.__name__ == 'WhisperModel':
|
||||||
|
model = self.model
|
||||||
|
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
||||||
|
model = self.model.model
|
||||||
|
else:
|
||||||
|
model = None
|
||||||
|
|
||||||
|
if model:
|
||||||
|
encoder = self.model.get_encoder()
|
||||||
|
decoder = self.model.get_decoder()
|
||||||
|
else:
|
||||||
|
encoder = self.model.encoder
|
||||||
|
decoder = None
|
||||||
|
|
||||||
|
num_encoder_layers = len(encoder.layers)
|
||||||
|
if decoder:
|
||||||
|
num_decoder_layers = len(decoder.layers)
|
||||||
|
else:
|
||||||
|
num_decoder_layers = 0
|
||||||
|
|
||||||
|
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||||
|
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
|
||||||
|
stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
|
||||||
|
decoder_starting_stage)
|
||||||
|
|
||||||
|
method_replacement = {
|
||||||
|
'forward':
|
||||||
|
partial(new_forward,
|
||||||
|
stage_manager=stage_manager,
|
||||||
|
stage_index=stage_index,
|
||||||
|
decoder_starting_stage=decoder_starting_stage)
|
||||||
|
}
|
||||||
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||||
|
|
||||||
|
|
||||||
# WhisperModel
|
# WhisperModel
|
||||||
class WhisperModelPolicy(WhisperPolicy):
|
class WhisperModelPolicy(WhisperPolicy):
|
||||||
|
@ -230,6 +377,24 @@ class WhisperModelPolicy(WhisperPolicy):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers import WhisperModel
|
||||||
|
policy = super().module_policy()
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager is not None:
|
||||||
|
self.set_pipeline_forward(model_cls=WhisperModel,
|
||||||
|
new_forward=WhisperPipelineForwards.whisper_model_forward,
|
||||||
|
policy=policy)
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[nn.Module]:
|
||||||
|
return super().get_held_layers()
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
"no shared params in whisper model"
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
# WhisperForConditionalGeneration
|
# WhisperForConditionalGeneration
|
||||||
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
||||||
|
@ -238,20 +403,82 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def module_policy(self):
|
def module_policy(self):
|
||||||
module_policy = super().module_policy()
|
from transformers import WhisperForConditionalGeneration
|
||||||
module_policy = self.add_lm_head_policy(module_policy)
|
policy = super().module_policy()
|
||||||
return module_policy
|
policy = self.add_lm_head_policy(policy)
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager is not None:
|
||||||
|
self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
|
||||||
|
new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
|
||||||
|
policy=policy)
|
||||||
|
return policy
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"}
|
|
||||||
for k, v in binding_map.items():
|
|
||||||
param = getattr_(self.model, k)
|
|
||||||
setattr_(self.model, v, param)
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[nn.Module]:
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
if self.pipeline_stage_manager.is_last_stage():
|
||||||
|
held_layers.append(self.model.proj_out)
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
module = self.model
|
||||||
|
model = module.model
|
||||||
|
|
||||||
|
if model:
|
||||||
|
encoder = self.model.get_encoder()
|
||||||
|
decoder = self.model.get_decoder()
|
||||||
|
else:
|
||||||
|
encoder = self.model.encoder
|
||||||
|
decoder = None
|
||||||
|
|
||||||
|
num_encoder_layers = len(encoder.layers)
|
||||||
|
if decoder:
|
||||||
|
num_decoder_layers = len(decoder.layers)
|
||||||
|
else:
|
||||||
|
num_decoder_layers = 0
|
||||||
|
|
||||||
|
stage_manager = self.pipeline_stage_manager
|
||||||
|
if stage_manager is not None and stage_manager.num_stages > 1:
|
||||||
|
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
|
||||||
|
stage_manager.num_stages)
|
||||||
|
shared_params = []
|
||||||
|
shared_embedding = {}
|
||||||
|
if id(module.proj_out) == id(model.decoder.embed_tokens):
|
||||||
|
shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens
|
||||||
|
shared_embedding[stage_manager.num_stages - 1] = module.proj_out
|
||||||
|
if len(shared_embedding) > 0:
|
||||||
|
shared_params.append(shared_embedding)
|
||||||
|
return shared_params
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
# WhisperForAudioClassification
|
# WhisperForAudioClassification
|
||||||
class WhisperForAudioClassificationPolicy(WhisperPolicy):
|
class WhisperForAudioClassificationPolicy(WhisperPolicy):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def preprocess(self):
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def module_policy(self):
|
||||||
|
from transformers import WhisperForAudioClassification
|
||||||
|
policy = super().module_policy()
|
||||||
|
|
||||||
|
if self.pipeline_stage_manager is not None:
|
||||||
|
self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
|
||||||
|
new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
|
||||||
|
policy=policy)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
def get_held_layers(self) -> List[nn.Module]:
|
||||||
|
held_layers = super().get_held_layers()
|
||||||
|
if self.pipeline_stage_manager.is_last_stage():
|
||||||
|
held_layers.append(self.model.projector)
|
||||||
|
held_layers.append(self.model.classifier)
|
||||||
|
return held_layers
|
||||||
|
|
||||||
|
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||||
|
return []
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
from colossalai.shardformer.policies.t5 import T5BasePolicy
|
||||||
|
|
||||||
|
|
||||||
|
def test_t5_pipeline_distribution():
|
||||||
|
num_test_cases = 8
|
||||||
|
test_dict = {
|
||||||
|
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
|
||||||
|
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
|
||||||
|
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
|
||||||
|
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(num_test_cases):
|
||||||
|
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
|
||||||
|
test_dict['num_decoder_layers'][i],
|
||||||
|
test_dict['num_stages'][i])
|
||||||
|
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
|
||||||
|
|
||||||
|
|
||||||
|
def test_t5_pipeline_layers():
|
||||||
|
num_test_cases = 4
|
||||||
|
test_dict = {
|
||||||
|
'num_encoder_layers': [2, 3, 2, 4],
|
||||||
|
'num_decoder_layers': [2, 0, 2, 8],
|
||||||
|
'num_stages': [2, 2, 4, 4],
|
||||||
|
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
|
||||||
|
[[0, 4], [0, 3], [3, 6], [6, 8]]]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(num_test_cases):
|
||||||
|
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||||
|
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
|
||||||
|
|
||||||
|
for stage in range(test_dict['num_stages'][i]):
|
||||||
|
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
|
||||||
|
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
|
||||||
|
decoder_starting_stage)
|
||||||
|
assert start_idx == predicted_start
|
||||||
|
assert end_idx == predicted_end
|
|
@ -0,0 +1,44 @@
|
||||||
|
from colossalai.shardformer.policies.whisper import WhisperPolicy
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisper_pipeline_distribution():
|
||||||
|
num_test_cases = 8
|
||||||
|
test_dict = {
|
||||||
|
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
|
||||||
|
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
|
||||||
|
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
|
||||||
|
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(num_test_cases):
|
||||||
|
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
|
||||||
|
test_dict['num_decoder_layers'][i],
|
||||||
|
test_dict['num_stages'][i])
|
||||||
|
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisper_pipeline_layers():
|
||||||
|
num_test_cases = 4
|
||||||
|
test_dict = {
|
||||||
|
'num_encoder_layers': [2, 3, 2, 4],
|
||||||
|
'num_decoder_layers': [2, 0, 2, 8],
|
||||||
|
'num_stages': [2, 2, 4, 4],
|
||||||
|
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
|
||||||
|
[[0, 4], [0, 3], [3, 6], [6, 8]]]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in range(num_test_cases):
|
||||||
|
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||||
|
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
|
||||||
|
|
||||||
|
for stage in range(test_dict['num_stages'][i]):
|
||||||
|
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
|
||||||
|
predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
|
||||||
|
decoder_starting_stage)
|
||||||
|
assert start_idx == predicted_start
|
||||||
|
assert end_idx == predicted_end
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_whisper_pipeline_distribution()
|
||||||
|
test_whisper_pipeline_layers()
|
|
@ -6,6 +6,7 @@ from torch import distributed as dist
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
@ -143,6 +144,7 @@ def run_llama_test(test_config):
|
||||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||||
|
|
||||||
clear_layout_converter()
|
clear_layout_converter()
|
||||||
|
Randomizer.reset_index()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,8 @@ import torch
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.logging import disable_existing_loggers
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.shardformer.layer.utils import Randomizer
|
||||||
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||||
from colossalai.testing import (
|
from colossalai.testing import (
|
||||||
assert_hf_output_close,
|
assert_hf_output_close,
|
||||||
clear_cache_before_run,
|
clear_cache_before_run,
|
||||||
|
@ -11,55 +13,145 @@ from colossalai.testing import (
|
||||||
spawn,
|
spawn,
|
||||||
)
|
)
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
|
from tests.test_shardformer.test_model._utils import (
|
||||||
|
build_model_from_hybrid_plugin,
|
||||||
|
check_grad,
|
||||||
|
check_loss,
|
||||||
|
check_output_hidden_state,
|
||||||
|
check_weight,
|
||||||
|
run_forward_backward_with_hybrid_plugin,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
|
||||||
# check forward
|
# check forward
|
||||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
|
||||||
output_transform_fn, loss_fn)
|
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
|
||||||
assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
|
|
||||||
|
|
||||||
# do backward
|
org_loss, org_output, sharded_loss, sharded_output = \
|
||||||
org_loss.backward()
|
run_forward_backward_with_hybrid_plugin(
|
||||||
shard_loss.backward()
|
org_model,
|
||||||
|
sharded_model,
|
||||||
|
sharded_optimizer,
|
||||||
|
data_gen_fn,
|
||||||
|
output_transform_fn,
|
||||||
|
criterion,
|
||||||
|
booster)
|
||||||
|
|
||||||
assert torch.allclose(org_loss, shard_loss,
|
stage_manager = booster.plugin.stage_manager
|
||||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
tp_group = booster.plugin.tp_group
|
||||||
|
|
||||||
|
# check last hidden state & loss
|
||||||
|
if stage_manager is None or stage_manager.is_last_stage():
|
||||||
|
if test_config['precision'] == 'fp32':
|
||||||
|
atol, rtol = 1e-3, 1e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
|
||||||
|
if org_model.__class__.__name__ == 'WhisperModel':
|
||||||
|
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
# unwarp the model
|
# unwarp the model
|
||||||
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
||||||
whisper = org_model.model
|
whisper = org_model.model
|
||||||
sharded_whisper = sharded_model.model
|
sharded_whisper = sharded_model.unwrap().model
|
||||||
else:
|
else:
|
||||||
whisper = org_model
|
whisper = org_model
|
||||||
sharded_whisper = sharded_model
|
sharded_whisper = sharded_model.unwrap()
|
||||||
|
|
||||||
# check grad
|
# check grad
|
||||||
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
|
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
|
||||||
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
|
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
|
||||||
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
|
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
|
||||||
else:
|
else:
|
||||||
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
|
col_layer_for_check = [
|
||||||
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
|
'encoder.layers[0].self_attn.q_proj',
|
||||||
check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
|
# 'decoder.layers[0].self_attn.q_proj'
|
||||||
check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
|
]
|
||||||
|
row_layer_for_check = [
|
||||||
|
'encoder.layers[0].self_attn.out_proj',
|
||||||
|
#'decoder.layers[0].self_attn.out_proj'
|
||||||
|
]
|
||||||
|
|
||||||
|
# check weights and gradients
|
||||||
|
if test_config['precision'] == 'fp32':
|
||||||
|
atol, rtol = 1e-3, 1e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
|
||||||
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
|
||||||
|
check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
||||||
|
|
||||||
|
# check weights after optimizer.step()
|
||||||
|
org_optimizer.step()
|
||||||
|
sharded_optimizer.step()
|
||||||
|
if test_config['precision'] == 'fp32':
|
||||||
|
atol, rtol = 1e-3, 1e-3
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-3, 5e-3
|
||||||
|
if stage_manager is None or stage_manager.is_first_stage():
|
||||||
|
check_weight(whisper,
|
||||||
|
sharded_whisper,
|
||||||
|
row_layer_for_check,
|
||||||
|
tp_group,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
dim=0,
|
||||||
|
verbose=False)
|
||||||
|
check_weight(whisper,
|
||||||
|
sharded_whisper,
|
||||||
|
col_layer_for_check,
|
||||||
|
tp_group,
|
||||||
|
atol=atol,
|
||||||
|
rtol=rtol,
|
||||||
|
dim=0,
|
||||||
|
verbose=False)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@parameterize('enable_fused_normalization', [True, False])
|
# TODO(jianghai) fix fp16
|
||||||
@parameterize('enable_tensor_parallelism', [True, False])
|
@parameterize('test_config', [{
|
||||||
@parameterize('enable_flash_attention', [True, False])
|
'tp_size': 2,
|
||||||
@parameterize('enable_jit_fused', [True, False])
|
'pp_size': 2,
|
||||||
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
|
'num_microbatches': 2,
|
||||||
|
'enable_all_optimization': True,
|
||||||
|
'use_lazy_init': True,
|
||||||
|
'precision': 'fp32',
|
||||||
|
'initial_scale': 1,
|
||||||
|
}, {
|
||||||
|
'tp_size': 1,
|
||||||
|
'pp_size': 2,
|
||||||
|
'num_microbatches': 4,
|
||||||
|
'use_lazy_init': False,
|
||||||
|
'precision': 'fp32',
|
||||||
|
'initial_scale': 1,
|
||||||
|
}, {
|
||||||
|
'tp_size': 4,
|
||||||
|
'pp_size': 1,
|
||||||
|
'enable_all_optimization': True,
|
||||||
|
'use_lazy_init': False,
|
||||||
|
'precision': 'fp32',
|
||||||
|
}, {
|
||||||
|
'tp_size': 1,
|
||||||
|
'pp_size': 4,
|
||||||
|
'num_microbatches': 4,
|
||||||
|
'use_lazy_init': False,
|
||||||
|
'precision': 'fp32',
|
||||||
|
}])
|
||||||
|
def run_whisper_test(test_config):
|
||||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
|
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
org_model, sharded_model = build_model(model_fn,
|
|
||||||
enable_fused_normalization=enable_fused_normalization,
|
|
||||||
enable_tensor_parallelism=enable_tensor_parallelism,
|
|
||||||
enable_flash_attention=enable_flash_attention,
|
|
||||||
enable_jit_fused=enable_jit_fused)
|
|
||||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
|
||||||
|
|
||||||
|
if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
|
||||||
|
continue
|
||||||
|
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||||
|
|
||||||
|
clear_layout_converter()
|
||||||
|
Randomizer.reset_index()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,7 +165,7 @@ def check_whisper(rank, world_size, port):
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
def test_whisper():
|
def test_whisper():
|
||||||
spawn(check_whisper, 2)
|
spawn(check_whisper, 4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue