From 3c992a2101ee4849718aaa98b5c574c38c7c9cbd Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Thu, 19 Oct 2023 17:29:30 +0800
Subject: [PATCH] fix(pipeline): fix interleave type assert and metrics error
 (#423)

* fix interleave type assert bug

* refactor code for assert

* fix is_no_pp_or_last_stage logic
---
 internlm/core/context/parallel_context.py     | 3 ++-
 internlm/core/scheduler/pipeline_scheduler.py | 8 ++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 997bd46..915905a 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -329,7 +329,8 @@ class ParallelContext(metaclass=SingletonMeta):
         return self.is_last_rank(ParallelMode.PIPELINE)
 
     def is_no_pp_or_last_stage(self):
-        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
+        # NOTICE: this will ignore the virtual stage
+        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
 
     def get_world_size(self, parallel_mode: ParallelMode):
         """Returns the world size for `parallel_mode`.
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 6f2558c..efc9187 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -977,11 +977,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             if gpc.is_pipeline_last_stage():
                 output_obj = None
 
+            assert output_obj is None or output_obj.dtype == self.dtype
+
             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
             if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
                 # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
-                assert output_obj.dtype == self.dtype
                 input_obj = comm.send_forward_recv_forward(
                     output_obj,
                     input_shape,
@@ -993,7 +994,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if self._communication_overlap:
                     # In this case, we should handle forward and backward communication separately, consistent with the
                     # overlap version of the 1F1B stage
-                    assert output_obj.dtype == self.dtype
                     input_obj = comm.send_forward_recv_forward(
                         output_obj,
                         input_shape,
@@ -1010,7 +1010,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 else:
                     # In this case, we should handle forward and backward communication together, consistent with the
                     # non-overlap version of the 1F1B stage
-                    assert output_obj.dtype == self.dtype
                     input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                         output_obj,
                         None,  # no backward grad to send
@@ -1082,6 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         else:
             input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
 
+        assert output_obj is None or output_obj.dtype == self.dtype
         forward_async_communicator = comm.AsynCommunicator(
             output_obj,
             input_obj_shape,
@@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
 
             # Communicate objs.
-            assert output_obj.dtype == self.dtype
+            assert output_obj is None or output_obj.dtype == self.dtype
             input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                 output_obj,
                 input_obj_grad,
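
Notes:

With interleaved pipeline parallelism, each pipeline rank owns several model
chunks (virtual stages), so is_pipeline_last_stage() is true only when the
currently active virtual chunk is the last stage. Metrics should instead be
gathered on the physically last pipeline rank regardless of which chunk is
active, hence the switch to is_last_rank(ParallelMode.PIPELINE); the added
comment records that virtual stages are deliberately ignored here.

The dtype asserts are hoisted above the communication branches and guarded
with `output_obj is None or ...` because the last stage sets output_obj to
None (there is no next stage to send activations to), and the old unguarded
`assert output_obj.dtype == self.dtype` would raise AttributeError there.
A minimal standalone sketch of the guarded pattern (the helper name and the
float16 dtype below are illustrative stand-ins, not InternLM APIs):

    import torch

    def check_send_dtype(output_obj, expected_dtype=torch.float16):
        # On the last pipeline stage, output_obj is None (nothing to send),
        # so the unguarded `assert output_obj.dtype == expected_dtype` would
        # crash with AttributeError; the guarded form accepts both cases.
        assert output_obj is None or output_obj.dtype == expected_dtype

    check_send_dtype(torch.ones(2, 2, dtype=torch.float16))  # mid stage: passes
    check_send_dtype(None)  # last stage: passes instead of crashing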