mirror of https://github.com/InternLM/InternLM
fix(pipeline_scheduler.py): fix tensor shape err and comm block (#210)
parent f5f5446560
commit 4832671abe
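In brief (a reading of the hunks below, not an official changelog): tensor_shape becomes a read/write property on PipelineScheduler, InterleavedPipelineScheduler overrides the setter so its per-chunk shape caches and send-shape flags stay in sync when the shape is reassigned, and the interleaved step resets the virtual pipeline rank before loading a batch.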
@@ -130,7 +130,7 @@ class PipelineScheduler(BaseScheduler):
         self.dtype = dtype
         self._hooks = scheduler_hooks
 
-        self.tensor_shape = (
+        self._tensor_shape = (
             tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
         )
 
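For reference, the expression retained in this hunk coerces any shape-like argument into a torch.Size while letting None pass through. A minimal sketch of that normalization in isolation (the function name is illustrative, not from the repo):

```python
import torch

def normalize_tensor_shape(tensor_shape):
    # None means "shape not fixed yet"; a torch.Size passes through
    # unchanged; any other iterable of ints (tuple, list) is coerced.
    if tensor_shape is None or isinstance(tensor_shape, torch.Size):
        return tensor_shape
    return torch.Size(tensor_shape)

assert normalize_tensor_shape([16, 2048, 4096]) == torch.Size([16, 2048, 4096])
assert normalize_tensor_shape(None) is None
```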
@@ -146,6 +146,14 @@ class PipelineScheduler(BaseScheduler):
         # cache for the batch data
         self.batch_data = None
 
+    @property
+    def tensor_shape(self) -> torch.Size:
+        return self._tensor_shape
+
+    @tensor_shape.setter
+    def tensor_shape(self, tensor_shape: torch.Size):
+        self._tensor_shape = tensor_shape
+
     def pre_processing(self, engine):
         types = set()
 
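The accessor pair added above is plain Python property plumbing: reads keep working as before, while writes are routed through a setter that subclasses can override. A self-contained sketch of the same pattern (class name and shape values are illustrative only, not from the repo):

```python
import torch

class SchedulerSketch:
    """Illustrative stand-in for PipelineScheduler's tensor_shape handling."""

    def __init__(self, tensor_shape=None):
        # Same normalization as in the constructor hunk above.
        self._tensor_shape = (
            tensor_shape
            if tensor_shape is None or isinstance(tensor_shape, torch.Size)
            else torch.Size(tensor_shape)
        )

    @property
    def tensor_shape(self) -> torch.Size:
        return self._tensor_shape

    @tensor_shape.setter
    def tensor_shape(self, tensor_shape: torch.Size):
        self._tensor_shape = tensor_shape

sched = SchedulerSketch(tensor_shape=(1, 2048, 4096))
sched.tensor_shape = torch.Size([2, 2048, 4096])  # dispatches through the setter
print(sched.tensor_shape)  # torch.Size([2, 2048, 4096])
```

Because assignment now dispatches through the setter, a subclass override (as in the interleaved hunk below) can refresh dependent state whenever the shape changes.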
@@ -685,6 +693,16 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         self._output_obj_shapes = [None for _ in range(num_chunks)]
         self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
 
+    @property
+    def tensor_shape(self) -> torch.Size:
+        return self._tensor_shape
+
+    @tensor_shape.setter
+    def tensor_shape(self, tensor_shape: torch.Size):
+        self._tensor_shape = tensor_shape
+        self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
+        self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
+
     def _clear_state(self) -> None:
         self._accum_loss = None
         self._return_tensors = None
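This override is presumably the core of the fix: before it, assigning tensor_shape on an InterleavedPipelineScheduler updated only the scalar shape and left the per-chunk caches _input_obj_shapes and _send_tensor_shape_flags stale, which can produce mismatched shapes in point-to-point communication. Continuing the SchedulerSketch example above (the num_chunks value is illustrative):

```python
class InterleavedSketch(SchedulerSketch):
    """Stand-in for InterleavedPipelineScheduler's overridden setter."""

    def __init__(self, tensor_shape=None, num_chunks=2):
        super().__init__(tensor_shape)
        self._num_chunks = num_chunks
        self._input_obj_shapes = [self._tensor_shape for _ in range(num_chunks)]
        self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(num_chunks)]

    @property
    def tensor_shape(self):
        return self._tensor_shape

    @tensor_shape.setter
    def tensor_shape(self, tensor_shape):
        # Mirror the diff: refresh every chunk's cached input shape and its
        # "send shape first" flag whenever the shape is reassigned.
        self._tensor_shape = tensor_shape
        self._input_obj_shapes = [tensor_shape for _ in range(self._num_chunks)]
        self._send_tensor_shape_flags = [tensor_shape is None for _ in range(self._num_chunks)]

ip = InterleavedSketch(tensor_shape=None, num_chunks=2)
ip.tensor_shape = torch.Size([2, 2048, 4096])
assert ip._input_obj_shapes == [torch.Size([2, 2048, 4096])] * 2
assert ip._send_tensor_shape_flags == [False, False]
```

One design note: Python properties don't inherit piecewise, so overriding just the setter requires restating the property in the subclass, which is why the diff redefines the getter verbatim even though it is unchanged.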
@@ -1250,6 +1268,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             forward_only or return_loss
         ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
 
+        gpc.set_virtual_pipeline_parallel_rank(0)
+
         self.load_batch(engine, data_iter)
 
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
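A likely reading of this last hunk (inferred from the code, not stated in the commit message): interleaved scheduling mutates the virtual pipeline rank as it cycles through model chunks, so pinning it back to 0 before load_batch makes each step's setup start from a well-defined chunk rather than whatever chunk the previous step ended on.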