ColossalAI/colossalai/legacy/tensor/dist_spec_mgr.py

207 lines
8.5 KiB
Python

from contextlib import contextmanager
import torch
import torch.distributed as dist
from numpy import prod
from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.legacy.tensor.process_group import ProcessGroup
# TODO(jiaruifang) circular import: move `divide` to colossalai.commons.
# colossalai.legacy.tensor must not import any submodule from colossalai.nn
def divide(numerator, denominator):
    """Divide ``numerator`` by ``denominator``, insisting that no remainder is left.

    Args:
        numerator (int): Value being divided.
        denominator (int): Value to divide by; must be non-zero.
    Returns:
        int: the exact quotient.
    """
    assert denominator != 0, "denominator can not be zero"
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0, "{} is not divisible by {}".format(numerator, denominator)
    return quotient
class TransformDistSpec(torch.autograd.Function):
    """Autograd wrapper that applies a distributed-spec transformation.

    Forward converts ``tensor`` from ``old_dist_spec`` to ``dist_spec`` via
    ``forward_trans_func``. Backward converts the incoming gradient in the
    opposite direction (``dist_spec`` -> ``old_dist_spec``) via
    ``backward_trans_func``. Only ``tensor`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
        # Stash everything backward needs in order to undo the transformation.
        ctx.pg = pg
        ctx.backward_trans_func = backward_trans_func
        ctx.dist_spec = dist_spec
        ctx.old_dist_spec = old_dist_spec
        transformed = forward_trans_func(tensor, old_dist_spec, dist_spec, pg)
        return transformed

    @staticmethod
    def backward(ctx, grad_outputs):
        # Source and destination specs are swapped relative to forward.
        grad_input = ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, ctx.pg)
        # The five non-tensor forward arguments receive no gradient.
        return (grad_input,) + (None,) * 5
class DistSpecManager:
    """Converts tensors between distributed specifications (replicate / shard).

    ``handle_trans_spec`` is the entry point. It dispatches to one of the
    ``_r2r`` / ``_r2s`` / ``_s2r`` / ``_s2s`` handlers based on the placements
    of the old and new specs.
    """

    # When True, ``handle_trans_spec`` routes through ``TransformDistSpec`` so the
    # backward pass applies the inverse transformation. Toggled by ``no_grad()``.
    _use_autograd_function: bool = True

    @staticmethod
    def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
        # Hook for validating a spec transition; currently performs no checks.
        pass

    @staticmethod
    def _shard_replicated(tensor: torch.Tensor, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """Return this rank's shard of a replicated ``tensor`` according to ``dist_spec``.

        The local tensor-parallel rank is decomposed in mixed radix over
        ``dist_spec.num_partitions`` (the first listed dim is the most significant
        digit). Each digit selects the chunk along the corresponding dim.

        Args:
            tensor (torch.Tensor): a global (replicated) tensor.
            dist_spec (_DistSpec): the SHARD spec to slice by.
            pg (ProcessGroup): process group supplying the local TP rank.
        Returns:
            torch.Tensor: a detached, contiguous copy of the local shard.
        """
        chunk = tensor
        idx = pg.tp_local_rank()
        num_parts = prod(dist_spec.num_partitions)
        for i, dim in enumerate(dist_spec.dims):
            num_parts //= dist_spec.num_partitions[i]
            chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
            chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
            idx %= num_parts
        # clone() so the shard owns its storage instead of viewing the full tensor.
        return chunk.clone().detach().contiguous()

    @staticmethod
    def _shard_as(
        tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
    ) -> torch.Tensor:
        """_shard_as: shard the tensor w.r.t a distributed specification.

        Assumes the tensor passed in is a global (replicated) tensor.

        Args:
            tensor (torch.Tensor): a global (replicated) tensor before shard
            old_dist_spec (_DistSpec): the current spec of ``tensor``; must be REPLICATE.
            dist_spec (_DistSpec): the distributed spec. to be sharded as.
            pg (ProcessGroup): the process group of the corresponding colotensor
        Returns:
            torch.Tensor: a torch tensor after sharded.
        Raises:
            AssertionError: if ``old_dist_spec`` is not REPLICATE.
        """
        assert (
            old_dist_spec.placement.value == "r"
        ), f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._shard_replicated(tensor, dist_spec, pg)

    @staticmethod
    def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """_gather gather sharded tensors to a replicated one.

        Args:
            tensor (torch.Tensor): the local shard held by this rank.
            old_dist_spec (_DistSpec): the distributed spec. of the tensor; must be SHARD.
            pg (ProcessGroup): process group whose TP group is gathered over.
        Returns:
            torch.Tensor: a replicated tensor, on the same device as the input.
        """
        assert old_dist_spec.placement.value == "s", f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
        saved_dev = tensor.device
        is_cpu_tensor = saved_dev.type == "cpu"
        if is_cpu_tensor:
            # PyTorch lower than 1.11 does not support gathering a CPU tensor, so we
            # gather on GPU. Move a local copy instead of reassigning ``tensor.data``,
            # which would silently relocate the caller's tensor to CUDA.
            tensor = tensor.cuda()
        buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
        assert tensor.device.type == "cuda"
        dist.all_gather(buffer, tensor, group=pg.tp_process_group())
        # Undo the mixed-radix rank layout used when sharding: merge groups of
        # consecutive shards starting from the last (least significant) sharded dim.
        for i in range(len(old_dist_spec.dims) - 1, -1, -1):
            dim = old_dist_spec.dims[i]
            num_parts = old_dist_spec.num_partitions[i]
            buffer = [torch.cat(buffer[start : start + num_parts], dim) for start in range(0, len(buffer), num_parts)]
        assert len(buffer) == 1
        result = buffer[0]
        if is_cpu_tensor:
            result = result.to(saved_dev)
        return result

    @staticmethod
    def _all_to_all(
        tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
    ) -> torch.Tensor:
        """Reshard between two single-dim SHARD specs with one all-to-all.

        The local tensor is split into ``world_size`` chunks along the new shard dim
        (``dist_spec.dims[0]``). The chunks received from all ranks are concatenated
        along the old shard dim (``old_dist_spec.dims[0]``). This avoids materialising
        the full replicated tensor.

        Raises:
            AssertionError: if the tensor is not on CUDA, or if the scatter dim is
                not divisible by the TP world size.
        """
        world_size = pg.tp_world_size()
        if world_size == 1:
            return tensor
        assert tensor.device.type == "cuda", (
            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll "
            f"collective function, however, we got {tensor.device.type} device"
        )
        gather_dim = old_dist_spec.dims[0]
        scatter_dim = dist_spec.dims[0]
        shapes = list(tensor.shape)
        # Exact division: uneven chunks would give mismatched all_to_all buffers.
        scattered_dim_size = divide(shapes[scatter_dim], world_size)
        gathered_dim_size = shapes[gather_dim] * world_size
        shapes[scatter_dim] = scattered_dim_size
        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
        dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
        return output_

    @staticmethod
    def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """REPLICATE -> REPLICATE: nothing to do."""
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return tensor

    @staticmethod
    def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """REPLICATE -> SHARD: keep only this rank's slice."""
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)

    @staticmethod
    def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """SHARD -> REPLICATE: all-gather the shards."""
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._gather(tensor, old_dist_spec, pg)

    @staticmethod
    def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """SHARD -> SHARD: all-to-all for single-dim specs, else gather then reshard."""
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        if old_dist_spec == dist_spec:
            return tensor
        if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
            # use all-to-all to save memory
            return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg)
        tensor = DistSpecManager._gather(tensor, old_dist_spec, pg)
        # ``tensor`` is now replicated. Slice it directly; passing ``old_dist_spec``
        # (a SHARD spec) to ``_shard_as`` would trip its REPLICATE assertion.
        return DistSpecManager._shard_replicated(tensor, dist_spec, pg)

    @staticmethod
    def handle_trans_spec(
        tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
    ) -> torch.Tensor:
        """Convert ``tensor`` from ``old_dist_spec`` to ``dist_spec``.

        Unless inside ``DistSpecManager.no_grad()``, the conversion runs through
        ``TransformDistSpec`` so that gradients are converted back to
        ``old_dist_spec`` during backward.
        """
        assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
        assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
        trans_func_key = (old_dist_spec.placement, dist_spec.placement)
        trans_funcs = {
            (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
            (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
            (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
            (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s,
        }
        forward_trans_handle = trans_funcs[trans_func_key]
        if not DistSpecManager._use_autograd_function:
            return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
        # Backward applies the reverse transition (new placement -> old placement).
        backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]
        return TransformDistSpec.apply(
            tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, backward_trans_handle
        )

    @staticmethod
    @contextmanager
    def no_grad():
        """Temporarily bypass ``TransformDistSpec`` inside ``handle_trans_spec``.

        The previous setting is restored on exit, so nested uses behave correctly.
        """
        previous = DistSpecManager._use_autograd_function
        DistSpecManager._use_autograd_function = False
        try:
            yield
        finally:
            DistSpecManager._use_autograd_function = previous