[model checkpoint] updated communication ops for cpu tensors (#590)

pull/625/head
アマデウス 2022-04-01 16:52:20 +08:00 committed by GitHub
parent c50bfb807b
commit 6302069c0e
1 changed file with 84 additions and 15 deletions


@@ -8,10 +8,13 @@ from torch import Tensor
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.utils import get_current_device


-def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
+def all_gather(tensor: Tensor,
+               dim: int,
+               parallel_mode: ParallelMode,
+               on_cpu: bool = False,
+               async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
@@ -23,6 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
         tensor (:class:`torch.Tensor`): Tensor to be gathered.
         dim (int): The dimension concatenating in.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -37,11 +41,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
         shape = list(tensor.shape)
         shape[0], shape[dim] = shape[dim], shape[0]
         shape[0] *= depth
-        out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device())
+        out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
         temp = list(torch.chunk(out, depth, dim=0))
+        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
         work = dist.all_gather(tensor_list=temp,
                                tensor=tensor.transpose(0, dim).contiguous(),
-                               group=gpc.get_group(parallel_mode),
+                               group=group,
                                async_op=async_op)
         out = torch.transpose(out, 0, dim)
     if async_op:
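A minimal usage sketch for the updated all_gather, assuming the distributed environment has already been initialized (for example via colossalai.launch) so that gpc holds both the default device group and a Gloo-backed CPU group for the chosen mode; ParallelMode.GLOBAL, the tensor shape, and the colossalai.communication import path are illustrative assumptions:

    import torch
    from colossalai.communication import all_gather
    from colossalai.context import ParallelMode

    # A CPU tensor is gathered over the Gloo group selected by on_cpu=True;
    # the output is allocated on tensor.device, so it also stays on the CPU.
    cpu_shard = torch.ones(4, 8)
    gathered = all_gather(cpu_shard, dim=0, parallel_mode=ParallelMode.GLOBAL, on_cpu=True)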
@@ -54,6 +59,7 @@ def reduce_scatter(tensor: Tensor,
                    dim: int,
                    parallel_mode: ParallelMode,
                    op: ReduceOp = ReduceOp.SUM,
+                   on_cpu: bool = False,
                    async_op: bool = False) -> Tensor:
     r"""Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.
@@ -70,6 +76,7 @@ def reduce_scatter(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
+        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -82,12 +89,9 @@ def reduce_scatter(tensor: Tensor,
         work = None
     else:
         temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
-        out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device())
-        work = dist.reduce_scatter(output=out,
-                                   input_list=temp,
-                                   op=op,
-                                   group=gpc.get_group(parallel_mode),
-                                   async_op=async_op)
+        out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
+        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
     else:
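The same on_cpu switch is threaded through reduce_scatter. A hedged sketch under the same assumptions as above, with a world size that evenly divides the scattered dimension:

    import torch
    from torch.distributed import ReduceOp
    from colossalai.communication import reduce_scatter
    from colossalai.context import ParallelMode

    # Every rank contributes a CPU tensor; each rank gets back one summed chunk.
    full = torch.randn(16, 8)
    chunk = reduce_scatter(full, dim=0, parallel_mode=ParallelMode.GLOBAL, op=ReduceOp.SUM, on_cpu=True)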
@@ -97,6 +101,7 @@ def reduce_scatter(tensor: Tensor,
 def all_reduce(tensor: Tensor,
                parallel_mode: ParallelMode,
                op: ReduceOp = ReduceOp.SUM,
+               on_cpu: bool = False,
                async_op: bool = False) -> Tensor:
     r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.
@@ -111,6 +116,7 @@ def all_reduce(tensor: Tensor,
             should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
             More details about ReduceOp please refer to
             `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
+        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -123,14 +129,15 @@ def all_reduce(tensor: Tensor,
         work = None
     else:
         out = tensor.contiguous()
-        work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        work = dist.all_reduce(out, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
     else:
         return out


-def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
+def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False):
     r"""Broadcast tensors to whole parallel group. Tensor must have the same
     number of elements in all processes participating in the collective.
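A corresponding sketch for all_reduce on a CPU tensor, under the same assumptions; async_op=True is shown only to illustrate the (tensor, work handle) return pair:

    import torch
    from colossalai.communication import all_reduce
    from colossalai.context import ParallelMode

    grad = torch.randn(1024)
    out, work = all_reduce(grad, parallel_mode=ParallelMode.GLOBAL, on_cpu=True, async_op=True)
    work.wait()    # the reduction ran asynchronously on the Gloo-backed CPU group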
@@ -142,6 +149,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
         tensor (:class:`torch.Tensor`): Tensor to be broadcast.
         src (int): Source rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -154,14 +162,20 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
         work = None
     else:
         out = tensor.contiguous()
-        work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
+        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        work = dist.broadcast(out, src=src, group=group, async_op=async_op)
     if async_op:
         return out, work
     else:
         return out


-def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
+def reduce(tensor: Tensor,
+           dst: int,
+           parallel_mode: ParallelMode,
+           op: ReduceOp = ReduceOp.SUM,
+           on_cpu: bool = False,
+           async_op: bool = False):
     r"""Reduce tensors across whole parallel group. Only the process with
     rank ``dst`` is going to receive the final result.
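broadcast (and reduce, whose new signature appears above) follows the same pattern; a hedged sketch of broadcasting a CPU tensor from global rank 0, under the same assumptions:

    import torch
    from colossalai.communication import broadcast
    from colossalai.context import ParallelMode

    # Every rank passes a tensor of the same shape; rank 0's values end up on all ranks.
    buf = torch.zeros(32)
    buf = broadcast(buf, src=0, parallel_mode=ParallelMode.GLOBAL, on_cpu=True)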
@@ -173,6 +187,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
         tensor (:class:`torch.Tensor`): Tensor to be reduced.
         dst (int): Destination rank.
         parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
+        on_cpu (bool, optional): Whether to communicate with Gloo backend.
         async_op (bool, optional): Whether operations are asynchronous.

     Returns:
@@ -185,8 +200,62 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
         work = None
     else:
         out = tensor.contiguous()
-        work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode)
+        work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op)
     if async_op:
         return out, work
     else:
         return out
+
+
+def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
+    r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
+    """
+    if dist._rank_not_in_group(group):
+        return
+
+    if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1):
+        raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.")
+
+    # set tensor device to cuda if backend is nccl
+    device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")
+
+    my_rank = dist.get_rank()    # use global rank
+    if my_rank == src:
+        tensor_list, tensor_sizes = zip(
+            *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
+        tensor_list = list(map(lambda x: x.to(device), tensor_list))
+        tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes))
+
+    # Src rank broadcasts the maximum tensor size. This is because all ranks are
+    # expected to call into scatter() with equal-sized tensors.
+    if my_rank == src:
+        max_tensor_size = max(tensor_sizes)
+        for tensor in tensor_list:
+            tensor.resize_(max_tensor_size)
+    else:
+        max_tensor_size = torch.tensor([0], dtype=torch.long).to(device)
+    dist.broadcast(max_tensor_size, src=src, group=group)
+
+    # Scatter actual serialized objects
+    output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8).to(device)
+    dist.scatter(
+        output_tensor,
+        scatter_list=None if my_rank != src else tensor_list,
+        src=src,
+        group=group,
+    )
+
+    # Scatter per-object sizes to trim tensors when deserializing back to object
+    obj_tensor_size = torch.tensor([0], dtype=torch.long).to(device)
+    dist.scatter(
+        obj_tensor_size,
+        scatter_list=None if my_rank != src else tensor_sizes,
+        src=src,
+        group=group,
+    )
+
+    output_tensor, obj_tensor_size = output_tensor.cpu(), obj_tensor_size.cpu()
+    # Deserialize back to object
+    scatter_object_output_list[0] = dist.distributed_c10d._tensor_to_object(output_tensor, obj_tensor_size)
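A hedged usage sketch for the new scatter_object_list helper, which mirrors the contract of the torch.distributed.scatter_object_list it is adapted from: the source rank supplies one picklable object per rank in the group, and every rank receives its own object in the first slot of the output list (the rank-0 source and the payload below are illustrative):

    import torch.distributed as dist

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Only the source rank needs to provide the input list (one object per rank in the group).
    inputs = [{"rank": r, "shard": f"meta-{r}"} for r in range(world_size)] if rank == 0 else None
    output = [None]
    scatter_object_list(output, inputs, src=0)
    received = output[0]    # this rank's deserialized object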