2022-05-19 04:44:59 +00:00
|
|
|
from colossalai.tensor.distspec import _DistSpec
|
2022-06-03 10:04:22 +00:00
|
|
|
# from colossalai.nn.layer.utils import divide
|
2022-05-13 07:13:52 +00:00
|
|
|
from numpy import prod
|
|
|
|
from contextlib import contextmanager
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
2022-06-21 10:28:38 +00:00
|
|
|
from packaging import version
|
2022-07-04 10:54:37 +00:00
|
|
|
from colossalai.logging import get_dist_logger
|
2022-07-06 08:15:16 +00:00
|
|
|
from colossalai.tensor import ProcessGroup
|
2022-05-13 07:13:52 +00:00
|
|
|
|
|
|
|
|
2022-06-03 10:04:22 +00:00
|
|
|
# TODO(jiaruifang) circular import, move the divide to colossalai.commons.
|
|
|
|
# colossalai.tensor shall not import any submodule from colossalai.nn
|
|
|
|
def divide(numerator, denominator):
    """Integer-divide two numbers, refusing any division that leaves a remainder.

    Args:
        numerator (int): The value being divided.
        denominator (int): The value to divide by; must not be zero.

    Returns:
        int: ``numerator // denominator``.

    Raises:
        AssertionError: If ``denominator`` is zero, or if it does not divide
            ``numerator`` evenly.
    """
    # Checked first so the modulo below never sees a zero divisor.
    assert denominator != 0, 'denominator can not be zero'
    assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
    return numerator // denominator
|
|
|
|
|
|
|
|
|
2022-05-13 07:13:52 +00:00
|
|
|
class TransformDistSpec(torch.autograd.Function):
    """Autograd wrapper around a pair of distributed-spec conversion functions.

    The forward pass applies ``forward_trans_func`` to go from ``old_dist_spec`` to
    ``dist_spec``. The backward pass applies ``backward_trans_func`` with the two specs
    swapped, so the incoming gradient is converted from ``dist_spec`` back to
    ``old_dist_spec``.
    """

    @staticmethod
    def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
        """Convert ``tensor`` and remember what ``backward`` needs.

        Args:
            ctx: Autograd context used to carry state to ``backward``.
            tensor: The tensor to convert.
            old_dist_spec: The spec ``tensor`` currently follows.
            dist_spec: The spec to convert to.
            pg: The process group handed through to both conversion functions.
            forward_trans_func: Callable ``(tensor, old_dist_spec, dist_spec, pg)`` used here.
            backward_trans_func: Callable ``(grad, dist_spec, old_dist_spec, pg)`` used in ``backward``.

        Returns:
            Whatever ``forward_trans_func`` returns.
        """
        ctx.pg = pg
        ctx.backward_trans_func = backward_trans_func
        ctx.dist_spec = dist_spec
        ctx.old_dist_spec = old_dist_spec
        converted = forward_trans_func(tensor, old_dist_spec, dist_spec, pg)
        return converted

    @staticmethod
    def backward(ctx, grad_outputs):
        """Convert the gradient back to the original spec.

        Returns:
            tuple: The converted gradient for ``tensor`` followed by ``None`` for each of
            the five remaining ``forward`` inputs.
        """
        grad_input = ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, ctx.pg)
        # One slot per forward() argument after ctx; only `tensor` receives a gradient.
        return (grad_input,) + (None,) * 5
