[hotfix] fix norm type error in zero optimizer (#4795)

littsk 2023-09-27 10:35:24 +08:00 committed by GitHub
parent da15fdb9ca
commit 54b3ad8924
1 changed file with 2 additions and 2 deletions


@@ -221,8 +221,8 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
     else:
         total_norm = 0.0
         for g in gradients:
-            param_norm = g.data.double().norm(2)
-            total_norm += param_norm.item() ** 2
+            param_norm = g.data.double().norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
     # Sum across all model parallel GPUs.
     total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
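For context, here is a minimal, standalone sketch of the accumulation loop this hunk fixes. The name compute_local_norm and the dummy gradients are illustrative, not part of the ColossalAI codebase, and the cross-rank all-reduce that the real compute_norm performs is only noted in a comment.

import torch

def compute_local_norm(gradients, norm_type: float = 2.0) -> float:
    # Per-rank accumulation as in the fixed hunk: before this commit the
    # L2 norm was hardcoded (norm(2) and ** 2), so any other norm_type
    # silently produced the wrong value.
    total_norm = 0.0
    for g in gradients:
        param_norm = g.data.double().norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    # The real optimizer all-reduces this partial sum across the data and
    # tensor parallel groups before taking the final root.
    return total_norm ** (1.0 / norm_type)

grads = [torch.ones(3), torch.full((2,), 2.0)]
print(compute_local_norm(grads, norm_type=2.0))  # L2 norm
print(compute_local_norm(grads, norm_type=4.0))  # L4 norm now differs, as expected

Raising each per-parameter norm to the power norm_type and taking the norm_type-th root at the end is the standard way to combine p-norms over shards, since sum_i ||g_i||_p ** p equals ||g||_p ** p for the concatenated gradient.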