From e76308c6e65cb73cc5b20936bd232ba7390c6b11 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Wed, 16 Oct 2024 03:25:04 +0000
Subject: [PATCH] [fix] rm use_zbv flag in Shardconfig; rm debug info;

---
 .../booster/plugin/hybrid_parallel_plugin.py  |   1 -
 .../plugin/moe_hybrid_parallel_plugin.py      |   1 -
 colossalai/shardformer/policies/llama.py      |  24 +-
 colossalai/shardformer/policies/mixtral.py    |  28 +-
 colossalai/shardformer/shard/shard_config.py  |   1 -
 examples/language/llama/benchmark.py          |   2 -
 .../test_schedule/test_zerobubble_pp.py       | 176 ++++-
 tests/test_pipeline/test_schedule/zbv_poc.py  | 628 ------------------
 .../test_model/test_shard_llama.py            |   2 +-
 9 files changed, 212 insertions(+), 651 deletions(-)
 delete mode 100644 tests/test_pipeline/test_schedule/zbv_poc.py

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index bba943f12..caeed5457 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1201,7 +1201,6 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
-            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index b7e65c6a2..8b62a1e2b 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -373,7 +373,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
-            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 5c68d0c5e..db4515d7e 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -60,6 +60,11 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm
 
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
+
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
@@ -129,7 +134,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -138,7 +143,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -147,7 +152,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -156,7 +161,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -165,7 +170,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -174,7 +179,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -183,7 +188,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                 ],
@@ -413,6 +418,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         from transformers import LlamaForSequenceClassification
 
         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
 
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
@@ -425,6 +434,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ]
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index de546b3c5..11291169a 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -52,6 +52,10 @@ class MixtralPolicy(Policy):
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         tp_size = self.shard_config.tensor_parallel_size
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
 
         # modified for both SP and TP
         num_q_heads = self.model.config.num_attention_heads
@@ -126,7 +130,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -134,7 +138,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -142,7 +146,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -150,7 +154,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -159,7 +163,7 @@ class MixtralPolicy(Policy):
                         kwargs={
                             "gather_output": True,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                 ],
@@ -195,7 +199,7 @@ class MixtralPolicy(Policy):
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     )
                 ],
@@ -330,6 +334,10 @@ class MixtralModelPolicy(MixtralPolicy):
 class MixtralForCausalLMPolicy(MixtralPolicy):
     def module_policy(self):
         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
         # TODO: assign pg mesh from plugin to all modules
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
@@ -342,7 +350,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
-                                use_zbv=self.shard_config.use_zbv,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ],
@@ -392,6 +400,10 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
         from transformers import MixtralForSequenceClassification
 
         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
 
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
@@ -404,7 +416,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
-                                use_zbv=self.shard_config.use_zbv,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ]
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 33e93fa51..1219119bb 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -49,7 +49,6 @@ class ShardConfig:
     make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
-    use_zbv: bool = False
 
     # For ring attention
     inner_ring_size: Optional[int] = None
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 4f2c45d75..041c51fb1 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -5,8 +5,6 @@ import warnings
 from contextlib import nullcontext
 
 import torch
-
-torch.autograd.set_detect_anomaly(True)
 import torch.distributed as dist
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 1a1fbbeb2..bdc539043 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -8,12 +8,14 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaModel
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralModel
 
 import colossalai
 from colossalai.booster.booster import Booster
-from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin, MoeHybridParallelPlugin
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import disable_existing_loggers
@@ -918,11 +920,181 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
     torch.cuda.empty_cache()
 
 
