@ -18,6 +18,7 @@ from colossalai.utils import get_current_device
from colossalai . zero . utils . gemini_hook import GeminiZeROHook
from . reducer import Reducer
from . utils import get_static_torch_model
try :
from torch . nn . modules . module import _EXTRA_STATE_KEY_SUFFIX , _IncompatibleKeys
@ -251,6 +252,7 @@ class ZeroDDP(ColoDDP):
pin_memory = pin_memory )
self . fp32_params . append ( fp32_p )
self . grads_device [ p ] = self . gemini_manager . default_device
self . chunk_manager . close_all_groups ( )
self . _cast_buffers ( )
@ -331,12 +333,11 @@ class ZeroDDP(ColoDDP):
for tensor in chunk . get_tensors ( ) :
self . grads_device [ tensor ] = device
def state_dict ( self , destination = None , prefix = ' ' , keep_vars = False , only_rank_0 : bool = True ) :
r """ Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers ( e . g . running averages ) are
included . Keys are corresponding parameter and buffer names .
Parameters and buffers set to ` ` None ` ` are not included .
def state_dict ( self , destination = None , prefix = ' ' , keep_vars = False , only_rank_0 : bool = True , strict : bool = True ) :
r """
Args :
strict ( bool ) : whether to return the whole model state
as the original pytorch state_dict ( )
Returns :
dict :
@ -346,7 +347,31 @@ class ZeroDDP(ColoDDP):
>> > module . state_dict ( ) . keys ( )
[ ' bias ' , ' weight ' ]
"""
if strict :
return get_static_torch_model ( zero_ddp_model = self , device = get_current_device ( ) ,
only_rank_0 = only_rank_0 ) . state_dict ( destination = destination ,
prefix = prefix ,
keep_vars = keep_vars )
return self . _non_strict_state_dict ( destination = destination ,
prefix = prefix ,
keep_vars = keep_vars ,
only_rank_0 = only_rank_0 )
def _non_strict_state_dict ( self , destination = None , prefix = ' ' , keep_vars = False , only_rank_0 : bool = True ) :
r """ Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers ( e . g . running averages ) are
included . Keys are corresponding parameter and buffer names .
Parameters and buffers set to ` ` None ` ` are not included .
Warning : The non-strict state dict omits a parameter if its tensor is
shared with another parameter that has already been included in
the dictionary .
Returns :
dict :
a dictionary containing a whole state of the module
"""
if destination is None :
destination = OrderedDict ( )