mirror of https://github.com/InternLM/InternLM
Merge branch 'feature_add_moe' of github.com:blankde/InternLM into feature_add_moe_pp_zl
commit
12b739e83b
|
@ -522,6 +522,19 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
|
|
||||||
return norm
|
return norm
|
||||||
|
|
||||||
|
def _compute_norm_with_moe_group(self, group_id):
    """Compute the gradient norm for the MoE parameter group ``group_id``.

    Fetches the fp16 parameters that the parameter store holds for this
    group on the local zero rank, collects each parameter's ``.grad``, and
    hands both lists to ``compute_norm`` with ``last_stage=True``.

    Args:
        group_id: Index of the parameter group whose norm is computed.

    Returns:
        The value returned by ``compute_norm`` for these gradients.
    """
    moe_params = self._param_store.get_fp16_params_by_rank_group(
        group_id=group_id,
        rank=self._zero_local_rank,
    )

    # We do not get the averaged grad for MoE parameters, so the gradient
    # list has to be constructed here. Maybe this can be optimized.
    moe_grads = []
    for param in moe_params:
        moe_grads.append(param.grad)

    return compute_norm(gradients=moe_grads, parameters=moe_params, last_stage=True)
|
||||||
|
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
"""Performs a single optimization step.
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
@ -559,12 +572,14 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# compute norm for gradients in the last bucket
|
# compute norm for gradients in the last bucket
|
||||||
total_norms = []
|
total_norms = []
|
||||||
for group_id in range(self.num_param_groups):
|
for group_id in range(self.num_param_groups):
|
||||||
|
if self._is_moe_group(self.optim.param_groups[group_id]):
|
||||||
|
total_norms.append(self._compute_norm_with_moe_group(group_id=group_id))
|
||||||
|
else:
|
||||||
total_norms.append(
|
total_norms.append(
|
||||||
self._compute_norm_with_stage(
|
self._compute_norm_with_stage(
|
||||||
group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
|
group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
timer("sync_grad").start()
|
timer("sync_grad").start()
|
||||||
self._sync_grad()
|
self._sync_grad()
|
||||||
timer("sync_grad").stop()
|
timer("sync_grad").stop()
|
||||||
|
|
Loading…
Reference in New Issue