refactor code for assert

pull/423/head
Wenwen Qu 2023-10-18 19:22:33 +08:00
parent 12f897f553
commit 6480e03949
1 changed files with 3 additions and 3 deletions

View File

@ -977,7 +977,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
if gpc.is_pipeline_last_stage():
output_obj = None
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
assert output_obj is None or output_obj.dtype == self.dtype
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
@ -1081,7 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
else:
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
assert output_obj is None or output_obj.dtype == self.dtype
forward_async_communicator = comm.AsynCommunicator(
output_obj,
input_obj_shape,
@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
# Communicate objs.
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
assert output_obj is None or output_obj.dtype == self.dtype
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
input_obj_grad,