mirror of https://github.com/hpcaitech/ColossalAI

[p2p]add object list send/recv (#1024)

* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
  This reverts commit df7e6506d4.
* [p2p]add object list send recv
* refactor for code reusability
* polish

pull/1034/head
parent: e4685832f8
commit: 7106bd671d
@@ -38,38 +38,74 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
     return tensor_chunk_shape, chunk_tensor
 
 
-def _communicate(tensor_send_next: torch.Tensor = None,
-                 tensor_send_prev: torch.Tensor = None,
+def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
+    if isinstance(recv_shapes, torch.Size):
+        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
+        buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        return buffer_recv, recv_split
+    buffer_recv = []
+    for recv_shape in recv_shapes:
+        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
+        tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        buffer_recv.append(tensor_recv)
+    return buffer_recv, recv_split
+
+
+def process_object_to_send(object_send, scatter_gather_tensors):
+    if isinstance(object_send, torch.Tensor):
+        send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
+        if send_split:
+            object_send = split_tensor_into_1d_equal_chunks(object_send)
+        return object_send
+    for tensor_send in object_send:
+        send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
+        if send_split:
+            tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)
+    return object_send
+
+
+def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
+    if isinstance(obj, torch.Tensor):
+        op_to_add = dist.P2POp(comm_op, obj, comm_rank)
+        ops_queue.append(op_to_add)
+    else:
+        for tensor_to_comm in obj:
+            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
+            ops_queue.append(op_to_add)
+
+
+def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
+                 object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
                  recv_prev: bool = False,
                  recv_next: bool = False,
-                 recv_prev_shape: TensorShape = None,
-                 recv_next_shape: TensorShape = None,
+                 recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
+                 recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
                  prev_rank: int = None,
                  next_rank: int = None,
                  dtype: torch.dtype = None,
-                 scatter_gather_tensors: bool = False) -> Tuple[torch.Tensor]:
+                 scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
     """
     Adapted from megatron.p2p_communication.
     Communicate tensors between stages. Used as helper method in other
     communication methods that are used in pipeline schedule.
     Takes the following arguments:
-        tensor_send_next (:class:`torch.Tensor`): tensor to send to next rank (no tensor sent if
+        object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if
             set to None).
-        tensor_send_prev (:class:`torch.Tensor`): tensor to send to prev rank (no tensor sent if
+        object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank (no tensor sent if
             set to None).
         recv_prev (bool): boolean for whether tensor should be received from
             previous rank.
         recv_next (bool): boolean for whether tensor should be received from
             next rank.
-        recv_prev_shape (TensorShape): shape of the tensor to be received from the previous stage, defaults to None.
-        recv_next_shape (TensorShape): shape of the tensor to be received from the next stage, defaults to None.
+        recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the previous stage, defaults to None.
+        recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received from the next stage, defaults to None.
         prev_rank (int): the rank of the previous pipeline stage, defaults to None.
         next_rank (int): the rank of the next pipeline stage, defaults to None.
         dtype (torch.dtype): data type of intermediate buffers, defaults to None.
         scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False.
 
     Returns:
-        Tuple[torch.Tensor]: returns tensor_recv_prev, tensor_recv_next
+        Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
     """
 
     # Create placeholder tensors for receive in forward and backward directions
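The three helpers above pull the tensor-versus-list handling out of `_communicate`: `create_recv_buffer_with_shapes` allocates one receive buffer per shape, `process_object_to_send` applies the optional scatter split to each outgoing tensor, and `filling_ops_queue` turns a tensor or a list of tensors into queued `dist.P2POp`s. A minimal sketch of the send-side composition (not part of the commit; `next_rank` and the tensors are illustrative):

    ops = []
    tensors = [torch.rand(3, 3, device=get_current_device()) for _ in range(2)]
    # with scatter_gather_tensors=False the tensors pass through unchanged
    obj = process_object_to_send(tensors, scatter_gather_tensors=False)
    # one isend P2POp is queued per tensor in the list
    filling_ops_queue(obj, dist.isend, next_rank, ops)
    assert len(ops) == len(tensors)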
@@ -79,50 +115,41 @@ def _communicate(tensor_send_next: torch.Tensor = None,
 
     if recv_prev:
         assert recv_prev_shape is not None
-        recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
-        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
-                                       requires_grad=True,
-                                       device=get_current_device(),
-                                       dtype=dtype)
+        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype,
+                                                                           scatter_gather_tensors)
     if recv_next:
         assert recv_next_shape is not None
-        recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
-        tensor_recv_next = torch.empty(recv_next_chunk_shape,
-                                       requires_grad=True,
-                                       device=get_current_device(),
-                                       dtype=dtype)
+        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype,
+                                                                           scatter_gather_tensors)
 
-    if tensor_send_prev is not None or recv_prev:
+    if object_send_prev is not None or recv_prev:
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
 
-    if tensor_send_next is not None or recv_next:
+    if object_send_next is not None or recv_next:
         if next_rank is None:
             next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
 
-    if tensor_send_prev is not None:
-        send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
-        if send_prev_split:
-            tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
+    if object_send_prev is not None:
+        object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
 
-    if tensor_send_next is not None:
-        send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
-        if send_next_split:
-            tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
+    if object_send_next is not None:
+        object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
 
     ops = []
-    if tensor_send_prev is not None:
-        send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
-        ops.append(send_prev_op)
+    if object_send_prev is not None:
+        filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
     if tensor_recv_prev is not None:
-        recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
-        ops.append(recv_prev_op)
+        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
     if tensor_recv_next is not None:
-        recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
-        ops.append(recv_next_op)
-    if tensor_send_next is not None:
-        send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
-        ops.append(send_next_op)
+        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
+    if object_send_next is not None:
+        filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
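For reference, the batching used here is plain `torch.distributed`: every pending send and receive becomes a `dist.P2POp`, and a single `dist.batch_isend_irecv` call launches the whole batch. A standalone sketch of that pattern, assuming an initialized NCCL process group and a `peer` rank:

    import torch
    import torch.distributed as dist

    send_buf = torch.ones(4, device='cuda')
    recv_buf = torch.empty(4, device='cuda')
    ops = [dist.P2POp(dist.isend, send_buf, peer),
           dist.P2POp(dist.irecv, recv_buf, peer)]
    reqs = dist.batch_isend_irecv(ops)  # launch all ops together
    for req in reqs:
        req.wait()  # block until each request completes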
@@ -131,21 +158,34 @@ def _communicate(tensor_send_next: torch.Tensor = None,
     torch.cuda.synchronize()
 
     if recv_prev and recv_prev_split:
-        tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+        if isinstance(tensor_recv_prev, torch.Tensor):
+            tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
+        else:
+            for tensor_recv, tensor_shape in zip(tensor_recv_prev, recv_prev_shape):
+                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
 
     if recv_next and recv_next_split:
-        tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
+        if isinstance(tensor_recv_next, torch.Tensor):
+            tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
+        else:
+            for tensor_recv, tensor_shape in zip(tensor_recv_next, recv_next_shape):
+                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
 
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
+def recv_forward(input_tensor_shape,
+                 prev_rank=None,
+                 dtype=torch.float,
+                 scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
 
     Args:
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
         prev_rank (int, optional): The rank of the source of the tensor.
 
     Returns:
-        :class:`torch.Tensor`: The input tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
     """
     if gpc.is_pipeline_first_stage():
         input_tensor = None
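After this hunk, the shape argument decides whether a stage receives one tensor or a list: passing a `torch.Size` returns a single tensor, passing a list of sizes returns a list. A short sketch of the two call shapes (sizes are illustrative; the new test at the bottom of this diff uses exactly these):

    single = recv_forward(torch.Size((3, 3)))      # returns one torch.Tensor
    many = recv_forward([torch.Size((3, 3))] * 3)  # returns a list of 3 tensors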
@@ -158,15 +198,18 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
     return input_tensor
 
 
-def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> torch.Tensor:
+def recv_backward(output_grad_shape,
+                  next_rank=None,
+                  dtype=torch.float,
+                  scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
 
     Args:
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
         next_rank (int, optional): The rank of the source of the tensor.
 
     Returns:
-        :class:`torch.Tensor`: The input gradient tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient tensor list.
     """
     if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -183,22 +226,22 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the input tensor to the next stage in pipeline.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor to be sent.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
         next_rank (int, optional): The rank of the recipient of the tensor.
     """
     if not gpc.is_pipeline_last_stage():
-        _communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
+        _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
 
 
 def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
     """Sends the gradient tensor to the previous stage in pipeline.
 
     Args:
-        input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
         prev_rank (int, optional): The rank of the recipient of the tensor.
     """
     if not gpc.is_pipeline_first_stage():
-        _communicate(tensor_send_prev=input_tensor_grad,
+        _communicate(object_send_prev=input_tensor_grad,
                      prev_rank=prev_rank,
                      scatter_gather_tensors=scatter_gather_tensors)
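`send_forward` and `send_backward` now accept a list of tensors as well; the peer stage must call the matching `recv_*` with a list of shapes of the same length. A sketch of the pairing across two adjacent stages (names and shapes are illustrative; the new test at the end of this diff does the same thing):

    # on pipeline stage i: send two activations downstream
    send_forward([act_a, act_b])
    # on pipeline stage i + 1: receive them with a matching shape list
    act_a, act_b = recv_forward([torch.Size((3, 3)), torch.Size((3, 3))])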
@@ -208,22 +251,22 @@ def send_forward_recv_backward(output_tensor,
                                recv_next=True,
                                next_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False) -> torch.Tensor:
+                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receiving the gradient tensor from the
     next stage in pipeline as the input gradient tensor of this stage.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor to be sent.
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input gradient tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
     """
     if gpc.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
-        _, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
+        _, output_tensor_grad = _communicate(object_send_next=output_tensor,
                                              recv_next=recv_next,
                                              recv_next_shape=output_grad_shape,
                                              next_rank=next_rank,
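On a non-last stage the combined call returns the gradient arriving from the next stage in the same batch that carries the outgoing activation. A one-line sketch (shape illustrative):

    output_grad = send_forward_recv_backward(output_tensor, torch.Size((3, 3)))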
@@ -237,22 +280,22 @@ def send_backward_recv_forward(input_tensor_grad,
                               recv_prev=True,
                               prev_rank=None,
                               dtype=torch.float,
-                              scatter_gather_tensors=False) -> torch.Tensor:
+                              scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receiving the output tensor from the
     previous stage in pipeline as the input of this stage.
 
     Args:
-        input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
     """
     if gpc.is_pipeline_first_stage():
         input_tensor = None
     else:
-        input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
+        input_tensor, _ = _communicate(object_send_prev=input_tensor_grad,
                                        recv_prev=recv_prev,
                                        recv_prev_shape=input_tensor_shape,
                                        prev_rank=prev_rank,
@@ -267,19 +310,19 @@ def send_forward_recv_forward(output_tensor,
                              prev_rank=None,
                              next_rank=None,
                              dtype=torch.float,
-                             scatter_gather_tensors=False) -> torch.Tensor:
+                             scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the input tensor to the
     next stage in pipeline, while receiving the output tensor from the
     previous stage in pipeline as the input of this stage.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor to be sent.
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
     """
-    input_tensor, _ = _communicate(tensor_send_next=output_tensor,
+    input_tensor, _ = _communicate(object_send_next=output_tensor,
                                    recv_prev=recv_prev,
                                    recv_prev_shape=input_tensor_shape,
                                    prev_rank=prev_rank,
@@ -295,19 +338,19 @@ def send_backward_recv_backward(input_tensor_grad,
                                prev_rank=None,
                                next_rank=None,
                                dtype=torch.float,
-                               scatter_gather_tensors=False) -> torch.Tensor:
+                               scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
     """Batched communication operation. Sends the gradient tensor to the
     previous stage in pipeline, while receiving the gradient tensor from the
     next member in pipeline as the input of this stage.
 
     Args:
-        input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor to be received.
 
     Returns:
-        :class:`torch.Tensor`: The input gradient tensor.
+        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
     """
-    _, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
+    _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad,
                                          recv_next=recv_next,
                                          recv_next_shape=output_grad_shape,
                                          prev_rank=prev_rank,
@@ -317,31 +360,32 @@ def send_backward_recv_backward(input_tensor_grad,
     return output_tensor_grad
 
 
-def send_forward_backward_recv_forward_backward(output_tensor,
-                                                input_tensor_grad,
-                                                input_tensor_shape,
-                                                output_grad_shape,
-                                                recv_prev=True,
-                                                recv_next=True,
-                                                prev_rank=None,
-                                                next_rank=None,
-                                                dtype=torch.float,
-                                                scatter_gather_tensors=False) -> Tuple[torch.Tensor]:
+def send_forward_backward_recv_forward_backward(
+        output_tensor,
+        input_tensor_grad,
+        input_tensor_shape,
+        output_grad_shape,
+        recv_prev=True,
+        recv_next=True,
+        prev_rank=None,
+        next_rank=None,
+        dtype=torch.float,
+        scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
     """Batched communication operation. Sends the input tensor to the next stage in pipeline and
     the gradient tensor to the previous stage, while receiving the input gradient tensor from the
     next stage and the input tensor from the previous stage.
 
     Args:
-        output_tensor (:class:`torch.Tensor`): Tensor sent to the next.
-        input_tensor_grad (:class:`torch.Tensor`): Tensor sent to the previous.
-        input_tensor_shape (:class:`torch.Size`): The shape of the tensor received from the previous.
-        output_grad_shape (:class:`torch.Size`): The shape of the tensor received from the next.
+        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
+        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
+        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the previous.
+        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received from the next.
 
     Returns:
-        Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
+        Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
     """
-    input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
-                                                    tensor_send_prev=input_tensor_grad,
+    input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor,
+                                                    object_send_prev=input_tensor_grad,
                                                     recv_prev=recv_prev,
                                                     recv_next=recv_next,
                                                     recv_prev_shape=input_tensor_shape,
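This combined operation is the kind of call a pipeline schedule makes when it overlaps both communication directions in a single batch. A hedged sketch of a call with lists on both sides (all tensor names are illustrative):

    input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
        output_tensor=[out_a, out_b],        # sent to the next stage
        input_tensor_grad=[grad_a, grad_b],  # sent to the previous stage
        input_tensor_shape=[torch.Size((3, 3))] * 2,
        output_grad_shape=[torch.Size((3, 3))] * 2)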
@@ -0,0 +1,105 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.initialize import launch
+from colossalai.utils import free_port, get_current_device
+from colossalai.testing import rerun_if_address_is_in_use
+
+CONFIG = dict(parallel=dict(pipeline=2))
+torch.manual_seed(123)
+LIST_LENGTH = 3
+TENSOR_SIZE = torch.Size((3, 3))
+TENSOR_SIZE_LIST = [TENSOR_SIZE for i in range(LIST_LENGTH)]
+data = torch.rand(3, 3)
+data_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
+grad = torch.rand(3, 3)
+grad_list = [torch.rand(3, 3) for i in range(LIST_LENGTH)]
+
+
+def check_send_recv_forward():
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        device = torch.device('cuda:0')
+        data_to_send = data.to(device)
+        data_list_to_send = []
+        for data_in_list in data_list:
+            data_list_to_send.append(data_in_list.to(device))
+        send_forward(data_to_send)
+        send_forward(data_list_to_send)
+    else:
+        device = torch.device('cuda:1')
+        data_recv = recv_forward(TENSOR_SIZE)
+        data_list_recv = recv_forward(TENSOR_SIZE_LIST)
+        data_to_check = data.to(device)
+        assert data_recv.equal(data_to_check)
+        for data_recv, data_send in zip(data_list_recv, data_list):
+            data_to_check = data_send.to(device)
+            assert data_recv.equal(data_to_check)
+
+
+def check_send_recv_backward():
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        device = torch.device('cuda:0')
+        grad_recv = recv_backward(TENSOR_SIZE)
+        grad_list_recv = recv_backward(TENSOR_SIZE_LIST)
+        grad_to_check = grad.to(device)
+        assert grad_recv.equal(grad_to_check)
+        for grad_recv, grad_send in zip(grad_list_recv, grad_list):
+            grad_to_check = grad_send.to(device)
+            assert grad_recv.equal(grad_to_check)
+    else:
+        device = torch.device('cuda:1')
+        grad_to_send = grad.to(device)
+        grad_list_to_send = []
+        for grad_in_list in grad_list:
+            grad_list_to_send.append(grad_in_list.to(device))
+        send_backward(grad_to_send)
+        send_backward(grad_list_to_send)
+
+
+def check_send_recv_forward_backward():
+    if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
+        device = torch.device('cuda:0')
+        data_list_to_send = []
+        for data_in_list in data_list:
+            data_list_to_send.append(data_in_list.to(device))
+        grad_list_recv = send_forward_recv_backward(data_list_to_send, TENSOR_SIZE_LIST)
+
+        for grad_recv, grad_send in zip(grad_list_recv, grad_list):
+            grad_to_check = grad_send.to(device)
+            assert grad_recv.equal(grad_to_check)
+    else:
+        device = torch.device('cuda:1')
+        grad_list_to_send = []
+        for grad_in_list in grad_list:
+            grad_list_to_send.append(grad_in_list.to(device))
+        data_list_recv = send_backward_recv_forward(grad_list_to_send, TENSOR_SIZE_LIST)
+        for data_recv, data_send in zip(data_list_recv, data_list):
+            data_to_check = data_send.to(device)
+            assert data_recv.equal(data_to_check)
+
+
+def check_layer(rank, world_size, port):
+    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_send_recv_forward()
+    check_send_recv_backward()
+    check_send_recv_forward_backward()
+    gpc.destroy()
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_object_list_p2p():
+    world_size = 2
+    run_func = partial(check_layer, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_object_list_p2p()
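The new test spawns two processes (one GPU each) with `mp.spawn` and checks the single-tensor and list paths in every direction. The diff view does not show the file's path; assuming it lands under the repository's test tree as tests/test_comm/test_object_list_p2p.py, it can be run directly or via pytest:

    python tests/test_comm/test_object_list_p2p.py
    # or
    pytest -m dist tests/test_comm/test_object_list_p2p.py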