#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device


def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
    """Gathers all tensors from the parallel group and concatenates them along a
    specific dimension.

    :param tensor: Tensor to be gathered
    :param dim: The dimension to concatenate along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by all-gather
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    # Allocate the full-size output and split it into one contiguous buffer per rank.
    shape = list(temp.shape)
    shape[dim] *= depth
    out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
    out = list(torch.chunk(out, depth, dim=dim))
    out = [val.contiguous() for val in out]
    # Each rank's tensor is gathered into the buffer list, then concatenated.
    dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
    out = torch.cat(out, dim=dim)
    return out


def reduce_scatter(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
    """Reduces all tensors, then scatters the result along a specific dimension to all
    members in the parallel group.

    :param tensor: Tensor to be reduced and scattered
    :param dim: The dimension to scatter along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by reduce-scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    # Split the input into one contiguous chunk per rank; each rank receives the
    # reduced version of its own chunk.
    temp = list(torch.chunk(tensor, depth, dim=dim))
    temp = [val.contiguous() for val in temp]
    out = torch.empty(temp[0].shape, dtype=temp[0].dtype, device=get_current_device())
    dist.reduce_scatter(output=out, input_list=temp, group=gpc.get_group(parallel_mode))
    return out


def scatter(tensor: Tensor, src: int, dim: int, parallel_mode: ParallelMode) -> Tensor:
    """Scatters a tensor along a specific dimension from the source rank to all ranks
    in the parallel group.

    :param tensor: Tensor to be scattered
    :param src: The source rank holding the tensor to be scattered
    :param dim: The dimension to scatter along
    :param parallel_mode: Parallel group mode used in this communication
    :type tensor: Tensor
    :type src: int
    :type dim: int
    :type parallel_mode: ParallelMode
    :return: The tensor generated by scatter
    :rtype: Tensor
    """
    depth = gpc.get_world_size(parallel_mode)
    temp = tensor.clone()
    # Broadcast the source rank's tensor to every rank, then keep only the chunk
    # that belongs to the local rank.
    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
    rank = gpc.get_local_rank(parallel_mode)
    out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
    return out
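

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): it assumes a ColossalAI
# context has already been launched (e.g. via ``colossalai.launch``) so that
# ``gpc`` knows about ``ParallelMode.TENSOR`` with, say, 4 ranks; the shapes
# below are illustrative only.
#
#     x = torch.randn(4, 8, device=get_current_device())
#
#     # every rank contributes its (4, 8) shard; the result is (16, 8) on every rank
#     gathered = all_gather(x, dim=0, parallel_mode=ParallelMode.TENSOR)
#
#     # element-wise reduction across ranks, then each rank keeps its (4, 8) slice of dim 0
#     reduced = reduce_scatter(gathered, dim=0, parallel_mode=ParallelMode.TENSOR)
#
#     # rank 0's tensor is split along the last dim; each rank receives a (4, 2) slice
#     shard = scatter(x, src=0, dim=-1, parallel_mode=ParallelMode.TENSOR)
# ---------------------------------------------------------------------------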