mirror of https://github.com/InternLM/InternLM
refactor code
parent 739a308c82
commit 4df2c47472
@@ -252,16 +252,6 @@ def reduce_grads(gradients, parameters, fine_grained=False):
     return parallel_grads


-def reduce_moe_norm(total_norm):
-    pg = gpc.get_group(ParallelMode.EXPERT)
-    scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
-    scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
-    dist.all_reduce(scaled_norm_tensor, group=pg)
-    total_norm = scaled_norm_tensor.item()
-
-    return total_norm
-
-
 def compute_norm(
     gradients,
     parameters,
@@ -344,7 +334,11 @@ def compute_norm(
     # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce
     # model and zero have been reduced!!!
     if zero_mode == ParallelMode.EXPERT_DATA:
-        total_norm = reduce_moe_norm(total_norm)
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_norm = total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.DATA))
+        scaled_norm_tensor = torch.tensor(scaled_norm, device=get_current_device(), dtype=torch.float)
+        dist.all_reduce(scaled_norm_tensor, group=pg)
+        total_norm = scaled_norm_tensor.item()

     # Scale.
     if total_norm == float("inf") or total_norm == -float("inf"):
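Both hunks above carry the same reduction pattern: pre-divide the locally computed norm by a world size, all-reduce (sum) it over the expert process group, and read the scalar back. Below is a minimal, self-contained sketch of that pattern written against plain torch.distributed; the single-process gloo group, the CPU tensor, and the reduce_moe_norm_sketch / data_world_size names are stand-ins for the repo's gpc, ParallelMode, and get_current_device plumbing, which the diff does not show.

import os

import torch
import torch.distributed as dist


def reduce_moe_norm_sketch(total_norm: float, expert_group, data_world_size: int) -> float:
    # Pre-divide locally, then sum across the group; when the divisor matches
    # the number of ranks contributing, the summed result is the average norm.
    scaled = torch.tensor(total_norm / float(data_world_size), dtype=torch.float)
    dist.all_reduce(scaled, group=expert_group)
    return scaled.item()


if __name__ == "__main__":
    # Single-process gloo group purely so the sketch runs standalone; InternLM
    # would obtain the group from gpc.get_group(ParallelMode.EXPERT) and place
    # the tensor on the current CUDA device instead.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(reduce_moe_norm_sketch(2.0, dist.group.WORLD, data_world_size=1))  # prints 2.0
    dist.destroy_process_group()

Dividing before the sum keeps the collective a plain SUM all-reduce, which every backend supports, rather than relying on ReduceOp.AVG.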
@@ -433,7 +427,12 @@ def compute_param_norm(

     # moe
     if zero_mode == ParallelMode.EXPERT_DATA:
-        total_param_norms = reduce_moe_norm(total_param_norms)
+        pg = gpc.get_group(ParallelMode.EXPERT)
+        scaled_param_norm = torch.cuda.FloatTensor(list(total_param_norms.values()), device=get_current_device())
+        scaled_param_norm = scaled_param_norm / float(gpc.get_world_size(ParallelMode.EXPERT))
+        dist.all_reduce(scaled_param_norm, group=pg)
+        for i, param_name in enumerate(total_param_norms.keys()):
+            total_param_norms[param_name] = scaled_param_norm[i].item()

     # scale
     for param_name, param_norm in total_param_norms.items():
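The compute_param_norm hunk applies the same averaging to a whole {param_name: norm} mapping by packing the values into one tensor, so a single all_reduce covers every parameter. Here is a sketch of that pack / all-reduce / unpack step under the same assumptions as above (plain torch.distributed, single-process gloo group, CPU tensor, illustrative names).

import os
from typing import Dict

import torch
import torch.distributed as dist


def reduce_param_norms_sketch(param_norms: Dict[str, float], expert_group, expert_world_size: int) -> Dict[str, float]:
    # Pack the per-parameter norms into one tensor so a single collective
    # covers the whole dict, pre-divide, sum across the group, then unpack.
    packed = torch.tensor(list(param_norms.values()), dtype=torch.float)
    packed = packed / float(expert_world_size)
    dist.all_reduce(packed, group=expert_group)
    return {name: packed[i].item() for i, name in enumerate(param_norms.keys())}


if __name__ == "__main__":
    # Single-process gloo group so the sketch runs standalone; the real code
    # uses gpc.get_group(ParallelMode.EXPERT) and a CUDA tensor.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(reduce_param_norms_sketch({"w1": 1.0, "w2": 3.0}, dist.group.WORLD, expert_world_size=1))
    dist.destroy_process_group()

Because Python dicts preserve insertion order, keys() and values() stay aligned, which is why indexing the reduced tensor by position (as the diff's enumerate loop does) maps each value back to the right parameter name.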