[checkpoint] sharded optim save/load grad scaler (#1350)

pull/1353/head
ver217 2022-07-21 15:21:21 +08:00 committed by GitHub
parent 05fae1fd56
commit ce470ba37e
1 changed file with 11 additions and 0 deletions

@@ -363,7 +363,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.master_params[p].trans_state(TensorState.HOLD)
 
+    def state_dict(self):
+        optim_state_dict = super().state_dict()
+        scaler_state_dict = self.grad_scaler.state_dict()
+        optim_state_dict['scaler'] = scaler_state_dict
+        return optim_state_dict
+
     def load_state_dict(self, *args, **kwargs):
+        if 'scaler' not in args[0]:
+            self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
+        else:
+            scaler_state_dict = args[0].pop('scaler')
+            self.grad_scaler.load_state_dict(scaler_state_dict)
         super().load_state_dict(*args, **kwargs)
         for group in self.optim.param_groups:
             for p in group['params']:
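
For reference, a minimal usage sketch of the behaviour this commit adds: state_dict() now embeds the grad scaler state under a 'scaler' key, and load_state_dict() pops and restores it (warning if it is absent) before delegating to the base optimizer. The optimizer instance name and checkpoint path below are hypothetical placeholders, not part of the commit.

    import torch

    # Save: the grad scaler state travels inside the optimizer state dict
    # under the 'scaler' key, so a plain torch.save round trip keeps it.
    state = optimizer.state_dict()            # hypothetical ShardedOptimizerV2 instance
    torch.save(state, 'optim_ckpt.pth')       # hypothetical checkpoint path

    # Load: load_state_dict() pops 'scaler' (if present), restores the
    # grad scaler, then calls the parent load_state_dict with the rest.
    optimizer.load_state_dict(torch.load('optim_ckpt.pth'))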