Fix dtype assert in the interleaved pipeline scheduler failing on the last pipeline stage, where output_obj is None

pull/423/head
Wenwen Qu 2023-10-18 13:56:42 +08:00
parent aa5e34d815
commit 12f897f553
1 changed file with 4 additions and 4 deletions

View File

@@ -977,11 +977,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
if gpc.is_pipeline_last_stage():
output_obj = None
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
# Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
assert output_obj.dtype == self.dtype
input_obj = comm.send_forward_recv_forward(
output_obj,
input_shape,
@@ -993,7 +994,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
if self._communication_overlap:
# In this case, we should handle forward and backward communication separately, consistent with the
# overlap version of the 1F1B stage
assert output_obj.dtype == self.dtype
input_obj = comm.send_forward_recv_forward(
output_obj,
input_shape,
@@ -1010,7 +1010,6 @@ class InterleavedPipelineScheduler(PipelineScheduler):
else:
# In this case, we should handle forward and backward communication together, consistent with the
# non-overlap version of the 1F1B stage
assert output_obj.dtype == self.dtype
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
None, # no backward grad to send
@@ -1082,6 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
else:
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
forward_async_communicator = comm.AsynCommunicator(
output_obj,
input_obj_shape,
@@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
# Communicate objs.
assert output_obj.dtype == self.dtype
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
input_obj_grad,