Merge branch 'feature_add_moe' of github.com:blankde/InternLM into feature_add_moe_pp_zl

pull/182/head
zhanglei 2023-08-22 18:56:29 +08:00
commit 12b739e83b
1 changed files with 20 additions and 5 deletions

View File

@ -522,6 +522,19 @@ class HybridZeroOptimizer(BaseOptimizer):
return norm
def _compute_norm_with_moe_group(self, group_id):
parameters = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
# wo do not get the average grad for moe parameters, so we have to constuct
# the gradients list hear. Maybe this can be optimized.
gradients = [p.grad for p in parameters]
norm = compute_norm(
gradients=gradients,
parameters=parameters,
last_stage=True,
)
return norm
def step(self, closure=None):
"""Performs a single optimization step.
@ -559,12 +572,14 @@ class HybridZeroOptimizer(BaseOptimizer):
# compute norm for gradients in the last bucket
total_norms = []
for group_id in range(self.num_param_groups):
total_norms.append(
self._compute_norm_with_stage(
group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
if self._is_moe_group(self.optim.param_groups[group_id]):
total_norms.append(self._compute_norm_with_moe_group(group_id=group_id))
else:
total_norms.append(
self._compute_norm_with_stage(
group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
)
)
)
timer("sync_grad").start()
self._sync_grad()
timer("sync_grad").stop()