diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index c14d84a..efc9187 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -977,7 +977,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): if gpc.is_pipeline_last_stage(): output_obj = None - assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype + assert output_obj is None or output_obj.dtype == self.dtype # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). @@ -1081,7 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): else: input_obj_shape = self._input_obj_shapes[next_forward_chunk_id] - assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype + assert output_obj is None or output_obj.dtype == self.dtype forward_async_communicator = comm.AsynCommunicator( output_obj, input_obj_shape, @@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None # Communicate objs. - assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype + assert output_obj is None or output_obj.dtype == self.dtype input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( output_obj, input_obj_grad,