diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 5a6b9597f..2b2c4ecbc 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -86,7 +86,14 @@ class PipelineSchedule(BaseSchedule): self.num_microbatches = num_microbatches self.dtype = torch.float - self.tensor_shape = tensor_shape + assert not isinstance(tensor_shape, + int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." + if tensor_shape is None: + self.tensor_shape = tensor_shape + elif isinstance(tensor_shape, torch.Size): + self.tensor_shape = tensor_shape + else: + self.tensor_shape = torch.Size(tensor_shape) self.scatter_gather_tensors = False if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: self.scatter_gather_tensors = scatter_gather_tensors