From f13954cd583336f6a12cdfa007f0340e0b3d73d4 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Tue, 1 Aug 2023 10:35:17 +0800
Subject: [PATCH] [pipeline] refactor test pipeline and remove useless utils in
 pipeline (#4324)

* refactor tests

* refactor bloom model

* finish policy tests

* refactor tests

* fix test pure pipeline

* remove test pipeline and cutdown launch process

* refactor tests

* refactor bloom model

* finish policy tests

* refactor tests

* fix test pure pipeline

* remove test pipeline and cutdown launch process
---
 colossalai/pipeline/policy/__init__.py        |  22 -
 colossalai/pipeline/policy/base.py            | 111 ----
 colossalai/pipeline/policy/bert.py            | 523 ------------------
 colossalai/pipeline/policy/bloom.py           | 220 --------
 colossalai/pipeline/schedule/one_f_one_b.py   |   1 -
 colossalai/shardformer/policies/bert.py       |   2 +-
 .../test_bert_for_pretraining_model.py        |  64 ---
 .../test_policy/test_bert_lm_head_model.py    |  64 ---
 .../test_policy/test_bert_model.py            |  66 ---
 .../test_policy/test_bloom_model.py           |  63 ---
 .../test_model/test_shard_bert.py             |   3 +
 .../test_model/test_shard_bert_pipeline.py    | 104 ++--
 .../test_model/test_shard_bloom_pipeline.py   |  71 +--
 .../test_model/test_shard_llama_pipeline.py   |  70 +--
 14 files changed, 138 insertions(+), 1246 deletions(-)
 delete mode 100644 colossalai/pipeline/policy/__init__.py
 delete mode 100644 colossalai/pipeline/policy/base.py
 delete mode 100644 colossalai/pipeline/policy/bert.py
 delete mode 100644 colossalai/pipeline/policy/bloom.py
 delete mode 100644 tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py
 delete mode 100644 tests/test_pipeline/test_policy/test_bert_lm_head_model.py
 delete mode 100644 tests/test_pipeline/test_policy/test_bert_model.py
 delete mode 100644 tests/test_pipeline/test_policy/test_bloom_model.py

diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py
deleted file mode 100644
index fd9e6e045..000000000
--- a/colossalai/pipeline/policy/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple, Type
-
-from torch import Tensor
-from torch.nn import Module, Parameter
-
-from colossalai.pipeline.stage_manager import PipelineStageManager
-
-from .base import Policy
-from .bert import BertModel, BertModelPolicy
-
-POLICY_MAP: Dict[Type[Module], Type[Policy]] = {
-    BertModel: BertModelPolicy,
-}
-
-
-def pipeline_parallelize(
-        model: Module,
-        stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
-    if type(model) not in POLICY_MAP:
-        raise NotImplementedError(f"Policy for {type(model)} not implemented")
-    policy = POLICY_MAP[type(model)](stage_manager)
-    return policy.parallelize_model(model)
diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py
deleted file mode 100644
index f51d74fdb..000000000
--- a/colossalai/pipeline/policy/base.py
+++ /dev/null
@@ -1,111 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple
-
-import numpy as np
-from torch import Tensor
-from torch.nn import Module, Parameter
-
-from colossalai.lazy import LazyTensor
-from colossalai.pipeline.stage_manager import PipelineStageManager
-
-
-class Policy:
-
-    def __init__(self, stage_manager: PipelineStageManager) -> None:
-        self.stage_manager = stage_manager
-
-    def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]:
-        """Setup model for pipeline parallel
-
-        Args:
-            module (Module): Module to be setup
-
-        Returns:
-            Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Hold parameters and buffers
-        """
-        hold_params = set()
-        hold_buffers = set()
-
-        def init_layer(layer: Module):
-            for p in layer.parameters():
-                if isinstance(p, LazyTensor):
-                    p.materialize()
-                p.data = p.cuda()
-                hold_params.add(p)
-            for b in layer.buffers():
-                if isinstance(b, LazyTensor):
-                    b.materialize()
-                b.data = b.cuda()
-                hold_buffers.add(b)
-
-        hold_layers = self.get_hold_layers(module)
-
-        for layer in hold_layers:
-            init_layer(layer)
-
-        hold_params_dict = {}
-        hold_buffers_dict = {}
-
-        # release other tensors
-        for n, p in module.named_parameters():
-            if p in hold_params:
-                hold_params_dict[n] = p
-            else:
-                if isinstance(p, LazyTensor):
-                    p.materialize()
-                p.data = p.cuda()
-                p.storage().resize_(0)
-        for n, b in module.named_buffers():
-            if b in hold_buffers:
-                hold_buffers_dict[n] = b
-            else:
-                if isinstance(b, LazyTensor):
-                    b.materialize()
-                b.data = b.cuda()
-                # FIXME(ver217): use meta tensor may be better
-                b.storage().resize_(0)
-        return hold_params_dict, hold_buffers_dict
-
-    def replace_forward(self, module: Module) -> None:
-        """Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict
-
-        Args:
-            module (Module): _description_
-        """
-        raise NotImplementedError
-
-    def get_hold_layers(self, module: Module) -> List[Module]:
-        """Get layers that should be hold in current stage. This method should be implemented by subclass.
-
-        Args:
-            module (Module): Module to be setup
-
-        Returns:
-            List[Module]: List of layers that should be hold in current stage
-        """
-        raise NotImplementedError
-
-    def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]:
-        """Get parameters that should be shared across stages. This method should be implemented by subclass.
-
-        Args:
-            module (Module): Module to be setup
-
-        Returns:
-            List[Module]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
-        """
-        raise NotImplementedError
-
-    def parallelize_model(self,
-                          module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
-        """Parallelize model for pipeline parallel
-
-        Args:
-            module (Module): Module to be setup
-
-        Returns:
-            Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Hold parameters, buffers and shared parameters
-        """
-        hold_params, hold_buffers = self.setup_model(module)
-        self.replace_forward(module)
-        shared_params = self.get_shared_params(module)
-        return hold_params, hold_buffers, shared_params
diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py
deleted file mode 100644
index abce504e9..000000000
--- a/colossalai/pipeline/policy/bert.py
+++ /dev/null
@@ -1,523 +0,0 @@
-from functools import partial
-from types import MethodType
-from typing import Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from torch import Tensor
-from torch.nn import CrossEntropyLoss, Module
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
-    BaseModelOutputWithPoolingAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-)
-from transformers.models.bert.modeling_bert import (
-    BertForPreTraining,
-    BertForPreTrainingOutput,
-    BertLMHeadModel,
-    BertModel,
-)
-from transformers.utils import ModelOutput, logging
-
-from colossalai.pipeline.stage_manager import PipelineStageManager
-
-from .base import Policy
-
-logger = logging.get_logger(__name__)
-
-
-def bert_model_forward(
-        self: BertModel,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-    # labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        stage_manager: Optional[PipelineStageManager] = None,
-        hidden_states: Optional[torch.FloatTensor] = None,    # this is from the previous stage
-):
-    # TODO: add explaination of the output here.
-    r"""
-        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
-            the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
-            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
-            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
-            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        """
-    # debugging
-    # preprocess:
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (output_hidden_states
-                            if output_hidden_states is not None else self.config.output_hidden_states)
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    if self.config.is_decoder:
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-    else:
-        use_cache = False
-
-    if stage_manager.is_first_stage():
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = input_ids.size()
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-        batch_size, seq_length = input_shape
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-    else:
-        input_shape = hidden_states.size()[:-1]
-        batch_size, seq_length = input_shape
-        device = hidden_states.device
-
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
-    if output_attentions:
-        logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
-        output_attentions = False
-    if output_hidden_states:
-        logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
-        output_hidden_states = False
-    if use_cache:
-        logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
-        use_cache = False
-
-    # past_key_values_length
-    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-
-    if attention_mask is None:
-        attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
-
-    if token_type_ids is None:
-        if hasattr(self.embeddings, "token_type_ids"):
-            buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
-            buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
-            token_type_ids = buffered_token_type_ids_expanded
-        else:
-            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
-
-    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-    # ourselves in which case we just need to make it broadcastable to all heads.
-    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
-    attention_mask = extended_attention_mask
-    # If a 2D or 3D attention mask is provided for the cross-attention
-    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-    if self.config.is_decoder and encoder_hidden_states is not None:
-        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
-        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
-        if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-    else:
-        encoder_extended_attention_mask = None
-
-    # Prepare head mask if needed
-    # 1.0 in head_mask indicate we keep the head
-    # attention_probs has shape bsz x n_heads x N x N
-    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-    hidden_states = hidden_states if hidden_states is not None else None
-
-    if stage_manager.is_first_stage():
-        hidden_states = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
-
-    # inherit from bert_layer,this should be changed when we add the feature to record hidden_states
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attentions = () if output_attentions else None
-    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
-
-    if self.encoder.gradient_checkpointing and self.encoder.training:
-        if use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
-            use_cache = False
-    next_decoder_cache = () if use_cache else None
-
-    # calculate the num_layers
-    num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
-    start_layer = stage_manager.stage * num_layers_per_stage
-    end_layer = (stage_manager.stage + 1) * num_layers_per_stage
-
-    # layer_outputs
-    layer_outputs = hidden_states if hidden_states is not None else None
-    for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
-        if stage_manager.is_first_stage() and idx == 0:
-            encoder_attention_mask = encoder_extended_attention_mask
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        layer_head_mask = head_mask[idx] if head_mask is not None else None
-        past_key_value = past_key_values[idx] if past_key_values is not None else None
-
-        if self.encoder.gradient_checkpointing and self.encoder.training:
-
-            def create_custom_forward(module):
-
-                def custom_forward(*inputs):
-                    return module(*inputs, past_key_value, output_attentions)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(encoder_layer),
-                hidden_states,
-                attention_mask,
-                layer_head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-            )
-        else:
-            layer_outputs = encoder_layer(
-                hidden_states,
-                attention_mask,
-                layer_head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                past_key_value,
-                output_attentions,
-            )
-        hidden_states = layer_outputs[0]
-        if use_cache:
-            next_decoder_cache += (layer_outputs[-1],)
-        if output_attentions:
-            all_self_attentions = all_self_attentions + (layer_outputs[1],)
-            if self.config.add_cross_attention:
-                all_cross_attentions = all_cross_attentions + \
-                    (layer_outputs[2],)
-
-    if output_hidden_states:
-        all_hidden_states = all_hidden_states + (hidden_states,)
-
-    # end of a stage loop
-    sequence_output = layer_outputs[0] if layer_outputs is not None else None
-
-    if stage_manager.is_last_stage():
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-        if not return_dict:
-            return (sequence_output, pooled_output) + layer_outputs[1:]
-        # return dict is not supported at this moment
-        else:
-            return BaseModelOutputWithPastAndCrossAttentions(
-                last_hidden_state=hidden_states,
-                past_key_values=next_decoder_cache,
-                hidden_states=all_hidden_states,
-                attentions=all_self_attentions,
-                cross_attentions=all_cross_attentions,
-            )
-
-    # output of non-first and non-last stages: must be a dict
-    else:
-        # intermediate stage always return dict
-        return {
-            'hidden_states': hidden_states,
-        }
-
-
-# The layer partition policy for bertmodel
-class BertModelPolicy(Policy):
-
-    def __init__(
-        self,
-        stage_manager: PipelineStageManager,
-        num_layers: int,
-    ):
-        super().__init__(stage_manager=stage_manager)
-        self.stage_manager = stage_manager
-        self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages)
-
-    def get_hold_layers(self, module: BertModel) -> List[Module]:
-        """
-        get pipeline layers for current stage
-        """
-        hold_layers = []
-        if self.stage_manager.is_first_stage():
-            hold_layers.append(module.embeddings)
-        start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
-        hold_layers.extend(module.encoder.layer[start_idx:end_idx])
-        if self.stage_manager.is_last_stage():
-            hold_layers.append(module.pooler)
-
-        return hold_layers
-
-    def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]:
-        '''no shared params in bertmodel'''
-        return []
-
-    def replace_forward(self, module: Module) -> None:
-        module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module)
-
-
-def bert_for_pretraining_forward(
-    self: BertForPreTraining,
-    input_ids: Optional[torch.Tensor] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    token_type_ids: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.Tensor] = None,
-    head_mask: Optional[torch.Tensor] = None,
-    inputs_embeds: Optional[torch.Tensor] = None,
-    labels: Optional[torch.Tensor] = None,
-    next_sentence_label: Optional[torch.Tensor] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    hidden_states: Optional[torch.FloatTensor] = None,
-    stage_manager: Optional[PipelineStageManager] = None,
-):
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
-    if output_attentions:
-        logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
-        output_attentions = False
-    if output_hidden_states:
-        logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
-        output_hidden_states = False
-    if return_dict:
-        logger.warning_once('return_dict is not supported for pipeline models at the moment')
-        return_dict = False
-
-    outputs = bert_model_forward(self.bert,
-                                 input_ids,
-                                 attention_mask=attention_mask,
-                                 token_type_ids=token_type_ids,
-                                 position_ids=position_ids,
-                                 head_mask=head_mask,
-                                 inputs_embeds=inputs_embeds,
-                                 output_attentions=output_attentions,
-                                 output_hidden_states=output_hidden_states,
-                                 return_dict=return_dict,
-                                 stage_manager=stage_manager,
-                                 hidden_states=hidden_states if hidden_states is not None else None)
-    past_key_values = None
-    all_hidden_states = None
-    all_self_attentions = None
-    all_cross_attentions = None
-    if stage_manager.is_last_stage():
-        sequence_output, pooled_output = outputs[:2]
-        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
-        # the last stage for pretraining model
-        total_loss = None
-        if labels is not None and next_sentence_label is not None:
-            loss_fct = CrossEntropyLoss()
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
-            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            total_loss = masked_lm_loss + next_sentence_loss
-
-        if not return_dict:
-            output = (prediction_scores, seq_relationship_score) + outputs[2:]
-            return ((total_loss,) + output) if total_loss is not None else output
-
-        return BertForPreTrainingOutput(
-            loss=total_loss,
-            prediction_logits=prediction_scores,
-            seq_relationship_logits=seq_relationship_score,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-    else:
-        hidden_states = outputs.get('hidden_states')
-
-        # intermediate stage always return dict
-        return {
-            'hidden_states': hidden_states,
-        }
-
-
-class BertForPreTrainingPolicy(Policy):
-
-    def __init__(self, stage_manager: PipelineStageManager, num_layers: int):
-        super().__init__(stage_manager=stage_manager)
-        self.stage_manager = stage_manager
-        self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages)
-
-    def get_hold_layers(self, module: BertForPreTraining) -> List[Module]:
-        """
-        get pipeline layers for current stage
-        """
-        hold_layers = []
-        if self.stage_manager.is_first_stage():
-            hold_layers.append(module.bert.embeddings)
-
-        start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
-        hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
-
-        if self.stage_manager.is_last_stage():
-            hold_layers.append(module.bert.pooler)
-            hold_layers.append(module.cls)
-
-        return hold_layers
-
-    def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]:
-        '''no shared params in bertmodel'''
-        return []
-
-    def replace_forward(self, module: Module) -> None:
-        module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager),
-                                    module.forward)
-
-
-def bert_lmhead_forward(self: BertLMHeadModel,
-                        input_ids: Optional[torch.Tensor] = None,
-                        attention_mask: Optional[torch.Tensor] = None,
-                        token_type_ids: Optional[torch.Tensor] = None,
-                        position_ids: Optional[torch.Tensor] = None,
-                        head_mask: Optional[torch.Tensor] = None,
-                        inputs_embeds: Optional[torch.Tensor] = None,
-                        encoder_hidden_states: Optional[torch.Tensor] = None,
-                        encoder_attention_mask: Optional[torch.Tensor] = None,
-                        labels: Optional[torch.Tensor] = None,
-                        past_key_values: Optional[List[torch.Tensor]] = None,
-                        use_cache: Optional[bool] = None,
-                        output_attentions: Optional[bool] = None,
-                        output_hidden_states: Optional[bool] = None,
-                        return_dict: Optional[bool] = None,
-                        hidden_states: Optional[torch.FloatTensor] = None,
-                        stage_manager: Optional[PipelineStageManager] = None):
-    r"""
-        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
-            the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
-            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
-            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
-            ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
-        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
-            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
-            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        """
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    if labels is not None:
-        use_cache = False
-    if output_attentions:
-        logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
-        output_attentions = False
-    if output_hidden_states:
-        logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
-        output_hidden_states = False
-    if return_dict:
-        logger.warning_once('return_dict is not supported for pipeline models at the moment')
-        return_dict = False
-
-    outputs = bert_model_forward(self.bert,
-                                 input_ids,
-                                 attention_mask=attention_mask,
-                                 token_type_ids=token_type_ids,
-                                 position_ids=position_ids,
-                                 head_mask=head_mask,
-                                 inputs_embeds=inputs_embeds,
-                                 encoder_hidden_states=encoder_hidden_states,
-                                 encoder_attention_mask=encoder_attention_mask,
-                                 past_key_values=past_key_values,
-                                 use_cache=use_cache,
-                                 output_attentions=output_attentions,
-                                 output_hidden_states=output_hidden_states,
-                                 return_dict=return_dict,
-                                 stage_manager=stage_manager,
-                                 hidden_states=hidden_states if hidden_states is not None else None)
-    past_key_values = None
-    all_hidden_states = None
-    all_self_attentions = None
-    all_cross_attentions = None
-
-    if stage_manager.is_last_stage():
-        sequence_output = outputs[0]
-        prediction_scores = self.cls(sequence_output)
-
-        lm_loss = None
-        if labels is not None:
-            # we are doing next-token prediction; shift prediction scores and input ids by one
-            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-            labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
-
-        if not return_dict:
-            output = (prediction_scores,) + outputs[2:]
-            return ((lm_loss,) + output) if lm_loss is not None else output
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=lm_loss,
-            logits=prediction_scores,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-            cross_attentions=outputs.cross_attentions,
-        )
-    else:
-        hidden_states = outputs.get('hidden_states')
-        # intermediate stage always return dict
-        return {'hidden_states': hidden_states}
-
-
-class BertLMHeadModelPolicy(Policy):
-
-    def __init__(self, stage_manager: PipelineStageManager, num_layers: int):
-        super().__init__(stage_manager=stage_manager)
-        self.stage_manager = stage_manager
-        self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages)
-
-    def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]:
-        """
-        get pipeline layers for current stage
-        """
-        hold_layers = []
-        if self.stage_manager.is_first_stage():
-            hold_layers.append(module.bert.embeddings)
-        start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
-        hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
-        if self.stage_manager.is_last_stage():
-            hold_layers.append(module.bert.pooler)
-            hold_layers.append(module.cls)
-
-        return hold_layers
-
-    def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]:
-        '''no shared params in bertmodel'''
-        return []
-
-    def replace_forward(self, module: Module) -> None:
-        module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module)
diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py
deleted file mode 100644
index 71d2913fc..000000000
--- a/colossalai/pipeline/policy/bloom.py
+++ /dev/null
@@ -1,220 +0,0 @@
-import warnings
-from functools import partial
-from types import MethodType
-from typing import Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from torch import Tensor
-from torch.nn import CrossEntropyLoss, Module
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.bloom.modeling_bloom import BloomModel
-from transformers.utils import logging
-
-from colossalai.pipeline.stage_manager import PipelineStageManager
-
-from .base import Policy
-
-logger = logging.get_logger(__name__)
-
-
-def bloom_model_forward(
-    self: BloomModel,
-    input_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    head_mask: Optional[torch.LongTensor] = None,
-    inputs_embeds: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    stage_manager: Optional[PipelineStageManager] = None,
-    hidden_states: Optional[torch.FloatTensor] = None,
-    **deprecated_arguments,
-) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
-    if deprecated_arguments.pop("position_ids", False) is not False:
-        # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
-        warnings.warn(
-            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
-            " passing `position_ids`.",
-            FutureWarning,
-        )
-    if len(deprecated_arguments) > 0:
-        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (output_hidden_states
-                            if output_hidden_states is not None else self.config.output_hidden_states)
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # add warnings here
-    if output_attentions:
-        logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
-        output_attentions = False
-    if output_hidden_states:
-        logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
-        output_hidden_states = False
-    if use_cache:
-        logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
-        use_cache = False
-    # Prepare head mask if needed
-    # 1.0 in head_mask indicate we keep the head
-    # attention_probs has shape batch_size x num_heads x N x N
-
-    # head_mask has shape n_layer x batch x num_heads x N x N
-    head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-    # case: First stage of training
-    if stage_manager.is_first_stage():
-        # check input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-        # initialize in the first stage and then pass to the next stage
-    else:
-        input_shape = hidden_states.shape[:-1]
-        batch_size, seq_length = input_shape
-
-    # extra recording tensor should be generated in the first stage
-
-    presents = () if use_cache else None
-    all_self_attentions = () if output_attentions else None
-    all_hidden_states = () if output_hidden_states else None
-
-    if self.gradient_checkpointing and self.training:
-        if use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
-            use_cache = False
-
-    if past_key_values is None:
-        past_key_values = tuple([None] * len(self.h))
-    # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
-    seq_length_with_past = seq_length
-    past_key_values_length = 0
-    if past_key_values[0] is not None:
-        past_key_values_length = past_key_values[0][0].shape[2]    # source_len
-
-        seq_length_with_past = seq_length_with_past + past_key_values_length
-    if attention_mask is None:
-        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
-    else:
-        attention_mask = attention_mask.to(hidden_states.device)
-
-    alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
-
-    # causal_mask is constructed every stage and its input is passed through different stages
-    causal_mask = self._prepare_attn_mask(
-        attention_mask,
-        input_shape=(batch_size, seq_length),
-        past_key_values_length=past_key_values_length,
-    )
-
-    # calculate the num_layers
-    num_layers_per_stage = len(self.h) // stage_manager.num_stages
-    start_layer = stage_manager.stage * num_layers_per_stage
-    end_layer = (stage_manager.stage + 1) * num_layers_per_stage
-
-    for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])):
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                return custom_forward
-
-            outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(block),
-                hidden_states,
-                alibi,
-                causal_mask,
-                layer_past,
-                head_mask[i],
-            )
-        else:
-            outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=causal_mask,
-                head_mask=head_mask[i],
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                alibi=alibi,
-            )
-
-        hidden_states = outputs[0]
-
-        if use_cache is True:
-            presents = presents + (outputs[1],)
-        if output_attentions:
-            all_self_attentions = all_self_attentions + \
-                (outputs[2 if use_cache else 1],)
-
-    if stage_manager.is_last_stage():
-        # Add last hidden state
-        hidden_states = self.ln_f(hidden_states)
-
-    # TODO: deal with all_hidden_states, all_self_attentions, presents
-    if output_hidden_states:
-        all_hidden_states = all_hidden_states + (hidden_states,)
-
-    if not return_dict:
-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
-
-    # attention_mask is not returned ; presents = past_key_values
-    return BaseModelOutputWithPastAndCrossAttentions(
-        last_hidden_state=hidden_states,
-        past_key_values=presents,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attentions,
-    )
-
-
-class BloomModelPolicy(Policy):
-
-    def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
-        super().__init__(stage_manager=stage_manager)
-        self.stage_manager = stage_manager
-        self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
-
-    def get_hold_layers(self, module: BloomModel) -> List[Module]:
-        """
-        get pipeline layers for current stage
-        """
-        hold_layers = []
-        if self.stage_manager.is_first_stage():
-            hold_layers.append(module.word_embeddings)
-            hold_layers.append(module.word_embeddings_layernorm)
-
-        start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
-        hold_layers.extend(module.h[start_idx:end_idx])
-
-        if self.stage_manager.is_last_stage():
-            hold_layers.append(module.ln_f)
-
-        return hold_layers
-
-    def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]:
-        '''no shared params in bloommodel'''
-        pass
-
-    def replace_forward(self, module: Module) -> None:
-        module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model)
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 6ed3055d6..d907d53ed 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -76,7 +76,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         # for the first stage, input_obj is None
         # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict
         output_obj = model_forward(model, micro_batch, input_obj)
