[hotfix]change to fit latest p2p (#1100)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [hotfix]change to fit latest p2p

* polish

* polish
pull/1103/head
YuliangLiu0306 2022-06-13 14:57:25 +08:00 committed by GitHub
parent 72bd7c696b
commit 1e9f9c227f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 1 deletions

View File

@ -86,7 +86,14 @@ class PipelineSchedule(BaseSchedule):
self.num_microbatches = num_microbatches
self.dtype = torch.float
self.tensor_shape = tensor_shape
assert not isinstance(tensor_shape,
int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
if tensor_shape is None:
self.tensor_shape = tensor_shape
elif isinstance(tensor_shape, torch.Size):
self.tensor_shape = tensor_shape
else:
self.tensor_shape = torch.Size(tensor_shape)
self.scatter_gather_tensors = False
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
self.scatter_gather_tensors = scatter_gather_tensors