mirror of https://github.com/InternLM/InternLM
optimize code with moe norm computing
parent 0ab3de8994
commit 94b8b18a49
```diff
@@ -533,7 +533,14 @@ class HybridZeroOptimizer(BaseOptimizer):
                 last_stage=True,
             )
 
-        return norm
+        # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
+        # model and zero have been reduced!!!
+        pg = gpc.get_group(ParallelMode.DATA)
+        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        all_groups_norm = scaled_norm_tensor.item()
+        return all_groups_norm
 
     def step(self, closure=None):
         """Performs a single optimization step.
```
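The added block folds the MoE norm correction into the norm computation itself: expert (MoE) gradients are not synchronized by the regular data-parallel all-reduce, so each rank pre-scales its local norm by 1 / world_size and sums the result over the ParallelMode.DATA group, which yields the group average. Below is a minimal standalone sketch of that averaging pattern; it is not InternLM code and `allreduce_avg_norm` is a hypothetical helper, but the torch.distributed calls are the same ones used in the diff.

```python
# Minimal sketch (not InternLM code): average a locally computed gradient norm
# across a process group with a pre-scaled SUM all-reduce, mirroring the pattern
# added in the hunk above for MoE params.
import os

import torch
import torch.distributed as dist


def allreduce_avg_norm(local_norm: float, group=None) -> float:
    """Average a scalar norm over all ranks in `group` and return it as a float."""
    world_size = dist.get_world_size(group=group)
    # Pre-scale by 1/world_size so that the default SUM all-reduce produces the mean.
    scaled = torch.tensor(local_norm / float(world_size), dtype=torch.float)
    dist.all_reduce(scaled, group=group)
    return scaled.item()


if __name__ == "__main__":
    # Defaults allow a single-process run; use `torchrun --nproc_per_node=N` for N ranks.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group(backend="gloo")
    print(allreduce_avg_norm(local_norm=3.0))  # each rank prints the group-averaged norm
    dist.destroy_process_group()
```

Launching the script with `torchrun --nproc_per_node=2` would show both ranks printing the same averaged value.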
```diff
@@ -586,19 +593,6 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return self._step(closure=closure, norms=total_norms)
 
-    def _get_norm_with_moe_layers(self, norm):
-        # all_groups_norm_old = all_groups_norm
-        # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
-        pg = gpc.get_group(ParallelMode.DATA)
-        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
-        scaled_norm_tensor = torch.tensor(
-            scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
-        )
-        dist.all_reduce(scaled_norm_tensor, group=pg)
-        all_groups_norm = scaled_norm_tensor.item()
-        # print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
-        return all_groups_norm
-
     def _step(self, closure=None, norms=None):
         assert closure is None, "closure is not supported by step()"
 
```
```diff
@@ -656,8 +650,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         global_norm_groups = []
         if self._clip_grad_norm > 0:
             for group_id in range(self.num_param_groups):
-                if self._is_moe_group(self.optim.param_groups[group_id]):
-                    self._get_norm_with_moe_layers(norms[group_id])
                 global_norm_groups.append(norms[group_id] ** 0.5)
 
         # the following operations are performed only on the rank to which parameters are assigned.
```
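With the averaging handled at norm-computation time, `_step` no longer special-cases MoE groups (the removed call's return value was not used anyway); it simply takes the square root of each stored squared norm via `norms[group_id] ** 0.5`. As a generic illustration of how per-group norms of this kind typically feed gradient clipping, here is a short sketch; it is not the repository's `_step` logic, and `clip_coefficients` / `squared_norms` are invented names for the example.

```python
# Generic sketch of norm-based gradient clipping driven by per-group squared norms.
# Not InternLM's implementation; the names below are illustrative only.
import math
from typing import List


def clip_coefficients(squared_norms: List[float], max_norm: float, eps: float = 1e-6) -> List[float]:
    """Return one clip coefficient per parameter group, capped at 1.0."""
    coefs = []
    for sq in squared_norms:
        group_norm = math.sqrt(sq)  # same role as `norms[group_id] ** 0.5`
        coefs.append(min(1.0, max_norm / (group_norm + eps)))
    return coefs


if __name__ == "__main__":
    # Squared norms 4.0 and 0.25 -> norms 2.0 and 0.5; only the first group is scaled down.
    print(clip_coefficients([4.0, 0.25], max_norm=1.0))  # ~[0.5, 1.0]
```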