From 7106bd671d66170d619f009455f0e1a83506aee4 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 26 May 2022 14:28:46 +0800
Subject: [PATCH] [p2p]add object list send/recv (#1024)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [p2p]add object list send recv

* refactor for code reusability

* polish
---
 colossalai/communication/p2p.py         | 222 ++++++++++++++----------
 tests/test_comm/test_object_list_p2p.py | 105 +++++++++++
 2 files changed, 238 insertions(+), 89 deletions(-)
 create mode 100644 tests/test_comm/test_object_list_p2p.py

diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 12737e21d..6722860e7 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -38,38 +38,74 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
     return tensor_chunk_shape, chunk_tensor
 
 
-def _communicate(tensor_send_next: torch.Tensor = None,
-                 tensor_send_prev: torch.Tensor = None,
+def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
+    if isinstance(recv_shapes, torch.Size):
+        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
+        buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        return buffer_recv, recv_split
+    buffer_recv = []
+    for recv_shape in recv_shapes:
+        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
+        tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        buffer_recv.append(tensor_recv)
+    return buffer_recv, recv_split
+
+
+def process_object_to_send(object_send, scatter_gather_tensors):
+    if isinstance(object_send, torch.Tensor):
+        send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
+        if send_split:
+            object_send = split_tensor_into_1d_equal_chunks(object_send)
+        return object_send
+    for tensor_send in object_send:
+        send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
+        if send_split:
+            tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)
+    return object_send
+
+
+def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
+    if isinstance(obj, torch.Tensor):
+        op_to_add = dist.P2POp(comm_op, obj, comm_rank)
+        ops_queue.append(op_to_add)
+    else:
+        for tensor_to_comm in obj:
+            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
+            ops_queue.append(op_to_add)
+
+
+def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
+                 object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
                  recv_prev: bool = False,
                  recv_next: bool = False,
-                 recv_prev_shape: TensorShape = None,
-                 recv_next_shape: TensorShape = None,
+                 recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
+                 recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
                  prev_rank: int = None,
                  next_rank: int = None,
                  dtype: torch.dtype = None,
-                 scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
+                 scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
     communication methods that are used in pipeline schedule.
     Takes the following arguments:
-        tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
+        object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if
                           set to None).
-        tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
+        object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if
                           set to None).
         recv_prev (bool): boolean for whether tensor should be received from
                    previous rank.
         recv_next (bool): boolean for whether tensor should be received from
                    next rank.
-        recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defualts to None.
-        recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defualts to None.
+        recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defualts to None.
+        recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defualts to None.
         prev_rank (int): the rank of the previous pipeline stage, defualts to None,
         next_rank (int): the rank of the next pipeline stage, defualts to None,
         dtype (torch.dtype): data type of intermediate buffers, defaults to None
         scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
 
     Returns:
-        Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
+        Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
     """
 
     # Create placeholder tensors for receive in forward and backward directions
@@ -79,50 +115,41 @@ def _communicate(tensor_send_next: torch.Tensor = None,
 
     if recv_prev:
         assert recv_prev_shape is not None
-        recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
-        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
-                                       requires_grad=True,
-                                       device=get_current_device(),
-                                       dtype=dtype)
+        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype,
+                                                                           scatter_gather_tensors)
+
     if recv_next:
         assert recv_next_shape is not None
-        recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
-        tensor_recv_next = torch.empty(recv_next_chunk_shape,
-                                       requires_grad=True,
-                                       device=get_current_device(),
-                                       dtype=dtype)
+        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype,
+                                                                           scatter_gather_tensors)
 
-    if tensor_send_prev is not None or recv_prev:
+    if object_send_prev is not None or recv_prev:
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
 
-    if tensor_send_next is not None or recv_next:
+    if object_send_next is not None or recv_next:
         if next_rank is None:
             next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
 
-    if tensor_send_prev is not None:
-        send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
-        if send_prev_split:
-            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
+    if object_send_prev is not None:
+        object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
 
-    if tensor_send_next is not None:
-        send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
-        if send_next_split:
-            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
+    if object_send_next is not None:
+        object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
 
     ops = []
-    if tensor_send_prev is not None:
-        send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
-        ops.append(send_prev_op)
+    if object_send_prev is not None:
+        filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
+
     if tensor_recv_prev is not None:
-        recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
-        ops.append(recv_prev_op)
+        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
+
     if tensor_recv_next is not None:
-        recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
-        ops.append(recv_next_op)
-    if tensor_send_next is not None:
-        send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
-        ops.append(send_next_op)
+        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+
+    if object_send_next is not None:
+        filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
+
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
@@ -131,21 +158,34 @@ def _communicate(tensor_send_next: torch.Tensor = None,
     torch.cuda.synchronize()
 
     if recv_prev and recv_prev_split:
-        tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+        if isinstance(tensor_recv_prev, torch.Tensor):
+            tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+        else:
+            for tensor_recv, tensor_shape in zip(tensor_recv_prev, recv_prev_shape):
+                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
+
     if recv_next and recv_next_split:
-        tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
+        if isinstance(tensor_recv_next, torch.Tensor):
+            tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
+        else:
+            for tensor_recv, tensor_shape in zip(tensor_recv_next, recv_next_shape):
+                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
+
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
+def recv_forward(input_tensor_shape,
+                 prev_rank=None,
+                 dtype=torch.float,
+                 scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
 
     Args:
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
         prev_rank (int, optional): The rank of the source of the tensor.
 
     Returns:
-        :class:`torch.Tensor`: The input tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
     """
     if gpc.is_pipeline_first_stage():
         input_tensor = None
