fix(pipeline): fix interleave type assert and metrics error (#423)

* fix interleave type assert bug

* refactor code for assert

* fix is_no_pp_or_last_stage logic
pull/427/head
Wenwen Qu 2023-10-19 17:29:30 +08:00 committed by GitHub
parent 3ea46324dd
commit 3c992a2101
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View File

@ -329,7 +329,8 @@ class ParallelContext(metaclass=SingletonMeta):
return self.is_last_rank(ParallelMode.PIPELINE)
def is_no_pp_or_last_stage(self):
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
# NOTICE!!!, this will ignore virutal stage
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.

View File

@ -977,11 +977,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
if gpc.is_pipeline_last_stage():
output_obj = None
assert output_obj is None or output_obj.dtype == self.dtype
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
# Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
assert output_obj.dtype == self.dtype
input_obj = comm.send_forward_recv_forward(
output_obj,
input_shape,
@ -993,7 +994,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
if self._communication_overlap:
# In this case, we should handle forward and backward communication separately, consistent with the
# overlap version of the 1F1B stage
assert output_obj.dtype == self.dtype
input_obj = comm.send_forward_recv_forward(
output_obj,
input_shape,
@ -1010,7 +1010,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
else:
# In this case, we should handle forward and backward communication together, consistent with the
# non-overlap version of the 1F1B stage
assert output_obj.dtype == self.dtype
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
None, # no backward grad to send
@ -1082,6 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
else:
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
assert output_obj is None or output_obj.dtype == self.dtype
forward_async_communicator = comm.AsynCommunicator(
output_obj,
input_obj_shape,
@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
# Communicate objs.
assert output_obj.dtype == self.dtype
assert output_obj is None or output_obj.dtype == self.dtype
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
input_obj_grad,