move optim.dtype to each param group

pull/382/head
Qu Wenwen 2023-10-09 12:39:03 +08:00
parent 11a4a8bb44
commit 856f88e97b
1 changed file with 8 additions and 9 deletions

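In outline, the commit drops the single optimizer-wide `self._dtype` and instead records a `dtype` entry on every param group when the optimizer is constructed, so groups of mixed precision are handled per group. A minimal sketch of that idea, assuming plain dict param groups (the helper name `tag_param_group_dtypes` is illustrative, not part of the repo):

    import torch

    def tag_param_group_dtypes(param_groups):
        """Record the dtype of each param group instead of one optimizer-wide dtype.

        Groups that own no parameters on this rank get dtype=None and are
        resolved later (e.g. when padding tensors are needed).
        """
        for param_group in param_groups:
            group_params = param_group["params"]
            param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None
        return param_groups

    # illustrative usage with two groups of different dtypes
    fp16_group = {"params": [torch.zeros(4, dtype=torch.float16)]}
    fp32_group = {"params": [torch.zeros(4, dtype=torch.float32)]}
    tag_param_group_dtypes([fp16_group, fp32_group])
    print(fp16_group["dtype"], fp32_group["dtype"])  # torch.float16 torch.float32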

@@ -115,7 +115,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         super().__init__(optim=optimizer)
-        self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
         self._zero_world_size = gpc.get_world_size(ParallelMode.ZERO1)
@@ -157,8 +156,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # need to record the rank in which parameter groups are not assigned parameters.
         self.param_group_has_params = []
         self.param_group_no_params_ranks = []
-        self.padding_grad = torch.zeros([32], dtype=self._dtype, device=get_current_device())
-        self.padding_tensor = torch.zeros([32], dtype=self._dtype, device=get_current_device())
+        self.padding_grad = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
+        self.padding_tensor = torch.zeros([32], dtype=gpc.config.model.dtype, device=get_current_device())
         self.rank_unique_id = (
             f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_"
@@ -177,6 +176,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = param_group["params"]
+            # set the dtype for each param group
+            param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None
             # add the fp16 params to fp16_param_groups for bookkeeping
             self._fp16_param_groups[group_id] = group_params
@@ -253,10 +255,6 @@ class HybridZeroOptimizer(BaseOptimizer):
     def zero_world_size(self):
         return self._zero_world_size
-    @property
-    def dtype(self):
-        return self._dtype
     @property
     def loss_scale(self):
         return self.grad_scaler.scale
@@ -528,8 +526,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients that have been reduced
         params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
         if len(params) == 0:
-            grads = [self.padding_grad]
-            params = [self.padding_tensor]
+            dtype = self.param_groups[group_id]["dtype"]
+            grads = [self.padding_grad.to(dtype)]
+            params = [self.padding_tensor.to(dtype)]
         norm = 0
         if self._clip_grad_norm > 0:
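For the norm-computation side of the change, here is a small sketch of the padding fallback: the shared padding buffers are kept in the model dtype and cast to the group's dtype when a rank holds no reduced params for that group. The function name `padded_params_and_grads` and the standalone buffers are assumptions for illustration; in the optimizer they correspond to `self.padding_grad` / `self.padding_tensor`.

    import torch

    def padded_params_and_grads(params, grads, group_dtype, padding_tensor, padding_grad):
        """If this rank holds no reduced params for the group, substitute padding
        buffers cast to the group's dtype so the norm reduction still runs with
        matching dtypes."""
        if len(params) == 0:
            grads = [padding_grad.to(group_dtype)]
            params = [padding_tensor.to(group_dtype)]
        return params, grads

    # illustrative usage: empty group on this rank, fp16 group dtype
    padding_grad = torch.zeros(32, dtype=torch.bfloat16)
    padding_tensor = torch.zeros(32, dtype=torch.bfloat16)
    p, g = padded_params_and_grads([], [], torch.float16, padding_tensor, padding_grad)
    print(p[0].dtype, g[0].dtype)  # torch.float16 torch.float16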