diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 7abca14..6894945 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -115,7 +115,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         super().__init__(optim=optimizer)
 
-        self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
         self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
@@ -157,8 +156,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # need to record the rank in which parameter groups are not assigned parameters.
         self.param_group_has_params = []
         self.param_group_no_params_ranks = []
-        self.padding_grad = torch.zeros([32], dtype=self._dtype, device=get_current_device())
-        self.padding_tensor = torch.zeros([32], dtype=self._dtype, device=get_current_device())
+        self.padding_grad = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
+        self.padding_tensor = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
 
         self.rank_unique_id = (
             f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
@@ -177,6 +176,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = param_group["params"]
 
+            # set the dtype for each param group
+            param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None
+
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params
 
@@ -253,10 +255,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def zero_world_size(self):
         return self._zero_world_size
 
-    @property
-    def dtype(self):
-        return self._dtype
-
     @property
     def loss_scale(self):
         return self.grad_scaler.scale
@@ -528,8 +526,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients that have been reduced
         params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
         if len(params) == 0:
-            grads = [self.padding_grad]
-            params = [self.padding_tensor]
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
 
         norm = 0
         if self._clip_grad_norm > 0:
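
The sketch below illustrates the pattern this diff adopts: each param group records its own dtype instead of the optimizer assuming a single `self._dtype`, and the shared padding tensors are cast to the group's dtype when used as a fallback input for norm computation. It is a minimal, self-contained approximation, not the actual HybridZeroOptimizer code; the helper names `tag_param_group_dtypes` and `padding_for_group` are hypothetical.

```python
import torch

def tag_param_group_dtypes(param_groups):
    # Record each group's dtype; a group with no params gets None,
    # mirroring the `param_group["dtype"]` bookkeeping added in the diff.
    for param_group in param_groups:
        group_params = param_group["params"]
        param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None
    return param_groups

def padding_for_group(param_group, padding_tensor):
    # Cast the shared padding tensor to the group's dtype so that groups of
    # different precisions (e.g. fp16 and fp32) each see a correctly typed input.
    return padding_tensor.to(param_group["dtype"])

if __name__ == "__main__":
    groups = [
        {"params": [torch.zeros(4, dtype=torch.float16)]},
        {"params": [torch.zeros(4, dtype=torch.float32)]},
    ]
    tag_param_group_dtypes(groups)
    # Stand-in for the padding tensor built from gpc.config.model.dtype.
    padding = torch.zeros([32], dtype=torch.float16)
    for g in groups:
        print(g["dtype"], padding_for_group(g, padding).dtype)
```

With this arrangement a single pre-allocated padding tensor can serve every param group, which is why the diff can drop the optimizer-wide `dtype` property entirely.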