|
2022-05-13 07:13:52 +00:00
|
|
|
|
|
|
|
|
|
|
|
class DistSpecManager:
    """Static helpers that convert a tensor from one distributed spec to another.

    Conversion methods are named ``_<old>2<new>``, where each side is the placement value
    of a ``_DistSpec`` (``'r'`` for REPLICATE, ``'s'`` for SHARD). ``handle_trans_spec``
    looks the method up by that name and, unless disabled through ``no_grad``, wraps it in
    ``TransformDistSpec`` so that gradients are converted back in the backward pass.
    """

    # When False, handle_trans_spec calls the conversion directly instead of going
    # through TransformDistSpec. Toggled by the no_grad() context manager.
    _use_autograd_function: bool = True

    @staticmethod
    def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
        """Validation hook for a pair of specs; currently performs no checks."""
        pass

    @staticmethod
    def _shard_replicated(tensor: torch.Tensor, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """Cut this rank's shard out of a full tensor.

        Unlike ``_shard_as`` this takes no ``old_dist_spec``: the caller guarantees that
        ``tensor`` already holds the complete (replicated) data.

        Args:
            tensor (torch.Tensor): the full tensor to slice.
            dist_spec (_DistSpec): the spec describing which dims are split and into how many parts.
            pg (ProcessGroup): the process group supplying this rank's tensor-parallel local rank.

        Returns:
            torch.Tensor: a detached, contiguous copy of the local shard.
        """
        chunk = tensor
        idx = pg.tp_local_rank()
        # num_parts starts as the total number of shards and is divided down as each
        # dimension is consumed, so `idx // num_parts` is this rank's index along the
        # current dimension and `idx % num_parts` is what is left for the remaining ones.
        num_parts = prod(dist_spec.num_partitions)
        for i, dim in enumerate(dist_spec.dims):
            num_parts //= dist_spec.num_partitions[i]
            # divide() asserts that the dimension splits evenly.
            chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
            chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
            idx %= num_parts
        return chunk.clone().detach().contiguous()

    @staticmethod
    def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
                  pg: ProcessGroup) -> torch.Tensor:
        """_shard_as: shard the tensor w.r.t a distributed specification.
        Assuming the tensor passed in is a global (replicated) tensor.

        Args:
            tensor (torch.Tensor): a global (replicated) tensor before shard
            old_dist_spec (_DistSpec): the current spec of the tensor; must be REPLICATE.
            dist_spec (_DistSpec): the distributed spec. to be sharded as.
            pg (ProcessGroup): the process group of the corresponding colotensor

        Returns:
            torch.Tensor: a torch tensor after sharded.

        Raises:
            AssertionError: if ``old_dist_spec`` is not a REPLICATE spec.
        """
        assert old_dist_spec.placement.value == 'r', \
            "The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._shard_replicated(tensor, dist_spec, pg)

    @staticmethod
    def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """_gather gather sharded tensors to a replicated one.

        Args:
            tensor (torch.Tensor): a sharded torch tensor
            old_dist_spec (_DistSpec): the distributed spec. of the tensor.
            pg (ProcessGroup): the process group whose tensor-parallel group is gathered over.

        Returns:
            torch.Tensor: a replicated tensor.

        Raises:
            AssertionError: if ``old_dist_spec`` is not a SHARD spec, or if the tensor is
                not on a CUDA device when the all-gather is issued.
        """
        assert old_dist_spec.placement.value == 's', "The old_dist_spec of DistSpecManager._gather must be SHARD!"

        # pytorch lower than 1.11 does not support gathering a cpu tensor.
        # Therefore, we transfer the tensor to GPU before the gather.
        needs_cuda_staging = version.parse(torch.__version__) < version.parse("1.11.0")
        if needs_cuda_staging:
            saved_dev = tensor.device
            # NOTE: this rebinds the caller's tensor storage to the CUDA copy; only the
            # gathered result is moved back to `saved_dev` below.
            tensor.data = tensor.data.cuda()

        buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
        assert tensor.device.type == 'cuda'
        dist.all_gather(buffer, tensor, group=pg.tp_process_group())

        # Undo the sharding one dimension at a time, innermost (last listed) first:
        # every run of `num_parts` consecutive pieces is concatenated along that dim.
        for i in range(len(old_dist_spec.dims) - 1, -1, -1):
            new_buffer = []
            dim = old_dist_spec.dims[i]
            num_parts = old_dist_spec.num_partitions[i]
            for start in range(0, len(buffer), num_parts):
                new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
            buffer = new_buffer
        assert len(buffer) == 1

        if needs_cuda_staging:
            buffer[0].data = buffer[0].data.to(saved_dev)
        return buffer[0]

    @staticmethod
    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
                    pg: ProcessGroup) -> torch.Tensor:
        """Re-shard a tensor from one single dimension to another with one all-to-all.

        Only the first entry of each spec's ``dims`` is used: the tensor is split along
        ``dist_spec.dims[0]`` and the received pieces are concatenated along
        ``old_dist_spec.dims[0]``.

        Args:
            tensor (torch.Tensor): the local shard; must live on a CUDA device.
            old_dist_spec (_DistSpec): the spec the tensor currently follows.
            dist_spec (_DistSpec): the spec to convert to.
            pg (ProcessGroup): the process group whose tensor-parallel group is used.

        Returns:
            torch.Tensor: the re-sharded tensor, or ``tensor`` itself when the
            tensor-parallel world size is 1.

        Raises:
            AssertionError: if the tensor is not a CUDA tensor, or if the result does not
                have the expected sizes along the scatter and gather dimensions.
        """
        world_size = pg.tp_world_size()
        if world_size == 1:
            return tensor

        assert tensor.device.type == "cuda", \
            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
            f"collective function, however, we got {tensor.device.type} device"

        gather_dim = old_dist_spec.dims[0]
        scatter_dim = dist_spec.dims[0]
        shapes = list(tensor.shape)
        scattered_dim_size = shapes[scatter_dim] // world_size
        gathered_dim_size = shapes[gather_dim] * world_size
        # Each received piece has the local shape with the scatter dim shrunk.
        shapes[scatter_dim] = scattered_dim_size

        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
        dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
        return output_

    @staticmethod
    def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """Replicate -> replicate: nothing to move, the tensor is returned unchanged."""
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return tensor

    @staticmethod
    def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """Replicate -> shard: keep only this rank's slice (see ``_shard_as``)."""
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)

    @staticmethod
    def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """Shard -> replicate: all-gather the shards (see ``_gather``)."""
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._gather(tensor, old_dist_spec, pg)

    @staticmethod
    def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """Shard -> shard: convert between two different shardings.

        Returns the tensor unchanged when both specs compare equal, uses a single
        all-to-all when both specs shard exactly one dimension, and otherwise gathers
        the full tensor and slices it again.
        """
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        if old_dist_spec == dist_spec:
            return tensor
        if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
            # use all-to-all to save memory
            return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg)
        full_tensor = DistSpecManager._gather(tensor, old_dist_spec, pg)
        # The gathered tensor is complete, so slice it directly. Routing it through
        # _shard_as with the (sharded) old_dist_spec would always trip that method's
        # REPLICATE assertion.
        return DistSpecManager._shard_replicated(full_tensor, dist_spec, pg)

    @staticmethod
    def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
                          pg: ProcessGroup) -> torch.Tensor:
        """Convert ``tensor`` from ``old_dist_spec`` to ``dist_spec``.

        The conversion method is selected by name from the two placement values. When the
        autograd wrapper is enabled, the reverse conversion is registered for the backward
        pass through ``TransformDistSpec``.

        Args:
            tensor (torch.Tensor): the tensor to convert.
            old_dist_spec (_DistSpec): the spec the tensor currently follows.
            dist_spec (_DistSpec): the spec to convert to.
            pg (ProcessGroup): the process group passed to the conversion methods.

        Returns:
            torch.Tensor: the converted tensor.

        Raises:
            AssertionError: if either spec is not a ``_DistSpec`` instance.
        """
        assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
        assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
        forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
        if not DistSpecManager._use_autograd_function:
            return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
        # The backward conversion is the same table looked up in the opposite direction.
        backward_trans_handle = getattr(DistSpecManager,
                                        f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}')
        return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
                                       backward_trans_handle)

    @staticmethod
    @contextmanager
    def no_grad():
        """Context manager that makes ``handle_trans_spec`` bypass ``TransformDistSpec``.

        The flag value seen on entry is restored on exit (also when the body raises), so
        nested ``no_grad()`` blocks keep the wrapper disabled until the outermost one ends.
        """
        previous = DistSpecManager._use_autograd_function
        DistSpecManager._use_autograd_function = False
        try:
            yield
        finally:
            DistSpecManager._use_autograd_function = previous
|