Merge branch 'feature_add_moe' into feature_add_moe_data

fix bugs with compute moe norm
pull/375/head
Wenwen Qu 2023-08-22 17:31:23 +08:00
commit 428b5d2f33
1 changed file with 27 additions and 20 deletions

@@ -525,6 +525,26 @@ class HybridZeroOptimizer(BaseOptimizer):
         return norm
+    def _compute_norm_with_moe_group(self, group_id):
+        parameters = self._param_store.get_fp16_params_by_rank_group(group_id=group_id, rank=self._zero_local_rank)
+        # we do not get the average grad for moe parameters, so we have to construct
+        # the gradients list here. Maybe this can be optimized.
+        gradients = [p.grad for p in parameters]
+        norm = compute_norm(
+            gradients=gradients,
+            parameters=parameters,
+            last_stage=True,
+        )
+        # Need to allreduce (avg) the norms across ranks because moe params will not be synced during allreduce;
+        # the model and zero groups have already been reduced.
+        pg = gpc.get_group(ParallelMode.DATA)
+        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        all_groups_norm = scaled_norm_tensor.item()
+        return all_groups_norm
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -563,31 +583,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         # compute norm for gradients in the last bucket
         total_norms = []
         for group_id in range(self.num_param_groups):
-            total_norms.append(
-                self._compute_norm_with_stage(
-                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
-                )
-            )
+            if self._is_moe_group(self.optim.param_groups[group_id]):
+                total_norms.append(self._compute_norm_with_moe_group(group_id=group_id))
+            else:
+                total_norms.append(
+                    self._compute_norm_with_stage(
+                        group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                    )
+                )
         timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()
         return self._step(closure=closure, norms=total_norms)
-    def _get_norm_with_moe_layers(self, norm):
-        # all_groups_norm_old = all_groups_norm
-        # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
-        # model and zero have been reduced!!!
-        pg = gpc.get_group(ParallelMode.DATA)
-        scaled_norm = norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
-        scaled_norm_tensor = torch.tensor(
-            scaled_norm, device=self._fp32_flat_param_groups_of_current_rank[0].device, dtype=torch.float
-        )
-        dist.all_reduce(scaled_norm_tensor, group=pg)
-        all_groups_norm = scaled_norm_tensor.item()
-        # print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
-        return all_groups_norm
     def _step(self, closure=None, norms=None):
         assert closure is None, "closure is not supported by step()"
@@ -645,8 +654,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         global_norm_groups = []
         if self._clip_grad_norm > 0:
             for group_id in range(self.num_param_groups):
-                if self._is_moe_group(self.optim.param_groups[group_id]):
-                    self._get_norm_with_moe_layers(norms[group_id])
                 global_norm_groups.append(norms[group_id] ** 0.5)
         # the following operations are performed only on the rank to which parameters are assigned.
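
Below is a minimal, self-contained sketch of the averaging pattern that the new _compute_norm_with_moe_group uses: each data-parallel rank pre-scales its local MoE gradient norm by 1/world_size and an all-reduce sums the scaled values, yielding the average. The helper name average_norm_across_ranks and the single-process "gloo" group setup are illustrative assumptions only; the optimizer itself uses gpc.get_group(ParallelMode.DATA) and get_current_device().

# Sketch only: averaging a per-rank norm with all_reduce, assuming a generic
# torch.distributed process group rather than the InternLM gpc/ParallelMode API.
import os

import torch
import torch.distributed as dist


def average_norm_across_ranks(local_norm: float, group=None) -> float:
    """Scale the local norm by 1/world_size, then sum across ranks (i.e. average)."""
    world_size = dist.get_world_size(group=group)
    scaled = torch.tensor(local_norm / world_size, dtype=torch.float)
    dist.all_reduce(scaled, group=group)  # SUM of pre-scaled values == average
    return scaled.item()


if __name__ == "__main__":
    # Single-process group so the sketch runs anywhere; with more ranks the
    # same call averages each rank's local MoE norm.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(average_norm_across_ranks(4.0))  # -> 4.0 when there is only one rank
    dist.destroy_process_group()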