mirror of https://github.com/hpcaitech/ColossalAI
[doc] improved docstring in the communication module (#863)
parent 8af5f7423d
commit 8004c8e938
@@ -208,7 +208,7 @@ def reduce(tensor: Tensor,
     return out


-def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
+def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
     r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):
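For context on the signature above: `torch.distributed.scatter_object_list` distributes arbitrary picklable objects from a source rank to every rank in a group, and the ColossalAI wrapper keeps the same parameters. A minimal usage sketch against the stock PyTorch API (the demo function and payload are ours, and an already-initialized process group is assumed):

import torch.distributed as dist


def scatter_demo():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Only the source rank needs to populate the input list: one object per rank.
    if rank == 0:
        objects = [{"rank": r, "payload": r * 10} for r in range(world_size)]
    else:
        objects = [None] * world_size
    # The output list must be pre-allocated; its first element receives this rank's object.
    received = [None]
    dist.scatter_object_list(received, objects, src=0)
    return received[0]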
@@ -23,7 +23,7 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
         chunk_tensor (bool, optional): whether to chunk tensor, defaults to False

     Returns:
-        Tuple[Union[torch.Size, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
+        Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
     """
     if chunk_tensor:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
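The helper touched above, `_get_tensor_shape`, collapses a multi-dimensional shape into a single flat length when chunking is requested. A small sketch of that flattening step, reconstructed only from the lines visible in this hunk (any divisibility checks the real helper performs are omitted):

import operator
from functools import reduce
from typing import List, Tuple, Union

import torch

TensorShape = Union[torch.Size, List[int], Tuple[int]]


def get_tensor_shape_sketch(tensor_shape: TensorShape, chunk_tensor: bool = False):
    if chunk_tensor:
        # A chunked tensor is communicated as one flat buffer whose length is
        # the product of all dimensions.
        return reduce(operator.mul, tensor_shape, 1), True
    return tensor_shape, False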
@@ -38,31 +38,38 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
     return tensor_chunk_shape, chunk_tensor


-def _communicate(tensor_send_next=None,
-                 tensor_send_prev=None,
-                 recv_prev=False,
-                 recv_next=False,
-                 recv_prev_shape=None,
-                 recv_next_shape=None,
-                 prev_rank=None,
-                 next_rank=None,
-                 dtype=None,
-                 scatter_gather_tensors=False):
+def _communicate(tensor_send_next: torch.Tensor = None,
+                 tensor_send_prev: torch.Tensor = None,
+                 recv_prev: bool = False,
+                 recv_next: bool = False,
+                 recv_prev_shape: TensorShape = None,
+                 recv_next_shape: TensorShape = None,
+                 prev_rank: int = None,
+                 next_rank: int = None,
+                 dtype: torch.dtype = None,
+                 scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
     communication methods that are used in pipeline schedule.
     Takes the following arguments:
-        tensor_send_next: tensor to send to next rank (no tensor sent if
+        tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
            set to None).
-        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+        tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
            set to None).
-        recv_prev: boolean for whether tensor should be received from
+        recv_prev (bool): boolean for whether tensor should be received from
            previous rank.
-        recv_next: boolean for whether tensor should be received from
+        recv_next (bool): boolean for whether tensor should be received from
            next rank.
+        recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defaults to None.
+        recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defaults to None.
+        prev_rank (int): the rank of the previous pipeline stage, defaults to None.
+        next_rank (int): the rank of the next pipeline stage, defaults to None.
+        dtype (torch.dtype): data type of intermediate buffers, defaults to None.
+        scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False.
+
     Returns:
-        (tensor_recv_prev, tensor_recv_next)
+        Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
     """

     # Create placeholder tensors for receive in forward and backward directions
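The docstring above describes the workhorse of the pipeline schedule: one call may send to the next and/or previous stage and receive from either direction, with receive buffers allocated from the advertised shapes. A hedged sketch of that pattern built only on the public torch.distributed P2P API (illustrative, not the committed implementation; the function name, the CUDA device choice, and the omission of the scatter/gather optimization are ours):

from typing import Optional, Tuple

import torch
import torch.distributed as dist


def p2p_communicate(tensor_send_next: Optional[torch.Tensor] = None,
                    tensor_send_prev: Optional[torch.Tensor] = None,
                    recv_prev_shape: Optional[torch.Size] = None,
                    recv_next_shape: Optional[torch.Size] = None,
                    prev_rank: Optional[int] = None,
                    next_rank: Optional[int] = None,
                    dtype: torch.dtype = torch.float) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Allocate receive buffers only for the requested directions.
    tensor_recv_prev = (torch.empty(recv_prev_shape, dtype=dtype, device="cuda")
                        if recv_prev_shape is not None else None)
    tensor_recv_next = (torch.empty(recv_next_shape, dtype=dtype, device="cuda")
                        if recv_next_shape is not None else None)

    ops = []
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    if tensor_recv_next is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))

    if ops:
        # batch_isend_irecv issues all point-to-point ops together and returns
        # the requests to wait on.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        torch.cuda.synchronize()
    return tensor_recv_prev, tensor_recv_next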
@@ -130,7 +137,7 @@ def _communicate(tensor_send_next=None,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
     """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

     Args:
@@ -151,7 +158,7 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
     return input_tensor


-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
     """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

     Args:
@@ -172,7 +179,7 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_
     return output_tensor_grad


-def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
+def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the input tensor to the next stage in pipeline.

     Args:
@@ -183,7 +190,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
     _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)


-def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
+def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the gradient tensor to the previous stage in pipeline.

     Args:
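The four directional wrappers in the hunks above are thin conveniences over the helper: each fixes one send or one receive direction. A hedged sketch of two of them in terms of the p2p_communicate sketch shown earlier (the _sketch names are ours; the real wrappers also thread the rank defaults and the scatter_gather_tensors flag through gpc):

import torch


def recv_forward_sketch(input_tensor_shape, prev_rank=None, dtype=torch.float):
    # Receive the previous stage's forward output as this stage's input.
    input_tensor, _ = p2p_communicate(recv_prev_shape=input_tensor_shape,
                                      prev_rank=prev_rank,
                                      dtype=dtype)
    return input_tensor


def send_backward_sketch(input_tensor_grad, prev_rank=None):
    # Send the gradient of this stage's input back to the previous stage.
    p2p_communicate(tensor_send_prev=input_tensor_grad, prev_rank=prev_rank)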
@@ -201,7 +208,7 @@ def send_forward_recv_backward(output_tensor,
                                recv_next=True,
                                next_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False):
+                               scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
@@ -230,7 +237,7 @@ def send_backward_recv_forward(input_tensor_grad,
                                recv_prev=True,
                                prev_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False):
+                               scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
@@ -260,7 +267,7 @@ def send_forward_recv_forward(output_tensor,
                              prev_rank=None,
                              next_rank=None,
                              dtype=torch.float,
-                              scatter_gather_tensors=False):
+                              scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
@@ -288,7 +295,7 @@ def send_backward_recv_backward(input_tensor_grad,
                                prev_rank=None,
                                next_rank=None,
                                dtype=torch.float,
-                                scatter_gather_tensors=False):
+                                scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the gradient tensor from the
     next member in pipeline as the input of this stage.
@@ -319,7 +326,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                prev_rank=None,
                                                next_rank=None,
                                                dtype=torch.float,
-                                                scatter_gather_tensors=False):
+                                                scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
     """Batched communication operation. Sends the input tensor to the next stage in pipeline and
     the gradient tensor to the previous stage, while receives the input gradient tensor from the
     next stage and the input tensor from the previous stage.
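The batched operations in the hunks above fuse a send and a receive into one call, which is what lets a 1F1B-style schedule exchange tensors in both directions at a stage boundary. A hedged sketch of how two of them pair up in the steady state (the second, shape-carrying argument of each helper is passed positionally because only the first parameter is visible in this diff; the import path and the pairing itself are illustrative, not taken from the scheduler code):

# Assumed import path; the diff shows only file contents, not the package layout.
from colossalai.communication import send_backward_recv_forward, send_forward_recv_backward


def steady_state_exchange(output_tensor, input_tensor_grad,
                          output_grad_shape, input_tensor_shape):
    # Send this stage's forward output downstream and receive the matching
    # gradient from the next stage in a single batched call.
    output_tensor_grad = send_forward_recv_backward(output_tensor, output_grad_shape)
    # Send the input gradient upstream and receive the next microbatch's
    # activation from the previous stage in a single batched call.
    next_input_tensor = send_backward_recv_forward(input_tensor_grad, input_tensor_shape)
    return output_tensor_grad, next_input_tensor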
@@ -8,13 +8,13 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device, synchronize


-def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
+def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
     """Sends a tensor to the next member and receives a tensor from the previous member.
     This function returns the received tensor from the previous member.

     Args:
-        tensor_send_next: Tensor sent to next member
-        parallel_mode: Parallel group mode used in this communication
+        tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member
+        parallel_mode (ParallelMode): Parallel group mode used in this communication

     Returns:
         :class:`torch.Tensor`: The tensor received from the previous.
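ring_forward documents a single step of a ring exchange: every member forwards its tensor to the next rank and receives one from the previous rank. A hedged sketch of that step with plain torch.distributed (rank arithmetic on the default group stands in for the gpc/ParallelMode lookups of the real function, and equal tensor shapes across ranks are assumed):

import torch
import torch.distributed as dist


def ring_forward_sketch(tensor_send_next: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    next_rank = (rank + 1) % world_size
    prev_rank = (rank - 1) % world_size

    tensor_recv_prev = torch.empty_like(tensor_send_next)
    ops = [
        dist.P2POp(dist.isend, tensor_send_next, next_rank),
        dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank),
    ]
    # Issue the send and the receive together so the ring does not deadlock.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return tensor_recv_prev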
@@ -4,16 +4,19 @@ import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from typing import Union, List, Tuple
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]


-def send_tensor_meta(tensor, need_meta=True, next_rank=None):
+def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
     """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
     synchronizes with :func:`recv_tensor_meta`.

     Args:
-        tensor (torch.Tensor): Tensor to be sent.
+        tensor (:class:`torch.Tensor`): Tensor to be sent.
         need_meta (bool, optional): If False, meta information won't be sent.
         next_rank (int): The rank of the next member in pipeline parallel group.

@@ -34,14 +37,14 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
     return False


-def recv_tensor_meta(tensor_shape, prev_rank=None):
+def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
     """Receives tensor meta information before receiving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be received before communications. This function
     synchronizes with :func:`send_tensor_meta`.

     Args:
-        tensor_shape (torch.Size): The shape of the tensor to be received.
+        tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
         prev_rank (int): The rank of the source of the tensor.

     Returns:
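The pair of functions above implement a shape handshake: before the payload tensor travels, its metadata goes first so the receiver can allocate a correctly sized buffer. A hedged sketch of one possible wire format, sending the number of dimensions and then the dimensions themselves (the committed code may encode the metadata differently; ranks are passed explicitly here instead of being resolved through gpc):

import torch
import torch.distributed as dist


def send_tensor_meta_sketch(tensor: torch.Tensor, next_rank: int) -> None:
    # First the number of dimensions, then the shape itself.
    send_ndims = torch.tensor([len(tensor.size())], dtype=torch.long, device=tensor.device)
    send_shape = torch.tensor(list(tensor.size()), dtype=torch.long, device=tensor.device)
    dist.send(send_ndims, next_rank)
    dist.send(send_shape, next_rank)


def recv_tensor_meta_sketch(prev_rank: int, device: torch.device) -> torch.Size:
    recv_ndims = torch.empty((1,), dtype=torch.long, device=device)
    dist.recv(recv_ndims, prev_rank)
    recv_shape = torch.empty((int(recv_ndims.item()),), dtype=torch.long, device=device)
    dist.recv(recv_shape, prev_rank)
    return torch.Size(recv_shape.tolist())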
@@ -63,15 +66,15 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
     return tensor_shape


-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
     """Break a tensor into equal 1D chunks.

     Args:
-        tensor (torch.Tensor): Tensor to be split before communication.
+        tensor (:class:`torch.Tensor`): Tensor to be split before communication.
         new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.

     Returns:
-        :class:`torch.Size`: The split tensor
+        :class:`torch.Tensor`: The split tensor
     """
     partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
     start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -84,13 +87,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     return data


-def gather_split_1d_tensor(tensor):
+def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
     """Opposite of above function, gather values from model parallel ranks.

     Args:
-        tensor (torch.Tensor): Tensor to be gathered after communication.
+        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
     Returns:
-        :class:`torch.Size`: The gathered tensor.
+        :class:`torch.Tensor`: The gathered tensor.
     """
     world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
     numel = torch.numel(tensor)
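The last two hunks cover what is presumably the scatter/gather path behind the scatter_gather_tensors flag used earlier: a tensor is flattened, each rank of the 1D tensor-parallel group keeps an equal slice, and the full tensor is reassembled with an all-gather on the other side. A hedged round-trip sketch with plain torch.distributed standing in for the gpc-based group lookups (the slicing assumes the element count divides evenly by the world size):

import torch
import torch.distributed as dist


def split_1d_sketch(tensor: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    partition_size = tensor.numel() // world_size
    start = partition_size * rank
    # View this rank's contiguous slice of the flattened tensor.
    return tensor.view(-1)[start:start + partition_size]


def gather_1d_sketch(chunk: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # Collect every rank's slice and concatenate back into one flat tensor.
    gathered = [torch.empty_like(chunk) for _ in range(world_size)]
    dist.all_gather(gathered, chunk.contiguous())
    return torch.cat(gathered)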