From e6e1a97a6d2d69fc8cd2907883e0627a61e6f372 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 4 Sep 2024 03:31:08 +0000 Subject: [PATCH] [fix] fix require grad position and detach position and input&output local buffer append position; --- .../pipeline/schedule/zero_bubble_pp.py | 37 +++++-------------- .../test_schedule/test_zerobubble_pp.py | 8 ++-- 2 files changed, 13 insertions(+), 32 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 5c9a02d4e..ad0adc7f7 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.cuda -import torch.distributed from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map @@ -496,29 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): inputs=list(model_chunk[model_chunk_id].parameters()), retain_graph=False, ) - # if model_chunk_id == 0: - # optimizer.backward_by_grad( - # tensor=output_obj, - # grad=output_obj_grad, - # inputs=list(model_chunk[model_chunk_id].parameters()), - # retain_graph=False, - # ) - - # else: - # if self.stage_manager.is_first_stage(ignore_chunk=True): - # optimizer.backward_by_grad( - # tensor=output_obj, - # grad=None, - # inputs=list(model_chunk[model_chunk_id].parameters()), - # retain_graph=False, - # ) - # else: - # optimizer.backward_by_grad( - # tensor=output_obj, - # grad=output_obj_grad, - # inputs=list(model_chunk[model_chunk_id].parameters()), - # retain_graph=False, - # ) def schedule_f( self, @@ -557,6 +533,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # not last stage; recv from next else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + input_obj.requires_grad_() # Step2: fwd step output_obj = self.forward_step( @@ -567,21 +544,25 @@ class 
ZeroBubbleVPipeScheduler(PipelineSchedule): accum_loss=accum_loss, outputs=outputs, ) - - detached_output_obj = output_obj.clone() + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # We should not detach bwd LOSS + detached_output_obj = output_obj.clone() + else: + detached_output_obj = output_obj.clone().detach() # Step3: send fwd # add output to send_fwd_buffer if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): + detached_output_obj = detached_output_obj.detach() self.local_send_forward_buffer.append(detached_output_obj) else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) else: # is first stage; end of fwd; append LOSS to local_send_backward_buffer if self.stage_manager.is_first_stage(ignore_chunk=True): - self.local_send_backward_buffer.append(detached_output_obj) + pass else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) @@ -624,7 +605,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): - output_tensor_grad = self.local_send_backward_buffer.pop(0) + output_tensor_grad = None # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 64e4b0676..3d07bb1dd 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -44,7 +44,7 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: "test_config", [ { - "batch_size": 4, + "batch_size": 8, "tp_size": 1, "pp_size": 4, "num_microbatches": 4, @@ -501,7 +501,7 @@ def run_fwd_bwd_iter_input(test_config): "test_config", [ { - 
"batch_size": 4, + "batch_size": 8, "tp_size": 1, "pp_size": 4, "num_microbatches": 4, @@ -689,13 +689,13 @@ def run_with_hybridplugin(test_config): "test_config", [ { - "batch_size": 4, + "batch_size": 8, "tp_size": 1, "pp_size": 4, "num_microbatches": 4, "zero_stage": 1, "precision": "bf16", - "num_model_chunk": 4, + "num_model_chunk": 2, }, ], )