refactor code

pull/448/head
Qu Wenwen 2023-10-27 15:34:48 +08:00
parent 739a308c82
commit 4df2c47472
1 changed file with 11 additions and 12 deletions

@@ -252,16 +252,6 @@ def reduce_grads(gradients, parameters, fine_grained=False):
     return parallel_grads
 
 
-def reduce_moe_norm(total_norm):
-    pg = gpc.get_group(ParallelMode.EXPERT)
-    scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
-    scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
-    dist.all_reduce(scaled_norm_tensor, group=pg)
-    total_norm = scaled_norm_tensor.item()
-    return total_norm
-
-
 def compute_norm(
     gradients,
     parameters,
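
The helper deleted above implemented a standard collective pattern: pre-scale a local scalar by the inverse of a world size, then SUM-allreduce it so every rank ends up holding the group mean. Below is a minimal, self-contained sketch of that pattern; the average_norm / run_worker names and the gloo-on-CPU scaffolding are mine, standing in for the repo's gpc process-group context:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def average_norm(local_norm: float, group=None) -> float:
    # Pre-scale by 1/world_size, then SUM-allreduce: every rank
    # receives the mean of the per-rank norms (all_reduce defaults
    # to ReduceOp.SUM). This mirrors what reduce_moe_norm did.
    world_size = dist.get_world_size(group=group)
    norm = torch.tensor(local_norm / world_size, dtype=torch.float)
    dist.all_reduce(norm, group=group)
    return norm.item()


def run_worker(rank: int, world_size: int):
    # Hypothetical scaffolding; the real code reads its groups from gpc.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Rank 0 holds 1.0 and rank 1 holds 3.0, so both print 2.0.
    print(f"rank {rank}: {average_norm(1.0 + 2.0 * rank)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2)
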
@@ -344,7 +334,11 @@ def compute_norm(
     # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
     # model and zero have been reduced!!!
     if zero_mode == ParallelMode.EXPERT_DATA:
-        total_norm = reduce_moe_norm(total_norm)
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        total_norm = scaled_norm_tensor.item()
 
     # Scale.
     if total_norm == float("inf") or total_norm == -float("inf"):
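
In compute_norm the inlined block sums over the EXPERT group after dividing by the DATA world size, which matches the "allreduce(avg)" comment when those two groups have the same size. The sketch below shows one conventional way such overlapping subgroups are carved out of the world with torch.distributed; the contiguous/strided layout and the function name are illustrative assumptions, not the repo's gpc initializer:

import torch.distributed as dist


def build_moe_groups(world_size: int, expert_parallel_size: int):
    # Illustrative layout (an assumption, not the repo's scheme):
    # contiguous ranks share an expert group, strided ranks share a
    # data group. new_group must be called identically on every rank,
    # so all groups are created everywhere and each rank keeps its own.
    rank = dist.get_rank()
    expert_group = data_group = None
    for start in range(0, world_size, expert_parallel_size):
        ranks = list(range(start, start + expert_parallel_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            expert_group = group
    for offset in range(expert_parallel_size):
        ranks = list(range(offset, world_size, expert_parallel_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            data_group = group
    return expert_group, data_group

With world_size=4 and expert_parallel_size=2, ranks {0,1} and {2,3} form expert groups while {0,2} and {1,3} form data groups, all of size 2, so divide-by-DATA-then-sum-over-EXPERT does produce an average.
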
@@ -433,7 +427,12 @@ def compute_param_norm(
 
     # moe
     if zero_mode == ParallelMode.EXPERT_DATA:
-        total_param_norms = reduce_moe_norm(total_param_norms)
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_param_norm = torch.cuda.FloatTensor(list(total_param_norms.values()), device=get_current_device())
+        scaled_param_norm = scaled_param_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
+        dist.all_reduce(scaled_param_norm, group=pg)
+        for i, param_name in enumerate(total_param_norms.keys()):
+            total_param_norms[param_name] = scaled_param_norm[i].item()
 
     # scale
     for param_name, param_norm in total_param_norms.items():
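
The compute_param_norm change applies the same scale-then-sum idea to a whole dict of per-parameter norms at once, packing them into a single tensor so one collective replaces len(total_param_norms) of them. Note it divides by the EXPERT world size where compute_norm divided by the DATA world size, which is presumably why a single shared helper no longer fit both call sites. A minimal CPU sketch of the pack/reduce/unpack round-trip (the function name and the plain torch.tensor standing in for torch.cuda.FloatTensor are mine):

import torch
import torch.distributed as dist


def average_param_norms(norms: dict, group=None) -> dict:
    # Pack every per-parameter norm into one tensor so a single
    # all_reduce covers all parameters. Python dicts preserve
    # insertion order, so keys() and values() stay aligned across
    # the pack/unpack round-trip (this assumes every rank builds
    # the dict in the same order).
    world_size = dist.get_world_size(group=group)
    packed = torch.tensor(list(norms.values()), dtype=torch.float)
    packed /= world_size  # pre-scale so the SUM-allreduce yields a mean
    dist.all_reduce(packed, group=group)
    return {name: packed[i].item() for i, name in enumerate(norms)}
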