From 1e9f9c227f3b6bf191a1e708889194993631f538 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Mon, 13 Jun 2022 14:57:25 +0800 Subject: [PATCH] [hotfix]change to fit latest p2p (#1100) * [CLI] add CLI launcher * Revert "[CLI] add CLI launcher" This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c. * [hotfix]change to fit latest p2p * polish * polish --- colossalai/engine/schedule/_pipeline_schedule.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 5a6b9597f..2b2c4ecbc 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -86,7 +86,14 @@ class PipelineSchedule(BaseSchedule): self.num_microbatches = num_microbatches self.dtype = torch.float - self.tensor_shape = tensor_shape + assert not isinstance(tensor_shape, + int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." + if tensor_shape is None: + self.tensor_shape = tensor_shape + elif isinstance(tensor_shape, torch.Size): + self.tensor_shape = tensor_shape + else: + self.tensor_shape = torch.Size(tensor_shape) self.scatter_gather_tensors = False if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: self.scatter_gather_tensors = scatter_gather_tensors