[shardformer] Pipeline/whisper (#4456)

* add some base tests and policies

* finish whisper base model

* add conditional generation

* finish basic tests

* whisper

* finish whisper

* finish whisper

* del useless  whisper test

* fix

* add argmin to replace

* finish revision
Jianghai 2023-08-18 21:29:25 +08:00 committed by GitHub
parent a27e0bb494
commit 8739aa7fa0
8 changed files with 1156 additions and 53 deletions


@@ -1,7 +1,26 @@
from typing import Optional, Tuple
import logging
import random
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
SequenceClassifierOutput,
)
from transformers.models.whisper.modeling_whisper import (
WhisperEncoder,
WhisperForAudioClassification,
WhisperForConditionalGeneration,
WhisperModel,
shift_tokens_right,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward():
@@ -247,3 +266,697 @@ def get_jit_fused_whisper_decoder_layer_forward():
return outputs
return forward
class WhisperPipelineForwards:
'''
This class serves as a micro library for forward function substitution of whisper models
under pipeline setting.
'''
@staticmethod
def whisper_encoder_forward(
self: WhisperEncoder,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_states=None,
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
Args:
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.Tensor`, *optional*):
Whisper does not support masking of the `input_features`; this argument is preserved for compatibility
but is not used. By default, the silence in the input log mel spectrogram is ignored.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
logger = logging.get_logger(__name__)
stage = stage_manager.stage
at_first_stage = (stage == 0)
at_last_stage = (stage == decoder_starting_stage - 1)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Process inputs if at the first stage of encoder.
if at_first_stage:
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
else:
if hidden_states is None:
raise ValueError(
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
start_idx, end_idx = stage_index[0], stage_index[1]
for idx in range(start_idx, end_idx):
encoder_layer = self.layers[idx]
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
None,
(head_mask[idx] if head_mask is not None else None),
)
else:
layer_outputs = encoder_layer(
hidden_states,
None,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if at_last_stage:
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions)
else:
return {'hidden_states': hidden_states, 'head_mask': head_mask}
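# Stage handoff note: intermediate encoder stages return this plain dict rather than a ModelOutput;
# assuming the pipeline schedule feeds it to the next stage as keyword arguments, the partially-encoded
# `hidden_states` (and `head_mask`) re-enter this same forward until the stage just before
# `decoder_starting_stage` applies the final layer_norm and produces the real encoder output.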
@staticmethod
def whisper_decoder_forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
logger = logging.get_logger(__name__)
stage = stage_manager.stage
at_first_stage = (stage == decoder_starting_stage)
at_last_stage = (stage == stage_manager.num_stages - 1)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (len(self.layers)), (
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}.")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if at_first_stage:
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
)
use_cache = False
else:
if hidden_states is None:
raise ValueError(
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
input_shape = hidden_states.size()[:-1]
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states,
past_key_values_length)
start_idx, end_idx = stage_index[0], stage_index[1]
for idx in range(start_idx, end_idx):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
decoder_layer = self.layers[idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
None, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, # past_key_value
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx]
if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if at_last_stage:
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
else:
return {
'head_mask': head_mask,
'cross_attn_head_mask': cross_attn_head_mask,
'hidden_states': hidden_states,
}
@staticmethod
def whisper_model_forward(
self: WhisperModel,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperModel
>>> from datasets import load_dataset
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""
logger = logging.get_logger(__name__)
# TODO: the recorded key/value tensors are left as () or None for now; this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
in_decoder = stage_manager.stage >= decoder_starting_stage
if not in_decoder:
if encoder_outputs is None:
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
self.encoder,
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
if stage_manager.stage == decoder_starting_stage - 1:
# last stage of encoder
return {'encoder_hidden_states': encoder_outputs[0]}
else:
return encoder_outputs
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
at_last_decoder_stage = stage_manager.is_last_stage()
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
if encoder_outputs is not None:
encoder_hidden_states = encoder_outputs[0]
elif encoder_hidden_states is None:
raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
if not at_first_decoder_stage and hidden_states is None:
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder,
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
# Directly return outputs of overloaded Whisper forward if not at last stage.
if not at_last_decoder_stage:
# encoder_hidden_states should be passed to the next stage
decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
return decoder_outputs
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_hidden_states,
)
@staticmethod
def whisper_for_conditional_generation_forward(
self: WhisperForConditionalGeneration,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id,
self.config.decoder_start_token_id)
in_decoder = stage_manager.stage >= decoder_starting_stage
at_last_decoder_stage = stage_manager.is_last_stage()
outputs = WhisperPipelineForwards.whisper_model_forward(self.model,
input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
if not in_decoder:
return outputs
if not at_last_decoder_stage:
# encoder_hidden_states should be passed to the next stage
outputs['encoder_hidden_states'] = encoder_hidden_states
return outputs
lm_logits = self.proj_out(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@staticmethod
def whisper_for_audio_classification_forward(
self: WhisperForAudioClassification,
input_features: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_states=None,
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
This function is adapted from transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
Please refer to the original transformers code for more details.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# audio_classification only holds encoder
encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
self.encoder,
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
)
if not stage_manager.is_last_stage():
return encoder_outputs
if self.config.use_weighted_layer_sum:
hidden_states = torch.stack(encoder_outputs, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = encoder_outputs[0]
hidden_states = self.projector(hidden_states)
pooled_output = hidden_states.mean(dim=1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + encoder_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@@ -304,15 +304,6 @@ class BlipPolicy(Policy):
return policy
def postprocess(self):
binding_map = {
'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
}
for k, v in binding_map.items():
src_mod = getattr_(self.model, k)
dst_mod = getattr_(self.model, v)
dst_mod.weight = src_mod.weight
return self.model


@@ -1,6 +1,7 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
from torch import Tensor, nn
from colossalai.shardformer.layer import (
@@ -228,13 +229,7 @@ class T5BasePolicy(Policy):
def objective(num_encoder_stages):
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
num_encoder_stages = 0
optimal_diff = 2**31 - 1
for i in range(1, num_stages):
attempt = objective(i)
if attempt < optimal_diff:
num_encoder_stages = i
optimal_diff = attempt
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
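# Worked check (hypothetical numbers): for 10 encoder layers, 6 decoder layers and 8 stages, the
# objective |10/i - 6/(8 - i)| reaches 0 at i = 5, so np.argmin over i in [1, num_stages) yields a
# decoder starting stage of 5, matching the brute-force search it replaces.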
num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)


@@ -1,10 +1,16 @@
from functools import partial
from typing import Callable, Dict, List, Tuple
import numpy as np
import torch.nn as nn
from torch import Tensor
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.whisper import (
WhisperPipelineForwards,
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
get_whisper_flash_attention_forward,
@@ -12,7 +18,8 @@ from ..modeling.whisper import (
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
'WhisperForAudioClassificationPolicy'
]
@@ -223,6 +230,146 @@ class WhisperPolicy(Policy):
def postprocess(self):
return self.model
@staticmethod
def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
num_stages: int) -> Tuple[List[int], int]:
"""
Distribute whisper layers into stages when pipeline parallel is used.
Return the layer distribution as a list and the starting stage of decoder.
If the decoder doesn't exist, the returned decoder starting stage is set to num_stages.
"""
# number of encoder layers must be a positive integer
if num_encoder_layers <= 0:
raise ValueError("The number of encoder layers for whisper must be a positive integer.")
# number of layers should be large enough to fill in every stage
if num_encoder_layers + num_decoder_layers < num_stages:
raise ValueError("The total number of layers can't be smaller than number of stages.")
# in the case of an encoder-only model (e.g. WhisperForAudioClassification), set the decoder starting stage to num_stages since the decoder doesn't exist
if num_decoder_layers == 0:
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
# the number of stages distributed between encoder and decoder is optimized in this way:
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
# s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
def objective(num_encoder_stages):
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages
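# Worked example (hypothetical numbers): with 4 encoder layers, 4 decoder layers and 4 stages, the
# objective |4/i - 4/(4 - i)| is minimized at i = 2, so this returns ([2, 2, 2, 2], 2): stages 0-1
# each hold 2 encoder layers, stages 2-3 each hold 2 decoder layers, and the decoder starts at stage 2.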
@staticmethod
def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
decoder_starting_stage: int) -> Tuple[int, int]:
"""
Input the distribution of layers among stages, the current stage and the first stage of decoder.
Return the starting/ending idx of layers in encoder/decoder
"""
if stage < decoder_starting_stage:
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else:
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
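# Continuing the hypothetical example above: with layers_per_stage = [2, 2, 2, 2] and
# decoder_starting_stage = 2, stage 1 maps to encoder layers [2, 4) while stage 3 maps to decoder
# layers [2, 4); the returned indices are local to the encoder or decoder the stage belongs to.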
def get_held_layers(self) -> List[nn.Module]:
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'WhisperModel':
model = self.model
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
model = self.model.model
else:
model = None
if model:
encoder = self.model.get_encoder()
decoder = self.model.get_decoder()
else:
# whisper for audio classification holds encoder only
encoder = self.model.encoder
decoder = None
num_encoder_layers = len(encoder.layers)
if decoder:
num_decoder_layers = len(decoder.layers)
else:
num_decoder_layers = 0
held_layers = []
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder
if stage_manager.is_first_stage():
held_layers.append(encoder.embed_positions)
held_layers.append(encoder.conv1)
held_layers.append(encoder.conv2)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.layer_norm)
held_layers.extend(encoder.layers[start_idx:end_idx])
else:
# current stage is in whisper's decoder
# TODO:(Jianghai) encoder and decoder layers are divided into different stages here;
# the case where encoder and decoder are placed in the same stage should be added in the future.
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.embed_positions)
if stage_manager.is_last_stage():
held_layers.append(decoder.layer_norm)
held_layers.extend(decoder.layers[start_idx:end_idx])
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'WhisperModel':
model = self.model
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
model = self.model.model
else:
model = None
if model:
encoder = self.model.get_encoder()
decoder = self.model.get_decoder()
else:
encoder = self.model.encoder
decoder = None
num_encoder_layers = len(encoder.layers)
if decoder:
num_decoder_layers = len(decoder.layers)
else:
num_decoder_layers = 0
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
decoder_starting_stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
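# Usage note (a sketch of the assumed dispatch): `partial` pins the pipeline context, so once the
# policy is applied every call to the model's forward runs the matching WhisperPipelineForwards
# function with stage_manager, stage_index and decoder_starting_stage already filled in, and each
# stage only executes the encoder/decoder layers selected by get_whisper_stage_index.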
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
@@ -230,6 +377,24 @@ class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperModel
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=WhisperModel,
new_forward=WhisperPipelineForwards.whisper_model_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"no shared params in whisper model"
return []
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
@@ -238,20 +403,82 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
from transformers import WhisperForConditionalGeneration
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
policy=policy)
return policy
def postprocess(self):
binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.proj_out)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
model = module.model
if model:
encoder = self.model.get_encoder()
decoder = self.model.get_decoder()
else:
encoder = self.model.encoder
decoder = None
num_encoder_layers = len(encoder.layers)
if decoder:
num_decoder_layers = len(decoder.layers)
else:
num_decoder_layers = 0
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
stage_manager.num_stages)
shared_params = []
shared_embedding = {}
if id(module.proj_out) == id(model.decoder.embed_tokens):
shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens
shared_embedding[stage_manager.num_stages - 1] = module.proj_out
if len(shared_embedding) > 0:
shared_params.append(shared_embedding)
return shared_params
return []
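# Illustration (hypothetical 4-stage pipeline with the decoder starting at stage 2): if proj_out and
# decoder.embed_tokens are the same module, the result would be [{2: embed_tokens, 3: proj_out}],
# letting the framework keep the tied weights synchronized across those two stages.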
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def preprocess(self):
return self.model
def module_policy(self):
from transformers import WhisperForAudioClassification
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.projector)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []


@@ -0,0 +1,39 @@
from colossalai.shardformer.policies.t5 import T5BasePolicy
def test_t5_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_t5_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end


@@ -0,0 +1,44 @@
from colossalai.shardformer.policies.whisper import WhisperPolicy
def test_whisper_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_whisper_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end
if __name__ == '__main__':
test_whisper_pipeline_distribution()
test_whisper_pipeline_layers()


@@ -6,6 +6,7 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -143,6 +144,7 @@ def run_llama_test(test_config):
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()


@@ -3,6 +3,8 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
@@ -11,55 +13,145 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
# do backward
org_loss.backward()
shard_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'WhisperModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap the model
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
whisper = org_model.model
sharded_whisper = sharded_model.model
sharded_whisper = sharded_model.unwrap().model
else:
whisper = org_model
sharded_whisper = sharded_model
sharded_whisper = sharded_model.unwrap()
# check grad
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
else:
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
col_layer_for_check = [
'encoder.layers[0].self_attn.q_proj',
# 'decoder.layers[0].self_attn.q_proj'
]
row_layer_for_check = [
'encoder.layers[0].self_attn.out_proj',
#'decoder.layers[0].self_attn.out_proj'
]
# check weights and gradients
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(whisper,
sharded_whisper,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
check_weight(whisper,
sharded_whisper,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
# TODO(jianghai): fix fp16
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
}])
def run_whisper_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn,
enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
continue
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@@ -73,7 +165,7 @@ def check_whisper(rank, world_size, port):
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper():
spawn(check_whisper, 2)
spawn(check_whisper, 4)
if __name__ == "__main__":