mirror of https://github.com/hpcaitech/ColossalAI
[doc] improved docstring in the communication module (#863)
parent 8af5f7423d
commit 8004c8e938

@@ -208,7 +208,7 @@ def reduce(tensor: Tensor,
     return out


-def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None):
+def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
     r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):

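For context, `scatter_object_list` wraps `torch.distributed.scatter_object_list`, which lets rank `src` hand one picklable Python object to every rank in the group; the hunk above only adds the `-> None` return annotation. A minimal single-process sketch of those semantics follows (the gloo group on a local TCP port is an arbitrary test setup, not part of the patch):

import torch.distributed as dist

def demo_scatter_object_list():
    # Minimal single-process gloo group; in a real job the launcher sets this up.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    # Rank `src` provides one picklable object per rank; every rank gets its slot filled.
    scatter_object_output_list = [None]
    scatter_object_input_list = [{"stage": 0, "msg": "hello"}]  # length == world_size on src
    dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0)

    print(scatter_object_output_list[0])  # {'stage': 0, 'msg': 'hello'}
    dist.destroy_process_group()

if __name__ == "__main__":
    demo_scatter_object_list()

With more than one rank, `scatter_object_input_list` must contain exactly `world_size` entries on `src`, and each rank receives the entry matching its rank index.
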
@@ -23,7 +23,7 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
         chunk_tensor (bool, optional): whether to chunk tensor, defaults to False

     Returns:
-        Tuple[Union[torch.Size, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
+        Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
     """
     if chunk_tensor:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)

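The hunk above touches `_get_tensor_shape`, which decides the shape that actually goes over the wire: with `chunk_tensor=True` the tensor is sent as a flattened 1-D buffer, so the shape collapses to the element count computed by `reduce(operator.mul, tensor_shape, 1)`. A rough re-sketch of that idea (the real helper also checks divisibility by the 1-D tensor-parallel world size and can fall back to no chunking; that branch is omitted here):

import operator
from functools import reduce
from typing import List, Tuple, Union

import torch

TensorShape = Union[torch.Size, List[int], Tuple[int]]

def get_tensor_shape_sketch(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
    """Return the shape used on the wire and whether chunking is applied."""
    if chunk_tensor:
        # Flattened 1-D send: the wire shape is just the total element count.
        numel = reduce(operator.mul, tensor_shape, 1)
        return torch.Size([numel]), True
    # Otherwise the tensor is sent with its full shape, no chunking.
    return tensor_shape, False

print(get_tensor_shape_sketch(torch.Size([4, 8, 16]), chunk_tensor=True))   # (torch.Size([512]), True)
print(get_tensor_shape_sketch(torch.Size([4, 8, 16]), chunk_tensor=False))  # (torch.Size([4, 8, 16]), False)
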
@@ -38,31 +38,38 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
     return tensor_chunk_shape, chunk_tensor


-def _communicate(tensor_send_next=None,
-                 tensor_send_prev=None,
-                 recv_prev=False,
-                 recv_next=False,
-                 recv_prev_shape=None,
-                 recv_next_shape=None,
-                 prev_rank=None,
-                 next_rank=None,
-                 dtype=None,
-                 scatter_gather_tensors=False):
+def _communicate(tensor_send_next: torch.Tensor = None,
+                 tensor_send_prev: torch.Tensor = None,
+                 recv_prev: bool = False,
+                 recv_next: bool = False,
+                 recv_prev_shape: TensorShape = None,
+                 recv_next_shape: TensorShape = None,
+                 prev_rank: int = None,
+                 next_rank: int = None,
+                 dtype: torch.dtype = None,
+                 scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
     communication methods that are used in pipeline schedule.
     Takes the following arguments:
-        tensor_send_next: tensor to send to next rank (no tensor sent if
+        tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
                           set to None).
-        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+        tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
                           set to None).
-        recv_prev: boolean for whether tensor should be received from
+        recv_prev (bool): boolean for whether tensor should be received from
                    previous rank.
-        recv_next: boolean for whether tensor should be received from
+        recv_next (bool): boolean for whether tensor should be received from
                    next rank.
+        recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defaults to None.
+        recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defaults to None.
+        prev_rank (int): the rank of the previous pipeline stage, defaults to None.
+        next_rank (int): the rank of the next pipeline stage, defaults to None.
+        dtype (torch.dtype): data type of intermediate buffers, defaults to None.
+        scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False.

     Returns:
-        (tensor_recv_prev, tensor_recv_next)
+        Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
     """

     # Create placeholder tensors for receive in forward and backward directions

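`_communicate` is the workhorse the rest of this module builds on: it allocates receive buffers from the given shapes and dtype, then batches up to four point-to-point operations with the neighbouring pipeline stages. The sketch below shows that pattern with plain `torch.distributed` batched isend/irecv; it assumes an already-initialized process group and leaves out the scatter/gather chunking and device placement the real helper performs.

from typing import Optional, Tuple

import torch
import torch.distributed as dist

def p2p_communicate_sketch(tensor_send_next: Optional[torch.Tensor] = None,
                           tensor_send_prev: Optional[torch.Tensor] = None,
                           recv_prev_shape: Optional[torch.Size] = None,
                           recv_next_shape: Optional[torch.Size] = None,
                           prev_rank: Optional[int] = None,
                           next_rank: Optional[int] = None,
                           dtype: torch.dtype = torch.float) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Batch the four possible p2p ops between neighbouring pipeline stages."""
    # Pre-allocate receive buffers: the shapes were agreed on beforehand
    # (e.g. via send_tensor_meta / recv_tensor_meta in this module).
    tensor_recv_prev = torch.empty(recv_prev_shape, dtype=dtype) if recv_prev_shape is not None else None
    tensor_recv_next = torch.empty(recv_next_shape, dtype=dtype) if recv_next_shape is not None else None

    ops = []
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    if tensor_recv_next is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))

    if ops:
        # Issue all sends/receives at once and wait for completion.
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return tensor_recv_prev, tensor_recv_next
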
@@ -130,7 +137,7 @@ def _communicate(tensor_send_next=None,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
+def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
     """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

     Args:

@@ -151,7 +158,7 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
     return input_tensor


-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
+def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
     """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

     Args:

@@ -172,7 +179,7 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_
     return output_tensor_grad


-def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
+def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the input tensor to the next stage in pipeline.

     Args:

@@ -183,7 +190,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
     _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)


-def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
+def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the gradient tensor to the previous stage in pipeline.

     Args:

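Taken together, `recv_forward`, `send_forward`, `recv_backward` and `send_backward` are enough to express a naive pipeline schedule. The sketch below shows how one stage might wire them up; it assumes a ColossalAI launch has already built the pipeline group, that these helpers are importable from `colossalai.communication`, and that `gpc.is_first_rank`/`gpc.is_last_rank` are available. The control flow is illustrative, not the library's actual schedule.

import torch
from colossalai.communication import recv_forward, send_forward, recv_backward, send_backward
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

def naive_pipeline_step(model_chunk: torch.nn.Module, input_tensor_shape, micro_batch=None):
    # Forward: the first stage feeds data, later stages receive the previous stage's activation.
    if gpc.is_first_rank(ParallelMode.PIPELINE):
        input_tensor = micro_batch
    else:
        input_tensor = recv_forward(input_tensor_shape)
        input_tensor.requires_grad_()

    output_tensor = model_chunk(input_tensor)

    # Backward: the last stage produces the loss gradient, earlier stages receive it from downstream.
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        output_grad = torch.ones_like(output_tensor)  # stand-in for the loss gradient
    else:
        send_forward(output_tensor)
        output_grad = recv_backward(output_tensor.shape)

    output_tensor.backward(output_grad)
    if not gpc.is_first_rank(ParallelMode.PIPELINE):
        send_backward(input_tensor.grad)
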
@@ -201,7 +208,7 @@ def send_forward_recv_backward(output_tensor,
                                recv_next=True,
                                next_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False):
+                               scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.

@@ -230,7 +237,7 @@ def send_backward_recv_forward(input_tensor_grad,
                                recv_prev=True,
                                prev_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False):
+                               scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.

@@ -260,7 +267,7 @@ def send_forward_recv_forward(output_tensor,
                               prev_rank=None,
                               next_rank=None,
                               dtype=torch.float,
-                              scatter_gather_tensors=False):
+                              scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.

@@ -288,7 +295,7 @@ def send_backward_recv_backward(input_tensor_grad,
                                 prev_rank=None,
                                 next_rank=None,
                                 dtype=torch.float,
-                                scatter_gather_tensors=False):
+                                scatter_gather_tensors=False) -> torch.Tensor:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the gradient tensor from the
     next member in pipeline as the input of this stage.

@@ -319,7 +326,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
                                                 prev_rank=None,
                                                 next_rank=None,
                                                 dtype=torch.float,
-                                                scatter_gather_tensors=False):
+                                                scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
     """Batched communication operation. Sends the input tensor to the next stage in pipeline and
     the gradient tensor to the previous stage, while receives the input gradient tensor from the
     next stage and the input tensor from the previous stage.

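The batched `send_*_recv_*` variants above exist so that a 1F1B-style schedule can fold the outgoing send and the matching receive into a single call instead of two blocking rounds. A hedged sketch of that steady-state loop for an interior stage follows; it takes the second positional argument of each combined call to be the shape to receive (consistent with the signatures above), assumes the helpers are importable from `colossalai.communication`, and omits the model/optimizer plumbing.

import torch
from colossalai.communication import (recv_forward, send_backward,
                                      send_backward_recv_forward,
                                      send_forward_recv_backward)

def one_f_one_b_steady_state(model_chunk: torch.nn.Module, num_microbatches: int, input_shape):
    """Steady-state 1F1B loop for an interior pipeline stage (sketch)."""
    input_tensor = recv_forward(input_shape)
    for i in range(num_microbatches):
        input_tensor.requires_grad_()
        output_tensor = model_chunk(input_tensor)

        # One batched call: ship the activation downstream and wait for its gradient.
        output_grad = send_forward_recv_backward(output_tensor, output_tensor.shape)

        output_tensor.backward(output_grad)

        if i == num_microbatches - 1:
            # Last microbatch: only the gradient goes upstream.
            send_backward(input_tensor.grad)
        else:
            # Otherwise ship the gradient upstream and receive the next activation in one call.
            input_tensor = send_backward_recv_forward(input_tensor.grad, input_shape)
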
@@ -8,13 +8,13 @@ from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device, synchronize


-def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
+def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
     """Sends a tensor to the next member and receives a tensor from the previous member.
     This function returns the received tensor from the previous member.

     Args:
-        tensor_send_next: Tensor sent to next member
-        parallel_mode: Parallel group mode used in this communication
+        tensor_send_next (:class:`torch.Tensor`): Tensor sent to next member
+        parallel_mode (ParallelMode): Parallel group mode used in this communication

     Returns:
         :class:`torch.Tensor`: The tensor received from the previous.

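`ring_forward` performs the classic ring step used by ring-style parallel attention: every member sends to rank + 1 and receives from rank - 1 within the given parallel group. Below is a rough equivalent on the default `torch.distributed` group; the real function resolves ranks through `gpc`, allocates the buffer on the current device, and may chunk the tensor, none of which is shown here.

import torch
import torch.distributed as dist

def ring_forward_sketch(tensor_send_next: torch.Tensor) -> torch.Tensor:
    """One ring step: send to (rank + 1) % world_size, receive from (rank - 1) % world_size."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    next_rank = (rank + 1) % world_size
    prev_rank = (rank - 1) % world_size

    # The receive buffer must match the sender's shape/dtype; in a ring all members agree on it.
    tensor_recv_prev = torch.empty_like(tensor_send_next)

    ops = [
        dist.P2POp(dist.isend, tensor_send_next, next_rank),
        dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return tensor_recv_prev
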
@@ -4,16 +4,19 @@ import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from typing import Union, List, Tuple
+
+TensorShape = Union[torch.Size, List[int], Tuple[int]]


-def send_tensor_meta(tensor, need_meta=True, next_rank=None):
+def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
     """Sends tensor meta information before sending a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be sent before communications. This function
     synchronizes with :func:`recv_tensor_meta`.

     Args:
-        tensor (torch.Tensor): Tensor to be sent.
+        tensor (:class:`torch.Tensor`): Tensor to be sent.
         need_meta (bool, optional): If False, meta information won't be sent.
         next_rank (int): The rank of the next member in pipeline parallel group.

@@ -34,14 +37,14 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
     return False


-def recv_tensor_meta(tensor_shape, prev_rank=None):
+def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
     """Receives tensor meta information before receiving a specific tensor.
     Since the recipient must know the shape of the tensor in p2p communications,
     meta information of the tensor should be received before communications. This function
     synchronizes with :func:`send_tensor_meta`.

     Args:
-        tensor_shape (torch.Size): The shape of the tensor to be received.
+        tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
         prev_rank (int): The rank of the source of the tensor.

     Returns:

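The meta-exchange pair exists because `irecv` needs a pre-allocated buffer, so the receiver has to learn the tensor's shape before the payload arrives. One way to implement that handshake with plain `torch.distributed` is sketched below (send the number of dimensions, then the dimensions themselves); the exact wire format ColossalAI uses is not visible in this diff, so treat the details as an assumption.

import torch
import torch.distributed as dist

def send_tensor_meta_sketch(tensor: torch.Tensor, next_rank: int) -> bool:
    """Tell the next stage how many dims the tensor has, then the dims themselves."""
    ndims = torch.tensor([tensor.dim()], dtype=torch.long)
    shape = torch.tensor(list(tensor.shape), dtype=torch.long)
    dist.send(ndims, next_rank)
    dist.send(shape, next_rank)
    return False  # meta has been sent once; no need to resend for later microbatches

def recv_tensor_meta_sketch(prev_rank: int) -> torch.Size:
    """Inverse handshake: learn the shape so a receive buffer can be allocated."""
    ndims = torch.empty(1, dtype=torch.long)
    dist.recv(ndims, prev_rank)
    shape = torch.empty(int(ndims.item()), dtype=torch.long)
    dist.recv(shape, prev_rank)
    return torch.Size(shape.tolist())
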
@@ -63,15 +66,15 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
     return tensor_shape


-def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
+def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
     """Break a tensor into equal 1D chunks.

     Args:
-        tensor (torch.Tensor): Tensor to be split before communication.
+        tensor (:class:`torch.Tensor`): Tensor to be split before communication.
         new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.

     Returns:
-        :class:`torch.Size`: The split tensor
+        :class:`torch.Tensor`: The split tensor
     """
     partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
     start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)

@@ -84,13 +87,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     return data


-def gather_split_1d_tensor(tensor):
+def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
     """Opposite of above function, gather values from model parallel ranks.

     Args:
-        tensor (torch.Tensor): Tensor to be gathered after communication.
+        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
     Returns:
-        :class:`torch.Size`: The gathered tensor.
+        :class:`torch.Tensor`: The gathered tensor.
     """
     world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
     numel = torch.numel(tensor)

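Finally, `split_tensor_into_1d_equal_chunks` and `gather_split_1d_tensor` implement the scatter/gather optimization behind `scatter_gather_tensors=True`: each 1-D tensor-parallel rank ships only its equal slice of the flattened tensor, and the receiver reassembles the whole thing. A single-process round-trip sketch, with `world_size`/`rank` passed in explicitly instead of being read from `gpc`:

import torch

def split_into_1d_chunk(tensor: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """Take this rank's equal-sized slice of the flattened tensor."""
    partition_size = tensor.numel() // world_size
    start = partition_size * rank
    return tensor.view(-1)[start:start + partition_size].clone()

def gather_1d_chunks(chunks: list) -> torch.Tensor:
    """Reassemble the slices; with real communication this would be an all_gather."""
    return torch.cat(chunks)

if __name__ == "__main__":
    full = torch.arange(16, dtype=torch.float)
    world_size = 4
    chunks = [split_into_1d_chunk(full, world_size, r) for r in range(world_size)]
    restored = gather_1d_chunks(chunks)
    assert torch.equal(restored, full)
    print(restored.view(4, 4))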