mirror of https://github.com/InternLM/InternLM
fix(pipeline_scheduler.py): fix tensor shape err and comm block (#210)
parent
f5f5446560
commit
4832671abe
|
@ -130,7 +130,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
self.dtype = dtype
|
||||
self._hooks = scheduler_hooks
|
||||
|
||||
self.tensor_shape = (
|
||||
self._tensor_shape = (
|
||||
tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
|
||||
)
|
||||
|
||||
|
@ -146,6 +146,14 @@ class PipelineScheduler(BaseScheduler):
|
|||
# cache for the batch data
|
||||
self.batch_data = None
|
||||
|
||||
@property
|
||||
def tensor_shape(self) -> torch.Size:
|
||||
return self._tensor_shape
|
||||
|
||||
@tensor_shape.setter
|
||||
def tensor_shape(self, tensor_shape: torch.Size):
|
||||
self._tensor_shape = tensor_shape
|
||||
|
||||
def pre_processing(self, engine):
|
||||
types = set()
|
||||
|
||||
|
@ -685,6 +693,16 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
self._output_obj_shapes = [None for _ in range(num_chunks)]
|
||||
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
|
||||
|
||||
@property
|
||||
def tensor_shape(self) -> torch.Size:
|
||||
return self._tensor_shape
|
||||
|
||||
@tensor_shape.setter
|
||||
def tensor_shape(self, tensor_shape: torch.Size):
|
||||
self._tensor_shape = tensor_shape
|
||||
self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
|
||||
self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
|
||||
|
||||
def _clear_state(self) -> None:
|
||||
self._accum_loss = None
|
||||
self._return_tensors = None
|
||||
|
@ -1250,6 +1268,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
|
||||
gpc.set_virtual_pipeline_parallel_rank(0)
|
||||
|
||||
self.load_batch(engine, data_iter)
|
||||
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
|
|
Loading…
Reference in New Issue