fix bugs with merge

pull/375/head
Wenwen Qu 2023-08-17 17:07:31 +08:00
parent c76182b2d6
commit 08532dc20b
2 changed files with 14 additions and 11 deletions

1
.gitignore vendored
View File

@ -132,6 +132,7 @@ runs_bak/
LLM_ALERT LLM_ALERT
small_demo/ small_demo/
7b_llama_nopp/ 7b_llama_nopp/
test/
# Pytorch # Pytorch
*.pth *.pth

View File

@ -276,12 +276,12 @@ class HybridZeroOptimizer(BaseOptimizer):
self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)]) self.params_per_rank_id_dict.append([[] for _ in range(self._zero_world_size)])
param_list = param_group["params"] param_list = param_group["params"]
if "moe" in param_group.keys() and param_group["moe"]: if self._is_moe_group(param_group):
# just add current params to params_per_rank[_zero_local_rank] # just add current params to params_per_rank[_zero_local_rank]
params_per_rank[self._zero_local_rank] = list(param_list) params_per_rank[self._zero_local_rank] = list(param_list)
self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None) self.params_per_rank_id_dict[-1][self._zero_local_rank].append(None)
no_params_ranks = list(range(self._zero_world_size)) no_params_ranks = list(range(self._zero_world_size))
no_params_ranks.pop(self._zero_world_size) no_params_ranks.pop(self._zero_local_rank)
else: else:
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
@ -307,6 +307,9 @@ class HybridZeroOptimizer(BaseOptimizer):
return params_per_rank, set(no_params_ranks) return params_per_rank, set(no_params_ranks)
def _is_moe_group(self, param_group):
return "moe" in param_group.keys() and param_group["moe"]
def _attach_reduction_hook(self): def _attach_reduction_hook(self):
# we iterate over the fp16 params # we iterate over the fp16 params
# on each param, we register a hook to its AccumulateGrad object # on each param, we register a hook to its AccumulateGrad object
@ -568,11 +571,11 @@ class HybridZeroOptimizer(BaseOptimizer):
return self._step(closure=closure, norms=total_norms) return self._step(closure=closure, norms=total_norms)
def _get_norm_with_moe_layers(self, norm_groups): def _get_norm_with_moe_layers(self, norm):
# all_groups_norm_old = all_groups_norm # all_groups_norm_old = all_groups_norm
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
pg = gpc.get_group(ParallelMode.DATA) pg = gpc.get_group(ParallelMode.DATA)
scaled_norm = norm_groups * 1.0 / float(gpc.get_world_size(ParallelMode.DATA)) scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
scaled_norm_tensor = torch.tensor( scaled_norm_tensor = torch.tensor(
scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
) )
@ -593,7 +596,6 @@ class HybridZeroOptimizer(BaseOptimizer):
if -1 in norms: if -1 in norms:
found_inf = True found_inf = True
loss_scale = float(self.loss_scale.item()) # backup loss_scale = float(self.loss_scale.item()) # backup
if gpc.config.model.dtype is not torch.float32: if gpc.config.model.dtype is not torch.float32:
self.grad_scaler.update(found_inf) self.grad_scaler.update(found_inf)
@ -638,17 +640,15 @@ class HybridZeroOptimizer(BaseOptimizer):
# get the global norm # get the global norm
global_norm_groups = [] global_norm_groups = []
if self._clip_grad_norm > 0: if self._clip_grad_norm > 0:
for norm in norms: for group_id in range(self.num_param_groups):
global_norm_groups.append(norm**0.5) if self._is_moe_group(self.optim.param_groups[group_id]):
self._get_norm_with_moe_layers(norms[group_id])
if self.has_moe: global_norm_groups.append(norms[group_id] ** 0.5)
global_norm = self._get_norm_with_moe_layers(global_norm)
# the following operations are performed only on the rank to which parameters are assigned. # the following operations are performed only on the rank to which parameters are assigned.
if gpc.config.model.dtype is not torch.float32: if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0: if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale) self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
# update the parameters # update the parameters
timer("step").start() timer("step").start()
@ -679,6 +679,8 @@ class HybridZeroOptimizer(BaseOptimizer):
handles = [] handles = []
for group_id in range(self.num_param_groups): for group_id in range(self.num_param_groups):
if self._is_moe_group(self.optim.param_groups[group_id]):
continue
for rank in range(self._zero_world_size): for rank in range(self._zero_world_size):
# The following operations are performed only on the rank to which parameters are assigned. # The following operations are performed only on the rank to which parameters are assigned.
if rank not in self.param_group_no_params_ranks[group_id]: if rank not in self.param_group_no_params_ranks[group_id]: