[doc] improved docstring in the communication module (#863)

pull/867/head
Frank Lee 3 years ago committed by GitHub
parent 8af5f7423d
commit 8004c8e938

@@ -208,7 +208,7 @@ def reduce(tensor: Tensor,
return out
def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
"""
if dist.distributed_c10d._rank_not_in_group(group):
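For orientation, here is a minimal, hedged sketch of how this patched scatter could be driven inside an already-initialized process group; the import path and the payloads are illustrative assumptions, not part of the commit:

from colossalai.communication.collective import scatter_object_list  # assumed module path

def demo_scatter(rank: int, world_size: int):
    # Only the source rank supplies the input objects; every rank supplies one output slot.
    inputs = [{"dst": r} for r in range(world_size)] if rank == 0 else None
    outputs = [None]
    scatter_object_list(outputs, inputs, src=0)
    return outputs[0]   # the object addressed to this rank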

@@ -23,7 +23,7 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
chunk_tensor (bool, optional): whether to chunk tensor, defaults to False
Returns:
Tuple[Union[torch.Size, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
"""
if chunk_tensor:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
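The chunked branch visible above simply flattens the shape into an element count; a standalone sketch of that arithmetic, using the same functools.reduce / operator.mul call as the function itself:

import operator
from functools import reduce

tensor_shape = (4, 8, 16)                                   # illustrative shape
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)  # product of dimensions
assert tensor_chunk_shape == 512                            # sent as one flat 1D buffer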
@@ -38,31 +38,38 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
return tensor_chunk_shape, chunk_tensor
def _communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
recv_prev_shape=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
dtype=None,
scatter_gather_tensors=False):
def _communicate(tensor_send_next: torch.Tensor = None,
tensor_send_prev: torch.Tensor = None,
recv_prev: bool = False,
recv_next: bool = False,
recv_prev_shape: TensorShape = None,
recv_next_shape: TensorShape = None,
prev_rank: int = None,
next_rank: int = None,
dtype: torch.dtype = None,
scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as a helper method by other
communication methods in the pipeline schedule.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
recv_prev (bool): boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
recv_next (bool): boolean for whether tensor should be received from
next rank.
recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defaults to None.
recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defaults to None.
prev_rank (int): the rank of the previous pipeline stage, defaults to None.
next_rank (int): the rank of the next pipeline stage, defaults to None.
dtype (torch.dtype): data type of intermediate buffers, defaults to None.
scatter_gather_tensors (bool): whether to scatter and gather the tensor between pipeline stages, defaults to False.
Returns:
(tensor_recv_prev, tensor_recv_next)
Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
"""
# Create placeholder tensors for receive in forward and backward directions
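As a hedged illustration only, an interior pipeline stage might call this helper to exchange tensors with both neighbours in a single batched operation; the tensor names and shapes below are placeholders, and an initialized pipeline parallel context (gpc) is assumed:

input_tensor, output_tensor_grad = _communicate(
    tensor_send_next=output_tensor,          # activation for the next stage (placeholder)
    tensor_send_prev=input_tensor_grad,      # gradient for the previous stage (placeholder)
    recv_prev=True,
    recv_next=True,
    recv_prev_shape=(micro_batch_size, hidden_size),
    recv_next_shape=(micro_batch_size, hidden_size),
    dtype=torch.float,
)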
@@ -130,7 +137,7 @@ def _communicate(tensor_send_next=None,
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
@@ -151,7 +158,7 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
@@ -172,7 +179,7 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_
return output_tensor_grad
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
"""Sends the input tensor to the next stage in pipeline.
Args:
@@ -183,7 +190,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
_communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
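Taken together, these p2p wrappers pair up across adjacent stages; a hedged sketch of one forward hop (the stage module, batch loader, shape and dtype are placeholders, and the gpc first/last-stage checks are assumed helpers of an initialized pipeline group):

if gpc.is_first_rank(ParallelMode.PIPELINE):          # assumed gpc helper
    input_tensor = load_micro_batch()                 # placeholder data-loading helper
else:
    input_tensor = recv_forward((micro_batch_size, hidden_size), dtype=torch.float)

output_tensor = stage_module(input_tensor)            # placeholder forward pass of this stage

if not gpc.is_last_rank(ParallelMode.PIPELINE):       # assumed gpc helper
    send_forward(output_tensor)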
@@ -201,7 +208,7 @@ def send_forward_recv_backward(output_tensor,
recv_next=True,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage.
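In a 1F1B-style schedule the separate send and receive above are fused into this call; a hedged sketch that passes the expected gradient shape positionally, since the hunk truncates the full signature (tensor names and shape are placeholders):

# Push this stage's activation forward and, in the same batched call,
# pull back the gradient of that activation from the next stage.
output_tensor_grad = send_forward_recv_backward(
    output_tensor,
    (micro_batch_size, hidden_size),   # shape of the gradient expected back
    dtype=torch.float,
)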
@@ -230,7 +237,7 @@ def send_backward_recv_forward(input_tensor_grad,
recv_prev=True,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
@@ -260,7 +267,7 @@ def send_forward_recv_forward(output_tensor,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
@@ -288,7 +295,7 @@ def send_backward_recv_backward(input_tensor_grad,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the gradient tensor from the
next member in pipeline as the input of this stage.
@@ -319,7 +326,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and
the gradient tensor to the previous stage, while receives the input gradient tensor from the
next stage and the input tensor from the previous stage.
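The fully batched variant folds all four transfers of a steady-state step into one call; a hedged sketch that assumes the two outgoing tensors precede the two expected receive shapes positionally, since the hunk above truncates the signature (all names and shapes are placeholders):

input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
    output_tensor,                      # activation for the next stage
    input_tensor_grad,                  # gradient for the previous stage
    (micro_batch_size, hidden_size),    # shape expected from the previous stage
    (micro_batch_size, hidden_size),    # shape expected from the next stage
    dtype=torch.float,
)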

@@ -8,13 +8,13 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
"""Sends a tensor to the next member and receives a tensor from the previous member.
This function returns the received tensor from the previous member.
Args:
tensor_send_next: Tensor sent to next member
parallel_mode: Parallel group mode used in this communication
tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member
parallel_mode (ParallelMode): Parallel group mode used in this communication
Returns:
:class:`torch.Tensor`: The tensor received from the previous member.
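A hedged usage sketch of the ring exchange; the block tensor is a placeholder and the process group for the chosen ParallelMode is assumed to be initialized:

# Pass this rank's block to the next ring member and receive the previous
# member's block, e.g. one step of a ring-based sequence-parallel exchange.
recv_block = ring_forward(local_block, ParallelMode.SEQUENCE)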

@@ -4,16 +4,19 @@ import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from typing import Union, List, Tuple
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def send_tensor_meta(tensor, need_meta=True, next_rank=None):
def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.
Args:
tensor (torch.Tensor): Tensor to be sent.
tensor (:class:`torch.Tensor`): Tensor to be sent.
need_meta (bool, optional): If False, meta information won't be sent.
next_rank (int): The rank of the next member in pipeline parallel group.
@@ -34,14 +37,14 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
return False
def recv_tensor_meta(tensor_shape, prev_rank=None):
def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
"""Receives tensor meta information before receiving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be received before communications. This function
synchronizes with :func:`send_tensor_meta`.
Args:
tensor_shape (torch.Size): The shape of the tensor to be received.
tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
prev_rank (int): The rank of the source of the tensor.
Returns:
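These two helpers form a shape handshake around the actual p2p transfer; a hedged sketch of both sides of the exchange (tensor names are placeholders, and passing None is assumed to mean the shape must still be received):

# Sender side: publish the shape once, then send the tensor itself.
need_meta = send_tensor_meta(output_tensor, need_meta=True)   # returns False once sent
send_forward(output_tensor)

# Receiver side: learn the shape first, then receive a tensor of that shape.
shape = recv_tensor_meta(None)
input_tensor = recv_forward(shape, dtype=torch.float)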
@@ -63,15 +66,15 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
return tensor_shape
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
"""Break a tensor into equal 1D chunks.
Args:
tensor (torch.Tensor): Tensor to be split before communication.
tensor (:class:`torch.Tensor`): Tensor to be split before communication.
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
Returns:
:class:`torch.Size`: The split tensor
:class:`torch.Tensor`: The split tensor
"""
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
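The slicing above is a plain even 1D partition by rank; a standalone sketch of the same index arithmetic, with illustrative stand-ins for the PARALLEL_1D world size and local rank:

import torch

tensor = torch.arange(16.)                             # pretend gradient buffer
world_size, rank = 4, 2                                # stand-ins for gpc values

partition_size = torch.numel(tensor) // world_size     # 4 elements per rank
start_index = partition_size * rank                    # 8
end_index = start_index + partition_size               # 12
chunk = tensor.view(-1)[start_index:end_index]         # this rank's slice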
@@ -84,13 +87,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
return data
def gather_split_1d_tensor(tensor):
def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Opposite of above function, gather values from model parallel ranks.
Args:
tensor (torch.Tensor): Tensor to be gathered after communication.
tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
Returns:
:class:`torch.Size`: The gathered tensor.
:class:`torch.Tensor`: The gathered tensor.
"""
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
numel = torch.numel(tensor)
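Split and gather are intended to be inverses across the PARALLEL_1D group; a hedged round-trip sketch (the full tensor is a placeholder and gpc with an initialized 1D group is assumed):

# Scatter a full tensor into per-rank 1D chunks, then reassemble the flat
# buffer on every rank and restore the original shape.
chunk = split_tensor_into_1d_equal_chunks(full_tensor, new_buffer=True)
restored = gather_split_1d_tensor(chunk).view(full_tensor.shape)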
