mirror of https://github.com/InternLM/InternLM
fix(pipeline): fix interleave type assert and metrics error (#423)
* fix interleave type assert bug
* refactor code for assert
* fix is_no_pp_or_last_stage logic

pull/427/head
parent 3ea46324dd
commit 3c992a2101
@@ -329,7 +329,8 @@ class ParallelContext(metaclass=SingletonMeta):
         return self.is_last_rank(ParallelMode.PIPELINE)
 
     def is_no_pp_or_last_stage(self):
-        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
+        # NOTICE!!! This will ignore the virtual stage.
+        return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
 
     def get_world_size(self, parallel_mode: ParallelMode):
         """Returns the world size for `parallel_mode`.
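Why this change matters: under the interleaved schedule each pipeline rank hosts several model chunks, so a chunk-aware last-stage check can answer differently depending on which virtual stage is currently active, while is_last_rank(ParallelMode.PIPELINE) depends only on the physical rank. A rough sketch of the distinction; the counts and helper bodies below are illustrative, not InternLM's internals:

# Illustrative only: 2 pipeline ranks, 2 chunks per rank -> 4 virtual stages.
pipeline_world_size = 2
num_chunks = 2

def is_last_rank(rank: int) -> bool:
    # Physical check: depends only on the rank, stable across chunks.
    return rank == pipeline_world_size - 1

def is_pipeline_last_stage(rank: int, chunk_id: int) -> bool:
    # Virtual-stage-aware check: true only while the last chunk is active.
    return is_last_rank(rank) and chunk_id == num_chunks - 1

# The physically last rank is always "last" to is_last_rank ...
assert is_last_rank(1)
# ... but is_pipeline_last_stage flips with the active chunk:
assert not is_pipeline_last_stage(1, chunk_id=0)
assert is_pipeline_last_stage(1, chunk_id=1)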
@@ -977,11 +977,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             if gpc.is_pipeline_last_stage():
                 output_obj = None
 
+            assert output_obj is None or output_obj.dtype == self.dtype
+
             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
             if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
                 # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
-                assert output_obj.dtype == self.dtype
                 input_obj = comm.send_forward_recv_forward(
                     output_obj,
                     input_shape,
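The assert is hoisted above the branch and relaxed because, per the context lines above, the last pipeline stage sets output_obj = None before communicating; the old in-branch assert would then die with an AttributeError rather than perform a meaningful check. A minimal standalone sketch of the failure mode, in plain PyTorch with self.dtype stood in by a local variable:

import torch

dtype = torch.bfloat16  # assumed training dtype; any dtype illustrates the point

output_obj = None  # what the last pipeline stage holds before sending

try:
    assert output_obj.dtype == dtype  # old check
except AttributeError as e:
    print(f"old assert crashes instead of checking: {e}")

# Fixed check: short-circuits on None, inspects dtype only for real tensors.
assert output_obj is None or output_obj.dtype == dtype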
@@ -993,7 +994,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if self._communication_overlap:
                     # In this case, we should handle forward and backward communication separately, consistent with the
                     # overlap version of the 1F1B stage
-                    assert output_obj.dtype == self.dtype
                     input_obj = comm.send_forward_recv_forward(
                         output_obj,
                         input_shape,
@@ -1010,7 +1010,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 else:
                     # In this case, we should handle forward and backward communication together, consistent with the
                     # non-overlap version of the 1F1B stage
-                    assert output_obj.dtype == self.dtype
                     input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                         output_obj,
                         None,  # no backward grad to send
@@ -1082,6 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             else:
                 input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
 
+            assert output_obj is None or output_obj.dtype == self.dtype
             forward_async_communicator = comm.AsynCommunicator(
                 output_obj,
                 input_obj_shape,
@@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
 
             # Communicate objs.
-            assert output_obj.dtype == self.dtype
+            assert output_obj is None or output_obj.dtype == self.dtype
             input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                 output_obj,
                 input_obj_grad,
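For orientation, comm.send_forward_backward_recv_forward_backward bundles up to four point-to-point transfers (activation forward, gradient backward, and the two matching receives) into one step. Below is a generic sketch of that pattern using torch.distributed batched P2P ops; this is an assumption about the shape of such a helper, not InternLM's actual implementation:

import torch
import torch.distributed as dist

def send_fwd_bwd_recv_fwd_bwd(output_obj, input_obj_grad, input_shape, output_shape,
                              prev_rank, next_rank, dtype, device="cuda"):
    # Hypothetical helper; requires an initialized process group.
    # Each of the four transfers is optional: pass None to skip it.
    ops, input_obj, output_obj_grad = [], None, None
    if output_obj is not None:
        # Send this stage's activation to the next pipeline rank.
        ops.append(dist.P2POp(dist.isend, output_obj, next_rank))
    if input_obj_grad is not None:
        # Send the gradient w.r.t. our input back to the previous rank.
        ops.append(dist.P2POp(dist.isend, input_obj_grad, prev_rank))
    if input_shape is not None:
        # Receive the next microbatch's input activation.
        input_obj = torch.empty(input_shape, dtype=dtype, device=device)
        ops.append(dist.P2POp(dist.irecv, input_obj, prev_rank))
    if output_shape is not None:
        # Receive the gradient w.r.t. our output from the next rank.
        output_obj_grad = torch.empty(output_shape, dtype=dtype, device=device)
        ops.append(dist.P2POp(dist.irecv, output_obj_grad, next_rank))
    if ops:
        # Launch all transfers as one batch, then wait for completion.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return input_obj, output_obj_grad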