ColossalAI/colossalai/communication/collective.py

85 lines
2.9 KiB
Python
Raw Normal View History

2021-10-28 16:21:23 +00:00
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int,
               parallel_mode: ParallelMode) -> Tensor:
    """Gathers all tensors from the parallel group and concatenates them in a
    specific dimension.

    Every rank contributes ``tensor``. The pieces are concatenated along
    ``dim`` in the order ``dist.all_gather`` fills its output list (group-rank
    order), so the result is ``world_size`` times larger along ``dim``. Each
    receive buffer is sized from the local tensor, so all ranks must pass
    tensors of the same shape and dtype.

    :param tensor: Tensor to be gathered
    :param dim: The dimension concatenating in
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by all-gather
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    # Collectives need contiguous buffers. A plain ``clone()`` preserves the
    # input's strides, so a transposed/permuted input would stay
    # non-contiguous and be rejected by backends such as NCCL.
    # ``contiguous()`` fixes the layout and skips the copy when the input is
    # already contiguous (the input is only read by the collective).
    temp = tensor.contiguous()
    # One fresh, contiguous receive buffer per rank. This replaces allocating
    # a single large buffer and then copying each of its chunks.
    gather_list = [
        torch.empty(temp.shape, dtype=temp.dtype, device=get_current_device())
        for _ in range(depth)
    ]
    dist.all_gather(gather_list, temp, group=gpc.get_group(parallel_mode))
    return torch.cat(gather_list, dim=dim)
def reduce_scatter(tensor: Tensor, dim: int,
                   parallel_mode: ParallelMode,
                   op=dist.ReduceOp.SUM) -> Tensor:
    """Reduces all tensors then scatters it in a specific dimension to all
    members in the parallel group.

    ``tensor`` is split into ``world_size`` equal chunks along ``dim``. Rank
    ``i`` of the group receives the element-wise reduction (``op``) of chunk
    ``i`` across all ranks.

    :param tensor: Tensor to be reduced and scattered
    :param dim: The dimension scattering in
    :param parallel_mode: Parallel group mode used in this communication
    :param op: Reduction operation, defaults to ``dist.ReduceOp.SUM``
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :type op: torch.distributed.ReduceOp
    :raises ValueError: If ``tensor.size(dim)`` is not divisible by the
        world size of ``parallel_mode``
    :return: The tensor generated by reduce-scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    # torch.chunk returns uneven (or fewer than ``depth``) pieces when the
    # size is not divisible, and the collective cannot handle either case.
    # Report it clearly instead of failing deep inside the backend.
    size = tensor.size(dim)
    if size % depth != 0:
        raise ValueError(
            f"reduce_scatter: size {size} of dim {dim} is not divisible "
            f"by the world size {depth}")
    # Chunks along a non-leading dim are strided views; collectives need
    # contiguous inputs.
    chunks = [piece.contiguous() for piece in torch.chunk(tensor, depth, dim=dim)]
    out = torch.empty(chunks[0].shape,
                      dtype=chunks[0].dtype,
                      device=get_current_device())
    dist.reduce_scatter(output=out,
                        input_list=chunks,
                        op=op,
                        group=gpc.get_group(parallel_mode))
    return out
def scatter(tensor: Tensor, src: int, dim: int,
            parallel_mode: ParallelMode) -> Tensor:
    """Scatters in a specific dimension from source rank to all ranks in
    the parallel group.

    The whole ``tensor`` is broadcast from ``src``. Each rank then keeps
    chunk ``local_rank`` of ``torch.chunk(tensor, world_size, dim)``.
    Non-source ranks must pass a tensor with the same shape and dtype as the
    source's. Its contents are overwritten by the broadcast on a private
    copy, so the caller's tensor is not modified.

    NOTE: if ``tensor.size(dim)`` is not divisible by the world size,
    ``torch.chunk`` may yield uneven pieces, or fewer pieces than ranks, in
    which case the higher ranks raise ``IndexError``.

    :param tensor: Tensor to be scattered
    :param src: Rank holding the data. It is forwarded unchanged to
        ``dist.broadcast``, whose documentation defines it as a rank in the
        global (default) process group regardless of ``group``.
    :param dim: The dimension scattering in
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type src: int
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    # broadcast() writes the received data into its argument on non-source
    # ranks, so work on a private copy. Request a contiguous layout
    # explicitly: plain clone() preserves the input's strides, and backends
    # such as NCCL reject non-contiguous buffers.
    temp = tensor.clone(memory_format=torch.contiguous_format)
    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
    rank = gpc.get_local_rank(parallel_mode)
    return torch.chunk(temp, depth, dim=dim)[rank].contiguous()