diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py
index 8a4f05677..c437ac549 100644
--- a/colossalai/zero/sharded_optim/low_level_optim.py
+++ b/colossalai/zero/sharded_optim/low_level_optim.py
@@ -68,9 +68,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # 2. contiguous gradients
         # 3. cpu offload
         # 4. support when some parameters requires_grad = False
-
-        self._optimizer = optimizer
-        self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
+        super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
+        self._dtype = self.optim.param_groups[0]['params'][0].dtype
 
         self._logger = get_dist_logger()
         self._verbose = verbose
@@ -116,7 +115,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._clip_grad_norm = clip_grad_norm
 
         if forced_dtype:
-            for group in self._optimizer.param_groups:
+            for group in self.optim.param_groups:
                 group_params = group['params']
                 for param in group_params:
                     param.data = param.data.to(forced_dtype)
@@ -134,7 +133,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
-        for group_id, param_group in enumerate(self._optimizer.param_groups):
+        for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = param_group['params']
 
             # add the fp16 params to fp16_param_groups for bookkeeping
@@ -198,7 +197,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         if self._overlap_communication or self._partition_grads:
             self._attach_reduction_hook()
 
-        self._initialize_optimizer_states()
+    @property
+    def dtype(self):
+        return self._dtype
 
     @property
     def loss_scale(self):
@@ -227,25 +228,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                               parallel_mode=self._dp_parallel_mode)
         return params_per_rank
 
-    def _initialize_optimizer_states(self):
-        # create a dummy zero tensor which has the same shape as that of the param
-        # set this dummpy zero tensor as grad
-        for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
-            fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-            fp32_partition_grad = torch.zeros_like(fp32_partition_param)
-            fp32_partition_param.grad = fp32_partition_grad
-
-        # we do not need log information for optimizer, so comment them
-        # update the parameter with zero gradients for initialization of optimizer states
-        # self._optimizer.step()
-
-        # remove the grad of the paramter to save memory
-        # for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
-        #     fp32_flat_tensor.grad = None
-
     def _sanity_checks(self):
         assert torch.cuda.is_available(), 'CUDA is required'
-        for param_group in self._optimizer.param_groups:
+        for param_group in self.optim.param_groups:
             group_params = param_group['params']
             for param in group_params:
                 assert param.dtype == self._dtype, \
@@ -484,7 +469,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)
 
         # update the parameters
-        self._optimizer.step()
+        self.optim.step()
 
         # release the fp32 grad
         release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
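
Note: the refactor above relies on the ColossalaiOptimizer base class storing the wrapped optimizer as self.optim, which is why super().__init__(optim=optimizer) lets every former self._optimizer reference become self.optim, and why a read-only dtype property can now be exposed. The sketch below is a minimal, hypothetical illustration of that delegation pattern under those assumptions; WrapperOptimizer and SketchLowLevelZeroOptimizer are stand-in names that mirror the diff, not the actual ColossalAI implementation.

import torch
from torch.optim import Optimizer


class WrapperOptimizer:
    # Hypothetical stand-in for the base class assumed by the diff: it simply
    # stores the inner optimizer as self.optim and delegates the common calls.
    def __init__(self, optim: Optimizer):
        self.optim = optim

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)


class SketchLowLevelZeroOptimizer(WrapperOptimizer):
    # Mirrors the two lines added in the first hunk: hand the optimizer to the
    # base class, then derive the working dtype from the first parameter group.
    def __init__(self, optimizer: Optimizer):
        super().__init__(optim=optimizer)
        self._dtype = self.optim.param_groups[0]['params'][0].dtype

    @property
    def dtype(self):
        # Read-only accessor corresponding to the dtype property added in the diff.
        return self._dtype


if __name__ == '__main__':
    # Usage: wrap a plain SGD optimizer and read back the exposed dtype.
    model = torch.nn.Linear(4, 4)
    wrapped = SketchLowLevelZeroOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
    print(wrapped.dtype)  # torch.float32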