@@ -10,6 +10,18 @@ from torch.distributed import ReduceOp
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
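+
+# Wraps the async work objects returned by the underlying collectives together with a
+# deferred post-processing callback (dequantize / concatenate / copy back). wait()
+# first blocks on every communication handle, then runs the remaining ops.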
+class Handle:
+    def __init__(self, handles=[], remain_ops=None) -> None:
+        self.handles = handles
+        self.remain_ops = remain_ops
+
+    def wait(self):
+        for handle in self.handles:
+            handle.wait()
+        if self.remain_ops:
+            self.remain_ops()
+
+
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Cast a torch Tensor into the specified fp8 format, using per-channel or per-tensor scaling.
@@ -68,7 +80,9 @@ def cast_from_fp8(
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
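+
+    Example (illustrative sketch; the tensor and group names are placeholders):
+
+        handle = all_reduce_fp8(grad, op=ReduceOp.AVG, group=group, async_op=True)
+        # ... overlap independent compute with the fp8 all-reduce ...
+        handle.wait()  # wait for communication, then dequantize back into `grad`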
@@ -105,6 +119,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
+
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
@@ -113,19 +128,28 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
         summed_out.div_(world_size)
 
     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
-    dist.all_gather(scale_list, scale, group=group)
+    gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
-    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
+    gather_tensor_handle = dist.all_gather(
+        tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
+    )
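+
+    # Dequantize each gathered chunk with its sender's scale, concatenate, and copy the
+    # result back into `tensor`; with async_op=True this is deferred to Handle.wait().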
+    def cat_op():
         for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
+            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
         out = torch.cat(tensor_list, dim=0)
         tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+
+    if async_op:
+        return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
+    else:
+        cat_op()
 
 
 def all_to_all_single_fp8(
     output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
-) -> None:
+) -> Optional[Handle]:
r """
This is an in - place operation for compressed all_reduce using fp8 .
It works like dist . all_to_all_single but during communication the data is cast to fp8 format .
@@ -163,9 +187,11 @@ def all_to_all_single_fp8(
     else:
         output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
-    dist.all_to_all(output_chunks, input_chunks, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale, group=group)
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
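+
+    # Dequantize the received chunks with their senders' scales and write the
+    # concatenated result into `output`; deferred to Handle.wait() when async_op=True.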
+    def cast_op():
         cast_output_chunk = [
             cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
         ]
@@ -178,6 +204,11 @@ def all_to_all_single_fp8(
             outputs_shape = input_shape
 
         output.data = tensor_out.view(outputs_shape).to(input_type)
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
     """
@@ -250,7 +281,9 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["dtype"]
 
 
-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
@@ -277,15 +310,21 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
         cast_input_list.append(ret)
         output_chunks.append(torch.empty_like(ret))
         output_scale_list.append(torch.empty_like(scale))
-    dist.all_to_all(output_chunks, cast_input_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
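+
+    # Sum the dequantized chunks to form this rank's reduce-scatter shard;
+    # deferred to Handle.wait() when async_op=True.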
+    def cast_op():
         summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
         for scale, out in zip(output_scale_list, output_chunks):
             out = out.view(fp8_type)
             summed_out += cast_from_fp8(out, scale, input_type)
         output.data = summed_out
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 def fp8_compress_ddp_grad_comm_hook_async(
     process_group: dist.ProcessGroup,
@@ -500,7 +539,8 @@ def all_gather_into_tensor_flat_fp8(
     output_shape: torch.Size,
     group: dist.ProcessGroup,
     fp8_format: str = "e4m3",
-):
+    async_op: bool = False,
+) -> Optional[Handle]:
     """all gather into tensor in fp8 format
 
     Args:
@@ -547,15 +587,25 @@ def all_gather_into_tensor_flat_fp8(
         scale = fp8_max / per_tensor_max
     fp8_input = (scale * input_tensor.float()).to(fp8_type)
     scale_inv = 1.0 / scale
 
     buffer = torch.empty_like(output_tensor, dtype=fp8_type)
-    dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
+    tensor_handle = dist.all_gather_into_tensor(
+        buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
+    )
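+
+    # Trim the gather buffer to the real output shape, dequantize (per-channel for 2-D
+    # shapes), and copy into output_tensor; deferred to Handle.wait() when async_op=True.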
+    def cast_op():
         numel = output_shape.numel()
         valid_buffer = buffer[:numel].reshape(output_shape)
         valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
         output_tensor[:numel].copy_(valid_buffer.view(-1))
+
+    if async_op:
+        return Handle([tensor_handle], cast_op)
+    else:
+        cast_op()
 
 
-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
     world_size = dist.get_world_size(group)
@@ -573,17 +623,23 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
     output_scale_list = [torch.empty_like(x) for x in scale_list]
     output_tensor_list = [torch.empty_like(x) for x in tensor_list]
-    dist.all_to_all(output_tensor_list, tensor_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    tensor_handle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
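+
+    # Dequantize each received tensor with its matching scale and copy it into
+    # output_list; deferred to Handle.wait() when async_op=True.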
+    def cast_op():
         for i in range(world_size):
             scale = output_scale_list[i]
             tensor = output_tensor_list[i]
             tensor = tensor.view(fp8_type)
             output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+    if async_op:
+        return Handle([tensor_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
-def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
     world_size = dist.get_world_size(group)
@@ -593,14 +649,20 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
     input_ = ret.view(torch.uint8)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-    dist.all_gather(tensor_list, input_, group=group)
-    dist.all_gather(scale_list, scale, group=group)
+    chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
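+
+    # Dequantize each gathered shard with its sender's scale and copy it into
+    # output_list; deferred to Handle.wait() when async_op=True.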
+    def cast_op():
         for i in range(world_size):
             output = tensor_list[i].view(fp8_type)
             scale = scale_list[i]
             output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 class _LinearFp8(torch.autograd.Function):
     @staticmethod