From 8739aa7fa01a9c04743dea813e2cc210e30dd77f Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Fri, 18 Aug 2023 21:29:25 +0800 Subject: [PATCH] [shardformer] Pipeline/whisper (#4456) * add some base tests and policies * finish whisper base model * add conditional generation * finish basic tests * whisper * finish whisper * finish whisper * del useless whisper test * fix * add argmin to replace * finish revision --- colossalai/shardformer/modeling/whisper.py | 715 +++++++++++++++++- colossalai/shardformer/policies/blip2.py | 9 - colossalai/shardformer/policies/t5.py | 9 +- colossalai/shardformer/policies/whisper.py | 243 +++++- .../test_t5_pipeline_utils.py | 39 + .../test_whisper_pipeline_utils.py | 44 ++ .../test_model/test_shard_llama.py | 2 + .../test_model/test_shard_whisper.py | 148 +++- 8 files changed, 1156 insertions(+), 53 deletions(-) create mode 100644 tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py create mode 100644 tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 0a16c6f78..62f8f7b47 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -1,7 +1,26 @@ -from typing import Optional, Tuple +import logging +import random +from typing import Dict, List, Optional, Set, Tuple, Union import torch from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, +) +from transformers.models.whisper.modeling_whisper import ( + WhisperEncoder, + WhisperForAudioClassification, + WhisperForConditionalGeneration, + WhisperModel, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager def get_whisper_flash_attention_forward(): @@ -247,3 +266,697 @@ def get_jit_fused_whisper_decoder_layer_forward(): return outputs return forward + + +class WhisperPipelineForwards: + ''' + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + ''' + + @staticmethod + def whisper_encoder_forward( + self: WhisperEncoder, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + + stage = stage_manager.stage + at_first_stage = (stage == 0) + at_last_stage = (stage == decoder_starting_stage - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Process inputs if at the first stage of encoder. + if at_first_stage: + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + else: + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + encoder_layer = self.layers[idx] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + None, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + else: + return {'hidden_states': hidden_states, 'head_mask': head_mask} + + @staticmethod + def whisper_decoder_forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more + control over how to convert `input_ids` indices into associated vectors than the model's internal + embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + logger = logging.get_logger(__name__) + stage = stage_manager.stage + at_first_stage = (stage == decoder_starting_stage) + at_last_stage = (stage == stage_manager.num_stages - 1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}.") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if at_first_stage: + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if input_ids is not None: + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + else: + positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, + past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + else: + + if hidden_states is None: + raise ValueError( + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + input_shape = hidden_states.size()[:-1] + + attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states, + past_key_values_length) + + start_idx, end_idx = stage_index[0], stage_index[1] + + for idx in range(start_idx, end_idx): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + None, # encoder attention mask + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, # past_key_value + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] + if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if at_last_stage: + hidden_states = self.layer_norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + else: + return { + 'head_mask': head_mask, + 'cross_attn_head_mask': cross_attn_head_mask, + 'hidden_states': hidden_states, + } + + @staticmethod + def whisper_model_forward( + self: WhisperModel, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + Returns: + + Example: + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, WhisperModel + >>> from datasets import load_dataset + + >>> model = WhisperModel.from_pretrained("openai/whisper-base") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id + >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state + >>> list(last_hidden_state.shape) + [1, 2, 512] + ```""" + # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. + if past_key_values: + logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + past_key_values = None + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + in_decoder = stage_manager.stage >= decoder_starting_stage + if not in_decoder: + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + if stage_manager.stage == decoder_starting_stage - 1: + # last stage of encoder + return {'encoder_hidden_states': encoder_outputs[0]} + else: + return encoder_outputs + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + at_last_decoder_stage = stage_manager.is_last_stage() + at_first_decoder_stage = stage_manager.stage == decoder_starting_stage + if encoder_outputs is not None: + encoder_hidden_states = encoder_outputs[0] + elif encoder_hidden_states is None: + raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.") + + if not at_first_decoder_stage and hidden_states is None: + raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + + # Directly return outputs of overloaded Whisper forward if not at last stage. + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + return decoder_outputs + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) + + @staticmethod + def whisper_for_conditional_generation_forward( + self: WhisperForConditionalGeneration, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` + or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, + self.config.decoder_start_token_id) + in_decoder = stage_manager.stage >= decoder_starting_stage + at_last_decoder_stage = stage_manager.is_last_stage() + outputs = WhisperPipelineForwards.whisper_model_forward(self.model, + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + if not in_decoder: + return outputs + + if not at_last_decoder_stage: + # encoder_hidden_states should be passed to the next stage + outputs['encoder_hidden_states'] = encoder_hidden_states + return outputs + + lm_logits = self.proj_out(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + @staticmethod + def whisper_for_audio_classification_forward( + self: WhisperForAudioClassification, + input_features: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + encoder_states=None, + all_attentions=None, + stage_index: Optional[List[int]] = None, + decoder_starting_stage: Optional[int] = None, + ): + r""" + This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. + Please refer to original code of transformers for more details. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # audio_classification only holds encoder + encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward( + self.encoder, + input_features, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) + + if not stage_manager.is_last_stage(): + return encoder_outputs + + if self.config.use_weighted_layer_sum: + hidden_states = torch.stack(encoder_outputs, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = encoder_outputs[0] + + hidden_states = self.projector(hidden_states) + pooled_output = hidden_states.mean(dim=1) + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + loss_fct = CrossEntropyLoss() + # move labels to correct device to enable PP + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + encoder_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 50356302e..3610e2c41 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -304,15 +304,6 @@ class BlipPolicy(Policy): return policy def postprocess(self): - binding_map = { - 'language_model.model.decoder.embed_tokens': 'language_model.lm_head', - } - - for k, v in binding_map.items(): - src_mod = getattr_(self.model, k) - dst_mod = getattr_(self.model, v) - dst_mod.weight = src_mod.weight - return self.model diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 2ef52c214..651883d35 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,6 +1,7 @@ from functools import partial from typing import Callable, Dict, List, Optional, Tuple +import numpy as np from torch import Tensor, nn from colossalai.shardformer.layer import ( @@ -228,13 +229,7 @@ class T5BasePolicy(Policy): def objective(num_encoder_stages): return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) - num_encoder_stages = 0 - optimal_diff = 2**31 - 1 - for i in range(1, num_stages): - attempt = objective(i) - if attempt < optimal_diff: - num_encoder_stages = i - optimal_diff = attempt + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 2ac7a49fd..a33f929f1 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -1,10 +1,16 @@ +from functools import partial +from typing import Callable, Dict, List, Tuple + +import numpy as np import torch.nn as nn +from torch import Tensor import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.whisper import ( + WhisperPipelineForwards, get_jit_fused_whisper_decoder_layer_forward, get_jit_fused_whisper_encoder_layer_forward, get_whisper_flash_attention_forward, @@ -12,7 +18,8 @@ from ..modeling.whisper import ( from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification' + 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', + 'WhisperForAudioClassificationPolicy' ] @@ -223,6 +230,146 @@ class WhisperPolicy(Policy): def postprocess(self): return self.model + @staticmethod + def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int, + num_stages: int) -> Tuple[List[int], int]: + """ + Distribute whisper layers into stages when pipeline parallel is used. + Return the layer distribution as a list and the starting stage of decoder. + If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. + """ + + # number of encoder layers must be a positive integer + if num_encoder_layers <= 0: + raise ValueError("The number of encoder layers for whisper must be a positive integer.") + + # number of layers should be large enough to fill in every stage + if num_encoder_layers + num_decoder_layers < num_stages: + raise ValueError("The total number of layers can't be smaller than number of stages.") + + # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist + if num_decoder_layers == 0: + return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages + + # the number of stages distributed between encoder and decoder is optmized in this way: + # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) + # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1 + def objective(num_encoder_stages): + return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages)) + + num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 + num_decoder_stages = num_stages - num_encoder_stages + + encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages) + return encoder_distribution + decoder_distribution, num_encoder_stages + + @staticmethod + def get_whisper_stage_index(layers_per_stage: List[int], stage: int, + decoder_starting_stage: int) -> Tuple[bool, int, int]: + """ + Input the distribution of layers among stages, the current stage and the first stage of decoder. + Return the starting/ending idx of layers in encoder/decoder + """ + if stage < decoder_starting_stage: + return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + else: + return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + + def get_held_layers(self) -> List[nn.Module]: + + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + # whisper for audio classification holds encoder only + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + held_layers = [] + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + if stage_manager.stage < decoder_starting_stage: + # current stage is in whisper's encoder + if stage_manager.is_first_stage(): + held_layers.append(encoder.embed_positions) + held_layers.append(encoder.conv1) + held_layers.append(encoder.conv2) + if stage_manager.stage == decoder_starting_stage - 1: + held_layers.append(encoder.layer_norm) + held_layers.extend(encoder.layers[start_idx:end_idx]) + else: + # current stage is in whisper's decoder + # TODO:(Jianghai) We divide encoder and decoder layers into different parts here, + # the case encoder and decoder put in same stage should be add in the future. + if stage_manager.stage == decoder_starting_stage: + held_layers.append(decoder.embed_tokens) + held_layers.append(decoder.embed_positions) + if stage_manager.is_last_stage(): + held_layers.append(decoder.layer_norm) + held_layers.extend(decoder.layers[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if not self.pipeline_stage_manager: + raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") + stage_manager = self.pipeline_stage_manager + + if self.model.__class__.__name__ == 'WhisperModel': + model = self.model + elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + model = self.model.model + else: + model = None + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, + decoder_starting_stage) + + method_replacement = { + 'forward': + partial(new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage) + } + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) + # WhisperModel class WhisperModelPolicy(WhisperPolicy): @@ -230,6 +377,24 @@ class WhisperModelPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + from transformers import WhisperModel + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperModel, + new_forward=WhisperPipelineForwards.whisper_model_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + return super().get_held_layers() + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + "no shared params in whisper model" + return [] + # WhisperForConditionalGeneration class WhisperForConditionalGenerationPolicy(WhisperPolicy): @@ -238,20 +403,82 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy): super().__init__() def module_policy(self): - module_policy = super().module_policy() - module_policy = self.add_lm_head_policy(module_policy) - return module_policy + from transformers import WhisperForConditionalGeneration + policy = super().module_policy() + policy = self.add_lm_head_policy(policy) + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration, + new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, + policy=policy) + return policy def postprocess(self): - binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - setattr_(self.model, v, param) return self.model + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.proj_out) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + module = self.model + model = module.model + + if model: + encoder = self.model.get_encoder() + decoder = self.model.get_decoder() + else: + encoder = self.model.encoder + decoder = None + + num_encoder_layers = len(encoder.layers) + if decoder: + num_decoder_layers = len(decoder.layers) + else: + num_decoder_layers = 0 + + stage_manager = self.pipeline_stage_manager + if stage_manager is not None and stage_manager.num_stages > 1: + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers, + stage_manager.num_stages) + shared_params = [] + shared_embedding = {} + if id(module.proj_out) == id(model.decoder.embed_tokens): + shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens + shared_embedding[stage_manager.num_stages - 1] = module.proj_out + if len(shared_embedding) > 0: + shared_params.append(shared_embedding) + return shared_params + return [] + # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): def __init__(self) -> None: super().__init__() + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers import WhisperForAudioClassification + policy = super().module_policy() + + if self.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=WhisperForAudioClassification, + new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + if self.pipeline_stage_manager.is_last_stage(): + held_layers.append(self.model.projector) + held_layers.append(self.model.classifier) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + return [] diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py new file mode 100644 index 000000000..0cbb852b9 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -0,0 +1,39 @@ +from colossalai.shardformer.policies.t5 import T5BasePolicy + + +def test_t5_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_t5_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py new file mode 100644 index 000000000..395519e97 --- /dev/null +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -0,0 +1,44 @@ +from colossalai.shardformer.policies.whisper import WhisperPolicy + + +def test_whisper_pipeline_distribution(): + num_test_cases = 8 + test_dict = { + 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], + 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], + 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], + 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + } + + for i in range(num_test_cases): + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i], + test_dict['num_decoder_layers'][i], + test_dict['num_stages'][i]) + assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + + +def test_whisper_pipeline_layers(): + num_test_cases = 4 + test_dict = { + 'num_encoder_layers': [2, 3, 2, 4], + 'num_decoder_layers': [2, 0, 2, 8], + 'num_stages': [2, 2, 4, 4], + 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]]] + } + + for i in range(num_test_cases): + layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + + for stage in range(test_dict['num_stages'][i]): + start_idx, end_idx = test_dict['layers_per_stage'][i][stage] + predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage, + decoder_starting_stage) + assert start_idx == predicted_start + assert end_idx == predicted_end + + +if __name__ == '__main__': + test_whisper_pipeline_distribution() + test_whisper_pipeline_layers() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a433567b3..ec5578a76 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -6,6 +6,7 @@ from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -143,6 +144,7 @@ def run_llama_test(test_config): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index 9b38ae07b..90e007e34 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -3,6 +3,8 @@ import torch import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( assert_hf_output_close, clear_cache_before_run, @@ -11,55 +13,145 @@ from colossalai.testing import ( spawn, ) from tests.kit.model_zoo import model_zoo -from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward +from tests.test_shardformer.test_model._utils import ( + build_model_from_hybrid_plugin, + check_grad, + check_loss, + check_output_hidden_state, + check_weight, + run_forward_backward_with_hybrid_plugin, +) -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): +def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ + build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - # do backward - org_loss.backward() - shard_loss.backward() + org_loss, org_output, sharded_loss, sharded_output = \ + run_forward_backward_with_hybrid_plugin( + org_model, + sharded_model, + sharded_optimizer, + data_gen_fn, + output_transform_fn, + criterion, + booster) - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + stage_manager = booster.plugin.stage_manager + tp_group = booster.plugin.tp_group + + # check last hidden state & loss + if stage_manager is None or stage_manager.is_last_stage(): + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if org_model.__class__.__name__ == 'WhisperModel': + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) + + check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # unwarp the model if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': whisper = org_model.model - sharded_whisper = sharded_model.model + sharded_whisper = sharded_model.unwrap().model else: whisper = org_model - sharded_whisper = sharded_model + sharded_whisper = sharded_model.unwrap() # check grad if org_model.__class__.__name__ == 'WhisperForAudioClassification': col_layer_for_check = ['encoder.layers[0].self_attn.q_proj'] row_layer_for_check = ['encoder.layers[0].self_attn.out_proj'] else: - col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj'] - row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj'] - check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) - check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) + col_layer_for_check = [ + 'encoder.layers[0].self_attn.q_proj', + # 'decoder.layers[0].self_attn.q_proj' + ] + row_layer_for_check = [ + 'encoder.layers[0].self_attn.out_proj', + #'decoder.layers[0].self_attn.out_proj' + ] + + # check weights and gradients + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + + if stage_manager is None or stage_manager.is_first_stage(): + check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0) + check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) + + # check weights after optimizer.step() + org_optimizer.step() + sharded_optimizer.step() + if test_config['precision'] == 'fp32': + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-3, 5e-3 + if stage_manager is None or stage_manager.is_first_stage(): + check_weight(whisper, + sharded_whisper, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + check_weight(whisper, + sharded_whisper, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False) + + torch.cuda.empty_cache() -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) -def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): +# TODO(jianghai) fix fp16 +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 2, + 'enable_all_optimization': True, + 'use_lazy_init': True, + 'precision': 'fp32', + 'initial_scale': 1, +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', + 'initial_scale': 1, +}, { + 'tp_size': 4, + 'pp_size': 1, + 'enable_all_optimization': True, + 'use_lazy_init': False, + 'precision': 'fp32', +}, { + 'tp_size': 1, + 'pp_size': 4, + 'num_microbatches': 4, + 'use_lazy_init': False, + 'precision': 'fp32', +}]) +def run_whisper_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, - enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused) - check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification': + continue + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) + + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() @@ -73,7 +165,7 @@ def check_whisper(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_whisper(): - spawn(check_whisper, 2) + spawn(check_whisper, 4) if __name__ == "__main__":