[doc] improved docstring in the communication module (#863)

pull/867/head
Frank Lee 2022-04-25 13:41:43 +08:00 committed by GitHub
parent 8af5f7423d
commit 8004c8e938
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 39 deletions

View File

@ -208,7 +208,7 @@ def reduce(tensor: Tensor,
return out return out
def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None): def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
""" """
if dist.distributed_c10d._rank_not_in_group(group): if dist.distributed_c10d._rank_not_in_group(group):

View File

@ -23,7 +23,7 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
chunk_tensor (bool, optional): whether to chunk tensor, defaults to False chunk_tensor (bool, optional): whether to chunk tensor, defaults to False
Returns: Returns:
Tuple[Union[torch.Size, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
""" """
if chunk_tensor: if chunk_tensor:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
@ -38,31 +38,38 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
return tensor_chunk_shape, chunk_tensor return tensor_chunk_shape, chunk_tensor
def _communicate(tensor_send_next=None, def _communicate(tensor_send_next: torch.Tensor = None,
tensor_send_prev=None, tensor_send_prev: torch.Tensor = None,
recv_prev=False, recv_prev: bool = False,
recv_next=False, recv_next: bool = False,
recv_prev_shape=None, recv_prev_shape: TensorShape = None,
recv_next_shape=None, recv_next_shape: TensorShape = None,
prev_rank=None, prev_rank: int = None,
next_rank=None, next_rank: int = None,
dtype=None, dtype: torch.dtype = None,
scatter_gather_tensors=False): scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
""" """
Adapted from megatron.p2p_communication. Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other Communicate tensors between stages. Used as helper method in other
communication methods that are used in pipeline schedule. communication methods that are used in pipeline schedule.
Takes the following arguments: Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
set to None). set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
set to None). set to None).
recv_prev: boolean for whether tensor should be received from recv_prev (bool): boolean for whether tensor should be received from
previous rank. previous rank.
recv_next: boolean for whether tensor should be received from recv_next (bool): boolean for whether tensor should be received from
next rank. next rank.
recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defualts to None.
recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defualts to None.
prev_rank (int): the rank of the previous pipeline stage, defualts to None,
next_rank (int): the rank of the next pipeline stage, defualts to None,
dtype (torch.dtype): data type of intermediate buffers, defaults to None
scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
Returns: Returns:
(tensor_recv_prev, tensor_recv_next) Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
""" """
# Create placeholder tensors for receive in forward and backward directions # Create placeholder tensors for receive in forward and backward directions
@ -130,7 +137,7 @@ def _communicate(tensor_send_next=None,
return tensor_recv_prev, tensor_recv_next return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False): def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args: Args:
@ -151,7 +158,7 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
return input_tensor return input_tensor
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False): def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args: Args:
@ -172,7 +179,7 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_
return output_tensor_grad return output_tensor_grad
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False): def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
"""Sends the input tensor to the next stage in pipeline. """Sends the input tensor to the next stage in pipeline.
Args: Args:
@ -183,7 +190,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
_communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors) _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False): def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
"""Sends the gradient tensor to the previous stage in pipeline. """Sends the gradient tensor to the previous stage in pipeline.
Args: Args:
@ -201,7 +208,7 @@ def send_forward_recv_backward(output_tensor,
recv_next=True, recv_next=True,
next_rank=None, next_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the next stage in pipeline, while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage. next stage in pipeline as the input gradient tensor of this stage.
@ -230,7 +237,7 @@ def send_backward_recv_forward(input_tensor_grad,
recv_prev=True, recv_prev=True,
prev_rank=None, prev_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the gradient tensor to the """Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the output tensor from the previous stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage. previous stage in pipeline as the input of this stage.
@ -260,7 +267,7 @@ def send_forward_recv_forward(output_tensor,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the output tensor from the next stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage. previous stage in pipeline as the input of this stage.
@ -288,7 +295,7 @@ def send_backward_recv_backward(input_tensor_grad,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False) -> torch.Tensor:
"""Batched communication operation. Sends the gradient tensor to the """Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the gradient tensor from the previous stage in pipeline, while receives the gradient tensor from the
next member in pipeline as the input of this stage. next member in pipeline as the input of this stage.
@ -319,7 +326,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and """Batched communication operation. Sends the input tensor to the next stage in pipeline and
the gradient tensor to the previous stage, while receives the input gradient tensor from the the gradient tensor to the previous stage, while receives the input gradient tensor from the
next stage and the input tensor from the previous stage. next stage and the input tensor from the previous stage.

View File

@ -8,13 +8,13 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize from colossalai.utils import get_current_device, synchronize
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode): def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
"""Sends a tensor to the next member and receives a tensor from the previous member. """Sends a tensor to the next member and receives a tensor from the previous member.
This function returns the received tensor from the previous member. This function returns the received tensor from the previous member.
Args: Args:
tensor_send_next: Tensor sent to next member tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member
parallel_mode: Parallel group mode used in this communication parallel_mode (ParallelMode): Parallel group mode used in this communication
Returns: Returns:
:class:`torch.Tensor`: The tensor received from the previous. :class:`torch.Tensor`: The tensor received from the previous.

View File

@ -4,16 +4,19 @@ import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from typing import Union, List, Tuple
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def send_tensor_meta(tensor, need_meta=True, next_rank=None): def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
"""Sends tensor meta information before sending a specific tensor. """Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications, Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`. synchronizes with :func:`recv_tensor_meta`.
Args: Args:
tensor (torch.Tensor): Tensor to be sent. tensor (:class:`torch.Tensor`): Tensor to be sent.
need_meta (bool, optional): If False, meta information won't be sent. need_meta (bool, optional): If False, meta information won't be sent.
next_rank (int): The rank of the next member in pipeline parallel group. next_rank (int): The rank of the next member in pipeline parallel group.
@ -34,14 +37,14 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
return False return False
def recv_tensor_meta(tensor_shape, prev_rank=None): def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
"""Receives tensor meta information before receiving a specific tensor. """Receives tensor meta information before receiving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications, Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be received before communications. This function meta information of the tensor should be received before communications. This function
synchronizes with :func:`send_tensor_meta`. synchronizes with :func:`send_tensor_meta`.
Args: Args:
tensor_shape (torch.Size): The shape of the tensor to be received. tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
prev_rank (int): The rank of the source of the tensor. prev_rank (int): The rank of the source of the tensor.
Returns: Returns:
@ -63,15 +66,15 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
return tensor_shape return tensor_shape
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
"""Break a tensor into equal 1D chunks. """Break a tensor into equal 1D chunks.
Args: Args:
tensor (torch.Tensor): Tensor to be split before communication. tensor (:class:`torch.Tensor`): Tensor to be split before communication.
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
Returns: Returns:
:class:`torch.Size`: The split tensor :class:`torch.Tensor`: The split tensor
""" """
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@ -84,13 +87,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
return data return data
def gather_split_1d_tensor(tensor): def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Opposite of above function, gather values from model parallel ranks. """Opposite of above function, gather values from model parallel ranks.
Args: Args:
tensor (torch.Tensor): Tensor to be gathered after communication. tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
Returns: Returns:
:class:`torch.Size`: The gathered tensor. :class:`torch.Tensor`: The gathered tensor.
""" """
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
numel = torch.numel(tensor) numel = torch.numel(tensor)