diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index ebdb374..0aee2ee 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -130,7 +130,7 @@ class PipelineScheduler(BaseScheduler):
         self.dtype = dtype
         self._hooks = scheduler_hooks
 
-        self.tensor_shape = (
+        self._tensor_shape = (
             tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
         )
@@ -146,6 +146,14 @@ class PipelineScheduler(BaseScheduler):
         # cache for the batch data
         self.batch_data = None
 
+    @property
+    def tensor_shape(self) -> torch.Size:
+        return self._tensor_shape
+
+    @tensor_shape.setter
+    def tensor_shape(self, tensor_shape: torch.Size):
+        self._tensor_shape = tensor_shape
+
     def pre_processing(self, engine):
         types = set()
@@ -685,6 +693,16 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         self._output_obj_shapes = [None for _ in range(num_chunks)]
         self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
 
+    @property
+    def tensor_shape(self) -> torch.Size:
+        return self._tensor_shape
+
+    @tensor_shape.setter
+    def tensor_shape(self, tensor_shape: torch.Size):
+        self._tensor_shape = tensor_shape
+        self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
+        self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
+
     def _clear_state(self) -> None:
         self._accum_loss = None
         self._return_tensors = None
@@ -1250,6 +1268,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             forward_only or return_loss
         ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
 
+        gpc.set_virtual_pipeline_parallel_rank(0)
+
         self.load_batch(engine, data_iter)
 
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
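
Usage note (not part of the patch): exposing tensor_shape as a property lets a trainer change the activation shape the scheduler communicates between pipeline stages without rebuilding the scheduler, and on InterleavedPipelineScheduler the setter also refreshes the per-chunk _input_obj_shapes and _send_tensor_shape_flags. A minimal sketch, assuming `scheduler` is the already-constructed scheduler held by the training engine and that activations are packed as (micro_bsz * seq_len, hidden_size); both assumptions are illustrative and not taken from this diff:

import torch

# Assumption: `scheduler` is the PipelineScheduler / InterleavedPipelineScheduler
# instance already created for the engine; the dimensions below are illustrative.
micro_bsz, seq_len, hidden_size = 1, 2048, 4096

# Update the activation shape used for pipeline p2p communication before the
# next step; on InterleavedPipelineScheduler this also resets the per-chunk
# input shapes and send-shape flags so they stay consistent with the new shape.
scheduler.tensor_shape = torch.Size((micro_bsz * seq_len, hidden_size))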