@@ -158,15 +198,18 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
     return input_tensor
 
 
-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
+def recv_backward(output_grad_shape,
+                  next_rank=None,
+                  dtype=torch.float,
+                  scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
 
     Args:
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
         next_rank (int, optional): The rank of the source of the tensor.
 
     Returns:
-        :class:`torch.Tensor`: The input gradient tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradident tensor list.
     """
     if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -183,22 +226,22 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) ->
     """Sends the input tensor to the next stage in pipeline.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor to be sent.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
         next_rank (int, optional): The rank of the recipient of the tensor.
     """
     if not gpc.is_pipeline_last_stage():
-        _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
+        _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
 
 
 def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the gradient tensor to the previous stage in pipeline.
 
     Args:
-        input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
         prev_rank (int, optional): The rank of the recipient of the tensor
     """
     if not gpc.is_pipeline_first_stage():
-        _communicate(tensor_send_prev=input_tensor_grad,
+        _communicate(object_send_prev=input_tensor_grad,
                      prev_rank=prev_rank,
                      scatter_gather_tensors=scatter_gather_tensors)
 
@@ -208,22 +251,22 @@ def send_forward_recv_backward(output_tensor,
                                recv_next=True,
                                next_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False) -> torch.Tensor:
+                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the 
     next stage in pipeline, while receives the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor to be sent.
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input gradient tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
     """
     if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
-        _, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
+        _, output_tensor_grad = _communicate(object_send_next=output_tensor,
                                              recv_next=recv_next,
                                              recv_next_shape=output_grad_shape,
                                              next_rank=next_rank,
@@ -237,22 +280,22 @@ def send_backward_recv_forward(input_tensor_grad,
                                recv_prev=True,
                                prev_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False) -> torch.Tensor:
+                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
 
     Args:
-        input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
     """
     if gpc.is_pipeline_first_stage():
         input_tensor = None
     else:
-        input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
+        input_tensor, _ = _communicate(object_send_prev=input_tensor_grad,
                                        recv_prev=recv_prev,
                                        recv_prev_shape=input_tensor_shape,
                                        prev_rank=prev_rank,
@@ -267,19 +310,19 @@ def send_forward_recv_forward(output_tensor,
                               prev_rank=None,
                               next_rank=None,
                               dtype=torch.float,
-                              scatter_gather_tensors=False) -> torch.Tensor:
+                              scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the 
     next stage in pipeline, while receives the output tensor from the
     previous stage in pipeline as the input of this stage.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor to be sent.
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
     """
-    input_tensor, _ = _communicate(tensor_send_next=output_tensor,
+    input_tensor, _ = _communicate(object_send_next=output_tensor,
                                    recv_prev=recv_prev,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
@@ -295,19 +338,19 @@ def send_backward_recv_backward(input_tensor_grad,
                                 prev_rank=None,
                                 next_rank=None,
                                 dtype=torch.float,
-                                scatter_gather_tensors=False) -> torch.Tensor:
+                                scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receives the gradient tensor from the
     next member in pipeline as the input of this stage.
 
     Args:
-        input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input gradient tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
     """
-    _, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
+    _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad,
                                          recv_next=recv_next,
                                          recv_next_shape=output_grad_shape,
                                          prev_rank=prev_rank,
@@ -317,31 +360,32 @@ def send_backward_recv_backward(input_tensor_grad,
     return output_tensor_grad
 
 
-def send_forward_backward_recv_forward_backward(output_tensor,
-                                                input_tensor_grad,
-                                                input_tensor_shape,
-                                                output_grad_shape,
-                                                recv_prev=True,
-                                                recv_next=True,
-                                                prev_rank=None,
-                                                next_rank=None,
-                                                dtype=torch.float,
-                                                scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
+def send_forward_backward_recv_forward_backward(
+        output_tensor,
+        input_tensor_grad,
+        input_tensor_shape,
+        output_grad_shape,
+        recv_prev=True,
+        recv_next=True,
+        prev_rank=None,
+        next_rank=None,
+        dtype=torch.float,
+        scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
     """Batched communication operation. Sends the input tensor to the next stage in pipeline and
     the gradient tensor to the previous stage, while receives the input gradient tensor from the
     next stage and the input tensor from the previous stage.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor sent to the next.
-        input_tensor_grad (:class:`torch.Tensor`): Tensor sent to the previous.
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor received from the previous.
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor received from the next.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next.
 
     Returns:
-        Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
+        Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
     """
-    input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
-                                                    tensor_send_prev=input_tensor_grad,
+    input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor,
+                                                    object_send_prev=input_tensor_grad,
                                                     recv_prev=recv_prev,
                                                     recv_next=recv_next,
                                                     recv_prev_shape=input_tensor_shape,
diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_comm/test_object_list_p2p.py
new file mode 100644
index 000000000..701e3e8ad
--- /dev/null
+++ b/tests/test_comm/test_object_list_p2p.py
@@ -0,0 +1,105 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_if_address_is_in_use
+
+CONFIG = dict(parallel=dict(pipeline=2))
+torch.manual_seed(123)
+LIST_LENGTH = 3
+TENSOR_SIZE = torch.Size((3, 3))
+TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)]
+data = torch.rand(3, 3)
+data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
+grad = torch.rand(3, 3)
+grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
+
+
+def check_send_recv_forward():
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        device = torch.device('cuda:0')
+        data_to_send = data.to(device)
+        data_list_to_send = []
+        for data_in_list in data_list:
+            data_list_to_send.append(data_in_list.to(device))
+        send_forward(data_to_send)
+        send_forward(data_list_to_send)
+    else:
+        device = torch.device('cuda:1')
+        data_recv = recv_forward(TENSOR_SIZE)
+        data_list_recv = recv_forward(TENSOR_SIZE_LIST)
+        data_to_check = data.to(device)
+        assert data_recv.equal(data_to_check)
+        for data_recv, data_send in zip(data_list_recv, data_list):
+            data_to_check = data_send.to(device)
+            assert data_recv.equal(data_to_check)
+
+
+def check_send_recv_backward():
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        device = torch.device('cuda:0')
+        grad_recv = recv_backward(TENSOR_SIZE)
+        grad_list_recv = recv_backward(TENSOR_SIZE_LIST)
+        grad_to_check = grad.to(device)
+        assert grad_recv.equal(grad_to_check)
+        for grad_recv, grad_send in zip(grad_list_recv, grad_list):
+            grad_to_check = grad_send.to(device)
+            assert grad_recv.equal(grad_to_check)
+    else:
+        device = torch.device('cuda:1')
+        grad_to_send = grad.to(device)
+        grad_list_to_send = []
+        for grad_in_list in grad_list:
+            grad_list_to_send.append(grad_in_list.to(device))
+        send_backward(grad_to_send)
+        send_backward(grad_list_to_send)
+
+
+def check_send_recv_forward_backward():
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        device = torch.device('cuda:0')
+        data_list_to_send = []
+        for data_in_list in data_list:
+            data_list_to_send.append(data_in_list.to(device))
+        grad_list_recv = send_forward_recv_backward(data_list_to_send, TENSOR_SIZE_LIST)
+
+        for grad_recv, grad_send in zip(grad_list_recv, grad_list):
+            grad_to_check = grad_send.to(device)
+            assert grad_recv.equal(grad_to_check)
+    else:
+        device = torch.device('cuda:1')
+        grad_list_to_send = []
+        for grad_in_list in grad_list:
+            grad_list_to_send.append(grad_in_list.to(device))
+        data_list_recv = send_backward_recv_forward(grad_list_to_send, TENSOR_SIZE_LIST)
+        for data_recv, data_send in zip(data_list_recv, data_list):
+            data_to_check = data_send.to(device)
+            assert data_recv.equal(data_to_check)
+
+
+def check_layer(rank, world_size, port):
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_send_recv_forward()
+    check_send_recv_backward()
+    check_send_recv_forward_backward()
+    gpc.destroy()
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_object_list_p2p():
+    world_size = 2
+    run_func = partial(check_layer, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_object_list_p2p()