mirror of https://github.com/InternLM/InternLM
refactor code for assert
parent
12f897f553
commit
6480e03949
|
@ -977,7 +977,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
if gpc.is_pipeline_last_stage():
|
||||
output_obj = None
|
||||
|
||||
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
|
||||
assert output_obj is None or output_obj.dtype == self.dtype
|
||||
|
||||
# Send and receive tensors as appropriate (send tensors computed
|
||||
# in this iteration; receive tensors for next iteration).
|
||||
|
@ -1081,7 +1081,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
else:
|
||||
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
|
||||
|
||||
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
|
||||
assert output_obj is None or output_obj.dtype == self.dtype
|
||||
forward_async_communicator = comm.AsynCommunicator(
|
||||
output_obj,
|
||||
input_obj_shape,
|
||||
|
@ -1203,7 +1203,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
|
||||
|
||||
# Communicate objs.
|
||||
assert gpc.is_pipeline_last_stage() or output_obj.dtype == self.dtype
|
||||
assert output_obj is None or output_obj.dtype == self.dtype
|
||||
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
||||
output_obj,
|
||||
input_obj_grad,
|
||||
|
|
Loading…
Reference in New Issue