#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device


def all_gather(tensor: Tensor, dim: int,
               parallel_mode: ParallelMode) -> Tensor:
    """Gathers all tensors from the parallel group and concatenates them along
    a specific dimension.

    :param tensor: Tensor to be gathered
    :param dim: The dimension to concatenate along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by all-gather
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    # Allocate an output buffer whose size along `dim` is `depth` times the
    # input, then expose it to dist.all_gather as a list of per-rank chunks.
    shape = list(temp.shape)
    shape[dim] *= depth
    out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
    out = list(torch.chunk(out, depth, dim=dim))
    out = [val.contiguous() for val in out]
    dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
    out = torch.cat(out, dim=dim)
    return out


def reduce_scatter(tensor: Tensor, dim: int,
                   parallel_mode: ParallelMode) -> Tensor:
    """Reduces all tensors, then scatters the result along a specific dimension
    to all members of the parallel group.

    :param tensor: Tensor to be reduced and scattered
    :param dim: The dimension to scatter along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by reduce-scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    # Split the input into one contiguous chunk per rank along `dim`; each rank
    # receives the element-wise reduction of its own chunk.
    temp = list(torch.chunk(tensor, depth, dim=dim))
    temp = [val.contiguous() for val in temp]
    out = torch.empty(temp[0].shape,
                      dtype=temp[0].dtype,
                      device=get_current_device())
    dist.reduce_scatter(output=out,
                        input_list=temp,
                        group=gpc.get_group(parallel_mode))
    return out


def scatter(tensor: Tensor, src: int, dim: int,
            parallel_mode: ParallelMode) -> Tensor:
    """Scatters a tensor along a specific dimension from the source rank to all
    ranks in the parallel group.

    :param tensor: Tensor to be scattered
    :param src: The source rank that holds the tensor to scatter
    :param dim: The dimension to scatter along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type src: int
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    # Implemented as a broadcast from `src` followed by each rank keeping only
    # its own chunk of the broadcast tensor along `dim`.
    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
    rank = gpc.get_local_rank(parallel_mode)
    out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
    return out

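
# A minimal usage sketch, not part of the original module. It assumes the
# distributed environment and gpc have already been initialized, e.g. by
# launching with torchrun and calling `colossalai.launch_from_torch(config={})`
# (the launch call is an assumption, not something this file provides), and
# that each process owns a GPU so the tensors land on get_current_device().
if __name__ == '__main__':
    import colossalai

    colossalai.launch_from_torch(config={})  # assumed launcher; initializes gpc

    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
    rank = gpc.get_local_rank(ParallelMode.GLOBAL)

    # Each rank contributes a (2, 4) tensor filled with its own rank id.
    local = torch.full((2, 4), float(rank), device=get_current_device())

    # all_gather: every rank ends up with the concatenation of all local
    # tensors along dim 0, i.e. shape (2 * world_size, 4).
    gathered = all_gather(local, dim=0, parallel_mode=ParallelMode.GLOBAL)
    assert gathered.shape[0] == 2 * world_size

    # reduce_scatter: the sum of all ranks' `gathered` tensors is split along
    # dim 0 and each rank keeps only its own (2, 4) slice.
    reduced = reduce_scatter(gathered, dim=0, parallel_mode=ParallelMode.GLOBAL)
    assert reduced.shape == local.shape

    # scatter: rank 0's `gathered` tensor is split along dim 0 and each rank
    # receives its own (2, 4) slice.
    scattered = scatter(gathered, src=0, dim=0, parallel_mode=ParallelMode.GLOBAL)
    assert scattered.shape == local.shape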