mirror of https://github.com/hpcaitech/ColossalAI
[tensor] dist spec s2s uses all-to-all (#1136)
* dist spec s2s uses all-to-all
* update unit test
* add sanity check
* polish unit test with titans
* add sanity check for DistMgr
* add sanity check
Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
pull/1154/head
parent c77da0dc81
commit ffa025e120
@@ -43,8 +43,25 @@ class DistSpecManager:

     _use_autograd_function: bool = True

+    @staticmethod
+    def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
+        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
+                and dist_spec.process_group is not None:
+            raise NotImplementedError
+
     @staticmethod
     def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
+        """_shard_as: shard the tensor w.r.t. a distributed specification.
+        Assumes the tensor passed in is a global (replicated) tensor.
+        Args:
+            tensor (torch.Tensor): a global (replicated) tensor before sharding
+            dist_spec (_DistSpec): the distributed spec to shard the tensor as.
+        Returns:
+            torch.Tensor: the sharded tensor.
+        """
+        assert old_dist_spec.placement.value == 'r', "The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
+
         chunk = tensor
         idx = dist_spec.process_group.rank()
         num_parts = prod(dist_spec.num_partitions)
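A note on the new _sanity_check: conversion between two different, non-None process groups is deliberately rejected. The sketch below (not part of the commit; it uses SimpleNamespace stand-ins for _DistSpec, since only the process_group attribute matters here) makes the contract concrete:

    from types import SimpleNamespace

    # Hypothetical stand-ins for _DistSpec: only .process_group is inspected.
    same = SimpleNamespace(process_group="tp_group")
    other = SimpleNamespace(process_group="dp_group")
    unset = SimpleNamespace(process_group=None)

    def sanity_check(old_dist_spec, dist_spec) -> None:
        # Mirrors DistSpecManager._sanity_check above.
        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
                and dist_spec.process_group is not None:
            raise NotImplementedError

    sanity_check(same, same)    # ok: identical groups
    sanity_check(unset, other)  # ok: the old spec has no group yet
    try:
        sanity_check(same, other)  # two different live groups: rejected
    except NotImplementedError:
        print("conversion across process groups is not implemented")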
@@ -57,6 +74,15 @@ class DistSpecManager:

     @staticmethod
     def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
+        """_gather: gather sharded tensors into a replicated one.
+        Args:
+            tensor (torch.Tensor): a sharded torch tensor
+            old_dist_spec (_DistSpec): the distributed spec of the tensor.
+
+        Returns:
+            torch.Tensor: a replicated tensor.
+        """
+        assert old_dist_spec.placement.value == 's', "The old_dist_spec of DistSpecManager._gather must be SHARD!"
         if version.parse(torch.__version__) < version.parse("1.11.0"):
             # PyTorch below 1.11 does not support gathering a CPU tensor.
             # Therefore, we transfer the tensor to GPU before the gather.
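The version guard exists because the gather collective in PyTorch < 1.11 cannot take CPU tensors. The device round-trip it implies looks roughly like the sketch below (illustrative only; fake_collective is a placeholder, not a real torch.distributed API):

    import torch
    from packaging import version

    def fake_collective(t: torch.Tensor) -> torch.Tensor:
        # Placeholder for the real dist.all_gather call.
        return t.clone()

    def gather_compat(local_shard: torch.Tensor) -> torch.Tensor:
        saved_dev = local_shard.device
        if version.parse(torch.__version__) < version.parse("1.11.0") and torch.cuda.is_available():
            # Older PyTorch cannot gather CPU tensors: run the collective on GPU...
            local_shard = local_shard.cuda()
        gathered = fake_collective(local_shard)
        # ...then move the result back to the device the caller handed us.
        return gathered.to(saved_dev)

    print(gather_compat(torch.ones(4)).device)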
@@ -78,32 +104,55 @@ class DistSpecManager:
             buffer[0].data = buffer[0].data.to(saved_dev)
         return buffer[0]

+    @staticmethod
+    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
+        world_size = old_dist_spec.process_group.size()
+        if world_size == 1:
+            return tensor
+
+        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
+            "Currently, only CUDA tensors with the NCCL backend are supported for the requested all-to-all " \
+            f"collective function; however, we got a {tensor.device.type} device and the " \
+            f"{dist.get_backend(old_dist_spec.process_group)} backend"
+
+        gather_dim = old_dist_spec.dims[0]
+        scatter_dim = dist_spec.dims[0]
+        shapes = list(tensor.shape)
+        scattered_dim_size = shapes[scatter_dim] // world_size
+        gathered_dim_size = shapes[gather_dim] * world_size
+        shapes[scatter_dim] = scattered_dim_size
+
+        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
+        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
+
+        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
+        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
+        return output_
+
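The shape bookkeeping in _all_to_all is easiest to verify in a single-process simulation (a sketch, not part of the commit): pretend there are four ranks, each holding one row shard. Splitting each shard along scatter_dim, exchanging piece j with rank j, and concatenating the received pieces along gather_dim lands every rank on exactly the column shard it would get from gather-then-reshard, which is also what the updated unit test asserts via x.chunk(size, -1)[rank]:

    import torch

    world_size = 4
    gather_dim, scatter_dim = 0, -1   # row-sharded in, column-sharded out
    x = torch.randn(8, 8)             # the logical (global) tensor

    # Each fake rank holds one row shard.
    row_shards = [t.contiguous() for t in torch.tensor_split(x, world_size, gather_dim)]

    # Simulate dist.all_to_all: rank src sends piece dst of its shard (split
    # along scatter_dim) to rank dst, and receives one piece from every rank.
    sent = [[t.contiguous() for t in torch.tensor_split(s, world_size, scatter_dim)]
            for s in row_shards]
    col_shards = [torch.cat([sent[src][dst] for src in range(world_size)], dim=gather_dim)
                  for dst in range(world_size)]

    # The result matches resharding the global tensor along the new dimension.
    for rank in range(world_size):
        assert torch.equal(col_shards[rank], x.chunk(world_size, scatter_dim)[rank])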
     @staticmethod
     def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
-                and dist_spec.process_group is not None:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         return tensor

     @staticmethod
     def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)

     @staticmethod
     def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group != dist_spec.process_group \
-                and dist_spec.process_group is not None:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         return DistSpecManager._gather(tensor, old_dist_spec)

     @staticmethod
     def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group != dist_spec.process_group:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         if old_dist_spec == dist_spec:
             return tensor
+        if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
+            # use all-to-all to save memory
+            return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec)
         tensor = DistSpecManager._gather(tensor, old_dist_spec)
         return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
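Why the all-to-all path in _s2s "saves memory": the old path gathered a full replica on every rank before resharding, so peak memory per rank scaled with the global tensor, while all-to-all only ever materializes shard-sized buffers. A back-of-envelope comparison (illustrative numbers, not from the commit):

    # Peak memory per rank when resharding an N-element fp16 tensor over W ranks.
    N, W, bytes_per_elem = 8192 * 8192, 8, 2
    shard = N // W * bytes_per_elem

    gather_then_shard = N * bytes_per_elem + shard   # gathered replica + local shard
    all_to_all = 3 * shard                           # send pieces + recv buffers + output

    print(f"gather + shard peak ~ {gather_then_shard / 2**20:.0f} MiB")   # ~144 MiB
    print(f"all-to-all peak     ~ {all_to_all / 2**20:.0f} MiB")          # ~48 MiB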
@@ -25,7 +25,7 @@ def run():
     row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec)
     assert torch.equal(x.chunk(size, 0)[rank], row_shard)
     assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec))
-    col_shard = DistSpecManager._shard_as(x, old_dist_spec, col_spec)
+    col_shard = DistSpecManager._all_to_all(row_shard, row_spec, col_spec)
     assert torch.equal(x.chunk(size, -1)[rank], col_shard)
     assert torch.equal(x, DistSpecManager._gather(col_shard, col_spec))
     mat_shard = DistSpecManager._shard_as(x, old_dist_spec, mat_spec)