-
         if self.stage_manager.is_last_stage():
             loss = criterion(output_obj, micro_batch) / self.num_microbatches
             if accum_loss is not None:
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index f6a4c706e..6f86de232 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -315,7 +315,7 @@ class BertForMaskedLMPolicy(BertPolicy):
     def module_policy(self):
         policy = super().module_policy()
         policy = self.add_lm_head_policy(policy)
-        mpolicy = self.add_lm_prediction_policy(policy)
+        policy = self.add_lm_prediction_policy(policy)
         from transformers.models.bert.modeling_bert import BertForMaskedLM
         if self.pipeline_stage_manager:
             self.set_pipeline_forward(model_cls=BertForMaskedLM,
diff --git a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py b/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py
deleted file mode 100644
index bc3a9bf1b..000000000
--- a/tests/test_pipeline/test_policy/test_bert_for_pretraining_model.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-from transformers.models.bert import BertConfig
-from transformers.models.bert.modeling_bert import BertForPreTraining
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.base_policy import Policy
-from colossalai.shardformer.policies.bert import BertForPreTrainingPolicy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-
-def check_bert_for_pretraining_policy():
-    configuration = BertConfig()
-    model = BertForPreTraining(configuration)
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
-    # print(pg_mesh)
-
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    rank = dist.get_rank()
-
-    model_policy = BertForPreTrainingPolicy()
-    model_policy.set_model(model)
-
-    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
-    model_policy.set_shard_config(model_config)
-    layers = model_policy.get_held_layers()
-    if stage_manager.is_first_stage():
-        assert len(layers) == 6 + 1
-    else:
-        assert len(layers) == 6 + 2
-
-
-def run_dist_policy(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
-    check_bert_for_pretraining_policy()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_bert_for_pretraining_policy():
-    spawn(run_dist_policy, 4)
-
-
-if __name__ == "__main__":
-    """test the bert for pretraining model forward and bert for pretraining model policy"""
-    test_bert_for_pretraining_policy()
diff --git a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py b/tests/test_pipeline/test_policy/test_bert_lm_head_model.py
deleted file mode 100644
index 1aeb00123..000000000
--- a/tests/test_pipeline/test_policy/test_bert_lm_head_model.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-from transformers.models.bert import BertConfig
-from transformers.models.bert.modeling_bert import BertLMHeadModel
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.base_policy import Policy
-from colossalai.shardformer.policies.bert import BertLMHeadModelPolicy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-
-def check_bert_lmhead_policy():
-    configuration = BertConfig()
-    model = BertLMHeadModel(configuration)
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
-    # print(pg_mesh)
-
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    rank = dist.get_rank()
-
-    model_policy = BertLMHeadModelPolicy()
-    model_policy.set_model(model)
-    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
-    model_policy.set_shard_config(model_config)
-    layers = model_policy.get_held_layers()
-
-    if stage_manager.is_first_stage():
-        assert len(layers) == 6 + 1
-    else:
-        assert len(layers) == 6 + 2
-
-
-def run_dist_policy(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
-    check_bert_lmhead_policy()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_bert_lmhead_policy():
-    spawn(run_dist_policy, 4)
-
-
-if __name__ == "__main__":
-    """test the bert for lm head model policy"""
-    test_bert_lmhead_policy()
diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py
deleted file mode 100644
index b366df017..000000000
--- a/tests/test_pipeline/test_policy/test_bert_model.py
+++ /dev/null
@@ -1,66 +0,0 @@
-'''
-In the test policy we only test policy: held layers and others, as the tests for forward logic are done in test_shardformer/test_model
-'''
-
-import pytest
-import torch.distributed as dist
-from transformers.models.bert.modeling_bert import BertModel
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.base_policy import Policy
-from colossalai.shardformer.policies.bert import BertModelPolicy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-
-def check_bert_model_policy():
-    model = BertModel.from_pretrained('bert-base-uncased')
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
-    # print(pg_mesh)
-
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    rank = dist.get_rank()
-
-    model_policy = BertModelPolicy()
-    model_policy.set_model(model)
-    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
-    model_policy.set_shard_config(model_config)
-
-    layers = model_policy.get_held_layers()
-
-    if stage_manager.is_first_stage():
-        assert len(layers) == 6 + 1
-    else:
-        assert len(layers) == 6 + 1
-
-
-def run_dist_policy(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
-    check_bert_model_policy()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_bert_model_policy():
-    spawn(run_dist_policy, 4)
-
-
-if __name__ == "__main__":
-    """test the bert model policy"""
-    test_bert_model_policy()
diff --git a/tests/test_pipeline/test_policy/test_bloom_model.py b/tests/test_pipeline/test_policy/test_bloom_model.py
deleted file mode 100644
index e6a222f4e..000000000
--- a/tests/test_pipeline/test_policy/test_bloom_model.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-from transformers.models.bloom import BloomConfig, BloomModel
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.base_policy import Policy
-from colossalai.shardformer.policies.bloom import BloomModelPolicy
-from colossalai.shardformer.shard import ShardConfig
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-
-def check_bloom_model_policy():
-    # create a BloomModel
-    configuration = BloomConfig()
-    model = BloomModel(configuration)
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
-    # print(pg_mesh)
-
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    rank = dist.get_rank()
-
-    model_policy = BloomModelPolicy()
-    model_policy.set_model(model)
-    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
-    model_policy.set_shard_config(model_config)
-    layers = model_policy.get_held_layers()
-    if stage_manager.is_first_stage():
-        assert len(layers) == 1 + 2
-    else:
-        assert len(layers) == 1 + 1
-
-
-def run_dist_policy(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
-    check_bloom_model_policy()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_bloom_model_policy():
-    spawn(run_dist_policy, 4)
-
-
-if __name__ == "__main__":
-    """test the bloom model policy"""
-    test_bloom_model_policy()
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index ea0f12264..6d0d3c798 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -2,7 +2,10 @@ import pytest
 import torch
 
 import colossalai
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
diff --git a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py
index 4feaf982a..3170b58a1 100644
--- a/tests/test_shardformer/test_model/test_shard_bert_pipeline.py
+++ b/tests/test_shardformer/test_model/test_shard_bert_pipeline.py
@@ -5,6 +5,8 @@ import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
+from colossalai.shardformer.shard import ShardConfig
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
@@ -17,9 +19,55 @@ from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
 
 
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # check forward
-    pass
+def check_bert_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
+    stage_manager = stage_manager
+    policy = get_autopolicy(model)
+    policy.set_model(model)
+    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
+    policy.set_shard_config(model_config)
+    layers = policy.get_held_layers()
+    if stage_manager.is_first_stage():
+        assert len(layers) == 1 + 1
+    else:
+        if name == "transformers_bert":
+            assert len(layers) == 1 + 1
+        elif name in [
+                "transformers_bert_for_sequence_classification", "transformers_bert_for_token_classification",
+                "transformers_bert_for_mcq"
+        ]:
+            assert len(layers) == 1 + 3
+        else:
+            assert len(layers) == 1 + 2
+
+
+def check_bert_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
+    if name == 'transformers_bert_for_mcq':
+        x = torch.randint(0, 1000, (2, 3, 3)).cuda()
+        attention_mask = torch.ones_like(x).cuda()
+        if stage_manager.stage == 0:
+            output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
+            assert output['hidden_states'].shape == (6, 3, 128)
+        else:
+            hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
+            output = sharded_model(input_ids=x,
+                                   hidden_states=hidden_states,
+                                   attention_mask=attention_mask,
+                                   stage_manager=stage_manager)
+            assert output[0].shape == (2, 3)
+    else:
+        x = torch.randint(0, 1000, (2, 3)).cuda()
+        # one batch, 2 single sentences, each sentence has 3 tokens
+        hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
+        if stage_manager.stage == 0:
+            attention_mask = torch.ones_like(x).cuda()
+            output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
+            assert output['hidden_states'].shape == (2, 3, 128)
+        else:
+            attention_mask = torch.ones((2, 3)).cuda()
+            output = sharded_model(hidden_states=hidden_states,
+                                   attention_mask=attention_mask,
+                                   stage_manager=stage_manager)
+            assert output[0].shape[0] == 2
 
 
 @parameterize('enable_fused_normalization', [False])
@@ -27,55 +75,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('use_lazy_init', [False])
 #TODO: merge this into test_shard_bert
 def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
+    PP_DIM = 0
+    PP_SIZE = 2
+    pg_mesh = ProcessGroupMesh(PP_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                         enable_tensor_parallelism, use_lazy_init)
-
-        if name == 'transformers_bert_for_mcq':
-            x = torch.randint(0, 1000, (2, 3, 3)).cuda()
-            attention_mask = torch.ones_like(x).cuda()
-            if stage_manager.stage == 0:
-                output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-                assert output['hidden_states'].shape == (6, 3, 128)
-            else:
-                hidden_states = torch.randint(0, 1000, (6, 3, 128)).to(torch.float32).cuda()
-                output = sharded_model(input_ids=x,
-                                       hidden_states=hidden_states,
-                                       attention_mask=attention_mask,
-                                       stage_manager=stage_manager)
-                assert output[0].shape == (2, 3)
-        else:
-            x = torch.randint(0, 1000, (2, 3)).cuda()
-            # one batch, 2 single sentences, each sentence has 3 tokens
-            hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
-            if stage_manager.stage == 0:
-                attention_mask = torch.ones_like(x).cuda()
-                output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
-                assert output['hidden_states'].shape == (2, 3, 128)
-            else:
-                attention_mask = torch.ones((2, 3)).cuda()
-                output = sharded_model(hidden_states=hidden_states,
-                                       attention_mask=attention_mask,
-                                       stage_manager=stage_manager)
-                assert output[0].shape[0] == 2
+        check_bert_model_policy(name, org_model, stage_manager)
+        check_bert_model_pipeline_forward(name, sharded_model, stage_manager)
 
     torch.cuda.empty_cache()
 
@@ -90,7 +100,7 @@ def check_bert(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_bert():
-    spawn(check_bert, 4)
+    spawn(check_bert, 2)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py
index 3a36479fc..6695e8a68 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom_pipeline.py
@@ -5,7 +5,9 @@ import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.shard import ShardConfig
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
@@ -18,9 +20,37 @@ from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
 
 
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # check forward
-    pass
+def check_bloom_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
+    policy = get_autopolicy(model)
+    policy.set_model(model)
+    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
+    policy.set_shard_config(model_config)
+    layers = policy.get_held_layers()
+    if stage_manager.is_first_stage():
+        assert len(layers) == 0 + 2
+    else:
+        if name == 'transformers_bloom':
+            assert len(layers) == 1 + 1
+        elif name == 'transformers_bloom_for_token_classification':
+            assert len(layers) == 1 + 3
+        else:
+            assert len(layers) == 1 + 2
+
+
+def check_bloom_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
+    if stage_manager.stage == 0:
+        x = torch.randint(0, 1000, (1, 3)).cuda()
+        attention_mask = torch.ones_like(x).cuda()
+        output = sharded_model(input_ids=x, attention_mask=attention_mask)
+        assert output['hidden_states'].shape == (1, 3, 64)
+    else:
+        attention_mask = torch.ones((1, 3)).cuda()
+        hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
+        output = sharded_model(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+        )
+        assert output[0].shape[0] == 1
 
 
 @parameterize('enable_fused_normalization', [False])
@@ -28,40 +58,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('use_lazy_init', [False])
 #TODO: merge this into test_shard_bloom
 def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
+    PP_DIM = 0
+    PP_SIZE = 2
+    pg_mesh = ProcessGroupMesh(PP_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
-    x = torch.randint(0, 1000, (1, 3)).cuda()
-    hidden_states = torch.randint(0, 1000, (1, 3, 64)).to(torch.float32).cuda()
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                         enable_tensor_parallelism, use_lazy_init)
-        if stage_manager.stage == 0:
-            attention_mask = torch.ones_like(x).cuda()
-            output = sharded_model(input_ids=x, attention_mask=attention_mask)
-            assert output['hidden_states'].shape == (1, 3, 64)
-        else:
-            attention_mask = torch.ones((1, 3)).cuda()
-            output = sharded_model(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-            )
-            assert output[0].shape[0] == 1
+        check_bloom_model_policy(name, org_model, stage_manager)
+        check_bloom_model_pipeline_forward(name, sharded_model, stage_manager)
 
     torch.cuda.empty_cache()
 
@@ -76,7 +83,7 @@ def check_bloom(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_bloom():
-    spawn(check_bloom, 4)
+    spawn(check_bloom, 2)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
index 8fd9ed099..6f1f0cb34 100644
--- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
+++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
@@ -5,7 +5,9 @@ import colossalai
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.auto_policy import get_autopolicy
 from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.shard import ShardConfig
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 from colossalai.testing import (
     assert_hf_output_close,
@@ -18,9 +20,35 @@ from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
 
 
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # check forward
-    pass
+def check_llama_model_policy(name, model: torch.nn.Module, stage_manager: PipelineStageManager):
+    policy = get_autopolicy(model)
+    policy.set_model(model)
+    model_config = ShardConfig(pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False)
+    policy.set_shard_config(model_config)
+    layers = policy.get_held_layers()
+    if stage_manager.is_first_stage():
+        assert len(layers) == 2 + 1
+    else:
+        if name == "transformers_llama":
+            assert len(layers) == 2 + 1
+        else:
+            assert len(layers) == 2 + 2
+
+
+def check_llama_model_pipeline_forward(name, sharded_model, stage_manager: PipelineStageManager):
+    x = torch.randint(0, 1000, (2, 3)).cuda()
+    if stage_manager.stage == 0:
+        attention_mask = torch.ones_like(x).cuda()
+        output = sharded_model(input_ids=x, attention_mask=attention_mask)
+        assert output['hidden_states'].shape == (2, 3, 128)
+    else:
+        hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
+        attention_mask = torch.ones((2, 3)).cuda()
+        output = sharded_model(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+        )
+        assert output[0] is not None
 
 
 @parameterize('enable_fused_normalization', [False])
@@ -28,40 +56,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('use_lazy_init', [False])
 #TODO: merge this into test_shard_llama
 def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
+    PP_DIM = 0
+    PP_SIZE = 2
+    pg_mesh = ProcessGroupMesh(PP_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
-    x = torch.randint(0, 1000, (2, 3)).cuda()
-    hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                         enable_tensor_parallelism, use_lazy_init)
-        if stage_manager.stage == 0:
-            attention_mask = torch.ones_like(x).cuda()
-            output = sharded_model(input_ids=x, attention_mask=attention_mask)
-            assert output['hidden_states'].shape == (2, 3, 128)
-        else:
-            attention_mask = torch.ones((2, 3)).cuda()
-            output = sharded_model(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-            )
-            assert output[0] is not None
+        check_llama_model_policy(name, org_model, stage_manager)
+        check_llama_model_pipeline_forward(name, sharded_model, stage_manager)
 
     torch.cuda.empty_cache()
 
@@ -76,7 +82,7 @@ def check_llama(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_llama():
-    spawn(check_llama, 4)
+    spawn(check_llama, 2)
 
 
 if __name__ == "__main__":