mirror of https://github.com/hpcaitech/ColossalAI
[model checkpoint] updated communication ops for cpu tensors (#590)
parent
c50bfb807b
commit
6302069c0e
|
@ -8,10 +8,13 @@ from torch import Tensor
|
|||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
|
||||
def all_gather(tensor: Tensor,
|
||||
dim: int,
|
||||
parallel_mode: ParallelMode,
|
||||
on_cpu: bool = False,
|
||||
async_op: bool = False) -> Tensor:
|
||||
r"""Gathers all tensors from the parallel group and concatenates them in a
|
||||
specific dimension.
|
||||
|
||||
|
@ -23,6 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
|
|||
tensor (:class:`torch.Tensor`): Tensor to be gathered.
|
||||
dim (int): The dimension concatenating in.
|
||||
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
|
||||
on_cpu (bool, optional): Whether to communicate with Gloo backend.
|
||||
async_op (bool, optional): Whether operations are asynchronous.
|
||||
|
||||
Returns:
|
||||
|
@ -37,11 +41,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
|
|||
shape = list(tensor.shape)
|
||||
shape[0], shape[dim] = shape[dim], shape[0]
|
||||
shape[0] *= depth
|
||||
out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device())
|
||||
out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
|
||||
temp = list(torch.chunk(out, depth, dim=0))
|
||||
group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
|
||||
work = dist.all_gather(tensor_list=temp,
|
||||
tensor=tensor.transpose(0, dim).contiguous(),
|
||||
group=gpc.get_group(parallel_mode),
|
||||
group=group,
|
||||
async_op=async_op)
|
||||
out = torch.transpose(out, 0, dim)
|
||||
if async_op:
|
||||
|
@ -54,6 +59,7 @@ def reduce_scatter(tensor: Tensor,
|
|||
dim: int,
|
||||
parallel_mode: ParallelMode,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
on_cpu: bool = False,
|
||||
async_op: bool = False) -> Tensor:
|
||||
r"""Reduces all tensors then scatters it in a specific dimension to all
|
||||
members in the parallel group.
|
||||
|
@ -70,6 +76,7 @@ def reduce_scatter(tensor: Tensor,
|
|||
should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
|
||||
More details about ReduceOp please refer to
|
||||
`ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
|
||||
on_cpu (bool, optional): Whether to communicate with Gloo backend.
|
||||
async_op (bool, optional): Whether operations are asynchronous.
|
||||
|
||||
Returns:
|
||||
|
@ -82,12 +89,9 @@ def reduce_scatter(tensor: Tensor,
|
|||
work = None
|
||||
else:
|
||||
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
|
||||
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device())
|
||||
work = dist.reduce_scatter(output=out,
|
||||
input_list=temp,
|
||||
op=op,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
|
||||
group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
|
||||
work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
|
||||
if async_op:
|
||||
return out, work
|
||||
else:
|
||||
|
@ -97,6 +101,7 @@ def reduce_scatter(tensor: Tensor,
|
|||
def all_reduce(tensor: Tensor,
|
||||
parallel_mode: ParallelMode,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
on_cpu: bool = False,
|
||||
async_op: bool = False) -> Tensor:
|
||||
r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.
|
||||
|
||||
|
@ -111,6 +116,7 @@ def all_reduce(tensor: Tensor,
|
|||
should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
|
||||
More details about ReduceOp please refer to
|
||||
`ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
|
||||
on_cpu (bool, optional): Whether to communicate with Gloo backend.
|
||||
async_op (bool, optional): Whether operations are asynchronous.
|
||||
|
||||
Returns:
|
||||
|
@ -123,14 +129,15 @@ def all_reduce(tensor: Tensor,
|
|||
work = None
|
||||
else:
|
||||
out = tensor.contiguous()
|
||||
work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
|
||||
group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
|
||||
work = dist.all_reduce(out, op=op, group=group, async_op=async_op)
|
||||
if async_op:
|
||||
return out, work
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
|
||||
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False):
|
||||
r"""Broadcast tensors to whole parallel group. Tensor must have the same
|
||||
number of elements in all processes participating in the collective.
|
||||
|
||||
|
@ -142,6 +149,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
|
|||
tensor (:class:`torch.Tensor`): Tensor to be broadcast.
|
||||
src (int): Source rank.
|
||||
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
|
||||
on_cpu (bool, optional): Whether to communicate with Gloo backend.
|
||||
async_op (bool, optional): Whether operations are asynchronous.
|
||||
|
||||
Returns:
|
||||
|
@ -154,14 +162,20 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
|
|||
work = None
|
||||
else:
|
||||
out = tensor.contiguous()
|
||||
work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
|
||||
group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
|
||||
work = dist.broadcast(out, src=src, group=group, async_op=async_op)
|
||||
if async_op:
|
||||
return out, work
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
|
||||
def reduce(tensor: Tensor,
|
||||
dst: int,
|
||||
parallel_mode: ParallelMode,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
on_cpu: bool = False,
|
||||
async_op: bool = False):
|
||||
r"""Reduce tensors across whole parallel group. Only the process with
|
||||
rank ``dst`` is going to receive the final result.
|
||||
|
||||
|
@ -173,6 +187,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
|
|||
tensor (:class:`torch.Tensor`): Tensor to be reduced.
|
||||
dst (int): Destination rank.
|
||||
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
|
||||
on_cpu (bool, optional): Whether to communicate with Gloo backend.
|
||||
async_op (bool, optional): Whether operations are asynchronous.
|
||||
|
||||
Returns:
|
||||
|
@ -185,8 +200,62 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
|
|||
work = None
|
||||
else:
|
||||
out = tensor.contiguous()
|
||||
work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
|
||||
group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
|
||||
work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op)
|
||||
if async_op:
|
||||
return out, work
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
|
||||
r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
|
||||
"""
|
||||
if dist._rank_not_in_group(group):
|
||||
return
|
||||
|
||||
if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1):
|
||||
raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.")
|
||||
|
||||
# set tensor device to cuda if backend is nccl
|
||||
device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")
|
||||
|
||||
my_rank = dist.get_rank() # use global rank
|
||||
if my_rank == src:
|
||||
tensor_list, tensor_sizes = zip(
|
||||
*[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
|
||||
tensor_list = list(map(lambda x: x.to(device), tensor_list))
|
||||
tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes))
|
||||
|
||||
# Src rank broadcasts the maximum tensor size. This is because all ranks are
|
||||
# expected to call into scatter() with equal-sized tensors.
|
||||
if my_rank == src:
|
||||
max_tensor_size = max(tensor_sizes)
|
||||
for tensor in tensor_list:
|
||||
tensor.resize_(max_tensor_size)
|
||||
else:
|
||||
max_tensor_size = torch.tensor([0], dtype=torch.long).to(device)
|
||||
|
||||
dist.broadcast(max_tensor_size, src=src, group=group)
|
||||
|
||||
# Scatter actual serialized objects
|
||||
output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8).to(device)
|
||||
dist.scatter(
|
||||
output_tensor,
|
||||
scatter_list=None if my_rank != src else tensor_list,
|
||||
src=src,
|
||||
group=group,
|
||||
)
|
||||
|
||||
# Scatter per-object sizes to trim tensors when deserializing back to object
|
||||
obj_tensor_size = torch.tensor([0], dtype=torch.long).to(device)
|
||||
dist.scatter(
|
||||
obj_tensor_size,
|
||||
scatter_list=None if my_rank != src else tensor_sizes,
|
||||
src=src,
|
||||
group=group,
|
||||
)
|
||||
|
||||
output_tensor, obj_tensor_size = output_tensor.cpu(), obj_tensor_size.cpu()
|
||||
# Deserialize back to object
|
||||
scatter_object_output_list[0] = dist.distributed_c10d._tensor_to_object(output_tensor, obj_tensor_size)
|
||||
|
|
Loading…
Reference in New Issue