diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index 8ee560ff9..f76b624c9 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -43,8 +43,25 @@ class DistSpecManager:
 
     _use_autograd_function: bool = True
 
+    def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
+        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
+                and dist_spec.process_group is not None:
+            raise NotImplementedError
+
     @staticmethod
     def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
+        """_shard_as: shard a tensor w.r.t. a distributed specification.
+        Assumes the tensor passed in is a global (replicated) tensor.
+        Args:
+            tensor (torch.Tensor): the global (replicated) tensor before sharding.
+            dist_spec (_DistSpec): the distributed spec to shard the tensor as.
+
+        Returns:
+            torch.Tensor: the sharded tensor.
+        """
+        assert old_dist_spec.placement.value == 'r', "The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
+
         chunk = tensor
         idx = dist_spec.process_group.rank()
         num_parts = prod(dist_spec.num_partitions)
@@ -57,6 +74,15 @@ class DistSpecManager:
 
     @staticmethod
     def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
+        """_gather: gather sharded tensors into a replicated one.
+        Args:
+            tensor (torch.Tensor): the sharded torch tensor.
+            old_dist_spec (_DistSpec): the distributed spec of the tensor.
+
+        Returns:
+            torch.Tensor: the replicated tensor.
+        """
+        assert old_dist_spec.placement.value == 's', "The old_dist_spec of DistSpecManager._gather must be SHARD!"
         if version.parse(torch.__version__) < version.parse("1.11.0"):
             # pytorch lower than 1.11 dose not support gather a cpu tensor.
             # Therefore, we transfer tensor to GPU before gather.
@@ -78,32 +104,55 @@ class DistSpecManager:
             buffer[0].data = buffer[0].data.to(saved_dev)
         return buffer[0]
 
+    @staticmethod
+    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
+        world_size = old_dist_spec.process_group.size()
+        if world_size == 1:
+            return tensor
+
+        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
+            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
+            f"collective function, however, we got {tensor.device.type} device and " \
+            f"{dist.get_backend(old_dist_spec.process_group)} backend"
+
+        gather_dim = old_dist_spec.dims[0]
+        scatter_dim = dist_spec.dims[0]
+        shapes = list(tensor.shape)
+        scattered_dim_size = shapes[scatter_dim] // world_size
+        gathered_dim_size = shapes[gather_dim] * world_size
+        shapes[scatter_dim] = scattered_dim_size
+
+        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
+        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
+
+        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
+        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
+        return output_
+
     @staticmethod
     def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
-                and dist_spec.process_group is not None:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         return tensor
 
     @staticmethod
     def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
 
     @staticmethod
     def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group != dist_spec.process_group \
-                and dist_spec.process_group is not None:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         return DistSpecManager._gather(tensor, old_dist_spec)
 
     @staticmethod
     def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group != dist_spec.process_group:
-            raise NotImplementedError
+        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         if old_dist_spec == dist_spec:
             return tensor
+        if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
+            # use all-to-all to save memory
+            return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec)
         tensor = DistSpecManager._gather(tensor, old_dist_spec)
         return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
 
diff --git a/tests/test_tensor/test_dist_spec_mgr.py b/tests/test_tensor/test_dist_spec_mgr.py
index f21790da1..9c459dace 100644
--- a/tests/test_tensor/test_dist_spec_mgr.py
+++ b/tests/test_tensor/test_dist_spec_mgr.py
@@ -25,7 +25,7 @@ def run():
     row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec)
     assert torch.equal(x.chunk(size, 0)[rank], row_shard)
     assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec))
-    col_shard = DistSpecManager._shard_as(x, old_dist_spec, col_spec)
+    col_shard = DistSpecManager._all_to_all(row_shard, row_spec, col_spec)
     assert torch.equal(x.chunk(size, -1)[rank], col_shard)
     assert torch.equal(x, DistSpecManager._gather(col_shard, col_spec))
     mat_shard = DistSpecManager._shard_as(x, old_dist_spec, mat_spec)
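For reference, below is a minimal single-process sketch (plain PyTorch, no `torch.distributed` initialization) of the chunk routing that `_all_to_all` performs when `_s2s` re-shards a tensor from dim 0 to dim -1. The helper name `simulate_s2s_reshard` and the in-process list transpose standing in for `dist.all_to_all` are illustrative only, not part of the Colossal-AI API.

```python
import torch


def simulate_s2s_reshard(x: torch.Tensor, world_size: int) -> None:
    # Each simulated rank starts with one shard of x along dim 0 (row shard).
    row_shards = list(torch.chunk(x, world_size, dim=0))

    # Step 1: every rank splits its local shard along the scatter dim (-1),
    # mirroring the `scatter_list` built in `_all_to_all`.
    send = [[t.contiguous() for t in torch.tensor_split(shard, world_size, dim=-1)]
            for shard in row_shards]

    # Step 2: the all-to-all exchange delivers piece `dst` of rank `src` to rank
    # `dst`; within a single process this is just a transpose of the nested list.
    recv = [[send[src][dst] for src in range(world_size)] for dst in range(world_size)]

    # Step 3: each rank concatenates what it received along the gather dim (0),
    # which yields exactly its column shard of the global tensor.
    for rank in range(world_size):
        col_shard = torch.cat(recv[rank], dim=0)
        assert torch.equal(col_shard, x.chunk(world_size, dim=-1)[rank])


simulate_s2s_reshard(torch.randn(8, 8), world_size=4)
print("row shard -> column shard via simulated all-to-all: OK")
```

The `# use all-to-all to save memory` branch in `_s2s` exploits exactly this routing: each rank only ever holds 1/world_size of the global tensor, whereas the previous `_gather` followed by `_shard_as` path materializes the full replicated tensor on every rank.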