+@parameterize(
+    "config",
+    [
+        (0, 4, 1, 1),
+        # (1, 2, 2, 1),
+        # (1, 2, 1, 2),
+        # (1, 1, 2, 2),
+    ],
+)
+def run_with_booster_hybridplugin(config: Tuple[int, ...]):
+    stage, pp_size, tp_size, sp_size = config
+    num_microbatches = pp_size
+    dist.get_world_size()
+    rank = dist.get_rank()
+    dtype, precision = torch.float16, "fp16"
+    torch.cuda.set_device(dist.get_rank())
+
+    ########
+    # init base model
+    ########
+    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
+    config = LlamaConfig(
+        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
+        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
+        num_hidden_layers=NUM_LAYERS,
+        num_attention_heads=NUM_HEADS,
+        num_key_value_heads=NUM_HEADS,
+        attn_implementation="flash_attention_2",
+    )
+
+    # init model with the same seed
+    seed_all(10086)
+
+    torch_model = LlamaModel(config).to(dtype).cuda()
+    # TODO: Support MixtralForCausalLM
+    # torch_model = MixtralForCausalLM(config).to(dtype).cuda()
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+    # init schedule
+    h, a, s = config.hidden_size, config.num_attention_heads, 1024
+    mem_f = 34 * h + 5 * a * s
+    mem_w = -32 * h
+    mem_b = -mem_w - mem_f
+    graph = PipelineGraph(
+        n_stage=pp_size,
+        n_micro=num_microbatches,
+        f_cost=1,
+        b_cost=1,
+        w_cost=1,
+        c_cost=1,
+        f_mem=mem_f,
+        b_mem=mem_b,
+        w_mem=mem_w,
+    )
+
+    zbv_schedule = graph.get_v_schedule()
+
+    # init MoeHybridPlugin
+    plugin = HybridParallelPlugin(
+        pp_size=pp_size,
+        num_microbatches=pp_size,
+        tp_size=tp_size,
+        sp_size=sp_size,
+        zero_stage=stage,
+        enable_sequence_parallelism=sp_size > 1,
+        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        overlap_communication=False,
+        initial_scale=1,
+        precision=precision,
+        find_unused_parameters=True,
+        pp_style="zbv",
+        scheduler_nodes=zbv_schedule,
+        num_model_chunks=2,
+    )
+
+    dp_size = plugin.dp_size
+
+    booster = Booster(plugin=plugin)
+
+    ########
+    # init pp model
+    ########
+
+    parallel_model = deepcopy(torch_model)
+    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
+    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)
+    # create different input along dp axis
+    seed_all(1453 + rank)
+
+    torch_model.train()
+    parallel_model.train()
+    for i in range(2):
+        # gen random input
+        # input = torch.rand(
+        #     NUM_BATCH, NUM_TOK_PER_BATCH, NUM_HEADS, HIDDEN_SIZE_PER_HEAD, requires_grad=True
+        # ).cuda()
+        input_ids = torch.randint(0, torch_model.vocab_size, (NUM_BATCH, config.max_position_embeddings)).cuda()
+        attention_mask = torch.ones_like(input_ids).cuda()
+        input_ids.clone().cuda()
+        input_data = {"input_ids": input_ids, "attention_mask": attention_mask}
+
+        # dist.all_reduce(
+        #     input, group=plugin.pp_group
+        # )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check
+
+        # dist.all_reduce(input, group=plugin.tp_group)  # tp group duplicate input
+        # dist.all_reduce(input, group=plugin.sp_group)  # sp group duplicate input
+
+        # run the model with hybrid parallel
+        if booster.plugin.stage_manager is not None:
+            # for test with pp
+            data_iter = iter([input_data])
+            sharded_output = booster.execute_pipeline(
+                data_iter,
+                parallel_model,
+                lambda x, y: x.last_hidden_state.mean(),
+                parallel_optimizer,
+                return_loss=True,
+                return_outputs=True,
+            )
+            # stage 0 chunk 0
+            parallel_output = None
+            if (
+                booster.plugin.stage_manager.is_first_stage(ignore_chunk=True)
+                and rank == dist.get_process_group_ranks(plugin.pp_group)[0]
+            ):
+                parallel_output = sharded_output["loss"]
+            else:
+                parallel_output = torch.tensor(12345.0, device="cuda")
+            # broadcast along pp axis
+            dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
+
+        else:
+            # for test without pp
+            parallel_output = parallel_model(
+                input_ids=input_data["input_ids"],
+                attention_mask=input_data["attention_mask"],
+            ).last_hidden_state.mean()
+            parallel_optimizer.backward(parallel_output)
+        parallel_optimizer.step()
+        parallel_optimizer.zero_grad()
+        dist.all_reduce(parallel_output, group=plugin.dp_group)
+
+        # ===================================================================================
+        # run normal model with all dp(different) inputs
+        all_inputs = [input_data for _ in range(dp_size)]
+        # dist.all_gather(all_inputs, input, group=plugin.dp_group)
+        torch_output_sum = 0
+        for input_data_ in all_inputs:
+            torch_output = torch_model(
+                input_ids=input_data_["input_ids"],
+                attention_mask=input_data_["attention_mask"],
+            ).last_hidden_state.mean()
+            torch_output.backward()
+            torch_output_sum += torch_output.detach()
+        # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
+        # avg dp grads follows zero optimizer
+        for p in torch_model.parameters():
+            if p.grad is not None:
+                p.grad /= dp_size
+        torch_optimizer.step()
+        torch_optimizer.zero_grad()
+
+        print(f"loop {i} rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
+        # assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
+        # print(f"rank {dist.get_rank()} config {test_config}  test passed")
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
 def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    # run_fwd_bwd_vschedule_with_optim()
     run_with_booster_moehybridplugin()
+    # run_with_booster_hybridplugin()
 
 
 @pytest.mark.dist
diff --git a/tests/test_pipeline/test_schedule/zbv_poc.py b/tests/test_pipeline/test_schedule/zbv_poc.py
deleted file mode 100644
index 6280990a9..000000000
--- a/tests/test_pipeline/test_schedule/zbv_poc.py
+++ /dev/null
@@ -1,628 +0,0 @@
-import gc
-import time
-from copy import deepcopy
-
-import torch
-import torch.nn as nn
-from torch.testing import assert_close
-
-
-def get_model_numel(model):
-    return sum(p.numel() for p in model.parameters()) / 1024**2
-
-
-# Step1: dx = w*dy
-def backward_b(loss, x, model):
-    torch.autograd.backward(loss, inputs=x, retain_graph=True)
-
-
-# Step2: dummy dw = x*dy
-def backward_w(loss, model):
-    torch.autograd.backward(loss, inputs=list(model.parameters()))
-
-
-def test_double_dx_dw_split_nsync():
-    device = "cuda:0"
-    model = nn.Linear(4096, 4096, bias=None).to(device=device)
-    # print(f"model numel {get_model_numel(model)}") # 4GB
-    x1 = torch.rand(4096, 4096).to(device=device)
-    x2 = torch.rand(4096, 4096).to(device=device)
-    ref_model = deepcopy(model)
-    ref_x1 = x1.clone()
-    ref_x2 = x1.clone()
-
-    # first step
-    x1.requires_grad_()
-    x2.requires_grad_()
-    ref_x1.requires_grad_()
-    ref_x2.requires_grad_()
-
-    # loss for dx_dw bwd
-    loss1 = model(x1).sum()
-    loss2 = model(x2).sum()
-
-    # loss for common bwd
-    ref_loss1 = ref_model(ref_x1).sum()
-    ref_loss2 = ref_model(ref_x2).sum()
-
-    # dx1
-    torch.cuda.synchronize()
-    bwd_b_start_time = time.time()
-    backward_b(loss1, x1, model)
-    bwd_b_end_time = time.time()
-    print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
-
-    for p in model.parameters():
-        assert p.grad is None
-    assert x1.grad is not None
-
-    # dx2
-    torch.cuda.synchronize()
-    bwd_b_start_time = time.time()
-    backward_b(loss2, x2, model)
-    bwd_b_end_time = time.time()
-    print(f"loss_2 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
-
-    # dw1
-    torch.cuda.synchronize()
-    bwd_w_start_time = time.time()
-    backward_w(loss1, model)
-    bwd_w_end_time = time.time()
-    print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
-    for p in model.parameters():
-        assert p.grad is not None
-
-    # common bwd 1
-    torch.cuda.synchronize()
-    comm_bwd_start_time = time.time()
-    ref_loss1.backward()
-    comm_bwd_end_time = time.time()
-    print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
-
-    # # assert dx1 & dw1 == bwd 1
-    # assert_close(x1.grad, ref_x1.grad)
-    # for p1, p2 in zip(model.parameters(), ref_model.parameters()):
-    #     assert_close(p1, p2)
-    #     assert_close(p1.grad, p2.grad)
-
-    # dw2
-    torch.cuda.synchronize()
-    bwd_w_start_time = time.time()
-    backward_w(loss2, model)
-    bwd_w_end_time = time.time()
-    print(f"loss_2 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
-
-    # common bwd 2
-    torch.cuda.synchronize()
-    comm_bwd_start_time = time.time()
-    ref_loss2.backward()
-    comm_bwd_end_time = time.time()
-    print(f"loss_2 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
-
-    # # assert dx2 & dw2 == bwd 2
-    # assert_close(x2.grad, ref_x2.grad)
-    # for p1, p2 in zip(model.parameters(), ref_model.parameters()):
-    #     print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
-    #     assert_close(p1, p2)
-    #     assert_close(p1.grad, p2.grad)
-
-
-def test_double_dx_dw_split_sync():
-    device = "cuda:0"
-    model = nn.Linear(8, 8, bias=None).to(device=device)
-    print(f"model size {get_model_numel(model)} ")  # 4GB
-    x1 = torch.rand(8, 8).to(device=device)
-    x2 = torch.rand(8, 8).to(device=device)
-
-    # x1 = torch.ones(8, 8).to(device=device)
-    # x2 = torch.ones(8, 8).to(device=device)
-
-    ref_model = deepcopy(model)
-    ref_x1 = x1.clone()
-    ref_x2 = x2.clone()
-
-    x1.requires_grad_()
-    x2.requires_grad_()
-    ref_x1.requires_grad_()
-    ref_x2.requires_grad_()
-
-    ############
-    # step1:
-    ############
-
-    # loss1
-    loss1 = model(x1).sum()
-
-    # ref_loss1
-    ref_model(ref_x1).sum()
-
-    # dx1
-    backward_b(loss1, x1, model)
-    for p in model.parameters():
-        assert p.grad is None
-    assert x1.grad is not None
-
-    # dw1
-    backward_w(loss1, model)
-    for p in model.parameters():
-        assert p.grad is not None
-
-    # common bwd 1
-    # ref_loss1.backward()
-
-    # assert dx1 & dw1 == bwd 1
-    assert_close(x1.grad, ref_x1.grad)
-    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
-        assert_close(p1, p2)
-        assert_close(p1.grad, p2.grad)
-
-    ############
-    # step2:
-    ############
-
-    # loss2
-    loss2 = model(x2).sum()
-
-    # ref_loss2
-    ref_loss2 = ref_model(ref_x2).sum()
-
-    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
-        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
-        assert_close(p1, p2)
-        assert_close(p1.grad, p2.grad)
-
-    # dx2
-    backward_b(loss2, x2, model)
-
-    # dw2
-    backward_w(loss2, model)
-
-    # common bwd 2
-    ref_loss2.backward()
-
-    # assert dx2 & dw2 == bwd 2
-    assert_close(x2.grad, ref_x2.grad)
-    for p1, p2 in zip(model.parameters(), ref_model.parameters()):
-        print(f"bwd2:\n p1 {p1.grad},\n p2 {p2.grad}\n")
-        assert_close(p1, p2)
-        assert_close(p1.grad, p2.grad)
-
-
-def deallocate_output_tensor(out):
-    """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
-
-    This method should be called right after the output tensor has been
-    sent to the next pipeline stage. At this point, the output tensor is
-    only useful for its '.grad_fn' field, and not its '.data'.
-    """
-    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
-    assert out._base is None, "counter-productive to free a view of another tensor."
-    out.data = torch.empty(
-        (1,),
-        device=out.device,
-        dtype=out.dtype,
-    )
-
-
-IN_DIM = 8192
-OUT_DIM = 8192
-NUM_LAYER = 3
-
-
-class MlpModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.layers = nn.ModuleList([nn.Linear(IN_DIM, OUT_DIM, bias=None) for _ in range(NUM_LAYER)])
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, with_qkv=True):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-        self.with_qkv = with_qkv
-        if self.with_qkv:
-            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-            self.proj = nn.Linear(dim, dim)
-            self.proj_drop = nn.Dropout(proj_drop)
-        self.attn_drop = nn.Dropout(attn_drop)
-
-    def forward(self, x):
-        B, N, C = x.shape
-        if self.with_qkv:
-            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-            q, k, v = qkv[0], qkv[1], qkv[2]
-        else:
-            qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-            q, k, v = qkv, qkv, qkv
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        if self.with_qkv:
-            x = self.proj(x)
-            x = self.proj_drop(x)
-        return x
-
-
-def mem_dx_dw():
-    device = "cuda:0"
-    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
-    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-    model = MlpModel().to(device=device)
-    print(f"model numel {get_model_numel(model)}")  # 4GB
-    print(f"After init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    print(f"Before init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
-    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
-    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
-
-    x1.requires_grad_()
-    x2.requires_grad_()
-    x3.requires_grad_()
-    print(f"After init x1&2&3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    ############
-    # step1:
-    ############
-    print(f"\nStep1")
-
-    # loss1
-    print(f"Before Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-    y1 = model(x1)
-    print(f"After Fwd x1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    print(f"Before loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-    loss1 = y1.sum()
-    print(f"After loss1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    # dx1
-    backward_b(loss1, x1, model)
-
-    # dw1
-    backward_w(loss1, model)
-
-    deallocate_output_tensor(x1)
-    deallocate_output_tensor(y1)
-    # del x1
-    # del y1
-    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    # print(f"\n Step1:collect:{gc.collect()}")
-    # print(f"object: {gc.get_objects()}")
-    # print(f"garbage: {gc.garbage}")
-
-    ############
-    # step2:
-    ############
-    print(f"\nStep2")
-
-    # loss2
-    y2 = model(x2)
-    loss2 = y2.sum()
-
-    # dx2
-    backward_b(loss2, x2, model)
-
-    # dw2
-    backward_w(loss2, model)
-    deallocate_output_tensor(x2)
-    deallocate_output_tensor(y2)
-    # del x2
-    # del y2
-    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    print(f"\n Step2:collect:{gc.collect()}")
-    # print(f"object: {gc.get_objects()}")
-    print(f"garbage: {gc.garbage}")
-
-    ############
-    # step3:
-    ############
-
-    print(f"\nStep3")
-
-    # loss3
-    y3 = model(x3)
-    loss3 = y3.sum()
-
-    # dx2
-    backward_b(loss3, x3, model)
-
-    # dw2
-    backward_w(loss3, model)
-
-    deallocate_output_tensor(x3)
-    deallocate_output_tensor(y3)
-    # del x3
-    # del y3
-
-    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    print(f"\n Step3:collect:{gc.collect()}")
-    # print(f"object: {gc.get_objects()}")
-    print(f"garbage: {gc.garbage}")
-
-
-# del activation
-def activation_dx_dw():
-    device = "cuda:0"
-    # model = nn.Linear(IN_DIM, OUT_DIM, bias=None).to(device=device)
-    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-    model = MlpModel().to(device=device)
-    x1 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
-    x2 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
-    x3 = torch.rand(IN_DIM, OUT_DIM).to(device=device)
-
-    x1.requires_grad_()
-    x2.requires_grad_()
-    x3.requires_grad_()
-    print(f"After init Model, x1,x2,x3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    activations = {}
-
-    def register_hooks(module):
-        def activation_hook(module, input, output):
-            activations[f"{module.__class__.__name__}_{id(module)}"] = output.detach()
-
-        def bwd_hook(module, grad_input, grad_output):
-            del activations[f"{module.__class__.__name__}_{id(module)}"]
-
-        module.register_forward_hook(activation_hook)
-        module.register_backward_hook(bwd_hook)
-
-    model.apply(register_hooks)
-
-    ############
-    # step1:
-    ############
-    print(f"\nStep1")
-
-    # loss1
-    loss1 = model(x1).sum()
-
-    # dx1
-    backward_b(loss1, x1, model)
-
-    # dw1
-    backward_w(loss1, model)
-
-    del loss1, x1
-    print(f"After del x1&y1: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    ############
-    # step2:
-    ############
-    print(f"\nStep2")
-
-    # loss2
-    loss2 = model(x2).sum()
-
-    # dx2
-    backward_b(loss2, x2, model)
-
-    # dw2
-    backward_w(loss2, model)
-
-    # deallocate_output_tensor(x2)
-    # deallocate_output_tensor(loss2)
-    del x2, loss2
-    print(f"After del x2&y2: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    ############
-    # step3:
-    ############
-    print(f"\nStep3")
-
-    # loss3
-    loss3 = model(x3).sum()
-
-    # dx2
-    backward_b(loss3, x3, model)
-
-    # dw2
-    backward_w(loss3, model)
-
-    del x3, loss3
-
-    print(f"After del x3&y3: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-
-# text dx dw in model chunk
-def model_chunk_dx_dw():
-    device = "cuda:0"
-    num_layers = 4
-    print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-    model = MlpModel(in_dim=4096, out_dim=4096, num_layers=num_layers).to(device=device)
-    x = torch.rand(4096, 4096).to(device=device)
-    x.requires_grad_()
-
-    model_chunk_0 = torch.nn.ModuleList()  # for layer 1 & 2
-    model_chunk_1 = torch.nn.ModuleList()  # for layer 3 & 4
-
-    for idx, sub_model in enumerate(model.layers):
-        if idx < 2:
-            model_chunk_0.append(sub_model).cuda()
-        else:
-            model_chunk_1.append(sub_model).cuda()
-
-    print(f"After init Model & input: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    # Step1:chunk 0 fwd
-    activation = dict()  # layer_id: activation
-    out = x
-    for i in range(len(model_chunk_0)):
-        layer = model_chunk_0[i]
-        activation[i] = layer(out)
-    print(f"After chunk0 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-    # Step2:chunk 1 fwd
-    for i in range(len(model_chunk_1)):
-        layer = model_chunk_0[i]
-        activation[i + 2] = layer(out)
-    print(f"After chunk1 fwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;")
-
-    # Step3:chunk 1 bwd b: dx=w*dy & bwd w:dw=x*dy
-    # visit layer reversely
-    for i in range(len(model_chunk_1) - 1, -1, -1):
-        layer = model_chunk_1[i]
-        global_layer_idx = i + 2
-        prev_global_layer_idx = i + 1 if i + 1 > 0 else None
-        i + 3 if i + 3 < 4 else None
-
-        # bwd b
-        if global_layer_idx == num_layers - 1:  # last layer in last chunk; calculate loss
-            loss = activation[global_layer_idx].sum()
-            x = activation[prev_global_layer_idx]
-            backward_b(loss, x, layer)
-        else:
-            loss = activation[global_layer_idx].sum()
-            x = activation[prev_global_layer_idx]
-            backward_b(loss, x, layer)
-
-        # bwd w
-        backward_w(loss, layer)
-
-
-def test_dx_dw_linear_benchmark():
-    device = "cuda:0"
-    model = nn.Linear(4096, 4096, bias=None).to(device=device)
-    # print(f"model numel {get_model_numel(model)}") # 4GB
-    x1 = torch.rand(4096, 4096).to(device=device)
-    # x2 = torch.rand(4096, 4096).to(device=device)
-    ref_model = deepcopy(model)
-    ref_x1 = x1.clone()
-    # ref_x2 = x1.clone()
-
-    # first step
-    x1.requires_grad_()
-    # x2.requires_grad_()
-    ref_x1.requires_grad_()
-    # ref_x2.requires_grad_()
-
-    # loss for dx_dw bwd
-    loss1 = model(x1).sum()
-    # loss2 = model(x2).sum()
-
-    # loss for common bwd
-    ref_model(ref_x1).sum()
-    # ref_loss2 = ref_model(ref_x2).sum()
-
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
-        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
-        on_trace_ready=torch.profiler.tensorboard_trace_handler(
-            f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
-        ),
-        record_shapes=True,
-        profile_memory=True,
-        with_stack=True,
-        with_flops=True,
-    ) as prof:
-        # dx1
-        torch.cuda.synchronize()
-        bwd_b_start_time = time.time()
-        backward_b(loss1, x1, model)
-        bwd_b_end_time = time.time()
-        print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
-
-        for p in model.parameters():
-            assert p.grad is None
-        assert x1.grad is not None
-
-        # dw1
-        torch.cuda.synchronize()
-        bwd_w_start_time = time.time()
-        backward_w(loss1, model)
-        bwd_w_end_time = time.time()
-        print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
-        for p in model.parameters():
-            assert p.grad is not None
-
-        # # common bwd 1
-        # torch.cuda.synchronize()
-        # comm_bwd_start_time = time.time()
-        # ref_loss1.backward()
-        # comm_bwd_end_time = time.time()
-        # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
-
-
-def test_dx_dw_attn_benchmark():
-    device = "cuda:0"
-    model = Attention(dim=4096).to(device=device)
-    # print(f"model numel {get_model_numel(model)}") # 4GB
-    x1 = torch.rand(1, 256, 4096).to(device=device)
-    # x2 = torch.rand(1, 256, 4096).to(device=device)
-    ref_model = deepcopy(model)
-    ref_x1 = x1.clone()
-    # ref_x2 = x1.clone()
-
-    # first step
-    x1.requires_grad_()
-    # x2.requires_grad_()
-    ref_x1.requires_grad_()
-    # ref_x2.requires_grad_()
-
-    # loss for dx_dw bwd
-    loss1 = model(x1).sum()
-    # loss2 = model(x2).sum()
-
-    # loss for common bwd
-    ref_model(ref_x1).sum()
-    # ref_loss2 = ref_model(ref_x2).sum()
-
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
-        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
-        on_trace_ready=torch.profiler.tensorboard_trace_handler(
-            f"/home/nvme-share/home/duanjunwen/ColossalAI/tests/test_pipeline/test_schedule"
-        ),
-        record_shapes=True,
-        profile_memory=True,
-        with_stack=True,
-        with_flops=True,
-    ) as prof:
-        # dx1
-        torch.cuda.synchronize()
-        bwd_b_start_time = time.time()
-        backward_b(loss1, x1, model)
-        bwd_b_end_time = time.time()
-        print(f"loss_1 bwd B runtime {bwd_b_end_time - bwd_b_start_time}")
-
-        for p in model.parameters():
-            assert p.grad is None
-        assert x1.grad is not None
-
-        # dw1
-        torch.cuda.synchronize()
-        bwd_w_start_time = time.time()
-        backward_w(loss1, model)
-        bwd_w_end_time = time.time()
-        print(f"loss_1 bwd W runtime {bwd_w_end_time - bwd_w_start_time}")
-        for p in model.parameters():
-            assert p.grad is not None
-
-        # # common bwd 1
-        # torch.cuda.synchronize()
-        # comm_bwd_start_time = time.time()
-        # ref_loss1.backward()
-        # comm_bwd_end_time = time.time()
-        # print(f"loss_1 comm bwd runtime {comm_bwd_end_time - comm_bwd_start_time}")
-
-
-if __name__ == "__main__":
-    # test_dx_dw_split()
-    # test_double_dx_dw_split_nsync()
-    # test_double_dx_dw_split_sync()
-    # mem_dx_dw()
-    # activation_dx_dw()
-    # test_dx_dw_linear_benchmark()
-    test_dx_dw_attn_benchmark()
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index ce513f1fd..33707a4f6 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -277,7 +277,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        # TODO: assert layer error
+        # # TODO: assert layer error
         # {
         #     "tp_size": 2,
         #     "pp_size": 2,