@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper):
master_weights : bool = True ,
extra_dp_group : Optional [ ProcessGroup ] = None ,
verbose : bool = False ,
enable_async_reduce : bool = True ,
) - > None :
assert mixed_precision in ( torch . float16 , torch . bfloat16 )
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper):
if is_ddp_ignored ( p ) :
continue
if p . requires_grad :
assert not hasattr ( p , " _grad_handle " )
p . _grad_handle = p . register_hook (
partial (
GeminiDDP . grad_handle ,
@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper):
master_weights = self . master_weights ,
enable_gradient_accumulation = self . enable_gradient_accumulation ,
p = p ,
async_reduce = enable_async_reduce ,
)
)
@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper):
setattr ( param , " _gemini_reduced " , False )
def _post_backward ( self ) :
for param in self . param2name :
if hasattr ( param , " _release_grad_chunk_cb " ) :
param . _release_grad_chunk_cb ( )
delattr ( param , " _release_grad_chunk_cb " )
if self . chunk_manager . accessed_mem != 0 :
error_params = [ " Reduction failed at followed parameters: " ]
for param in self . param2name :
@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper):
master_weights : bool ,
enable_gradient_accumulation : bool ,
p : nn . Parameter ,
async_reduce : bool ,
) :
setattr ( p , " _gemini_reduced " , True )
empty_grad = torch . empty_like ( grad )
@ -406,31 +415,57 @@ class GeminiDDP(ModelWrapper):
grad_chunk . copy_tensor_to_chunk_slice ( p , grad , update_ptr = chunk_manager . reuse_fp16_chunk )
else :
grad_chunk . add_tensor_to_chunk_slice ( p , grad )
reduced = chunk_manager . reduce_chunk ( grad_chunk )
if reduced :
if not chunk_manager . reuse_fp16_chunk :
if chunk . keep_gathered :
chunk_manager . fake_release_chunk ( chunk )
else :
chunk_manager . release_chunk ( chunk )
if grad_chunk . is_gathered :
grad_chunk . cuda_global_chunk . div_ ( chunk . pg_size )
if chunk . extra_dp_group is not None :
grad_chunk . cuda_global_chunk . div_ ( chunk . extra_dp_size )
reduced = chunk_manager . reduce_chunk ( grad_chunk , async_op = async_reduce )
if reduced : # if not async, can release immediately, else release in when work finished
if async_reduce :
# dirty fix by installing callback
assert not hasattr ( p , " _release_grad_chunk_cb " )
def _release_grad_chunk_cb ( ) :
grad_chunk . wait_async_reduce ( )
GeminiDDP . release_grad_chunk_handle (
chunk_manager ,
grads_device ,
master_weights ,
enable_gradient_accumulation ,
p ,
chunk ,
grad_chunk ,
)
p . _release_grad_chunk_cb = _release_grad_chunk_cb
else :
grad_chunk . cuda_shard . div_ ( chunk . pg_size )
if chunk . extra_dp_group is not None :
grad_chunk . cuda_shard . div_ ( chunk . extra_dp_size )
# check overflow elements
chunk_manager . overflow_counter + = grad_chunk . has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk . l2_norm_flag :
grad_chunk . set_l2_norm ( )
chunk_manager . move_chunk ( grad_chunk , grads_device [ p ] , force_copy = True )
if not ( master_weights ) or ( enable_gradient_accumulation ) :
chunk_manager . move_chunk ( chunk , grads_device [ p ] , force_copy = True )
GeminiDDP . release_grad_chunk_handle (
chunk_manager , grads_device , master_weights , enable_gradient_accumulation , p , chunk , grad_chunk
)
return empty_grad
@staticmethod
def release_grad_chunk_handle (
chunk_manager , grads_device , master_weights , enable_gradient_accumulation , p , chunk , grad_chunk
) :
if not chunk_manager . reuse_fp16_chunk :
if chunk . keep_gathered :
chunk_manager . fake_release_chunk ( chunk )
else :
chunk_manager . release_chunk ( chunk )
if grad_chunk . is_gathered :
grad_chunk . cuda_global_chunk . div_ ( chunk . pg_size )
if chunk . extra_dp_group is not None :
grad_chunk . cuda_global_chunk . div_ ( chunk . extra_dp_size )
else :
grad_chunk . cuda_shard . div_ ( chunk . pg_size )
if chunk . extra_dp_group is not None :
grad_chunk . cuda_shard . div_ ( chunk . extra_dp_size )
# check overflow elements
chunk_manager . overflow_counter + = grad_chunk . has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk . l2_norm_flag :
grad_chunk . set_l2_norm ( )
chunk_manager . move_chunk ( grad_chunk , grads_device [ p ] , force_copy = True )
if not ( master_weights ) or ( enable_gradient_accumulation ) :
chunk_manager . move_chunk ( chunk , grads_device [ p ] , force_copy = True )
def zero_grad ( self , set_to_none : bool = False ) - > None :
self . module . zero_grad ( set_to_none = True )