
[hotfix] fix norm type error in zero optimizer (#4795)

Branch: pull/4815/head
Author: littsk, committed by GitHub (1 year ago)
Commit: 54b3ad8924
1 changed file with 2 additions and 2 deletions:

colossalai/zero/low_level/_utils.py
@@ -221,8 +221,8 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup
     else:
         total_norm = 0.0
         for g in gradients:
-            param_norm = g.data.double().norm(2)
-            total_norm += param_norm.item() ** 2
+            param_norm = g.data.double().norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
