import os
import traceback
from contextlib import contextmanager
from time import sleep
from typing import Callable, List, Optional
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map

def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
    """Assert that two tensors match under dtype-appropriate tolerances."""
    assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}"


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    """Compare two tensors with tolerances loosened to match the precision of `dtype`."""
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3
    else:
        assert dtype is torch.float32
        # torch.allclose defaults
        rtol = 1e-05
        atol = 1e-08

    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)

    return torch.allclose(a, b, rtol=rtol, atol=atol)
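

# A minimal usage sketch (illustrative only; `_example_loose_close` is not part
# of the utilities above): a perturbation well below the bf16 tolerance of 4e-3
# compares equal under bf16, while the strict fp32 defaults reject it.
def _example_loose_close():
    a = torch.ones(4, 4)
    b = a + 1e-3
    assert loose_close(a, b, dtype=torch.bfloat16)
    assert not loose_close(a, b, dtype=torch.float32)
    assert_loose_close(a, b, dtype=torch.bfloat16, name="demo")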
def check_model_equal(model1, model2, dtype):
    """Assert that two models have the same parameter names and close values.

    Relies on both models registering their parameters in the same order.
    """
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for (name, p1), p2 in zip(model1.named_parameters(), model2.parameters()):
        assert_loose_close(p1, p2, dtype, name=name)
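

# A minimal usage sketch (illustrative only; `_example_check_model_equal` is
# not part of the utilities above): the second copy is loaded from the first's
# state dict, so the check passes exactly.
def _example_check_model_equal():
    m1 = torch.nn.Linear(8, 8)
    m2 = torch.nn.Linear(8, 8)
    m2.load_state_dict(m1.state_dict())
    check_model_equal(m1, m2, dtype=torch.float32)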
@contextmanager
def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
    """Temporarily patch torch.distributed collectives to log their call sites.

    Each patched collective prints the calling rank, the ranks of the process
    group, the last `num_stacks` stack frames, and the arguments (tensors are
    rendered as shapes) before executing the real collective.
    """
    if enable:
        assert (
            os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
        ), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
    if funcs_to_patch is None:
        funcs_to_patch = [
            dist.all_reduce,
            dist.all_reduce_coalesced,
            dist.all_gather,
            dist.all_gather_coalesced,
            dist.all_gather_into_tensor,
            dist.all_to_all,
            dist.all_to_all_single,
            dist.reduce_scatter,
        ]

    original_funcs = {}
    patched_funcs = {}

    def make_patched(func):
        def patched_func(*args, **kwargs):
            stack = traceback.format_stack()

            def format_node(node):
                # Render tensors as their shapes to keep the log readable.
                if isinstance(node, torch.Tensor):
                    return f"{node.shape}"
                elif isinstance(node, list):
                    return f"[{', '.join([format_node(n) for n in node])}]"
                return str(node)

            args_str, kwargs_str = tree_map(format_node, (args, kwargs))
            # Keep the last `num_stacks` frames, excluding this wrapper itself.
            en = len(stack) - 1
            st = max(0, en - num_stacks)
            # Barrier plus a rank-staggered sleep so per-rank logs do not interleave.
            dist.barrier()
            sleep(0.001 * dist.get_rank())
            print(
                f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: "
                f"Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
            )
            dist.barrier()
            return func(*args, **kwargs)

        return patched_func

    if enable:
        for func in funcs_to_patch:
            original_funcs[func.__name__] = getattr(dist, func.__name__)
            patched_funcs[func.__name__] = make_patched(func)
            setattr(dist, func.__name__, patched_funcs[func.__name__])
    try:
        yield
    finally:
        # Restore the original collectives even if the block raised.
        for func_name, original_func in original_funcs.items():
            setattr(dist, func_name, original_func)
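

# A minimal usage sketch (illustrative only; `_example_distributed_debug_mode`
# is not part of the utilities above). Assumes launch via torchrun on CUDA GPUs
# with CUDA_LAUNCH_BLOCKING=1 exported, as the context manager requires; every
# collective issued inside the block is logged per rank.
def _example_distributed_debug_mode():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())  # assumes single-node launch
    x = torch.ones(4, device="cuda")
    with distributed_debug_mode(num_stacks=2):
        dist.all_reduce(x)  # printed with call-site stack, shapes, and group ranks
    dist.destroy_process_group()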