From f13954cd583336f6a12cdfa007f0340e0b3d73d4 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 1 Aug 2023 10:35:17 +0800 Subject: [PATCH] [pipeline] refactor test pipeline and remove useless utils in pipeline (#4324) * refactor tests * refactor bloom model * finish policy tests * refactor tests * fix test pure pipeline * remove test pipeline and cutdown launch process * refactor tests * refactor bloom model * finish policy tests * refactor tests * fix test pure pipeline * remove test pipeline and cutdown launch process --- colossalai/pipeline/policy/__init__.py | 22 - colossalai/pipeline/policy/base.py | 111 ---- colossalai/pipeline/policy/bert.py | 523 ------------------ colossalai/pipeline/policy/bloom.py | 220 -------- colossalai/pipeline/schedule/one_f_one_b.py | 1 - colossalai/shardformer/policies/bert.py | 2 +- .../test_bert_for_pretraining_model.py | 64 --- .../test_policy/test_bert_lm_head_model.py | 64 --- .../test_policy/test_bert_model.py | 66 --- .../test_policy/test_bloom_model.py | 63 --- .../test_model/test_shard_bert.py | 3 + .../test_model/test_shard_bert_pipeline.py | 104 ++-- .../test_model/test_shard_bloom_pipeline.py | 71 +-- .../test_model/test_shard_llama_pipeline.py | 70 +-- 14 files changed, 138 insertions(+), 1246 deletions(-) delete mode 100644 colossalai/pipeline/policy/__init__.py delete mode 100644 colossalai/pipeline/policy/base.py delete mode 100644 colossalai/pipeline/policy/bert.py delete mode 100644 colossalai/pipeline/policy/bloom.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_lm_head_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bert_model.py delete mode 100644 tests/test_pipeline/test_policy/test_bloom_model.py diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py deleted file mode 100644 index fd9e6e045..000000000 --- a/colossalai/pipeline/policy/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Type - -from torch import Tensor -from torch.nn import Module, Parameter - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy -from .bert import BertModel, BertModelPolicy - -POLICY_MAP: Dict[Type[Module], Type[Policy]] = { - BertModel: BertModelPolicy, -} - - -def pipeline_parallelize( - model: Module, - stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: - if type(model) not in POLICY_MAP: - raise NotImplementedError(f"Policy for {type(model)} not implemented") - policy = POLICY_MAP[type(model)](stage_manager) - return policy.parallelize_model(model) diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py deleted file mode 100644 index f51d74fdb..000000000 --- a/colossalai/pipeline/policy/base.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -from torch import Tensor -from torch.nn import Module, Parameter - -from colossalai.lazy import LazyTensor -from colossalai.pipeline.stage_manager import PipelineStageManager - - -class Policy: - - def __init__(self, stage_manager: PipelineStageManager) -> None: - self.stage_manager = stage_manager - - def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]: - """Setup model for pipeline parallel - - Args: - module (Module): Module to be setup - - 
Returns: - Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Hold parameters and buffers - """ - hold_params = set() - hold_buffers = set() - - def init_layer(layer: Module): - for p in layer.parameters(): - if isinstance(p, LazyTensor): - p.materialize() - p.data = p.cuda() - hold_params.add(p) - for b in layer.buffers(): - if isinstance(b, LazyTensor): - b.materialize() - b.data = b.cuda() - hold_buffers.add(b) - - hold_layers = self.get_hold_layers(module) - - for layer in hold_layers: - init_layer(layer) - - hold_params_dict = {} - hold_buffers_dict = {} - - # release other tensors - for n, p in module.named_parameters(): - if p in hold_params: - hold_params_dict[n] = p - else: - if isinstance(p, LazyTensor): - p.materialize() - p.data = p.cuda() - p.storage().resize_(0) - for n, b in module.named_buffers(): - if b in hold_buffers: - hold_buffers_dict[n] = b - else: - if isinstance(b, LazyTensor): - b.materialize() - b.data = b.cuda() - # FIXME(ver217): use meta tensor may be better - b.storage().resize_(0) - return hold_params_dict, hold_buffers_dict - - def replace_forward(self, module: Module) -> None: - """Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict - - Args: - module (Module): _description_ - """ - raise NotImplementedError - - def get_hold_layers(self, module: Module) -> List[Module]: - """Get layers that should be hold in current stage. This method should be implemented by subclass. - - Args: - module (Module): Module to be setup - - Returns: - List[Module]: List of layers that should be hold in current stage - """ - raise NotImplementedError - - def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]: - """Get parameters that should be shared across stages. This method should be implemented by subclass. - - Args: - module (Module): Module to be setup - - Returns: - List[Module]: List of parameters that should be shared across stages. E.g. 
[{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] - """ - raise NotImplementedError - - def parallelize_model(self, - module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: - """Parallelize model for pipeline parallel - - Args: - module (Module): Module to be setup - - Returns: - Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Hold parameters, buffers and shared parameters - """ - hold_params, hold_buffers = self.setup_model(module) - self.replace_forward(module) - shared_params = self.get_shared_params(module) - return hold_params, hold_buffers, shared_params diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py deleted file mode 100644 index abce504e9..000000000 --- a/colossalai/pipeline/policy/bert.py +++ /dev/null @@ -1,523 +0,0 @@ -from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, -) -from transformers.models.bert.modeling_bert import ( - BertForPreTraining, - BertForPreTrainingOutput, - BertLMHeadModel, - BertModel, -) -from transformers.utils import ModelOutput, logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy - -logger = logging.get_logger(__name__) - - -def bert_model_forward( - self: BertModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - # labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage -): - # TODO: add explaination of the output here. - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
- - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ - # debugging - # preprocess: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.is_decoder: - use_cache = use_cache if use_cache is not None else self.config.use_cache - else: - use_cache = False - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - else: - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape - device = hidden_states.device - - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - - if token_type_ids is None: - if hasattr(self.embeddings, "token_type_ids"): - buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) - token_type_ids = buffered_token_type_ids_expanded - else: - token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - hidden_states = hidden_states if hidden_states is not None else None - - if stage_manager.is_first_stage(): - hidden_states = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - - # inherit from bert_layer,this should be changed when we add the feature to record hidden_states - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") - use_cache = False - next_decoder_cache = () if use_cache else None - - # calculate the num_layers - num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - # layer_outputs - layer_outputs = hidden_states if hidden_states is not None else None - for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer): - if stage_manager.is_first_stage() and idx == 0: - encoder_attention_mask = encoder_extended_attention_mask - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[idx] if head_mask is not None else None - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.encoder.gradient_checkpointing and self.encoder.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + \ - (layer_outputs[2],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # end of a stage loop - sequence_output = layer_outputs[0] if layer_outputs is not None else None - - if stage_manager.is_last_stage(): - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - if not return_dict: - return (sequence_output, pooled_output) + layer_outputs[1:] - # return dict is not supported at this moment - else: - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - # output of non-first and non-last stages: must be a dict - else: - # intermediate stage always return dict - return { - 'hidden_states': hidden_states, - } - - -# The layer partition policy for bertmodel -class BertModelPolicy(Policy): - - def __init__( - self, - stage_manager: PipelineStageManager, - num_layers: int, - ): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) - - def get_hold_layers(self, module: BertModel) -> List[Module]: - """ - get pipeline layers for current stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.encoder.layer[start_idx:end_idx]) - if self.stage_manager.is_last_stage(): - hold_layers.append(module.pooler) - - return hold_layers - - def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]: - '''no shared 
params in bertmodel''' - return [] - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module) - - -def bert_for_pretraining_forward( - self: BertForPreTraining, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - next_sentence_label: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None, -): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - if stage_manager.is_last_stage(): - sequence_output, pooled_output = outputs[:2] - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - # the last stage for pretraining model - total_loss = None - if labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) - total_loss = masked_lm_loss + next_sentence_loss - - if not return_dict: - output = (prediction_scores, seq_relationship_score) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - - return BertForPreTrainingOutput( - loss=total_loss, - prediction_logits=prediction_scores, - seq_relationship_logits=seq_relationship_score, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - - # intermediate stage always return dict - return { - 'hidden_states': hidden_states, - } - - -class BertForPreTrainingPolicy(Policy): - - def __init__(self, stage_manager: PipelineStageManager, num_layers: int): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) - - def get_hold_layers(self, module: BertForPreTraining) -> List[Module]: - """ - get pipeline layers for current 
stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.bert.embeddings) - - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) - - if self.stage_manager.is_last_stage(): - hold_layers.append(module.bert.pooler) - hold_layers.append(module.cls) - - return hold_layers - - def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]: - '''no shared params in bertmodel''' - return [] - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), - module.forward) - - -def bert_lmhead_forward(self: BertLMHeadModel, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.Tensor]] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_manager: Optional[PipelineStageManager] = None): - r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - use_cache = False - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if return_dict: - logger.warning_once('return_dict is not supported for pipeline models at the moment') - return_dict = False - - outputs = bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states if hidden_states is not None else None) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None - - if stage_manager.is_last_stage(): - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss() - lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - else: - hidden_states = outputs.get('hidden_states') - # intermediate stage always return dict - return {'hidden_states': hidden_states} - - -class BertLMHeadModelPolicy(Policy): - - def __init__(self, stage_manager: PipelineStageManager, num_layers: int): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages) - - def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]: - """ - get pipeline layers for current stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.bert.embeddings) - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx]) - if self.stage_manager.is_last_stage(): - hold_layers.append(module.bert.pooler) - hold_layers.append(module.cls) - - return hold_layers - - def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]: - '''no shared params in bertmodel''' - return [] - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module) diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py deleted file mode 100644 index 71d2913fc..000000000 --- 
a/colossalai/pipeline/policy/bloom.py +++ /dev/null @@ -1,220 +0,0 @@ -import warnings -from functools import partial -from types import MethodType -from typing import Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch import Tensor -from torch.nn import CrossEntropyLoss, Module -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bloom.modeling_bloom import BloomModel -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - -from .base import Policy - -logger = logging.get_logger(__name__) - - -def bloom_model_forward( - self: BloomModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - **deprecated_arguments, -) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # add warnings here - if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') - output_attentions = False - if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') - output_hidden_states = False - if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') - use_cache = False - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - # case: First stage of training - if stage_manager.is_first_stage(): - # check input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = 
self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # initialize in the first stage and then pass to the next stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - # extra recording tensor should be generated in the first stage - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") - use_cache = False - - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len - - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - # calculate the num_layers - num_layers_per_stage = len(self.h) // stage_manager.num_stages - start_layer = stage_manager.stage * num_layers_per_stage - end_layer = (stage_manager.stage + 1) * num_layers_per_stage - - for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - alibi, - causal_mask, - layer_past, - head_mask[i], - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - ) - - hidden_states = outputs[0] - - if use_cache is True: - presents = presents + (outputs[1],) - if output_attentions: - all_self_attentions = all_self_attentions + \ - (outputs[2 if use_cache else 1],) - - if stage_manager.is_last_stage(): - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - # TODO: deal with all_hidden_states, all_self_attentions, presents - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - # attention_mask is not returned ; presents = past_key_values - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class 
BloomModelPolicy(Policy): - - def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int): - super().__init__(stage_manager=stage_manager) - self.stage_manager = stage_manager - self.layers_per_stage = self.distribute_layers(num_layers, num_stages) - - def get_hold_layers(self, module: BloomModel) -> List[Module]: - """ - get pipeline layers for current stage - """ - hold_layers = [] - if self.stage_manager.is_first_stage(): - hold_layers.append(module.word_embeddings) - hold_layers.append(module.word_embeddings_layernorm) - - start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage) - hold_layers.extend(module.h[start_idx:end_idx]) - - if self.stage_manager.is_last_stage(): - hold_layers.append(module.ln_f) - - return hold_layers - - def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]: - '''no shared params in bloommodel''' - pass - - def replace_forward(self, module: Module) -> None: - module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 6ed3055d6..d907d53ed 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -76,7 +76,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict output_obj = model_forward(model, micro_batch, input_obj) - if self.stage_manager.is_last_stage(): loss = criterion(output_obj, micro_batch) / self.num_microbatches if accum_loss is not None: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index f6a4c706e..6f86de232 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -315,7 +315,7 @@ class BertForMaskedLMPolicy(BertPolicy): def module_policy(self): policy = super().module_policy() policy = self.add_lm_head_policy(policy) - mpolicy = self.add_lm_prediction_policy(policy) + policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM if self.pipeline_stage_manager: self.set_pipeline_forward(model_cls=BertForMaskedLM, diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py deleted file mode 100644 index bc3a9bf1b..000000000 --- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bert import BertConfig -from transformers.models.bert.modeling_bert import BertForPreTraining - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_for_pretraining_policy(): - configuration = BertConfig() - model = BertForPreTraining(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 
2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertForPreTrainingPolicy() - model_policy.set_model(model) - - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - layers = model_policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 6 + 1 - else: - assert len(layers) == 6 + 2 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_for_pretraining_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_for_pretraining_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert for pretraining model forward and bert for pretraining model policy""" - test_bert_for_pretraining_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py deleted file mode 100644 index 1aeb00123..000000000 --- a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bert import BertConfig -from transformers.models.bert.modeling_bert import BertLMHeadModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_lmhead_policy(): - configuration = BertConfig() - model = BertLMHeadModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertLMHeadModelPolicy() - model_policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - layers = model_policy.get_held_layers() - - if stage_manager.is_first_stage(): - assert len(layers) == 6 + 1 - else: - assert len(layers) == 6 + 2 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_lmhead_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_lmhead_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert for lm head model policy""" - test_bert_lmhead_policy() diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py deleted file mode 100644 index b366df017..000000000 --- a/tests/test_pipeline/test_policy/test_bert_model.py +++ /dev/null @@ -1,66 +0,0 @@ -''' -In the test policy we only test policy: held layers and others, as the tests for forward logic are done in test_shardformer/test_model -''' - -import pytest -import torch.distributed as 
dist -from transformers.models.bert.modeling_bert import BertModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bert import BertModelPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bert_model_policy(): - model = BertModel.from_pretrained('bert-base-uncased') - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BertModelPolicy() - model_policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - - layers = model_policy.get_held_layers() - - if stage_manager.is_first_stage(): - assert len(layers) == 6 + 1 - else: - assert len(layers) == 6 + 1 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bert_model_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bert_model_policy(): - spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bert model policy""" - test_bert_model_policy() diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py deleted file mode 100644 index e6a222f4e..000000000 --- a/tests/test_pipeline/test_policy/test_bloom_model.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from transformers.models.bloom import BloomConfig, BloomModel - -import colossalai -from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.policies.bloom import BloomModelPolicy -from colossalai.shardformer.shard import ShardConfig -from colossalai.testing import rerun_if_address_is_in_use, spawn - - -def check_bloom_model_policy(): - # create a BloomModel - configuration = BloomConfig() - model = BloomModel(configuration) - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) - # print(pg_mesh) - - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - rank = dist.get_rank() - - model_policy = BloomModelPolicy() - model_policy.set_model(model) - model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) - model_policy.set_shard_config(model_config) - layers = model_policy.get_held_layers() - if stage_manager.is_first_stage(): - assert len(layers) == 1 + 2 - else: - assert len(layers) == 1 + 1 - - -def run_dist_policy(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') - check_bloom_model_policy() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_bloom_model_policy(): - 
spawn(run_dist_policy, 4) - - -if __name__ == "__main__": - """test the bloom model policy""" - test_bloom_model_policy() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index ea0f12264..6d0d3c798 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -2,7 +2,10 @@ import pytest import torch import colossalai +from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py index 4feaf982a..3170b58a1 100644 --- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py @@ -5,6 +5,8 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy +from colossalai.shardformer.shard import ShardConfig from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -17,9 +19,55 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - pass +def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): + stage_manager = stage_manager + policy = get_autopolicy(model) + policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + policy.set_shard_config(model_config) + layers = policy.get_held_layers() + if stage_manager.is_first_stage(): + assert len(layers) == 1 + 1 + else: + if name == "transformers_bert": + assert len(layers) == 1 + 1 + elif name in [ + "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification", + "transformers_bert_for_mcq" + ]: + assert len(layers) == 1 + 3 + else: + assert len(layers) == 1 + 2 + + +def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): + if name == 'transformers_bert_for_mcq': + x = torch.randint(0, 1000, (2, 3, 3)).cuda() + attention_mask = torch.ones_like(x).cuda() + if stage_manager.stage == 0: + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + assert output['hidden_states'].shape == (6, 3, 128) + else: + hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() + output = sharded_model(input_ids=x, + hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + assert output[0].shape == (2, 3) + else: + x = torch.randint(0, 1000, (2, 3)).cuda() + # one batch, 2 single sentences, each sentence has 3 tokens + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + if stage_manager.stage == 
0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) + assert output['hidden_states'].shape == (2, 3, 128) + else: + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model(hidden_states=hidden_states, + attention_mask=attention_mask, + stage_manager=stage_manager) + assert output[0].shape[0] == 2 @parameterize('enable_fused_normalization', [False]) @@ -27,55 +75,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_bert def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + PP_DIM = 0 + PP_SIZE = 2 + pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - - if name == 'transformers_bert_for_mcq': - x = torch.randint(0, 1000, (2, 3, 3)).cuda() - attention_mask = torch.ones_like(x).cuda() - if stage_manager.stage == 0: - output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - assert output['hidden_states'].shape == (6, 3, 128) - else: - hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda() - output = sharded_model(input_ids=x, - hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - assert output[0].shape == (2, 3) - else: - x = torch.randint(0, 1000, (2, 3)).cuda() - # one batch, 2 single sentences, each sentence has 3 tokens - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager) - assert output['hidden_states'].shape == (2, 3, 128) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model(hidden_states=hidden_states, - attention_mask=attention_mask, - stage_manager=stage_manager) - assert output[0].shape[0] == 2 + check_bert_model_policy(name, org_model, stage_manager) + check_bert_model_pipeline_forward(name, sharded_model, stage_manager) torch.cuda.empty_cache() @@ -90,7 +100,7 @@ def check_bert(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bert(): - spawn(check_bert, 4) + spawn(check_bert, 2) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py index 3a36479fc..6695e8a68 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py @@ -5,7 +5,9 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy 
from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.shard import ShardConfig from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -18,9 +20,37 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - pass +def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): + policy = get_autopolicy(model) + policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + policy.set_shard_config(model_config) + layers = policy.get_held_layers() + if stage_manager.is_first_stage(): + assert len(layers) == 0 + 2 + else: + if name == 'transformers_bloom': + assert len(layers) == 1 + 1 + elif name == 'transformers_bloom_for_token_classification': + assert len(layers) == 1 + 3 + else: + assert len(layers) == 1 + 2 + + +def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): + if stage_manager.stage == 0: + x = torch.randint(0, 1000, (1, 3)).cuda() + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (1, 3, 64) + else: + attention_mask = torch.ones((1, 3)).cuda() + hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0].shape[0] == 1 @parameterize('enable_fused_normalization', [False]) @@ -28,40 +58,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_bloom def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + PP_DIM = 0 + PP_SIZE = 2 + pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') - x = torch.randint(0, 1000, (1, 3)).cuda() - hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda() for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (1, 3, 64) - else: - attention_mask = torch.ones((1, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0].shape[0] == 1 + check_bloom_model_policy(name, org_model, stage_manager) + check_bloom_model_pipeline_forward(name, sharded_model, stage_manager) torch.cuda.empty_cache() @@ -76,7 +83,7 @@ def check_bloom(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_bloom(): 
- spawn(check_bloom, 4) + spawn(check_bloom, 2) if __name__ == "__main__": diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py index 8fd9ed099..6f1f0cb34 100644 --- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py +++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py @@ -5,7 +5,9 @@ import colossalai from colossalai.cluster import ProcessGroupMesh from colossalai.logging import disable_existing_loggers from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.auto_policy import get_autopolicy from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.shard import ShardConfig from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, @@ -18,9 +20,35 @@ from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward -def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): - # check forward - pass +def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager): + policy = get_autopolicy(model) + policy.set_model(model) + model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False) + policy.set_shard_config(model_config) + layers = policy.get_held_layers() + if stage_manager.is_first_stage(): + assert len(layers) == 2 + 1 + else: + if name == "transformers_llama": + assert len(layers) == 2 + 1 + else: + assert len(layers) == 2 + 2 + + +def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager): + x = torch.randint(0, 1000, (2, 3)).cuda() + if stage_manager.stage == 0: + attention_mask = torch.ones_like(x).cuda() + output = sharded_model(input_ids=x, attention_mask=attention_mask) + assert output['hidden_states'].shape == (2, 3, 128) + else: + hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + attention_mask = torch.ones((2, 3)).cuda() + output = sharded_model( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + assert output[0] is not None @parameterize('enable_fused_normalization', [False]) @@ -28,40 +56,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo @parameterize('use_lazy_init', [False]) #TODO: merge this into test_shard_llama def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init): - DP_DIM, PP_DIM = 0, 1 - DP_SIZE, PP_SIZE = 2, 2 - RANK_TO_COORDINATE = { - 0: (0, 0), - 1: (0, 1), - 2: (1, 0), - 3: (1, 1), - } - PP_RANKS_IN_GROUP = { - 0: [0, 1], - 1: [0, 1], - 2: [2, 3], - 3: [2, 3], - } - pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + PP_DIM = 0 + PP_SIZE = 2 + pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - x = torch.randint(0, 1000, (2, 3)).cuda() - hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init) - if stage_manager.stage == 0: - attention_mask = torch.ones_like(x).cuda() - output = 
sharded_model(input_ids=x, attention_mask=attention_mask) - assert output['hidden_states'].shape == (2, 3, 128) - else: - attention_mask = torch.ones((2, 3)).cuda() - output = sharded_model( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - assert output[0] is not None + check_llama_model_policy(name, org_model, stage_manager) + check_llama_model_pipeline_forward(name, sharded_model, stage_manager) torch.cuda.empty_cache() @@ -76,7 +82,7 @@ def check_llama(rank, world_size, port): @rerun_if_address_is_in_use() @clear_cache_before_run() def test_llama(): - spawn(check_llama, 4) + spawn(check_llama, 2) if __name__ == "__main__":
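
A minimal, self-contained sketch of the test pattern this patch converges on, assuming only the ColossalAI APIs exercised in the hunks above (ProcessGroupMesh, PipelineStageManager, get_autopolicy, ShardConfig, spawn, colossalai.launch); the default BertConfig model and the asserted layer count (taken from the removed test_bert_model.py) are illustrative, not part of the patch:

import torch.distributed as dist
from transformers.models.bert import BertConfig
from transformers.models.bert.modeling_bert import BertModel

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from colossalai.shardformer.shard import ShardConfig
from colossalai.testing import rerun_if_address_is_in_use, spawn

PP_DIM, PP_SIZE = 0, 2    # pipeline-only mesh: no DP dimension, so 2 ranks suffice


def check_policy(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')

    # build a pipeline-only process group and its stage manager
    pg_mesh = ProcessGroupMesh(PP_SIZE)
    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

    # resolve the shardformer policy automatically instead of a hand-written pipeline policy
    model = BertModel(BertConfig())
    policy = get_autopolicy(model)
    policy.set_model(model)
    policy.set_shard_config(ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False))

    # with 12 encoder layers split over 2 stages, each stage holds 6 layers plus
    # either the embeddings (first stage) or the pooler (last stage)
    held_layers = policy.get_held_layers()
    assert len(held_layers) == 6 + 1
    print(f"rank {dist.get_rank()}: holds {len(held_layers)} modules")


@rerun_if_address_is_in_use()
def test_policy():
    spawn(check_policy, PP_SIZE)


if __name__ == "__main__":
    test_policy()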