diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 2f00e1d..2fb8f57 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -444,13 +444,14 @@ def compute_param_metric( # scale norm if metric_type == "norm": for param_name, param_metric in total_metrics.items(): - metric_value = param_metric.item() - if metric_value in (inf, -inf): + if torch.is_tensor(param_metric): + param_metric = param_metric.item() + if param_metric in (inf, -inf): total_metrics[param_name] = -1 - elif math.isnan(metric_value): + elif math.isnan(param_metric): total_metrics[param_name] = -2 else: - total_metrics[param_name] = metric_value + total_metrics[param_name] = param_metric return total_metrics