mirror of https://github.com/hpcaitech/ColossalAI
[hotfix]change to fit latest p2p (#1100)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4
.
* [hotfix]change to fit latest p2p
* polish
* polish
pull/1103/head
parent
72bd7c696b
commit
1e9f9c227f
|
@ -86,7 +86,14 @@ class PipelineSchedule(BaseSchedule):
|
|||
|
||||
self.num_microbatches = num_microbatches
|
||||
self.dtype = torch.float
|
||||
self.tensor_shape = tensor_shape
|
||||
assert not isinstance(tensor_shape,
|
||||
int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
|
||||
if tensor_shape is None:
|
||||
self.tensor_shape = tensor_shape
|
||||
elif isinstance(tensor_shape, torch.Size):
|
||||
self.tensor_shape = tensor_shape
|
||||
else:
|
||||
self.tensor_shape = torch.Size(tensor_shape)
|
||||
self.scatter_gather_tensors = False
|
||||
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
|
||||
self.scatter_gather_tensors = scatter_gather_tensors
|
||||
|
|
Loading…
Reference in New Issue