@ -23,7 +23,7 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
chunk_tensor ( bool , optional ) : whether to chunk tensor , defaults to False
Returns :
Tuple [ Union [ torch . Size , List [ int ] , Tuple [ int ] ] , bool ] : exact tensor shape , whether to chunk tensor
Tuple [ Union [ : class : ` torch . Size ` , List [ int ] , Tuple [ int ] ] , bool ] : exact tensor shape , whether to chunk tensor
"""
if chunk_tensor :
tensor_chunk_shape = reduce ( operator . mul , tensor_shape , 1 )
@ -38,31 +38,38 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
return tensor_chunk_shape , chunk_tensor
def _communicate ( tensor_send_next = None ,
tensor_send_prev = None ,
recv_prev = False ,
recv_next = False ,
recv_prev_shape = None ,
recv_next_shape = None ,
prev_rank = None ,
next_rank = None ,
dtype = None ,
scatter_gather_tensors = False ) :
def _communicate ( tensor_send_next : torch . Tensor = None ,
tensor_send_prev : torch . Tensor = None ,
recv_prev : bool = False ,
recv_next : bool = False ,
recv_prev_shape : TensorShape = None ,
recv_next_shape : TensorShape = None ,
prev_rank : int = None ,
next_rank : int = None ,
dtype : torch . dtype = None ,
scatter_gather_tensors : bool = False ) - > Tuple [ torch . Tensor ] :
"""
Adapted from megatron . p2p_communication .
Communicate tensors between stages . Used as helper method in other
communication methods that are used in pipeline schedule .
Takes the following arguments :
tensor_send_next : tensor to send to next rank ( no tensor sent if
tensor_send_next ( : class : ` torch . Tensor ` ) : tensor to send to next rank ( no tensor sent if
set to None ) .
tensor_send_prev : tensor to send to prev rank ( no tensor sent if
tensor_send_prev ( : class : ` torch . Tensor ` ) : tensor to send to prev rank ( no tensor sent if
set to None ) .
recv_prev : boolean for whether tensor should be received from
recv_prev ( bool ) : boolean for whether tensor should be received from
previous rank .
recv_next : boolean for whether tensor should be received from
recv_next ( bool ) : boolean for whether tensor should be received from
next rank .
recv_prev_shape ( TensorShape ) : shape of the tensor to be received from the previous stage , defualts to None .
recv_next_shape ( TensorShape ) : shape of the tensor to be received from the next stage , defualts to None .
prev_rank ( int ) : the rank of the previous pipeline stage , defualts to None ,
next_rank ( int ) : the rank of the next pipeline stage , defualts to None ,
dtype ( torch . dtype ) : data type of intermediate buffers , defaults to None
scatter_gather_tensors ( bool ) : whether to scatter and gather tensor between pipeline stages , defaults to False
Returns :
( tensor_recv_prev , tensor_recv_next )
Tuple [ torch . Tensor ] : returns tensor_recv_prev , tensor_recv_next
"""
# Create placeholder tensors for receive in forward and backward directions
@ -130,7 +137,7 @@ def _communicate(tensor_send_next=None,
return tensor_recv_prev , tensor_recv_next
def recv_forward ( input_tensor_shape , prev_rank = None , dtype = torch . float , scatter_gather_tensors = False ) :
def recv_forward ( input_tensor_shape , prev_rank = None , dtype = torch . float , scatter_gather_tensors = False ) - > torch . Tensor :
""" Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args :
@ -151,7 +158,7 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
return input_tensor
def recv_backward ( output_grad_shape , next_rank = None , dtype = torch . float , scatter_gather_tensors = False ) :
def recv_backward ( output_grad_shape , next_rank = None , dtype = torch . float , scatter_gather_tensors = False ) - > torch . Tensor :
""" Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args :
@ -172,7 +179,7 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_
return output_tensor_grad
def send_forward ( output_tensor , next_rank = None , scatter_gather_tensors = False ) :
def send_forward ( output_tensor , next_rank = None , scatter_gather_tensors = False ) - > None :
""" Sends the input tensor to the next stage in pipeline.
Args :
@ -183,7 +190,7 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
_communicate ( tensor_send_next = output_tensor , next_rank = next_rank , scatter_gather_tensors = scatter_gather_tensors )
def send_backward ( input_tensor_grad , prev_rank = None , scatter_gather_tensors = False ) :
def send_backward ( input_tensor_grad , prev_rank = None , scatter_gather_tensors = False ) - > None :
""" Sends the gradient tensor to the previous stage in pipeline.
Args :
@ -201,7 +208,7 @@ def send_forward_recv_backward(output_tensor,
recv_next = True ,
next_rank = None ,
dtype = torch . float ,
scatter_gather_tensors = False ) :
scatter_gather_tensors = False ) - > torch . Tensor :
""" Batched communication operation. Sends the input tensor to the
next stage in pipeline , while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage .
@ -230,7 +237,7 @@ def send_backward_recv_forward(input_tensor_grad,
recv_prev = True ,
prev_rank = None ,
dtype = torch . float ,
scatter_gather_tensors = False ) :
scatter_gather_tensors = False ) - > torch . Tensor :
""" Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline , while receives the output tensor from the
previous stage in pipeline as the input of this stage .
@ -260,7 +267,7 @@ def send_forward_recv_forward(output_tensor,
prev_rank = None ,
next_rank = None ,
dtype = torch . float ,
scatter_gather_tensors = False ) :
scatter_gather_tensors = False ) - > torch . Tensor :
""" Batched communication operation. Sends the input tensor to the
next stage in pipeline , while receives the output tensor from the
previous stage in pipeline as the input of this stage .
@ -288,7 +295,7 @@ def send_backward_recv_backward(input_tensor_grad,
prev_rank = None ,
next_rank = None ,
dtype = torch . float ,
scatter_gather_tensors = False ) :
scatter_gather_tensors = False ) - > torch . Tensor :
""" Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline , while receives the gradient tensor from the
next member in pipeline as the input of this stage .
@ -319,7 +326,7 @@ def send_forward_backward_recv_forward_backward(output_tensor,
prev_rank = None ,
next_rank = None ,
dtype = torch . float ,
scatter_gather_tensors = False ) :
scatter_gather_tensors = False ) - > Tuple [ torch . Tensor ] :
""" Batched communication operation. Sends the input tensor to the next stage in pipeline and
the gradient tensor to the previous stage , while receives the input gradient tensor from the
next stage and the input tensor from the previous stage .