mirror of https://github.com/InternLM/InternLM
refactor code
parent
739a308c82
commit
4df2c47472
|
@ -252,16 +252,6 @@ def reduce_grads(gradients, parameters, fine_grained=False):
|
|||
return parallel_grads
|
||||
|
||||
|
||||
def reduce_moe_norm(total_norm):
|
||||
pg = gpc.get_group(ParallelMode.EXPERT)
|
||||
scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
|
||||
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
|
||||
dist.all_reduce(scaled_norm_tensor, group=pg)
|
||||
total_norm = scaled_norm_tensor.item()
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
def compute_norm(
|
||||
gradients,
|
||||
parameters,
|
||||
|
@ -344,7 +334,11 @@ def compute_norm(
|
|||
# Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
|
||||
# model and zero have been reduced!!!
|
||||
if zero_mode == ParallelMode.EXPERT_DATA:
|
||||
total_norm = reduce_moe_norm(total_norm)
|
||||
pg = gpc.get_group(ParallelMode.EXPERT)
|
||||
scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
|
||||
scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
|
||||
dist.all_reduce(scaled_norm_tensor, group=pg)
|
||||
total_norm = scaled_norm_tensor.item()
|
||||
|
||||
# Scale.
|
||||
if total_norm == float("inf") or total_norm == -float("inf"):
|
||||
|
@ -433,7 +427,12 @@ def compute_param_norm(
|
|||
|
||||
# moe
|
||||
if zero_mode == ParallelMode.EXPERT_DATA:
|
||||
total_param_norms = reduce_moe_norm(total_param_norms)
|
||||
pg = gpc.get_group(ParallelMode.EXPERT)
|
||||
scaled_param_norm = torch.cuda.FloatTensor(list(total_param_norms.values()), device=get_current_device())
|
||||
scaled_param_norm = scaled_param_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
|
||||
dist.all_reduce(scaled_param_norm, group=pg)
|
||||
for i, param_name in enumerate(total_param_norms.keys()):
|
||||
total_param_norms[param_name] = scaled_param_norm[i].item()
|
||||
|
||||
# scale
|
||||
for param_name, param_norm in total_param_norms.items():
|
||||
|
|
Loading…
Reference in New Issue