diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py
index a37bef29a..6f605d22c 100644
--- a/colossalai/interface/optimizer.py
+++ b/colossalai/interface/optimizer.py
@@ -58,7 +58,7 @@ class OptimizerWrapper:
     def backward_by_grad(self, tensor: Tensor, grad: Tensor):
         torch.autograd.backward(tensor, grad)
 
-    def backward_b_by_grad(self, tensor: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
+    def backward_b_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = True):
         """
         Performs a backward pass for dx, we only calculate dx = w*dy here
@@ -69,16 +69,28 @@ class OptimizerWrapper:
             retain_graph (bool): default to be True, we retain graph in backward_b
         """
         torch.autograd.backward(
-            tensors=tensor,
+            tensors=tensors,
             grad_tensors=grad_tensors,
             inputs=inputs,
             retain_graph=retain_graph,
         )
 
-    def backward_w_by_grad():
+    def backward_w_by_grad(self, tensors: Tensor, grad_tensors: Tensor, inputs: Tensor, retain_graph: bool = False):
         """
         Performs a backward pass for dw, we only calculate dw = x*dy here
+
+        Args:
+            tensors (Tensor): y or loss of the current chunk;
+            grad_tensors (Tensor): dy of the current chunk;
+            inputs (Tensor): w;
+            retain_graph (bool): defaults to False; we release the graph in backward_w
+        """
+        torch.autograd.backward(
+            tensors=tensors,
+            grad_tensors=grad_tensors,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
 
     def state_dict(self):
         """
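The two methods added above split one full autograd pass into a dx-only step (B) and a dw-only step (W). The standalone sketch below (illustrative only, not part of the patch) reproduces that split with plain `torch.autograd.backward`, which is exactly the call the wrappers forward to:

```python
import torch

# Toy linear layer: y = x @ w, with dL/dy = all-ones.
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
y = x @ w
dy = torch.ones_like(y)

# B step (what backward_b_by_grad wraps): compute only dx = dy @ w^T,
# keeping the graph alive for the W step that follows.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=x, retain_graph=True)
assert w.grad is None  # weight gradients are deliberately deferred

# W step (what backward_w_by_grad wraps): compute only dw = x^T @ dy,
# after which the graph may be freed.
torch.autograd.backward(tensors=y, grad_tensors=dy, inputs=w, retain_graph=False)
assert torch.allclose(x.grad, dy @ w.T)
assert torch.allclose(w.grad, x.T @ dy)
```

Running the B step with `retain_graph=True` is what lets the W step be deferred to an arbitrarily later slot in the schedule, which is the freedom the zero-bubble pipeline exploits.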
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 90da38fcd..23039af6d 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -13,7 +13,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, retain_grad, to_device
+from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -51,8 +51,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.schedules = schedule
         # TODO: optim post valid
         self.do_post_validation = False
-        self.is_first_run = True
-        self.optimizer = None
 
         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -438,17 +439,36 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
 
         if model_chunk_id == 0:
             # bwd step
-            torch.autograd.backward(
-                tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+            optimizer.backward_b_by_grad(
+                tensors=output_obj,
+                grad_tensors=output_obj_grad,
+                inputs=input_obj,
+                retain_graph=True,
             )
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
                 # loss backward; output_obj is loss
-                torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
+                optimizer.backward_b_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=None,
+                    inputs=input_obj,
+                    retain_graph=True,
+                )
             else:
                 # common bwd step
-                torch.autograd.backward(
-                    tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
+                optimizer.backward_b_by_grad(
+                    tensors=output_obj,
+                    grad_tensors=output_obj_grad,
+                    inputs=input_obj,
+                    retain_graph=True,
                 )
         return input_obj.grad
@@ -457,7 +477,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
@@ -475,15 +495,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         # calculate bwd w step; only dw = x*dy
         if model_chunk_id == 0:
-            torch.autograd.backward(
+            optimizer.backward_w_by_grad(
                 tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())
             )
         else:
             if self.stage_manager.is_first_stage(ignore_chunk=True):
-                torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()))
+                optimizer.backward_w_by_grad(
+                    tensors=output_obj, grad_tensors=None, inputs=list(model_chunk[model_chunk_id].parameters())
+                )
             else:
-                torch.autograd.backward(
+                optimizer.backward_w_by_grad(
                     tensors=output_obj,
                     grad_tensors=output_obj_grad,
                     inputs=list(model_chunk[model_chunk_id].parameters()),
                 )
@@ -535,7 +567,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         self.output_tensors[model_chunk_id].append(output_obj)
@@ -641,7 +672,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         scheduled_node,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
-        # optimizer: OptimizerWrapper,
+        optimizer: OptimizerWrapper,
     ):
         """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
@@ -660,7 +691,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.backward_w_step(
             model_chunk=model_chunk,
             model_chunk_id=model_chunk_id,
-            # optimizer: OptimizerWrapper,
+            optimizer=optimizer,
             output_obj=output_obj,
             output_obj_grad=output_obj_grad,
         )
@@ -677,16 +708,26 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         Runs the ZeroBubble schedule, with communication between pipeline stages.
         """
-        # # prepare batch
+        # prepare batch
         self.load_batch(data_iter)
-        print(
-            f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}"
-        )
+
+        # prepare accum loss & outputs
+        accum_loss = None
+        # only the stage that holds the loss accumulates it
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
 
         it = 0
         # while we still have schedules_node in self.schedules
         while it < len(self.schedules):
             scheduled_node = self.schedules[it]
-            print(
-                f"it {it}; manager_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};"
-            )
@@ -706,8 +747,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
                     criterion=criterion,
-                    accum_loss=return_loss,
-                    outputs=return_outputs,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
                 )
             elif scheduled_node.type == "B":
                 self.schedule_b(
@@ -721,5 +762,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
                     model_chunk_id=scheduled_node.chunk,
+                    optimizer=optimizer,
                 )
             it += 1
+
+        # return loss & outputs
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
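In the reworked `run_forward_backward`, each forward microbatch adds its detached loss into `accum_loss` and appends its detached output to `outputs`; `merge_batch` then stitches the per-microbatch outputs back into one batch. A standalone sketch of that accumulate-then-merge pattern, with plain `torch.cat` standing in for `merge_batch` and made-up microbatch values:

```python
import torch

# Hypothetical per-microbatch (loss, output) pairs standing in for the
# scheduler's buffers; the real values come from forward_step.
micro_results = [
    (torch.tensor(0.50), torch.ones(2, 4)),
    (torch.tensor(0.25), torch.zeros(2, 4)),
]

accum_loss = torch.scalar_tensor(0.0)  # lives only on the loss-holding stage
outputs = []
for micro_loss, micro_out in micro_results:
    accum_loss.add_(micro_loss.detach())  # running sum over microbatches
    outputs.append(micro_out.detach())    # detached, merged after the loop

merged = torch.cat(outputs, dim=0)  # stand-in for merge_batch(outputs)
print(accum_loss.item(), merged.shape)  # 0.75 torch.Size([4, 4])
```

Detaching before buffering matters: the merged outputs are for inspection only, and keeping them attached would pin every microbatch's graph in memory past its W step.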
""" - # # prepare batch + # prepare batch self.load_batch(data_iter) print( f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}" ) + # prepare accum loss & output + accum_loss = None + + # reset accum loss at fwd end; + if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): + accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) + + outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None + it = 0 # while we still have schedules_node in self.schedules while it < len(self.schedules): scheduled_node = self.schedules[it] + print( f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};" ) @@ -706,8 +747,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, criterion=criterion, - accum_loss=return_loss, - outputs=return_outputs, + accum_loss=accum_loss, + outputs=outputs, ) elif scheduled_node.type == "B": self.schedule_b( @@ -721,5 +762,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, + optimizer=optimizer, ) it += 1 + + # return loss & output + if outputs is not None: + outputs = merge_batch(outputs) + return {"loss": accum_loss, "outputs": outputs} diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index d97e60e2f..8086f4b7d 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -672,7 +672,7 @@ def run_fwd_bwd_vschedule_with_optim( batch_size = batch_size num_layers = 8 assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" - in_dim = out_dim = 8 + in_dim = out_dim = 16 print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] @@ -714,15 +714,17 @@ def run_fwd_bwd_vschedule_with_optim( ) torch.cuda.synchronize() - scheduler.run_forward_backward( + result = scheduler.run_forward_backward( model_chunk=local_chunk, data_iter=iter(data_iter), criterion=criterion, optimizer=optimizer_pp, - return_loss=None, - return_outputs=None, + return_loss=True, + return_outputs=True, ) + optimizer_pp.step() + ########################## # Fwd bwd for base ########################## @@ -733,6 +735,15 @@ def run_fwd_bwd_vschedule_with_optim( optimizer_base.step() print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + ########################## + # assert loss & output + ########################## + # only chunk 1 stage 0 hold loss and output + if rank == 0: + assert_close(result["loss"], loss_base) + assert_close(result["outputs"], output_base) + + # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ") ########################## # assert weight ########################## @@ -768,6 +779,18 @@ def run_fwd_bwd_vschedule_with_optim( ########################## # assert optim state ########################## + optim_base_state_dict = optimizer_base.state_dict()["param_groups"][0] + 
+    optim_pp_state_dict = optimizer_pp.state_dict()["param_groups"][0]
+
+    for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_state_dict.items(), optim_pp_state_dict.items()):
+        if key_base == key_pp:
+            if key_base != "params":
+                assert val_base == val_pp
+            else:
+                # BUG: the pp optimizer only registers its local parameters:
+                # params base: [0, 1, 2, 3, 4, 5, 6, 7]
+                # params pp:   [0, 1]
+                assert val_base[:2] == val_pp
 
 
 @pytest.mark.dist
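The `params` mismatch flagged in the BUG note follows from each pipeline rank registering only its local chunks' parameters with its optimizer, so the pp index list covers a prefix of the base list. A length-agnostic form of the check, with hypothetical param-group contents mirroring the note above, avoids hard-coding `[:2]`:

```python
# Hypothetical param groups mirroring the BUG note: the base optimizer sees
# all eight layers, the pipeline rank's optimizer only its local two.
base_group = {"lr": 1e-3, "weight_decay": 0.0, "params": [0, 1, 2, 3, 4, 5, 6, 7]}
pp_group = {"lr": 1e-3, "weight_decay": 0.0, "params": [0, 1]}

for key, val_base in base_group.items():
    val_pp = pp_group[key]
    if key != "params":
        assert val_base == val_pp  # hyperparameters must match exactly
    else:
        # compare only the locally held prefix, however many params the rank owns
        assert val_base[: len(val_pp)] == val_pp
```

If the local parameters were not guaranteed to be a prefix of the base ordering, a set or mapping comparison would be needed instead; the prefix check only holds for this test's layout.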