mirror of https://github.com/InternLM/InternLM
fix(pipeline): fix interleave type assert and metrics error (#423)
* fix interleave type assert bug * refactor code for assert * fix is_no_pp_or_last_stage logic (pull/427/head)
parent
3ea46324dd
commit
3c992a2101
|
@ -329,7 +329,8 @@ class ParallelContext(metaclass=SingletonMeta):
|
||||||
return self.is_last_rank(ParallelMode.PIPELINE)
|
return self.is_last_rank(ParallelMode.PIPELINE)
|
||||||
|
|
||||||
def is_no_pp_or_last_stage(self):
|
def is_no_pp_or_last_stage(self):
|
||||||
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_pipeline_last_stage()
|
# NOTICE!!!, this will ignore virtual stage
|
||||||
|
return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE)
|
||||||
|
|
||||||
def get_world_size(self, parallel_mode: ParallelMode):
|
def get_world_size(self, parallel_mode: ParallelMode):
|
||||||
"""Returns the world size for `parallel_mode`.
|
"""Returns the world size for `parallel_mode`.
|
||||||
|
|
|
@ -977,11 +977,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
||||||
if gpc.is_pipeline_last_stage():
|
if gpc.is_pipeline_last_stage():
|
||||||
output_obj = None
|
output_obj = None
|
||||||
|
|
||||||
|
assert output_obj is None or output_obj.dtype == self.dtype
|
||||||
|
|
||||||
# Send and receive tensors as appropriate (send tensors computed
|
# Send and receive tensors as appropriate (send tensors computed
|
||||||
# in this iteration; receive tensors for next iteration).
|
# in this iteration; receive tensors for next iteration).
|
||||||
if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
|
if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
|
||||||
# Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
|
# Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
|
||||||
assert output_obj.dtype == self.dtype
|
|
||||||
input_obj = comm.send_forward_recv_forward(
|
input_obj = comm.send_forward_recv_forward(
|
||||||
output_obj,
|
output_obj,
|
||||||
input_shape,
|
input_shape,
|
||||||
|
@ -993,7 +994,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
||||||
if self._communication_overlap:
|
if self._communication_overlap:
|
||||||
# In this case, we should handle forward and backward communication separately, consistent with the
|
# In this case, we should handle forward and backward communication separately, consistent with the
|
||||||
# overlap version of the 1F1B stage
|
# overlap version of the 1F1B stage
|
||||||
assert output_obj.dtype == self.dtype
|
|
||||||
input_obj = comm.send_forward_recv_forward(
|
input_obj = comm.send_forward_recv_forward(
|
||||||
output_obj,
|
output_obj,
|
||||||
input_shape,
|
input_shape,
|
||||||
|
@ -1010,7 +1010,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
||||||
else:
|
else:
|
||||||
# In this case, we should handle forward and backward communication together, consistent with the
|
# In this case, we should handle forward and backward communication together, consistent with the
|
||||||
# non-overlap version of the 1F1B stage
|
# non-overlap version of the 1F1B stage
|
||||||
assert output_obj.dtype == self.dtype
|
|
||||||
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
||||||
output_obj,
|
output_obj,
|
||||||
None, # no backward grad to send
|
None, # no backward grad to send
|
||||||
|
@ -1082,6 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
||||||
else:
|
else:
|
||||||
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
|
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
|
||||||
|
|
||||||
|
assert output_obj is None or output_obj.dtype == self.dtype
|
||||||
forward_async_communicator = comm.AsynCommunicator(
|
forward_async_communicator = comm.AsynCommunicator(
|
||||||
output_obj,
|
output_obj,
|
||||||
input_obj_shape,
|
input_obj_shape,
|
||||||
|
@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
||||||
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
|
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
|
||||||
|
|
||||||
# Communicate objs.
|
# Communicate objs.
|
||||||
assert output_obj.dtype == self.dtype
|
assert output_obj is None or output_obj.dtype == self.dtype
|
||||||
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
||||||
output_obj,
|
output_obj,
|
||||||
input_obj_grad,
|
input_obj_grad,
|
||||||
|
|
Loading…
Reference in New Issue