@@ -80,9 +80,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                  tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                  forced_dtype: Optional[torch.dtype] = None):
 
-        # TODO:
-        # 1. state_dict for checkpoint IO
-
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
         self._logger = get_dist_logger()
@@ -528,9 +525,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != 'step':
                     working_param = self._param_store.master_to_working_param[id(param)]
-                    gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
-                    dist.all_gather(gather_tensor, v, group=self.dp_pg)
-                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    gather_tensor = [
+                        torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
+                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(
+                        working_param).cpu()
                     zero_state[param][k] = param_state
 
         states_dict = self._pack_state(zero_state)
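Note on this hunk: `dist.all_gather` with the NCCL backend only accepts CUDA tensors, and with CPU offload the optimizer state `v` lives on the host, so the gather buffers and the input are moved to CUDA for the collective and the assembled full state is brought back to CPU. A minimal standalone sketch of the stack/truncate/reshape step it performs (the names `flat`/`shards` are illustrative stand-ins, not the patch's variables):

```python
import torch

world_size = 4
working_param = torch.randn(5, 3)                      # 15 elements
flat = working_param.flatten()
padding = (world_size - flat.numel() % world_size) % world_size
flat = torch.nn.functional.pad(flat, [0, padding])     # pad to 16 so it splits evenly
shards = list(flat.chunk(world_size))                  # what each rank would hold

# same recipe as the patch: stack the shards, flatten, drop the padding, reshape
restored = torch.stack(shards).view(-1)[:working_param.numel()].reshape_as(working_param)
assert torch.equal(restored, working_param)
```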
@@ -553,7 +553,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if padding_size > 0:
                     v = torch.nn.functional.pad(v, [0, padding_size])
                 v_list = v.split(v.numel() // self._world_size)
-                zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
+                device = 'cpu' if self._cpu_offload else 'cuda'
+                zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
 
         self.optim.load_state_dict(zero_state_dict)
         zero_state_dict = dict()
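On load, the full checkpointed state is re-sharded and the local shard now has to land where the optimizer keeps its states: CPU when offload is enabled, CUDA otherwise. A hedged standalone sketch of the pad/split/select step (`cpu_offload`, `local_rank`, and `world_size` are stand-ins for the wrapper's `_cpu_offload`, `_local_rank`, and `_world_size` attributes):

```python
import torch

world_size, local_rank = 4, 1
cpu_offload = True

v = torch.arange(10, dtype=torch.float32)              # full optimizer state from the checkpoint
padding_size = (world_size - v.numel() % world_size) % world_size
if padding_size > 0:
    v = torch.nn.functional.pad(v, [0, padding_size])  # pad so the split is even
v_list = v.split(v.numel() // world_size)

# keep only this rank's slice, on the device the optimizer states live on
device = 'cpu' if cpu_offload else 'cuda'
local_shard = v_list[local_rank].to(device).detach()
```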
@@ -585,9 +586,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != 'step':
-                    state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
-                    dist.all_gather(state_tensor, v, group=self.dp_pg)
-                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
+                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(
+                        working_param).cpu()
                     current_block_size += state_tensor.numel()
                     current_block[k] = state_tensor
 
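This is the same gather-on-CUDA, return-on-CPU pattern as in `state_dict()`, applied to sharded saving. If one wanted to factor out the repetition, a sketch could look like this (hypothetical helper, not part of the patch; assumes an initialized NCCL process group and a CUDA device):

```python
import torch
import torch.distributed as dist

def gather_full_state(v: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    """Mirror the patch's pattern: NCCL all_gather needs CUDA tensors, so a
    (possibly CPU-offloaded) shard is moved to the GPU for the collective and
    the concatenated flat result is returned on CPU."""
    buf = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(world_size)]
    dist.all_gather(buf, v.cuda(), group=group)
    return torch.stack(buf).view(-1).cpu()
```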