|
|
|
@ -12,7 +12,6 @@ from functools import reduce
|
|
|
|
|
import operator |
|
|
|
|
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TensorShape = Union[torch.Size, List[int], Tuple[int]] |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -88,13 +87,11 @@ def _communicate(tensor_send_next=None,
|
|
|
|
|
|
|
|
|
|
if tensor_send_prev is not None or recv_prev: |
|
|
|
|
if prev_rank is None: |
|
|
|
|
prev_rank = gpc.get_prev_global_rank( |
|
|
|
|
ParallelMode.PIPELINE) |
|
|
|
|
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) |
|
|
|
|
|
|
|
|
|
if tensor_send_next is not None or recv_next: |
|
|
|
|
if next_rank is None: |
|
|
|
|
next_rank = gpc.get_next_global_rank( |
|
|
|
|
ParallelMode.PIPELINE) |
|
|
|
|
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) |
|
|
|
|
|
|
|
|
|
if tensor_send_prev is not None: |
|
|
|
|
send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1] |
|
|
|
@ -183,9 +180,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
|
|
|
|
|
next_rank (int, optional): The rank of the recipient of the tensor. |
|
|
|
|
""" |
|
|
|
|
if not gpc.is_pipeline_last_stage(): |
|
|
|
|
_communicate(tensor_send_next=output_tensor, |
|
|
|
|
next_rank=next_rank, |
|
|
|
|
scatter_gather_tensors=scatter_gather_tensors) |
|
|
|
|
_communicate(tensor_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False): |
|
|
|
@ -338,15 +333,14 @@ def send_forward_backward_recv_forward_backward(output_tensor,
|
|
|
|
|
Returns: |
|
|
|
|
Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor) |
|
|
|
|
""" |
|
|
|
|
input_tensor, output_tensor_grad = _communicate( |
|
|
|
|
tensor_send_next=output_tensor, |
|
|
|
|
tensor_send_prev=input_tensor_grad, |
|
|
|
|
recv_prev=recv_prev, |
|
|
|
|
recv_next=recv_next, |
|
|
|
|
recv_prev_shape=input_tensor_shape, |
|
|
|
|
recv_next_shape=output_grad_shape, |
|
|
|
|
prev_rank=prev_rank, |
|
|
|
|
next_rank=next_rank, |
|
|
|
|
dtype=dtype, |
|
|
|
|
scatter_gather_tensors=scatter_gather_tensors) |
|
|
|
|
input_tensor, output_tensor_grad = _communicate(tensor_send_next=output_tensor, |
|
|
|
|
tensor_send_prev=input_tensor_grad, |
|
|
|
|
recv_prev=recv_prev, |
|
|
|
|
recv_next=recv_next, |
|
|
|
|
recv_prev_shape=input_tensor_shape, |
|
|
|
|
recv_next_shape=output_grad_shape, |
|
|
|
|
prev_rank=prev_rank, |
|
|
|
|
next_rank=next_rank, |
|
|
|
|
dtype=dtype, |
|
|
|
|
scatter_gather_tensors=scatter_gather_tensors) |
|
|
|
|
return input_tensor, output_tensor_grad |
|
|
|
|