[tensor] dist spec s2s uses all-to-all (#1136)

* dist spec s2s uses all-to-all

* update unit test

* add sanity check

* polish unit test with titans

* add sanity check for DistMgr

* add sanity check

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
pull/1154/head
ver217 2022-06-22 11:32:38 +08:00 committed by GitHub
parent c77da0dc81
commit ffa025e120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 11 deletions

View File

@ -43,8 +43,25 @@ class DistSpecManager:
_use_autograd_function: bool = True
@staticmethod
def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
    """Reject a conversion between two different, explicitly set process groups.

    Declared as a static method, like the other helpers of this class, so it
    behaves the same whether it is reached through the class or an instance.

    Args:
        old_dist_spec (_DistSpec): the distributed spec. the tensor currently has.
        dist_spec (_DistSpec): the distributed spec. to convert to.

    Raises:
        NotImplementedError: if both specs carry a process group (neither is
            None) and the two process groups are not equal.
    """
    old_pg = old_dist_spec.process_group
    if old_pg is None:
        # Nothing to compare against: the old spec is not bound to a group.
        return
    new_pg = dist_spec.process_group
    if old_pg != new_pg and new_pg is not None:
        raise NotImplementedError
@staticmethod
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
"""_shard_as: shard the tensor w.r.t a distributed specification.
Assuming the tensor passed in is a global (replicated) tensor.
Args:
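tensor (torch.Tensor): a global (replicated) tensor before shard
old_dist_spec (_DistSpec): the distributed spec. of the tensor before shard; its placement must be REPLICATE.
dist_spec (_DistSpec): the distributed spec. to be sharded as.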
Returns:
torch.Tensor: a torch tensor after sharded.
"""
assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
chunk = tensor
idx = dist_spec.process_group.rank()
num_parts = prod(dist_spec.num_partitions)
@ -57,6 +74,15 @@ class DistSpecManager:
@staticmethod
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
"""_gather gather sharded tensors to a replicated one.
Args:
tensor (torch.Tensor): a sharded torch tensor
old_dist_spec (_DistSpec): the distributed spec. of the tensor.
Returns:
torch.Tensor: a replicated tensor.
"""
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
if version.parse(torch.__version__) < version.parse("1.11.0"):
# PyTorch lower than 1.11 does not support gathering a CPU tensor.
# Therefore, we transfer the tensor to GPU before gathering.
@ -78,32 +104,55 @@ class DistSpecManager:
buffer[0].data = buffer[0].data.to(saved_dev)
return buffer[0]
@staticmethod
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
    """_all_to_all: re-shard a tensor from one dim to another with a single all-to-all.

    The local shard is split into ``world_size`` pieces along the new sharding
    dim, the pieces are exchanged with ``dist.all_to_all`` and the received
    pieces are concatenated along the old sharding dim.

    Args:
        tensor (torch.Tensor): the local shard held by this rank.
        old_dist_spec (_DistSpec): the current spec. ``dims[0]`` is the dim that
            gets gathered; its process group is used for the collective.
        dist_spec (_DistSpec): the target spec. ``dims[0]`` is the dim that gets
            scattered.

    Returns:
        torch.Tensor: the re-sharded local tensor. When the process group has a
        single member the input tensor is returned unchanged.

    Raises:
        AssertionError: if the size of the scatter dim is not divisible by the
            world size, or if the tensor is not a CUDA tensor on an NCCL group.
    """
    group = old_dist_spec.process_group
    world_size = group.size()
    if world_size == 1:
        return tensor

    gather_dim = old_dist_spec.dims[0]
    scatter_dim = dist_spec.dims[0]
    shapes = list(tensor.shape)

    # Check the shape first: with an uneven split, tensor_split would yield
    # pieces whose sizes differ from the equally sized receive buffers below.
    assert shapes[scatter_dim] % world_size == 0, \
        f"The size of dim {scatter_dim} ({shapes[scatter_dim]}) must be divisible by " \
        f"the world size ({world_size}) to be scattered by AlltoAll"

    backend = dist.get_backend(group)
    assert tensor.device.type == "cuda" and backend == "nccl", \
        "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
        f"collective function, however, we got {tensor.device.type} device and " \
        f"{backend} backend"

    scattered_dim_size = shapes[scatter_dim] // world_size
    gathered_dim_size = shapes[gather_dim] * world_size
    # From here on `shapes` describes one received piece.
    shapes[scatter_dim] = scattered_dim_size

    scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
    gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
    dist.all_to_all(gather_list, scatter_list, group=group)

    output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
    assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
    return output_
@staticmethod
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
    """_r2r: convert between two specs without touching the tensor data.

    Args:
        tensor (torch.Tensor): the tensor to convert.
        old_dist_spec (_DistSpec): the distributed spec. the tensor currently has.
        dist_spec (_DistSpec): the distributed spec. to convert to.

    Returns:
        torch.Tensor: the input tensor, returned as is.

    Raises:
        NotImplementedError: raised by ``_sanity_check`` when both specs carry a
            process group and the two groups differ.
    """
    # The process-group compatibility rule lives in _sanity_check; it is the
    # only validation this conversion needs.
    DistSpecManager._sanity_check(old_dist_spec, dist_spec)
    return tensor
@staticmethod
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
# Shards `tensor` according to `dist_spec` by delegating to _shard_as, which
# asserts that old_dist_spec's placement is REPLICATE.
# Raises NotImplementedError when old_dist_spec has a process group that
# differs from dist_spec's process group.
# NOTE(review): the inline check below is stricter than _sanity_check (it also
# rejects dist_spec.process_group being None), so the _sanity_check call that
# follows can never raise here. Presumably only one of the two is meant to
# stay -- confirm which.
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
raise NotImplementedError
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
@staticmethod
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
# Turns a sharded tensor into a replicated one by delegating to _gather, which
# asserts that old_dist_spec's placement is SHARD.
# Raises NotImplementedError when the two process groups differ and
# dist_spec's process group is not None.
# NOTE(review): the inline check below is stricter than _sanity_check (it also
# rejects old_dist_spec.process_group being None while dist_spec has a group),
# so the _sanity_check call that follows can never raise here. Presumably only
# one of the two is meant to stay -- confirm which.
if old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
raise NotImplementedError
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec)
@staticmethod
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
# Re-shards `tensor` from old_dist_spec to dist_spec. Three paths:
#   1. the specs compare equal        -> the tensor is returned unchanged;
#   2. both specs shard a single dim  -> one all-to-all exchange;
#   3. anything else                  -> gather everything, then shard again.
# Raises NotImplementedError when the two process groups differ.
# NOTE(review): the inline check below is stricter than _sanity_check (it
# rejects any process-group mismatch, including None on either side), so the
# _sanity_check call that follows can never raise here. Presumably only one of
# the two is meant to stay -- confirm which.
if old_dist_spec.process_group != dist_spec.process_group:
raise NotImplementedError
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
if old_dist_spec == dist_spec:
return tensor
if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
# use all-to-all to save memory
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec)
# NOTE(review): _gather asserts old_dist_spec's placement is SHARD, while
# _shard_as asserts the old_dist_spec it receives is REPLICATE. Passing the same
# old_dist_spec to both looks like it would trip the assertion in _shard_as --
# confirm, and pass a replicate spec for the gathered tensor if so.
tensor = DistSpecManager._gather(tensor, old_dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)

View File

@ -25,7 +25,7 @@ def run():
row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec)
assert torch.equal(x.chunk(size, 0)[rank], row_shard)
assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec))
col_shard = DistSpecManager._shard_as(x, old_dist_spec, col_spec)
col_shard = DistSpecManager._all_to_all(row_shard, row_spec, col_spec)
assert torch.equal(x.chunk(size, -1)[rank], col_shard)
assert torch.equal(x, DistSpecManager._gather(col_shard, col_spec))
mat_shard = DistSpecManager._shard_as(x, old_dist_spec, mat_spec)