fix param_metrics is not a tensor

pull/449/head
JiaoPL 2023-10-26 18:41:13 +08:00
parent 83cb7036a7
commit 325b549707
1 changed files with 5 additions and 4 deletions

View File

@ -444,13 +444,14 @@ def compute_param_metric(
# scale norm
if metric_type == "norm":
for param_name, param_metric in total_metrics.items():
metric_value = param_metric.item()
if metric_value in (inf, -inf):
if torch.is_tensor(param_metric):
param_metric = param_metric.item()
if param_metric in (inf, -inf):
total_metrics[param_name] = -1
elif math.isnan(metric_value):
elif math.isnan(param_metric):
total_metrics[param_name] = -2
else:
total_metrics[param_name] = metric_value
total_metrics[param_name] = param_metric
return total_metrics