@@ -96,6 +96,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool = True,
         extra_dp_group: Optional[ProcessGroup] = None,
         verbose: bool = False,
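+        # when True, gradient reduction is launched asynchronously so it overlaps with the rest of the backward pass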
+        enable_async_reduce: bool = True,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -178,6 +179,7 @@ class GeminiDDP(ModelWrapper):
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
+                assert not hasattr(p, "_grad_handle")
                 p._grad_handle = p.register_hook(
                     partial(
                         GeminiDDP.grad_handle,
@@ -187,6 +189,7 @@ class GeminiDDP(ModelWrapper):
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
+                        async_reduce=enable_async_reduce,
                     )
                 )
@@ -334,6 +337,11 @@ class GeminiDDP(ModelWrapper):
                 setattr(param, "_gemini_reduced", False)
 
     def _post_backward(self):
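+        # flush callbacks installed by async grad reduce: each one waits for its pending reduce and releases the grad chunk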
+        for param in self.param2name:
+            if hasattr(param, "_release_grad_chunk_cb"):
+                param._release_grad_chunk_cb()
+                delattr(param, "_release_grad_chunk_cb")
+
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
             for param in self.param2name:
@@ -371,6 +379,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
+        async_reduce: bool,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -406,8 +415,35 @@ class GeminiDDP(ModelWrapper):
             grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
         else:
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
-        reduced = chunk_manager.reduce_chunk(grad_chunk)
-        if reduced:
+        reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
+        if reduced:  # if not async, release immediately; otherwise release once the reduce work has finished
+            if async_reduce:
+                # dirty fix by installing callback
+                assert not hasattr(p, "_release_grad_chunk_cb")
+
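+                # deferred release: the callback waits for the async reduce, then releases the chunk; _post_backward() invokes it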
+                def _release_grad_chunk_cb():
+                    grad_chunk.wait_async_reduce()
+                    GeminiDDP.release_grad_chunk_handle(
+                        chunk_manager,
+                        grads_device,
+                        master_weights,
+                        enable_gradient_accumulation,
+                        p,
+                        chunk,
+                        grad_chunk,
+                    )
+
+                p._release_grad_chunk_cb = _release_grad_chunk_cb
+            else:
+                GeminiDDP.release_grad_chunk_handle(
+                    chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+                )
+        return empty_grad
+
+    @staticmethod
+    def release_grad_chunk_handle(
+        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
+    ):
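+        # factored out of grad_handle so the sync path and the async callback share the same release logic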
         if not chunk_manager.reuse_fp16_chunk:
             if chunk.keep_gathered:
                 chunk_manager.fake_release_chunk(chunk)
@@ -429,7 +465,6 @@ class GeminiDDP(ModelWrapper):
         chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
         if not (master_weights) or (enable_gradient_accumulation):
             chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
-        return empty_grad
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)