fix param_metrics is not a tensor

pull/449/head
JiaoPL 2023-10-26 18:41:13 +08:00
parent 83cb7036a7
commit 325b549707
1 changed files with 5 additions and 4 deletions

View File

@ -444,13 +444,14 @@ def compute_param_metric(
# scale norm # scale norm
if metric_type == "norm": if metric_type == "norm":
for param_name, param_metric in total_metrics.items(): for param_name, param_metric in total_metrics.items():
metric_value = param_metric.item() if torch.is_tensor(param_metric):
if metric_value in (inf, -inf): param_metric = param_metric.item()
if param_metric in (inf, -inf):
total_metrics[param_name] = -1 total_metrics[param_name] = -1
elif math.isnan(metric_value): elif math.isnan(param_metric):
total_metrics[param_name] = -2 total_metrics[param_name] = -2
else: else:
total_metrics[param_name] = metric_value total_metrics[param_name] = param_metric
return total_metrics return total_metrics