fix(pipeline_scheduler.py): fix tensor shape err and comm block (#210)

pull/216/head^2
huangting4201 2023-08-21 12:09:27 +08:00 committed by GitHub
parent f5f5446560
commit 4832671abe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 1 deletions

View File

@@ -130,7 +130,7 @@ class PipelineScheduler(BaseScheduler):
self.dtype = dtype
self._hooks = scheduler_hooks
self.tensor_shape = (
self._tensor_shape = (
tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
)
@@ -146,6 +146,14 @@ class PipelineScheduler(BaseScheduler):
# cache for the batch data
self.batch_data = None
@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape
@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape
def pre_processing(self, engine):
types = set()
@@ -685,6 +693,16 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._output_obj_shapes = [None for _ in range(num_chunks)]
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape
@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape
self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
def _clear_state(self) -> None:
self._accum_loss = None
self._return_tensors = None
@@ -1250,6 +1268,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
gpc.set_virtual_pipeline_parallel_rank(0)
self.load_batch(engine, data_iter)
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):