From 6480e03949ece71e670364f231953f58fafb94ae Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Wed, 18 Oct 2023 19:22:33 +0800 Subject: [PATCH] refactor code for assert --- internlm/core/scheduler/pipeline_scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index c14d84a..efc9187 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -977,7 +977,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): if gpc.is_pipeline_last_stage(): output_obj = None - assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype + assert output_obj is None or output_obj.dtype == self.dtype # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). @@ -1081,7 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): else: input_obj_shape = self._input_obj_shapes[next_forward_chunk_id] - assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype + assert output_obj is None or output_obj.dtype == self.dtype forward_async_communicator = comm.AsynCommunicator( output_obj, input_obj_shape, @@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler): output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None # Communicate objs. - assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype + assert output_obj is None or output_obj.dtype == self.dtype input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( output_obj, input_obj_grad,