From 12f897f553f7e4b12a89e09bad72c6acc218931a Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Wed, 18 Oct 2023 13:56:42 +0800 Subject: [PATCH] fix interleave type assert bug --- internlm/core/scheduler/pipeline_scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 6f2558c..c14d84a 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -977,11 +977,12 @@ class InterleavedPipelineScheduler(PipelineScheduler): if gpc.is_pipeline_last_stage(): output_obj = None + assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype + # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). if k != (num_warmup_microsteps - 1) or not receive_extra_backward: # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage - assert output_obj.dtype == self.dtype input_obj = comm.send_forward_recv_forward( output_obj, input_shape, @@ -993,7 +994,6 @@ class InterleavedPipelineScheduler(PipelineScheduler): if self._communication_overlap: # In this case, we should handle forward and backward communication separately, consistent with the # overlap version of the 1F1B stage - assert output_obj.dtype == self.dtype input_obj = comm.send_forward_recv_forward( output_obj, input_shape, @@ -1010,7 +1010,6 @@ class InterleavedPipelineScheduler(PipelineScheduler): else: # In this case, we should handle forward and backward communication together, consistent with the # non-overlap version of the 1F1B stage - assert output_obj.dtype == self.dtype input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( output_obj, None, # no backward grad to send @@ -1082,6 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): else: input_obj_shape = self._input_obj_shapes[next_forward_chunk_id] + assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype forward_async_communicator = comm.AsynCommunicator( output_obj, input_obj_shape, @@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None # Communicate objs. - assert output_obj.dtype == self.dtype + assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( output_obj, input_obj_grad,