From fc8b016887e48b03a52e789770d295a8a9842943 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Wed, 25 Sep 2024 06:15:45 +0000
Subject: [PATCH] [fix] fix stage_indices;

---
 .../pipeline/schedule/zero_bubble_pp.py       | 26 ++++++++++++-------
 .../test_schedule/test_zerobubble_pp.py       |  3 +++
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index bbad921b2..307d1035c 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -430,7 +430,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         with self.stage_manager.switch_model_chunk_id(model_chunk_id):
             # fwd calculate
             internal_inputs = {} if input_obj is None else input_obj
-            # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
+            internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
             output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs)
 
         # last layer in model
@@ -480,22 +480,26 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         # For chunk 0 stage 0, use micro_batch as input_obj_
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
-            input_obj_, _ = tree_flatten(micro_batch)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            input_obj_, _ = tree_flatten({k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)})
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy
 
         # For loss backward; output_obj is loss; output_obj_grad should be None
         elif model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
             assert output_obj_grad is None
-            input_obj_, _ = tree_flatten(input_obj)
+            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(output_obj_grad)  # None
 
         # For other chunk stage, use input_obj as input_obj_;
         else:
-            input_obj_, _ = tree_flatten(input_obj)
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            input_obj_, _ = tree_flatten({k: v for k, v in input_obj.items() if isinstance(v, torch.Tensor)})
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
@@ -547,8 +551,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_.append(output_obj)  # LOSS
             output_obj_grad_.append(None)  # None
         else:
-            output_obj_, _ = tree_flatten(output_obj)  # y
-            output_obj_grad_, _ = tree_flatten(output_obj_grad)  # dy
+            output_obj_, _ = tree_flatten({k: v for k, v in output_obj.items() if isinstance(v, torch.Tensor)})  # y
+            output_obj_grad_, _ = tree_flatten(
+                {k: v for k, v in output_obj_grad.items() if isinstance(v, torch.Tensor)}
+            )  # dy
 
         optimizer.backward_by_grad(
             tensor=output_obj_,
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 14bc3475d..9fa636504 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -39,6 +39,7 @@ def pp_linear_fwd(
     forward,
     data: torch.Tensor = None,
     hidden_states: torch.Tensor = None,
+    stage_index=None,
     stage_mgr: PipelineStageManager = None,
     model_chunk_id: int = None,
 ):
@@ -605,6 +606,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     # input_base = [t.clone() for t in data_iter]
     input_base = {k: v.clone() for k, v in data_iter.items()}
     model_base = deepcopy(model)
+    layers_per_stage = stage_manager.distribute_layers(len(model.layers))
+    stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
 
     if rank == 0:
         # layer 0 & 7 to chunk 0 on rank0
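Note (reviewer sketch, not part of the patch): the hunks above all narrow dict
values down to torch.Tensor entries before calling tree_flatten, presumably
because the newly enabled internal_inputs["stage_index"] puts a non-tensor
tuple into the same dicts that optimizer.backward_by_grad later flattens. A
minimal, self-contained illustration of that filtering idiom, using made-up
example data:

    import torch
    from torch.utils._pytree import tree_flatten

    # Hypothetical micro-batch: one tensor plus a non-tensor "stage_index"
    # entry like the one the first hunk injects into internal_inputs.
    micro_batch = {
        "hidden_states": torch.randn(2, 4, requires_grad=True),
        "stage_index": (0, 2),  # plain tuple; autograd cannot take grads w.r.t. it
    }

    # Keep only tensor leaves so the flattened list is safe to hand to autograd.
    tensors_only = {k: v for k, v in micro_batch.items() if isinstance(v, torch.Tensor)}
    flat, _spec = tree_flatten(tensors_only)
    assert all(isinstance(t, torch.Tensor) for t in flat)

The test-side hunks mirror the scheduler change: because the forward path now
reads self.stage_manager.stage_indices[model_chunk_id], the test populates
stage_indices via distribute_layers/get_stage_index before driving the
schedule, and pp_linear_fwd accepts the extra stage_index argument that the
scheduler now passes through.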