From 9912cc8c07f66e9f5537d469428b1f06f890e29a Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Tue, 15 Oct 2024 06:26:01 +0000
Subject: [PATCH] [fix] fix bwd b; now bwd w runs only for layers replaced by
 Linear1D_Col/Row; other layers perform a full bwd

---
 colossalai/pipeline/schedule/zero_bubble_pp.py        |  9 +++------
 colossalai/shardformer/layer/_operation.py            |  1 -
 colossalai/shardformer/layer/linear.py                |  1 -
 .../test_pipeline/test_schedule/test_zerobubble_pp.py | 11 ++++++++---
 4 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 089ca48ee..e155284bf 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -509,12 +509,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             optimizer.backward_by_grad(
                 tensor=output_obj_,
                 grad=output_obj_grad_,
-                inputs=input_obj_,
-                retain_graph=True,
+                # inputs=input_obj_,
+                # retain_graph=True,
             )
 
-        # Format output_obj_grad
-        input_obj_grad = {}
+        input_obj_grad = dict()
         if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
             pass
         else:
@@ -714,7 +713,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # # we save output_tensor_grad here
         # self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad)
 
-        # Step2: bwd step
         input_object_grad = self.backward_b_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
@@ -761,7 +759,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # get y & dy from buffer
         # output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
         # output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0)
-
         WeightGradStore.pop(chunk=model_chunk_id)
 
         # self.backward_w_step(
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 9d3d91034..4a0800468 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -177,7 +177,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
         # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
         # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
-
         if _grad_accum_fusion_available and weight.grad is not None:
             grad = weight.grad
             if use_zbv:
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index cb3ad0b45..a8a3be63a 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -230,7 +230,6 @@ class Linear1D_Col(ParallelModule):
             fp8_communication=self.fp8_communication,
             use_zbv=self.use_zbv,
         )
-
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_forward_split_backward(
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index fb59e0b2c..6286cc6f0 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -753,8 +753,10 @@ def run_with_hybridplugin(test_config):
     "config",
     [
         # TODO:ERR in second iter
-        # (0, 1, 4, 1, 1),
-        # (1, 2, 2, 1, 1),
+        (0, 1, 4, 1, 1),
+        (1, 2, 2, 1, 1),
+        (1, 1, 2, 2, 1),
+        # Pass
         (1, 2, 1, 2, 1),
         (1, 2, 1, 1, 2),
     ],
@@ -891,19 +893,22 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
-    all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
+    all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
     dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()
         torch_output_sum += torch_output.detach()
+    # print(f"parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
     # avg dp grads follows zero optimizer
     for p in torch_model.parameters():
         if p.grad is not None:
             p.grad /= dp_size
     torch_optimizer.step()
     torch_optimizer.zero_grad()
+
+    # print(f"rank {dist.get_rank()} parallel_output {parallel_output} torch_output_sum {torch_output_sum}")
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
     print(f"rank {dist.get_rank()} config {test_config} test passed")
     clear_layout_converter()
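
Note on the backward_by_grad change in zero_bubble_pp.py: passing inputs= to a
backward call restricts gradient accumulation to those tensors (producing only
dX), and retain_graph=True keeps the autograd graph alive so dW can be computed
later; with both dropped, the B step performs a full backward (dX and dW
together) for any layer that does not defer its weight gradient. A minimal
sketch of the two modes in plain PyTorch (standalone toy tensors, not the
optimizer.backward_by_grad wrapper):

import torch

w = torch.randn(4, 4, requires_grad=True)   # stands in for a layer weight
x = torch.randn(2, 4, requires_grad=True)   # stands in for a layer input
y = (x @ w).sum()

# Restricted backward (old behavior): only x.grad is produced; w.grad stays
# None, and retain_graph=True keeps the graph alive for a later dW pass.
torch.autograd.backward(y, inputs=[x], retain_graph=True)
assert x.grad is not None and w.grad is None

# Full backward (new behavior for layers not replaced by Linear1D_Col/Row):
# one pass produces both x.grad and w.grad, and the graph can be freed.
x2 = torch.randn(2, 4, requires_grad=True)
y2 = (x2 @ w).sum()
torch.autograd.backward(y2)
assert w.grad is not None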
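
The W step then only drains the weight-gradient work that Linear1D_Col/Row
deferred during B, via WeightGradStore.pop(chunk=model_chunk_id). A minimal
sketch of such a store, assuming a put/flush/pop queue per model chunk; the
method names follow the patch, but the bodies are illustrative rather than the
actual ColossalAI implementation:

import queue

class WeightGradStore:
    cache = []                                          # dW work recorded during B steps
    weight_grad_queue = [queue.Queue(), queue.Queue()]  # one queue per model chunk

    @classmethod
    def put(cls, total_input, grad_output, weight, func):
        # B step: stash everything needed to compute dW later; only dX is
        # computed immediately.
        cls.cache.append((total_input, grad_output, weight, func))

    @classmethod
    def flush(cls, chunk=0):
        # End of a B step: hand the recorded work to the given chunk's queue.
        cls.weight_grad_queue[chunk].put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls, chunk=0):
        # W step: replay the recorded closures to materialize weight gradients.
        stored_grads = cls.weight_grad_queue[chunk].get()
        for total_input, grad_output, weight, func in stored_grads:
            func(total_input, grad_output, weight)

Any layer that never calls put() has nothing queued for the W step, so the full
backward it performs during B is the only gradient work it does, which is
exactly the behavior described in the commit message.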