[p2p]add object list send/recv (#1024)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [p2p]add object list send recv

* refactor for code reusability

* polish
pull/1034/head
YuliangLiu0306 2022-05-26 14:28:46 +08:00 committed by GitHub
parent e4685832f8
commit 7106bd671d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 238 additions and 89 deletions

View File

@ -38,38 +38,74 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
return tensor_chunk_shape, chunk_tensor
def _communicate(tensor_send_next: torch.Tensor = None,
tensor_send_prev: torch.Tensor = None,
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
if isinstance(recv_shapes, torch.Size):
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
return buffer_recv, recv_split
buffer_recv = []
for recv_shape in recv_shapes:
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
buffer_recv.append(tensor_recv)
return buffer_recv, recv_split
def process_object_to_send(object_send, scatter_gather_tensors):
if isinstance(object_send, torch.Tensor):
send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
if send_split:
object_send = split_tensor_into_1d_equal_chunks(object_send)
return object_send
for tensor_send in object_send:
send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
if send_split:
tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)
return object_send
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
if isinstance(obj, torch.Tensor):
op_to_add = dist.P2POp(comm_op, obj, comm_rank)
ops_queue.append(op_to_add)
else:
for tensor_to_comm in obj:
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
ops_queue.append(op_to_add)
def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
recv_prev: bool = False,
recv_next: bool = False,
recv_prev_shape: TensorShape = None,
recv_next_shape: TensorShape = None,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
prev_rank: int = None,
next_rank: int = None,
dtype: torch.dtype = None,
scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
communication methods that are used in pipeline schedule.
Takes the following arguments:
tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if
set to None).
recv_prev (bool): boolean for whether tensor should be received from
previous rank.
recv_next (bool): boolean for whether tensor should be received from
next rank.
recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defualts to None.
recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defualts to None.
recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defualts to None.
recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defualts to None.
prev_rank (int): the rank of the previous pipeline stage, defualts to None,
next_rank (int): the rank of the next pipeline stage, defualts to None,
dtype (torch.dtype): data type of intermediate buffers, defaults to None
scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
Returns:
Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
"""
# Create placeholder tensors for receive in forward and backward directions
@ -79,50 +115,41 @@ def _communicate(tensor_send_next: torch.Tensor = None,
if recv_prev:
assert recv_prev_shape is not None
recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype,
scatter_gather_tensors)
if recv_next:
assert recv_next_shape is not None
recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
tensor_recv_next = torch.empty(recv_next_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype,
scatter_gather_tensors)
if tensor_send_prev is not None or recv_prev:
if object_send_prev is not None or recv_prev:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
if tensor_send_next is not None or recv_next:
if object_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
if tensor_send_prev is not None:
send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
if send_prev_split:
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
if object_send_prev is not None:
object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
if tensor_send_next is not None:
send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
if send_next_split:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
if object_send_next is not None:
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
ops = []
if tensor_send_prev is not None:
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
ops.append(send_prev_op)
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
if tensor_recv_prev is not None:
recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
ops.append(recv_prev_op)
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
if tensor_recv_next is not None:
recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
ops.append(recv_next_op)
if tensor_send_next is not None:
send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
ops.append(send_next_op)
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
@ -131,21 +158,34 @@ def _communicate(tensor_send_next: torch.Tensor = None,
torch.cuda.synchronize()
if recv_prev and recv_prev_split:
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
if isinstance(tensor_recv_prev, torch.Tensor):
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else:
for tensor_recv, tensor_shape in zip(tensor_recv_prev, recv_prev_shape):
tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
if recv_next and recv_next_split:
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else:
for tensor_recv, tensor_shape in zip(tensor_recv_next, recv_next_shape):
tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
def recv_forward(input_tensor_shape,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
:class:`torch.Tensor`: The input tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
"""
if gpc.is_pipeline_first_stage():
input_tensor = None
@ -158,15 +198,18 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
def recv_backward(output_grad_shape,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
:class:`torch.Tensor`: The input gradient tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list.
"""
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
@ -183,22 +226,22 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) ->
"""Sends the input tensor to the next stage in pipeline.
Args:
output_tensor (:class:`torch.Tensor`): Tensor to be sent.
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not gpc.is_pipeline_last_stage():
_communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
_communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
Args:
input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not gpc.is_pipeline_first_stage():
_communicate(tensor_send_prev=input_tensor_grad,
_communicate(object_send_prev=input_tensor_grad,
prev_rank=prev_rank,
scatter_gather_tensors=scatter_gather_tensors)
@ -208,22 +251,22 @@ def send_forward_recv_backward(output_tensor,
recv_next=True,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> torch.Tensor:
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage.
Args:
output_tensor (:class:`torch.Tensor`): Tensor to be sent.
output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
Returns:
:class:`torch.Tensor`: The input gradient tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
"""
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
_, output_tensor_grad = _communicate(object_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
@ -237,22 +280,22 @@ def send_backward_recv_forward(input_tensor_grad,
recv_prev=True,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> torch.Tensor:
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
Args:
input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
Returns:
:class:`torch.Tensor`: The input tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
"""
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
input_tensor, _ = _communicate(object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
@ -267,19 +310,19 @@ def send_forward_recv_forward(output_tensor,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> torch.Tensor:
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
Args:
output_tensor (:class:`torch.Tensor`): Tensor to be sent.
input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
Returns:
:class:`torch.Tensor`: The input tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
"""
input_tensor, _ = _communicate(tensor_send_next=output_tensor,
input_tensor, _ = _communicate(object_send_next=output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
@ -295,19 +338,19 @@ def send_backward_recv_backward(input_tensor_grad,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> torch.Tensor:
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the gradient tensor from the
next member in pipeline as the input of this stage.
Args:
input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
Returns:
:class:`torch.Tensor`: The input gradient tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
"""
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
_, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
@ -317,31 +360,32 @@ def send_backward_recv_backward(input_tensor_grad,
return output_tensor_grad
def send_forward_backward_recv_forward_backward(output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
def send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and
the gradient tensor to the previous stage, while receives the input gradient tensor from the
next stage and the input tensor from the previous stage.
Args:
output_tensor (:class:`torch.Tensor`): Tensor sent to the next.
input_tensor_grad (:class:`torch.Tensor`): Tensor sent to the previous.
input_tensor_shape (:class:`torch.Size`): The shape of the tensor received from the previous.
output_grad_shape (:class:`torch.Size`): The shape of the tensor received from the next.
output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous.
output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next.
Returns:
Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
"""
input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor,
object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,

View File

@ -0,0 +1,105 @@
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
CONFIG = dict(parallel=dict(pipeline=2))
torch.manual_seed(123)
LIST_LENGTH = 3
TENSOR_SIZE = torch.Size((3, 3))
TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)]
data = torch.rand(3, 3)
data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
grad = torch.rand(3, 3)
grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
def check_send_recv_forward():
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
device = torch.device('cuda:0')
data_to_send = data.to(device)
data_list_to_send = []
for data_in_list in data_list:
data_list_to_send.append(data_in_list.to(device))
send_forward(data_to_send)
send_forward(data_list_to_send)
else:
device = torch.device('cuda:1')
data_recv = recv_forward(TENSOR_SIZE)
data_list_recv = recv_forward(TENSOR_SIZE_LIST)
data_to_check = data.to(device)
assert data_recv.equal(data_to_check)
for data_recv, data_send in zip(data_list_recv, data_list):
data_to_check = data_send.to(device)
assert data_recv.equal(data_to_check)
def check_send_recv_backward():
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
device = torch.device('cuda:0')
grad_recv = recv_backward(TENSOR_SIZE)
grad_list_recv = recv_backward(TENSOR_SIZE_LIST)
grad_to_check = grad.to(device)
assert grad_recv.equal(grad_to_check)
for grad_recv, grad_send in zip(grad_list_recv, grad_list):
grad_to_check = grad_send.to(device)
assert grad_recv.equal(grad_to_check)
else:
device = torch.device('cuda:1')
grad_to_send = grad.to(device)
grad_list_to_send = []
for grad_in_list in grad_list:
grad_list_to_send.append(grad_in_list.to(device))
send_backward(grad_to_send)
send_backward(grad_list_to_send)
def check_send_recv_forward_backward():
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
device = torch.device('cuda:0')
data_list_to_send = []
for data_in_list in data_list:
data_list_to_send.append(data_in_list.to(device))
grad_list_recv = send_forward_recv_backward(data_list_to_send, TENSOR_SIZE_LIST)
for grad_recv, grad_send in zip(grad_list_recv, grad_list):
grad_to_check = grad_send.to(device)
assert grad_recv.equal(grad_to_check)
else:
device = torch.device('cuda:1')
grad_list_to_send = []
for grad_in_list in grad_list:
grad_list_to_send.append(grad_in_list.to(device))
data_list_recv = send_backward_recv_forward(grad_list_to_send, TENSOR_SIZE_LIST)
for data_recv, data_send in zip(data_list_recv, data_list):
data_to_check = data_send.to(device)
assert data_recv.equal(data_to_check)
def check_layer(rank, world_size, port):
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_send_recv_forward()
check_send_recv_backward()
check_send_recv_forward_backward()
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
world_size = 2
run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_object_list_p2p()