from colossalai.tensor.distspec import _DistSpec
from colossalai.nn.layer.utils import divide
from numpy import prod
from contextlib import contextmanager
import torch
import torch.distributed as dist


class TransformDistSpec(torch.autograd.Function):
    """Autograd function that converts a tensor from ``old_dist_spec`` to
    ``dist_spec`` in the forward pass and applies the reverse conversion
    to the incoming gradient in the backward pass."""

    @staticmethod
    def forward(ctx, tensor, old_dist_spec, dist_spec, forward_trans_func, backward_trans_func):
        ctx.old_dist_spec = old_dist_spec
        ctx.dist_spec = dist_spec
        ctx.backward_trans_func = backward_trans_func
        return forward_trans_func(tensor, old_dist_spec, dist_spec)

    @staticmethod
    def backward(ctx, grad_outputs):
        # Convert the gradient back from the new spec to the old one; the
        # four non-tensor arguments of forward() receive no gradient.
        return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec), None, None, None, None


class DistSpecManager:
    """Static helpers that convert a tensor between distribution specs
    (replicated or sharded) and, unless disabled via ``no_grad``, make the
    conversion differentiable through ``TransformDistSpec``."""

    _use_autograd_function: bool = True

    @staticmethod
    def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        """Slice out the chunk of ``tensor`` owned by the current rank.

        The rank index is decomposed in row-major order over
        ``dist_spec.num_partitions`` to locate the chunk along every
        sharded dimension listed in ``dist_spec.dims``.
        """
        chunk = tensor
        idx = dist_spec.process_group.rank()
        num_parts = prod(dist_spec.num_partitions)
        for i, dim in enumerate(dist_spec.dims):
            num_parts //= dist_spec.num_partitions[i]
            chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
            chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
            idx %= num_parts
        return chunk.detach().contiguous()
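
    # Worked example (added illustration, not from the original source):
    # with dims = [0, 1], num_partitions = [2, 3] and a 4 x 6 tensor, rank r
    # is decomposed row-major as (r // 3, r % 3); e.g. rank 4 maps to block
    # (1, 1) and keeps rows 2:4 and columns 2:4.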

    @staticmethod
    def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
        """Reassemble the full tensor from the chunks held by every rank.

        Chunks are all-gathered in rank order and then concatenated along
        the sharded dimensions in reverse order, undoing ``_shard_as``.
        """
        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
        for i in range(len(old_dist_spec.dims) - 1, -1, -1):
            new_buffer = []
            dim = old_dist_spec.dims[i]
            num_parts = old_dist_spec.num_partitions[i]
            for start in range(0, len(buffer), num_parts):
                new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
            buffer = new_buffer
        assert len(buffer) == 1
        return buffer[0]
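
    # Illustration (added, not from the original source): for
    # num_partitions = [2, 3] the all-gather yields 6 chunks in rank order;
    # the first pass glues groups of 3 chunks along the last sharded dim,
    # leaving 2 row blocks, and the second pass glues those 2 blocks along
    # the first sharded dim, restoring the full tensor.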

    @staticmethod
    def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        # Replicated -> replicated: nothing to do unless the two specs live
        # on different process groups, which is not supported yet.
        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
                and dist_spec.process_group is not None:
            raise NotImplementedError
        return tensor

    @staticmethod
    def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        # Replicated -> sharded: keep only this rank's chunk.
        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
            raise NotImplementedError
        return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)

    @staticmethod
    def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        # Sharded -> replicated: all-gather the chunks into a full tensor.
        if old_dist_spec.process_group != dist_spec.process_group \
                and dist_spec.process_group is not None:
            raise NotImplementedError
        return DistSpecManager._gather(tensor, old_dist_spec)

    @staticmethod
    def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        # Sharded -> sharded: if the layouts differ, gather first and re-shard.
        if old_dist_spec.process_group != dist_spec.process_group:
            raise NotImplementedError
        if old_dist_spec == dist_spec:
            return tensor
        tensor = DistSpecManager._gather(tensor, old_dist_spec)
        return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)

    @staticmethod
    def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        """Dispatch the conversion from ``old_dist_spec`` to ``dist_spec``.

        The handler is looked up by the placement values of the two specs,
        e.g. ``_s2r`` for sharded -> replicated.
        """
        forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
        if not DistSpecManager._use_autograd_function:
            return forward_trans_handle(tensor, old_dist_spec, dist_spec)
        backward_trans_handle = getattr(DistSpecManager,
                                        f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}')
        return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, forward_trans_handle, backward_trans_handle)
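
    # Illustration (added, not from the original source): converting from a
    # sharded spec ('s') to a replicated spec ('r') resolves _s2r as the
    # forward handler and _r2s as the backward handler, so the gradient of
    # the all-gather is re-sharded on the way back.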

    @staticmethod
    @contextmanager
    def no_grad():
        """Context manager that temporarily bypasses ``TransformDistSpec``,
        so spec conversions inside the block are not tracked by autograd."""
        try:
            DistSpecManager._use_autograd_function = False
            yield
        finally:
            DistSpecManager._use_autograd_function = True
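

# The sketch below is an added illustration, not part of the original module.
# It re-implements the rank -> chunk arithmetic of _shard_as and the reverse
# concatenation of _gather on a single process (no process group required),
# and checks that sharding followed by gathering reproduces the tensor. The
# helper name _chunk_for_rank is ours, introduced only for this sketch.
if __name__ == '__main__':

    def _chunk_for_rank(tensor, rank, dims, num_partitions):
        # Same narrow() logic as DistSpecManager._shard_as, with the rank
        # passed in explicitly instead of read from a process group.
        chunk = tensor
        idx = rank
        num_parts = int(prod(num_partitions))
        for i, dim in enumerate(dims):
            num_parts //= num_partitions[i]
            size = tensor.size(dim) // num_partitions[i]
            chunk = chunk.narrow(dim, idx // num_parts * size, size)
            idx %= num_parts
        return chunk.contiguous()

    dims, num_partitions = [0, 1], [2, 3]
    full = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    chunks = [_chunk_for_rank(full, r, dims, num_partitions) for r in range(6)]

    # Reassemble with the same reverse-order concatenation as _gather.
    buffer = chunks
    for i in range(len(dims) - 1, -1, -1):
        buffer = [torch.cat(buffer[s:s + num_partitions[i]], dims[i])
                  for s in range(0, len(buffer), num_partitions[i])]
    assert torch.equal(buffer[0], full)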