From 1739df423c79b0c52ff5957b7992c14081d5dd24 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 20 Sep 2024 07:34:43 +0000 Subject: [PATCH] [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs --- colossalai/pipeline/schedule/zero_bubble_pp.py | 15 +++------------ .../test_schedule/test_zerobubble_pp.py | 6 +++--- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 1af62cc8a..bc2b0b7bf 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -429,18 +429,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # Only attention_mask from micro_batch is used with self.stage_manager.switch_model_chunk_id(model_chunk_id): # fwd calculate - if isinstance(model_chunk, ModuleList): - # fwd for ModuleList model - if input_obj is None: - output_obj = model_chunk[model_chunk_id](**micro_batch) - else: - output_obj = model_chunk[model_chunk_id](**input_obj) - else: - # fwd for shardformer - # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers - internal_inputs = {} if input_obj is None else input_obj - # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] - output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs) + internal_inputs = {} if input_obj is None else input_obj + # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs) # last layer in model if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index ab69d93d3..8ac1f6d01 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ 
b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -48,7 +48,7 @@ def pp_linear_fwd( return forward(hidden_states) # fwd start elif stage_mgr.is_first_stage() and model_chunk_id == 0: - return {"hidden_states": forward(hidden_states)} + return {"hidden_states": forward(data)} # fwd middle else: return {"hidden_states": forward(hidden_states)} @@ -601,7 +601,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)} + data_iter = {"data": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)} # input_base = [t.clone() for t in data_iter] input_base = {k: v.clone() for k, v in data_iter.items()} model_base = deepcopy(model) @@ -694,7 +694,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): # Fwd bwd for base ########################## # fwd & bwd - output_base = model_base(input_base["hidden_states"]) + output_base = model_base(input_base["data"]) loss_base = criterion_base(output_base) loss_base.backward() optimizer_base.step()