From ed4c4484880b733894e6088e681f7cca32afe0b4 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang <eddiezhang@pku.edu.cn>
Date: Tue, 8 Aug 2023 17:46:44 +0800
Subject: [PATCH] [pipeline] rewrite t5 tests & support multi-tensor
 transmitting in pipeline (#4388)

* fix remaining t5 bugs/rewrite t5 tests

* fix multi-tensor communication in pipeline

* rearrange test_config

* fix keyerror in sync_shared_params

* fix get_held_layers & Randomnizer, complete t5 tests

* erase printing

* fix get_held_layers through modifying _release_unheld_layers

* fix _get_recursive_held_layers bug
---
 .../booster/plugin/hybrid_parallel_plugin.py  |   6 +-
 colossalai/pipeline/p2p.py                    |   6 +-
 colossalai/pipeline/schedule/_utils.py        |   2 +-
 colossalai/pipeline/schedule/one_f_one_b.py   |  11 +-
 colossalai/shardformer/layer/utils.py         |   7 +
 colossalai/shardformer/modeling/t5.py         |  95 ++++++-------
 colossalai/shardformer/policies/t5.py         |  51 ++-----
 colossalai/shardformer/shard/sharder.py       |  16 ++-
 .../test_model/test_shard_gpt2.py             |  13 +-
 .../test_model/test_shard_t5.py               | 134 ++++++++++++------
 .../test_model/test_shard_t5_pipeline.py      | 101 -------------
 11 files changed, 196 insertions(+), 246 deletions(-)
 delete mode 100644 tests/test_shardformer/test_model/test_shard_t5_pipeline.py

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index a22bdb719..42942aaeb 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -50,8 +50,10 @@ class HybridParallelModule(ModelWrapper):
 
     def sync_shared_params(self):
         for shared_param, group in zip(self.shared_params, self.shared_param_process_groups):
-            param = shared_param[self.stage_manager.stage]
-            dist.all_reduce(param.grad, group=group)
+            if self.stage_manager.stage in shared_param:
+                param = shared_param[self.stage_manager.stage]
+                dist.all_reduce(param.grad, group=group)
+            dist.barrier()
 
     def no_sync(self) -> Iterator[None]:
         # no sync grads across data parallel
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index f741b8363..af7a00b5c 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -3,6 +3,7 @@
 
 import io
 import pickle
+import re
 from typing import Any, List, Optional, Union
 
 import torch
@@ -31,7 +32,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     if b'cuda' in buf:
         buf_array = bytearray(buf)
         device_index = torch.cuda.current_device()
-        buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
+        # There might be more than one output tensors during forward
+        for cuda_str in re.finditer(b'cuda', buf_array):
+            pos = cuda_str.start()
+            buf_array[pos + 5] = 48 + device_index
         buf = bytes(buf_array)
 
     io_bytes = io.BytesIO(buf)
diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 045c86e40..3ed923927 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -86,7 +86,7 @@ def retain_grad(x: Any) -> None:
     Args:
         x (Any): Object to be called.
     """
-    if isinstance(x, torch.Tensor):
+    if isinstance(x, torch.Tensor) and x.requires_grad:
         x.retain_grad()
 
 
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index d907d53ed..ade3cf456 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -107,8 +107,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         if output_obj_grad is None:
             optimizer.backward(output_obj)
         else:
-            for k, grad in output_obj_grad.items():
-                optimizer.backward_by_grad(output_obj[k], grad)
+            if "backward_tensor_keys" not in output_obj:
+                for k, grad in output_obj_grad.items():
+                    optimizer.backward_by_grad(output_obj[k], grad)
+            else:
+                for k, grad in output_obj_grad.items():
+                    output_obj[k].grad = grad
+                for k in output_obj["backward_tensor_keys"]:
+                    tensor_to_backward = output_obj[k]
+                    optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index f2ac6563c..09cb7bfe1 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -122,6 +122,13 @@ class Randomizer:
         """
         Randomizer._INDEX += 1
 
+    @staticmethod
+    def reset_index():
+        """
+        Reset the index to zero.
+        """
+        Randomizer._INDEX = 0
+
     @staticmethod
     def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
         """
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index 0b3486e87..d622da452 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -238,7 +238,8 @@ class T5PipelineForwards:
             return {
                 'hidden_states': hidden_states,
                 'position_bias': position_bias,
-                'encoder_decoder_position_bias': encoder_decoder_position_bias
+                'encoder_decoder_position_bias': encoder_decoder_position_bias,
+                'backward_tensor_keys': ['hidden_states']
             }
 
     @staticmethod
@@ -261,8 +262,10 @@ class T5PipelineForwards:
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        backward_tensor_keys: Optional[List[str]] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
@@ -303,7 +306,6 @@ class T5PipelineForwards:
                 decoder_head_mask = head_mask
 
         in_decoder = stage_manager.stage >= decoder_starting_stage
-
         # Stage is in encoder, directly return the output of t5_stack_forward
         if not in_decoder:
             encoder_outputs = T5PipelineForwards.t5_stack_forward(
@@ -323,25 +325,18 @@ class T5PipelineForwards:
                 decoder_starting_stage=decoder_starting_stage)
             if stage_manager.stage == decoder_starting_stage - 1:
                 # last stage of encoder
-                return {'encoder_outputs': encoder_outputs}
+                return {'encoder_hidden_states': encoder_outputs[0]}
             else:
                 return encoder_outputs
 
         at_last_decoder_stage = stage_manager.is_last_stage()
         at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
 
-        if encoder_outputs is None:
-            raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
+        if encoder_outputs is not None:
+            encoder_hidden_states = encoder_outputs[0]
+        elif encoder_hidden_states is None:
+            raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
 
-        encoder_hidden_states = encoder_outputs[0]
-        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
-            )
-
-        # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
         if not at_first_decoder_stage and hidden_states is None:
             raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
 
@@ -360,6 +355,7 @@ class T5PipelineForwards:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            stage_manager=stage_manager,
             hidden_states=hidden_states,
             position_bias=position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -368,22 +364,19 @@ class T5PipelineForwards:
 
         # Directly return outputs of overloaded T5Stack forward if not at last stage.
         if not at_last_decoder_stage:
-            decoder_outputs['encoder_outputs'] = encoder_outputs    # encoder_outputs should be passed to the next stage
+            # encoder_hidden_states should be passed to the next stage
+            decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
             return decoder_outputs
 
         if not return_dict:
-            return decoder_outputs + encoder_outputs
-
-        return Seq2SeqModelOutput(
-            last_hidden_state=decoder_outputs.last_hidden_state,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        )
+            return decoder_outputs + encoder_hidden_states
+        else:
+            return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state,
+                                      past_key_values=decoder_outputs.past_key_values,
+                                      decoder_hidden_states=decoder_outputs.hidden_states,
+                                      decoder_attentions=decoder_outputs.attentions,
+                                      cross_attentions=decoder_outputs.cross_attentions,
+                                      encoder_last_hidden_state=encoder_hidden_states)
 
     @staticmethod
     def t5_for_conditional_generation_forward(
@@ -406,8 +399,10 @@ class T5PipelineForwards:
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        backward_tensor_keys: Optional[List[str]] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
@@ -468,28 +463,25 @@ class T5PipelineForwards:
                 decoder_starting_stage=decoder_starting_stage)
             if stage_manager.stage == decoder_starting_stage - 1:
                 # last stage of encoder
-                return {'encoder_outputs': encoder_outputs}
+                return {'encoder_hidden_states': encoder_outputs[0]}
             else:
                 return encoder_outputs
 
         at_last_decoder_stage = stage_manager.is_last_stage()
         at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
 
-        if encoder_outputs is None:
-            raise ValueError("Non-empty encoder_outputs should be passed in at decoder stages.")
+        if encoder_outputs is not None:
+            encoder_hidden_states = encoder_outputs[0]
+        elif encoder_hidden_states is None:
+            raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
 
-        encoder_hidden_states = encoder_outputs[0]
-        if return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
-            )
-
-        # Stage is in decoder, we assume that the outputs of last stage of encoder will be passed in.
         if not at_first_decoder_stage and hidden_states is None:
             raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
 
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(labels)
+
         # Decode
         decoder_outputs = T5PipelineForwards.t5_stack_forward(
             self.decoder,
@@ -505,6 +497,7 @@ class T5PipelineForwards:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            stage_manager=stage_manager,
             hidden_states=hidden_states,
             position_bias=position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -513,7 +506,8 @@ class T5PipelineForwards:
 
         # Directly return outputs of overloaded T5Stack forward if not at last stage.
         if not at_last_decoder_stage:
-            decoder_outputs['encoder_outputs'] = encoder_outputs    # encoder_outputs should be passed to the next stage
+            # encoder_hidden_states should be passed to the next stage
+            decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
             return decoder_outputs
 
         sequence_output = decoder_outputs[0]
@@ -533,20 +527,16 @@ class T5PipelineForwards:
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
 
         if not return_dict:
-            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states
             return ((loss,) + output) if loss is not None else output
 
-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        )
+        return Seq2SeqLMOutput(loss=loss,
+                               logits=lm_logits,
+                               past_key_values=decoder_outputs.past_key_values,
+                               decoder_hidden_states=decoder_outputs.hidden_states,
+                               decoder_attentions=decoder_outputs.attentions,
+                               cross_attentions=decoder_outputs.cross_attentions,
+                               encoder_last_hidden_state=encoder_hidden_states)
 
     @staticmethod
     def t5_encoder_model_forward(
@@ -562,6 +552,7 @@ class T5PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         position_bias: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        backward_tensor_keys: Optional[List[str]] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 5e78ae909..2ef52c214 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -260,7 +260,7 @@ class T5BasePolicy(Policy):
 
         model = self.model
         encoder = self.model.encoder
-        decoder = self.model.__dict__.get('decoder', None)
+        decoder = getattr(self.model, 'decoder', None)
 
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
@@ -300,7 +300,7 @@ class T5BasePolicy(Policy):
         stage_manager = self.pipeline_stage_manager
 
         encoder = self.model.encoder
-        decoder = self.model.__dict__.get('decoder', None)
+        decoder = getattr(self.model, 'decoder', None)
 
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
@@ -355,15 +355,6 @@ class T5ModelPolicy(T5BasePolicy):
                 return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
         return []
 
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
-            for k, v in binding_map.items():
-                src = getattr_(self.model, k)
-                for dst in v:
-                    setattr_(self.model, dst, src)
-        return self.model
-
 
 class T5ForConditionalGenerationPolicy(T5BasePolicy):
 
@@ -409,29 +400,22 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                                                                           stage_manager.num_stages)
 
             shared_params = []
+            shared_embedding = {}
             if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
-                shared_params.append({
-                    0: module.shared.weight,
-                    decoder_starting_stage: module.decoder.embed_tokens.weight
-                })
+                shared_embedding[0] = module.shared.weight
+                shared_embedding[decoder_starting_stage] = module.decoder.embed_tokens.weight
+
             if id(module.lm_head.weight) == id(module.shared.weight):
-                shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
+                shared_embedding[0] = module.shared.weight
+                shared_embedding[stage_manager.num_stages - 1] = module.lm_head.weight
+
+            if len(shared_embedding) > 0:
+                shared_params.append(shared_embedding)
+
             return shared_params
+
         return []
 
-    def postprocess(self):
-        super().postprocess()
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {
-                "shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
-            }
-            for k, v in binding_map.items():
-                src = getattr_(self.model, k)
-                for dst in v:
-                    setattr_(self.model, dst, src)
-
-        return self.model
-
 
 class T5EncoderPolicy(T5BasePolicy):
 
@@ -462,12 +446,3 @@ class T5EncoderPolicy(T5BasePolicy):
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         return []
-
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
-            for k, v in binding_map.items():
-                src = getattr_(self.model, k)
-                for dst in v:
-                    setattr_(self.model, dst, src)
-        return self.model
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index ae8cd8c6e..0ed745a1f 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -198,6 +198,20 @@ class ModelSharder(object):
 
             setattr_(org_layer, suffix, replace_layer)
 
+    def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]:
+
+        def collect_sub_modules(module: nn.Module):
+            if module is None:
+                return
+            recursive_held_layers.append(module)
+            for name, child in module.named_children():
+                collect_sub_modules(child)
+
+        recursive_held_layers = []
+        for module in held_layers:
+            collect_sub_modules(module)
+        return recursive_held_layers
+
     def _release_unheld_layers(self) -> Optional[Set[nn.Module]]:
         r"""
         Release the unheld layers in the model
@@ -205,7 +219,7 @@ class ModelSharder(object):
         if self.shard_config and self.shard_config.pipeline_stage_manager:
             held_layers = self.policy.get_held_layers()
             set_tensors_to_none(self.model, exclude=set(held_layers))
-            return set(held_layers)
+            return set(self._get_recursive_held_layers(held_layers))
         return None
 
     def _materialize(self) -> None:
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index f7213d8c5..1882bf782 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -68,16 +68,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 
     torch.cuda.empty_cache()
 
+
 @parameterize('test_config', [{
-    'tp_size': 1,
-    'pp_size': 2,
-    'num_microbatches': 4,
-    'use_lazy_init': True
-}, {
     'tp_size': 2,
     'pp_size': 2,
     'num_microbatches': 4,
-    'enable_fused_normalization': False,
+    'enable_fused_normalization': True,
+    'use_lazy_init': True
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
     'use_lazy_init': False
 }, {
     'tp_size': 4,
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 22f04c879..d807ffa06 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -1,60 +1,110 @@
-import os
-
 import pytest
 import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import (
+    build_model_from_hybrid_plugin,
+    check_grad,
+    check_loss,
+    check_output_hidden_state,
+    check_weight,
+    run_forward_backward_with_hybrid_plugin,
+)
 
 
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # check forward
-    # the value "past_key_values" is sharded, so we ignore
-    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
-                                                                 output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], atol=1e-5)
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
 
-    # do backward
-    org_loss.backward()
-    shard_loss.backward()
+    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
+        build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
 
-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    org_loss, org_output, sharded_loss, sharded_output = \
+        run_forward_backward_with_hybrid_plugin(
+            org_model,
+            sharded_model,
+            sharded_optimizer,
+            data_gen_fn,
+            output_transform_fn,
+            criterion,
+            booster)
 
-    # check grad
-    col_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.q', 'shared']
-    row_layer_for_check = ['encoder.block[0].layer[0].SelfAttention.relative_attention_bias']
-    check_grad(org_model, sharded_model, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
-    check_grad(org_model, sharded_model, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
+    stage_manager = booster.plugin.stage_manager
+    tp_group = booster.plugin.tp_group
 
-    # check weights are tied
-    if hasattr(org_model, 'lm_head'):
-        assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr()
-        assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
+    # check last hidden state & loss
+    if stage_manager is None or stage_manager.is_last_stage():
+
+        if org_model.__class__.__name__ != 'T5ForConditionalGeneration':
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
+
+        check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
+
+    # unwrap model
+    t5 = org_model
+    sharded_t5 = sharded_model.unwrap()
+
+    row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
+
+    # check weights and gradients
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-5, rtol=1e-3, dim=0)
+
+    # check weights after optimizer.step()
+    org_optimizer.step()
+    sharded_optimizer.step()
+    if stage_manager is None or stage_manager.is_first_stage():
+        check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
+
+    torch.cuda.empty_cache()
 
 
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-@parameterize('use_lazy_init', [False, True])
-@parameterize('enable_flash_attention', [True, False])
-@parameterize('enable_jit_fused', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init, enable_flash_attention,
-                enable_jit_fused):
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 2,
+    'enable_fused_normalization': True,
+    'use_lazy_init': True
+}, {
+    'tp_size': 1,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'use_lazy_init': False
+}, {
+    'tp_size': 4,
+    'pp_size': 1,
+    'enable_fused_normalization': True,
+    'use_lazy_init': False
+}, {
+    'tp_size': 1,
+    'pp_size': 4,
+    'num_microbatches': 4,
+    'use_lazy_init': False
+}])
+@clear_cache_before_run()
+def run_t5_test(test_config):
+
+    # TODO: add plugin_config for TP+DP after supporting & debugging it
+    # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
+
+    # TODO: add test_config for flash attention & jit operator after supporting
+
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+    test_config['precision'] = 'float'    # Do not use fp16/bf16 in testing
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
-                                               enable_flash_attention, enable_jit_fused, use_lazy_init)
-        check_state_dict(org_model, sharded_model, name=name)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+        # skip 4-stage pp test for t5_encoder
+        if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model':
+            continue
+
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
 
 
@@ -68,7 +118,7 @@ def check_t5(rank, world_size, port):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_t5():
-    spawn(check_t5, 2)
+    spawn(check_t5, 4)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py b/tests/test_shardformer/test_model/test_shard_t5_pipeline.py
deleted file mode 100644
index 7f3a5f2ea..000000000
--- a/tests/test_shardformer/test_model/test_shard_t5_pipeline.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.policies.t5 import T5BasePolicy
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_pipeline_model
-
-
-def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
-    # TODO: add tests for forward/backward later
-    pass
-
-
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('enable_fused_normalization', [False])
-@parameterize('use_lazy_init', [False])
-#TODO: merge this into test_shard_t5.py
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    DP_DIM, PP_DIM = 0, 1
-    DP_SIZE, PP_SIZE = 2, 2
-    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-
-        inputs = data_gen_fn()
-        inputs = {k: v.cuda() for k, v in inputs.items()}
-        input_ids = inputs['input_ids']
-
-        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                enable_tensor_parallelism, use_lazy_init)
-
-        batch_size, seq_len = input_ids.shape
-        hidden_size = sharded_model.config.d_model
-        num_heads = sharded_model.config.num_heads
-        hidden_state_shape = (batch_size, seq_len, hidden_size)
-        position_bias_shape = (batch_size, num_heads, seq_len, seq_len)
-
-        num_encoder_layers = len(sharded_model.encoder.block)
-        decoder = sharded_model.__dict__.get('decoder', None)
-        num_decoder_layers = len(decoder.block) if decoder else 0
-
-        _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(num_encoder_layers, num_decoder_layers, PP_SIZE)
-        stage = stage_manager.stage
-        at_first_stage = (stage == 0) or (stage == decoder_starting_stage)
-        at_last_stage = (stage == decoder_starting_stage - 1) or (stage == stage_manager.num_stages - 1)
-        in_decoder = stage >= decoder_starting_stage
-
-        if not at_first_stage:
-            # change inputs if not the first stage
-            hidden_states = torch.zeros(*hidden_state_shape).cuda()
-            position_bias = torch.zeros(*position_bias_shape).cuda()
-            encoder_decoder_position_bias = torch.zeros(*position_bias_shape).cuda()
-            inputs['input_ids'] = None
-            inputs['hidden_states'] = hidden_states
-            inputs['position_bias'] = position_bias
-            inputs['encoder_decoder_position_bias'] = encoder_decoder_position_bias
-        if in_decoder:
-            encoder_output_states = torch.zeros(*hidden_state_shape).cuda()
-            inputs['encoder_outputs'] = (encoder_output_states,)
-
-        sharded_model.train()
-        output = sharded_model(**inputs)
-        if at_last_stage:
-            if name == 'transformers_t5_for_conditional_generation' and in_decoder:
-                assert output.loss is not None
-            else:
-                if name != 'transformers_t5_encoder_model' and not in_decoder:
-                    output = output['encoder_outputs']
-                assert output[0].shape == hidden_state_shape
-        else:
-            assert output['hidden_states'].shape == hidden_state_shape
-            # position_bias information should be passed in T5
-            assert output['position_bias'].shape == position_bias_shape
-            if in_decoder:
-                assert output['encoder_decoder_position_bias'].shape == position_bias_shape
-
-    torch.cuda.empty_cache()
-
-
-def check_t5(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_t5_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_t5():
-    spawn(check_t5, 4)
-
-
-if __name__ == "__main__":
-    test_t5()