diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py
index ba1135940..0a15f8ddd 100644
--- a/colossalai/zero/low_level/_utils.py
+++ b/colossalai/zero/low_level/_utils.py
@@ -221,8 +221,8 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro
     else:
         total_norm = 0.0
         for g in gradients:
-            param_norm = g.data.double().norm(2)
-            total_norm += param_norm.item() ** 2
+            param_norm = g.data.double().norm(norm_type)
+            total_norm += param_norm.item() ** norm_type
 
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
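
For context: the patch generalizes a hard-coded 2-norm so that `compute_norm` actually honors its `norm_type` argument. The pattern is standard p-norm aggregation: each gradient's p-norm is raised to the p-th power, the powers are summed, and the p-th root is taken at the end (in the real function, an all-reduce across the parallel groups sits between the sum and the root). Below is a minimal single-process sketch of that aggregation; `local_grad_norm` is a hypothetical helper written for illustration, not ColossalAI's API, and the cross-GPU reduction is deliberately omitted.

```python
import torch

def local_grad_norm(gradients, norm_type: float = 2.0) -> float:
    """Single-process sketch of the accumulation in the patch above:
    sum each gradient's p-norm raised to the p-th power, then take the
    p-th root. The all-reduce over dp/tp groups done by compute_norm
    between the sum and the root is omitted here."""
    total_norm = 0.0
    for g in gradients:
        param_norm = g.data.double().norm(norm_type)  # per-tensor p-norm
        total_norm += param_norm.item() ** norm_type  # accumulate ||g||_p ** p
    return total_norm ** (1.0 / norm_type)            # final p-th root

# With norm_type=2.0 this reproduces the old hard-coded behavior; other
# values (e.g. 1.0) now flow through instead of being silently ignored.
grads = [torch.randn(4, 4), torch.randn(8)]
print(local_grad_norm(grads, norm_type=2.0))
print(local_grad_norm(grads, norm_type=1.0))
```

Note the design point of the fix: summing `param_norm ** norm_type` is only correct when the same `norm_type` is used for both the per-tensor norm and the exponent, which is exactly the invariant the old code broke for any `norm_type != 2`.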