From 8004c8e93872857a3dde69c443cc711973898798 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 25 Apr 2022 13:41:43 +0800 Subject: [PATCH] [doc] improved docstring in the communication module (#863) --- colossalai/communication/collective.py | 2 +- colossalai/communication/p2p.py | 57 +++++++++++++++----------- colossalai/communication/ring.py | 6 +-- colossalai/communication/utils.py | 23 ++++++----- 4 files changed, 49 insertions(+), 39 deletions(-) diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py index e0db6ca6c..62436fbbc 100644 --- a/colossalai/communication/collective.py +++ b/colossalai/communication/collective.py @@ -208,7 +208,7 @@ def reduce(tensor: Tensor, return out -def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None): +def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None: r"""Modified from `torch.distributed.scatter_object_list ` to fix issues """ if dist.distributed_c10d._rank_not_in_group(group): diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py index 220f04861..12737e21d 100644 --- a/colossalai/communication/p2p.py +++ b/colossalai/communication/p2p.py @@ -23,7 +23,7 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> chunk_tensor (bool, optional): whether to chunk tensor, defaults to False Returns: - Tuple[Union[torch.Size, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor + Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor """ if chunk_tensor: tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) @@ -38,31 +38,38 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> return tensor_chunk_shape, chunk_tensor -def _communicate(tensor_send_next=None, - tensor_send_prev=None, - recv_prev=False, - recv_next=False, - recv_prev_shape=None, - recv_next_shape=None, - prev_rank=None, - next_rank=None, - dtype=None, - scatter_gather_tensors=False): +def _communicate(tensor_send_next: torch.Tensor = None, + tensor_send_prev: torch.Tensor = None, + recv_prev: bool = False, + recv_next: bool = False, + recv_prev_shape: TensorShape = None, + recv_next_shape: TensorShape = None, + prev_rank: int = None, + next_rank: int = None, + dtype: torch.dtype = None, + scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]: """ Adapted from megatron.p2p_communication. Communicate tensors between stages. Used as helper method in other communication methods that are used in pipeline schedule. Takes the following arguments: - tensor_send_next: tensor to send to next rank (no tensor sent if + tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if set to None). - tensor_send_prev: tensor to send to prev rank (no tensor sent if + tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if set to None). - recv_prev: boolean for whether tensor should be received from + recv_prev (bool): boolean for whether tensor should be received from previous rank. - recv_next: boolean for whether tensor should be received from + recv_next (bool): boolean for whether tensor should be received from next rank. + recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defualts to None. + recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defualts to None. + prev_rank (int): the rank of the previous pipeline stage, defualts to None, + next_rank (int): the rank of the next pipeline stage, defualts to None, + dtype (torch.dtype): data type of intermediate buffers, defaults to None + scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False + Returns: - (tensor_recv_prev, tensor_recv_next) + Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next """ # Create placeholder tensors for receive in forward and backward directions @@ -130,7 +137,7 @@ def _communicate(tensor_send_next=None, return tensor_recv_prev, tensor_recv_next -def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False): +def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -151,7 +158,7 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_ return input_tensor -def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False): +def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. Args: @@ -172,7 +179,7 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_ return output_tensor_grad -def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False): +def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None: """Sends the input tensor to the next stage in pipeline. Args: @@ -183,7 +190,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False): _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors) -def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False): +def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None: """Sends the gradient tensor to the previous stage in pipeline. Args: @@ -201,7 +208,7 @@ def send_forward_recv_backward(output_tensor, recv_next=True, next_rank=None, dtype=torch.float, - scatter_gather_tensors=False): + scatter_gather_tensors=False) -> torch.Tensor: """Batched communication operation. Sends the input tensor to the next stage in pipeline, while receives the gradient tensor from the next stage in pipeline as the input gradient tensor of this stage. @@ -230,7 +237,7 @@ def send_backward_recv_forward(input_tensor_grad, recv_prev=True, prev_rank=None, dtype=torch.float, - scatter_gather_tensors=False): + scatter_gather_tensors=False) -> torch.Tensor: """Batched communication operation. Sends the gradient tensor to the previous stage in pipeline, while receives the output tensor from the previous stage in pipeline as the input of this stage. @@ -260,7 +267,7 @@ def send_forward_recv_forward(output_tensor, prev_rank=None, next_rank=None, dtype=torch.float, - scatter_gather_tensors=False): + scatter_gather_tensors=False) -> torch.Tensor: """Batched communication operation. Sends the input tensor to the next stage in pipeline, while receives the output tensor from the previous stage in pipeline as the input of this stage. @@ -288,7 +295,7 @@ def send_backward_recv_backward(input_tensor_grad, prev_rank=None, next_rank=None, dtype=torch.float, - scatter_gather_tensors=False): + scatter_gather_tensors=False) -> torch.Tensor: """Batched communication operation. Sends the gradient tensor to the previous stage in pipeline, while receives the gradient tensor from the next member in pipeline as the input of this stage. @@ -319,7 +326,7 @@ def send_forward_backward_recv_forward_backward(output_tensor, prev_rank=None, next_rank=None, dtype=torch.float, - scatter_gather_tensors=False): + scatter_gather_tensors=False) -> Tuple[torch.Tensor]: """Batched communication operation. Sends the input tensor to the next stage in pipeline and the gradient tensor to the previous stage, while receives the input gradient tensor from the next stage and the input tensor from the previous stage. diff --git a/colossalai/communication/ring.py b/colossalai/communication/ring.py index 3e1a8998f..aece7574b 100644 --- a/colossalai/communication/ring.py +++ b/colossalai/communication/ring.py @@ -8,13 +8,13 @@ from colossalai.core import global_context as gpc from colossalai.utils import get_current_device, synchronize -def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode): +def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor: """Sends a tensor to the next member and receives a tensor from the previous member. This function returns the received tensor from the previous member. Args: - tensor_send_next: Tensor sent to next member - parallel_mode: Parallel group mode used in this communication + tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member + parallel_mode (ParallelMode): Parallel group mode used in this communication Returns: :class:`torch.Tensor`: The tensor received from the previous. diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py index 7554910e4..f57a0009c 100644 --- a/colossalai/communication/utils.py +++ b/colossalai/communication/utils.py @@ -4,16 +4,19 @@ import torch.distributed as dist from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import get_current_device +from typing import Union, List, Tuple +TensorShape = Union[torch.Size, List[int], Tuple[int]] -def send_tensor_meta(tensor, need_meta=True, next_rank=None): + +def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool: """Sends tensor meta information before sending a specific tensor. Since the recipient must know the shape of the tensor in p2p communications, meta information of the tensor should be sent before communications. This function synchronizes with :func:`recv_tensor_meta`. Args: - tensor (torch.Tensor): Tensor to be sent. + tensor (:class:`torch.Tensor`): Tensor to be sent. need_meta (bool, optional): If False, meta information won't be sent. next_rank (int): The rank of the next member in pipeline parallel group. @@ -34,14 +37,14 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None): return False -def recv_tensor_meta(tensor_shape, prev_rank=None): +def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size: """Receives tensor meta information before receiving a specific tensor. Since the recipient must know the shape of the tensor in p2p communications, meta information of the tensor should be received before communications. This function synchronizes with :func:`send_tensor_meta`. Args: - tensor_shape (torch.Size): The shape of the tensor to be received. + tensor_shape (:class:`torch.Size`): The shape of the tensor to be received. prev_rank (int): The rank of the source of the tensor. Returns: @@ -63,15 +66,15 @@ def recv_tensor_meta(tensor_shape, prev_rank=None): return tensor_shape -def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): +def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: """Break a tensor into equal 1D chunks. Args: - tensor (torch.Tensor): Tensor to be split before communication. + tensor (:class:`torch.Tensor`): Tensor to be split before communication. new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. Returns: - :class:`torch.Size`: The split tensor + :class:`torch.Tensor`: The split tensor """ partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) @@ -84,13 +87,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): return data -def gather_split_1d_tensor(tensor): +def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: """Opposite of above function, gather values from model parallel ranks. Args: - tensor (torch.Tensor): Tensor to be gathered after communication. + tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. Returns: - :class:`torch.Size`: The gathered tensor. + :class:`torch.Tensor`: The gathered tensor. """ world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) numel = torch.numel(tensor)