[zero] polish low level zero optimizer (#2275)

pull/2286/head
HELSON 2023-01-03 17:22:34 +08:00 committed by GitHub
parent ac863a01d6
commit 62c38e3330
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 9 additions and 24 deletions

View File

@ -68,9 +68,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# 2. contiguous gradients
# 3. cpu offload
# 4. support when some parameters requires_grad = False
self._optimizer = optimizer
self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
self._verbose = verbose
@ -116,7 +115,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._clip_grad_norm = clip_grad_norm
if forced_dtype:
for group in self._optimizer.param_groups:
for group in self.optim.param_groups:
group_params = group['params']
for param in group_params:
param.data = param.data.to(forced_dtype)
@ -134,7 +133,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self._optimizer.param_groups):
for group_id, param_group in enumerate(self.optim.param_groups):
group_params = param_group['params']
# add the fp16 params to fp16_param_groups for bookkeeping
@ -198,7 +197,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
if self._overlap_communication or self._partition_grads:
self._attach_reduction_hook()
self._initialize_optimizer_states()
# Property exposing the value stored in self._dtype; no setter is shown for it here.
@property
def dtype(self):
return self._dtype
@property
def loss_scale(self):
@ -227,25 +228,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
parallel_mode=self._dp_parallel_mode)
return params_per_rank
def _initialize_optimizer_states(self):
# Attach an all-zero gradient to every flat fp32 parameter partition stored in
# self._fp32_flat_param_groups_of_current_rank.
# NOTE(review): the container is indexed by 0..len-1 here, while the disabled
# cleanup code below calls .items() on it -- presumably a dict keyed by group id; confirm.
# create a dummy zero tensor which has the same shape as that of the param
# set this dummy zero tensor as grad
for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp32_partition_grad = torch.zeros_like(fp32_partition_param)
fp32_partition_param.grad = fp32_partition_grad
# The warm-up optimizer step below is disabled; original note:
# "we do not need log information for optimizer, so comment them"
# update the parameter with zero gradients for initialization of optimizer states
# self._optimizer.step()
# NOTE: the cleanup below is disabled as well, so the zero gradients assigned
# above stay attached to the partitions when this method returns.
# remove the grad of the parameter to save memory
# for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
# fp32_flat_tensor.grad = None
def _sanity_checks(self):
assert torch.cuda.is_available(), 'CUDA is required'
for param_group in self._optimizer.param_groups:
for param_group in self.optim.param_groups:
group_params = param_group['params']
for param in group_params:
assert param.dtype == self._dtype, \
@ -484,7 +469,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
# update the parameters
self._optimizer.step()
self.optim.step()
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())