@ -5,9 +5,12 @@ import io
import pickle
import re
from typing import Any , List , Optional , Union
from collections import namedtuple
import torch
import torch . distributed as dist
from dataclasses import dataclass
from enum import Enum
from packaging . version import Version
from torch . distributed import ProcessGroup
from torch . distributed import distributed_c10d as c10d
@ -45,6 +48,21 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
return unpickle
def check_for_nccl_backend(group):
    """Report whether ``group`` communicates over the NCCL backend.

    Args:
        group: Process group to inspect. A falsy value (e.g. ``None``) selects
            the default process group.

    Returns:
        bool: ``True`` only when NCCL is available and the unwrapped process
        group reports NCCL as its backend name.
    """
    process_group = group if group else c10d._get_default_group()

    # The wrapper check is gated on Gloo availability. A group is not expected
    # to be wrapped more than once, but every layer is peeled off just in case.
    if c10d._GLOO_AVAILABLE:
        while isinstance(process_group, c10d._ProcessGroupWrapper):
            process_group = process_group.wrapped_pg

    if not c10d.is_nccl_available():
        return False
    return process_group.name() == c10d.Backend.NCCL
def _broadcast_object_list (
object_list : List [ Any ] , src : int , group : ProcessGroup , device : Optional [ Union [ torch . device , str , int ] ] = None
) :
@ -65,7 +83,7 @@ def _broadcast_object_list(
c10d . _warn_not_in_group ( " broadcast_object_list " )
return
is_nccl_backend = c 10d. _c heck_for_nccl_backend( group )
is_nccl_backend = c heck_for_nccl_backend( group )
current_device = None
if device is not None :
@ -113,7 +131,7 @@ def _broadcast_object_list(
if my_rank != src :
for i , obj_size in enumerate ( object_sizes_tensor ) :
obj_view = object_tensor [ offset : offset + obj_size ]
obj_view = object_tensor [ offset : offset + obj_size ]
obj_view = obj_view . type ( torch . uint8 )
if obj_view . device != torch . device ( " cpu " ) :
obj_view = obj_view . cpu ( )
@ -131,6 +149,258 @@ def _broadcast_object_list(
object_list [ i ] = unpickle_object
def check_device(group):
    """Pick the device on which communication buffers for ``group`` should live.

    Args:
        group: Process group to inspect; forwarded to
            ``check_for_nccl_backend``.

    Returns:
        tuple: ``(current_device, is_nccl_backend)``. ``current_device`` is the
        current CUDA device when the group uses NCCL and the CPU device
        otherwise.
    """
    is_nccl_backend = check_for_nccl_backend(group)
    if is_nccl_backend:
        current_device = torch.device("cuda", torch.cuda.current_device())
    else:
        current_device = torch.device("cpu")
    return current_device, is_nccl_backend
# Description of one tensor, exchanged ahead of the tensor payload so the
# receiver can allocate a matching buffer (shape, dtype, requires_grad).
# ``key`` carries the dict key when the tensor travels as part of a dict and is
# ``None`` for a bare tensor or a list element.
TensorMetadata = namedtuple("TensorMetadata", "key shape dtype requires_grad")
class P2PDataType(Enum):
    """Tag describing how a point-to-point payload is transferred."""

    # The object itself is carried (pickled) inside the metadata message.
    serialization = 0
    # A single tensor follows the metadata message.
    tensor = 1
    # A list of tensors follows the metadata message.
    list = 2
    # The values of a dict of tensors follow; the keys travel in the metadata.
    dict = 3
@dataclass
class P2PMetadata:
    """Header exchanged between peers before any tensor payload is sent.

    For the ``serialization`` data type the header is the whole message:
    ``content`` is the transferred object itself. For the tensor-based data
    types ``content`` describes the tensors that follow.
    """

    # How the payload is transferred.
    data_type: P2PDataType
    # One ``TensorMetadata`` (single tensor), a list of them (list/dict of
    # tensors), or the arbitrary object being sent (serialization).
    content: Union[List[TensorMetadata], TensorMetadata, Any]
def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
    """Append one ``dist.P2POp`` per tensor in ``obj`` to ``ops_queue``.

    Args:
        obj: A single tensor, or an iterable of tensors.
        comm_op: Communication primitive handed to ``dist.P2POp``
            (callers pass ``dist.isend`` or ``dist.irecv``).
        comm_rank: Rank of the peer for every queued operation.
        ops_queue: List that is extended in place with the new operations.
        group: Process group handed to ``dist.P2POp``.
    """
    tensors = (obj,) if isinstance(obj, torch.Tensor) else obj
    for tensor in tensors:
        # Every operation is built on the ``.contiguous()`` form of the tensor.
        ops_queue.append(dist.P2POp(comm_op, tensor.contiguous(), comm_rank, group))
def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
    """Allocate uninitialized tensors matching the announced incoming payload.

    Args:
        p2p_metadata: Header describing the tensors about to be received.
        current_device: Device on which the buffers are allocated.

    Returns:
        A single tensor for the ``tensor`` data type, or a list of tensors (in
        metadata order) for the ``list`` and ``dict`` data types.

    Raises:
        ValueError: If the data type is none of ``tensor``, ``list``, ``dict``.
    """

    def _allocate(meta):
        # Shape, dtype and requires_grad all come from the sender's metadata.
        return torch.empty(
            meta.shape,
            requires_grad=meta.requires_grad,
            device=current_device,
            dtype=meta.dtype,
        )

    kind = p2p_metadata.data_type
    if kind == P2PDataType.tensor:
        return _allocate(p2p_metadata.content)
    if kind in (P2PDataType.list, P2PDataType.dict):
        return [_allocate(meta) for meta in p2p_metadata.content]
    raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
def _batch_send_recv_tensor ( send_tensor_list , recv_tensor_metadata , send_dst , recv_src , send_group , recv_group , current_device ) :
buffer_recv = None
if recv_tensor_metadata is not None :
buffer_recv = create_recv_buffer ( recv_tensor_metadata , current_device )
ops = [ ]
if send_dst is not None :
filling_ops_queue ( send_tensor_list , dist . isend , send_dst , ops , send_group )
if recv_src is not None :
assert buffer_recv is not None
filling_ops_queue ( buffer_recv , dist . irecv , recv_src , ops , recv_group )
if len ( ops ) > 0 :
reqs = dist . batch_isend_irecv ( ops )
for req in reqs :
req . wait ( )
torch . cuda . synchronize ( )
# Remove synchronization according to Pytorch's documentation
# However, the Megatron-LM does synchronization here
# https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
# In case there is potential error, uncomment the following `torch.cuda.synchronize()`
# torch.cuda.synchronize()
return buffer_recv
def _send_recv_serialization_object (
object : Any ,
send_dst : Optional [ int ] , recv_src : Optional [ int ] ,
send_group : Optional [ ProcessGroup ] , recv_group : Optional [ ProcessGroup ] ,
current_device ,
is_nccl_backend ) :
ops = [ ]
send_object_tensor = None
if object is not None and send_dst is not None :
if Version ( torch . __version__ ) > = Version ( " 1.13.0 " ) :
send_object_tensor , send_object_size_tensor = c10d . _object_to_tensor ( object , device = current_device )
else :
send_object_tensor , send_object_size_tensor = c10d . _object_to_tensor ( object )
if is_nccl_backend :
send_object_size_tensor = send_object_size_tensor . to ( current_device )
send_object_tensor = send_object_tensor . to ( current_device )
filling_ops_queue ( send_object_size_tensor , dist . isend , send_dst , ops , send_group )
recv_object_size_tensor = None
if recv_src is not None :
recv_object_size_tensor = torch . empty ( 1 , dtype = torch . long )
if is_nccl_backend :
recv_object_size_tensor = recv_object_size_tensor . to ( current_device )
filling_ops_queue ( recv_object_size_tensor , dist . irecv , recv_src , ops , recv_group )
if len ( ops ) > 0 :
reqs = dist . batch_isend_irecv ( ops )
for req in reqs :
req . wait ( )
torch . cuda . synchronize ( )
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
ops = [ ]
if send_dst is not None and send_object_tensor is not None :
filling_ops_queue ( send_object_tensor , dist . isend , send_dst , ops , send_group )
recv_object_tensor = None
if recv_src is not None and recv_object_size_tensor is not None :
recv_object_tensor = torch . empty ( recv_object_size_tensor . item ( ) , dtype = torch . uint8 )
if is_nccl_backend :
recv_object_tensor = recv_object_tensor . to ( current_device )
filling_ops_queue ( recv_object_tensor , dist . irecv , recv_src , ops , recv_group )
if len ( ops ) > 0 :
reqs = dist . batch_isend_irecv ( ops )
for req in reqs :
req . wait ( )
torch . cuda . synchronize ( )
# See the comment in `_batch_send_recv_tensor`
# torch.cuda.synchronize()
if recv_object_tensor is not None and recv_object_size_tensor is not None :
recv_object_tensor = recv_object_tensor . type ( torch . uint8 )
if recv_object_tensor . device != torch . device ( " cpu " ) :
recv_object_tensor = recv_object_tensor . cpu ( )
unpickle_object = _cuda_safe_tensor_to_object (
recv_object_tensor , recv_object_size_tensor . item ( ) )
if (
isinstance ( unpickle_object , torch . Tensor )
and unpickle_object . device . index != torch . cuda . current_device ( )
) :
unpickle_object = unpickle_object . cuda ( )
return unpickle_object
def _check_if_fast_send_available ( object ) :
if type ( object ) is torch . Tensor :
return True
elif type ( object ) is list :
is_list_of_tensor = all ( [ type ( v ) is torch . Tensor for v in object ] )
return is_list_of_tensor
elif type ( object ) is dict :
is_dict_of_tensor = all ( [ type ( k ) is str and type (
v ) is torch . Tensor for k , v in object . items ( ) ] )
return is_dict_of_tensor
return False
def _communicate(
    object,
    send_dst: Optional[int],
    recv_src: Optional[int],
    send_group: Optional[ProcessGroup] = None,
    recv_group: Optional[ProcessGroup] = None,
) -> Any:
    """Send ``object`` to ``send_dst`` and/or receive an object from ``recv_src``.

    A ``P2PMetadata`` header is always exchanged first through
    ``_send_recv_serialization_object``. Objects that cannot take the fast path
    travel inside that header; tensors, lists of tensors and dicts of tensors
    on an NCCL backend are transferred afterwards by
    ``_batch_send_recv_tensor``.

    Args:
        object: Object to send; ignored when ``send_dst`` is ``None``.
        send_dst: Destination rank, or ``None`` to skip sending.
        recv_src: Source rank, or ``None`` to skip receiving.
        send_group: Process group used for sending.
        recv_group: Process group used for receiving.

    Returns:
        The received object (arbitrary object, tensor, list of tensors or dict
        of tensors), or ``None`` when nothing is received or when this rank is
        not part of one of the groups.

    Raises:
        ValueError: If a data type outside ``P2PDataType`` is encountered.
    """
    if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
        c10d._warn_not_in_group("_communicate")
        return None

    current_send_device, is_send_nccl_backend = check_device(send_group)
    current_recv_device, is_recv_nccl_backend = check_device(recv_group)
    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend

    assert current_send_device == current_recv_device
    current_device = current_send_device

    assert (send_dst is not None) or (recv_src is not None)

    # Build the header describing what this rank is about to send.
    can_fast_send = False
    send_metadata = None
    if send_dst is not None:
        can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
        if can_fast_send:
            if type(object) is torch.Tensor:
                data_type = P2PDataType.tensor
                content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
            elif type(object) is list:
                data_type = P2PDataType.list
                content = [
                    TensorMetadata(None, item.shape, item.dtype, item.requires_grad)
                    for item in object
                ]
            elif type(object) is dict:
                data_type = P2PDataType.dict
                content = [
                    TensorMetadata(key, value.shape, value.dtype, value.requires_grad)
                    for key, value in object.items()
                ]
            else:
                raise ValueError('Cannot send object of type {}'.format(type(object)))
            send_metadata = P2PMetadata(data_type, content)
        else:
            # Slow path: the object itself rides inside the header.
            send_metadata = P2PMetadata(P2PDataType.serialization, object)

    recv_metadata = _send_recv_serialization_object(
        send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend
    )

    if recv_metadata is not None:
        assert type(recv_metadata) is P2PMetadata
        if recv_metadata.data_type == P2PDataType.serialization:
            # The header already contains the complete received object.
            return recv_metadata.content

    # NOTE(review): after a serialized send this returns without receiving
    # tensors, even if ``recv_metadata`` announced a tensor payload; confirm
    # both peers always choose the same path.
    if send_dst is not None and not can_fast_send:
        return None

    # Tensor phase: flatten the outgoing object to what the batch helper takes.
    send_tensor_list = None
    if type(object) is torch.Tensor or type(object) is list:
        send_tensor_list = object
    elif type(object) is dict:
        send_tensor_list = list(object.values())

    recv_buffer = _batch_send_recv_tensor(
        send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device
    )

    if recv_metadata is None:
        return None

    assert recv_buffer is not None
    if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
        return recv_buffer
    if recv_metadata.data_type == P2PDataType.dict:
        # Re-attach the keys, which travelled in the header, to the tensors.
        return dict(zip((meta.key for meta in recv_metadata.content), recv_buffer))
    raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
def _send_object ( object : Any , src : int , dst : int , group : ProcessGroup ) - > None :
""" send anything to dst rank
@ -141,8 +411,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
Returns :
None
"""
# then broadcast safely
_broadcast_object_list ( [ object ] , src , group )
_communicate ( object , send_dst = dst , recv_src = None , send_group = group )
def _recv_object ( src : int , dst : int , group : ProcessGroup ) - > Any :
@ -154,10 +423,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
Returns :
Any : Object received from src .
"""
object_list = [ None ]
_broadcast_object_list ( object_list , src , group )
return object_list [ 0 ]
return _communicate ( None , send_dst = None , recv_src = src , recv_group = group )
def _p2p_comm (
@ -302,6 +568,64 @@ class PipelineP2PCommunication:
cur_rank = self . stage_manager . get_rank ( )
_send_object ( input_object , cur_rank , prev_rank , self . stage_manager . get_p2p_process_group ( cur_rank , prev_rank ) )
def send_forward_recv_backward(self, input_object: Any, next_rank: Optional[int] = None) -> Any:
    """Send ``input_object`` to the next stage and receive an object back from it.

    Both directions use the same peer and the same process group.

    Args:
        input_object (Any): Object to be sent.
        next_rank (int, optional): Rank that is both the recipient of
            ``input_object`` and the sender of the returned object. Defaults
            to ``self.stage_manager.get_next_rank()``.

    Returns:
        Any: Whatever ``_communicate`` returns for this exchange, i.e. the
        object received from ``next_rank``.
    """
    if next_rank is None:
        next_rank = self.stage_manager.get_next_rank()
    cur_rank = self.stage_manager.get_rank()
    group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
    return _communicate(
        input_object,
        send_dst=next_rank,
        recv_src=next_rank,
        send_group=group,
        recv_group=group,
    )
def send_backward_recv_forward(self, input_object: Any, prev_rank: Optional[int] = None) -> Any:
    """Send ``input_object`` to the previous stage and receive an object back from it.

    Both directions use the same peer and the same process group.

    Args:
        input_object (Any): Object to be sent.
        prev_rank (int, optional): Rank that is both the recipient of
            ``input_object`` and the sender of the returned object. Defaults
            to ``self.stage_manager.get_prev_rank()``.

    Returns:
        Any: Whatever ``_communicate`` returns for this exchange, i.e. the
        object received from ``prev_rank``.
    """
    if prev_rank is None:
        prev_rank = self.stage_manager.get_prev_rank()
    cur_rank = self.stage_manager.get_rank()
    group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
    return _communicate(
        input_object,
        send_dst=prev_rank,
        recv_src=prev_rank,
        send_group=group,
        recv_group=group,
    )
def send_forward_recv_forward(
    self, input_object: Any, prev_rank: Optional[int] = None, next_rank: Optional[int] = None
) -> Any:
    """Send ``input_object`` to the next stage and receive an object from the previous stage.

    Sending and receiving use separate process groups: one for the
    (previous, current) pair and one for the (current, next) pair.

    Args:
        input_object (Any): Object to be sent to ``next_rank``.
        prev_rank (int, optional): Rank the returned object is received from.
            Defaults to ``self.stage_manager.get_prev_rank()``.
        next_rank (int, optional): Rank ``input_object`` is sent to.
            Defaults to ``self.stage_manager.get_next_rank()``.

    Returns:
        Any: Whatever ``_communicate`` returns for this exchange, i.e. the
        object received from ``prev_rank``.
    """
    if prev_rank is None:
        prev_rank = self.stage_manager.get_prev_rank()
    if next_rank is None:
        next_rank = self.stage_manager.get_next_rank()

    cur_rank = self.stage_manager.get_rank()
    recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
    send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
    return _communicate(
        input_object,
        send_dst=next_rank,
        recv_src=prev_rank,
        send_group=send_group,
        recv_group=recv_group,
    )
def p2p_communicate (
self , output_object : Any , recv_pre : bool , peer : int = None , comm_dtype : torch . dtype = torch . float16
) - > None :