From 220953d7e5031da44fbe94e0d0b2530503897189 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Fri, 29 Dec 2023 20:21:24 +0800 Subject: [PATCH] fix(metrics): remove redundant cuda memory in metric calculations (#557) --- internlm/model/metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 704d2d6..55e0219 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -67,8 +67,8 @@ class AccPerplex: if isinstance(logits, (list, tuple)): logits = logits[0] - logits = logits.detach().clone() - labels = labels.detach().clone() + # logits = logits.detach().clone() + # labels = labels.detach().clone() if self.tokenizer: # need to calculate bits per bytes sequences = self.tokenizer.decode_ids(labels.tolist()) @@ -136,9 +136,9 @@ class AccPerplex: predicted_logits = predicted_logits.to(dtype=self.metric_dtype) shift_logits = shift_logits.to(dtype=self.metric_dtype) - pred_exp_logits = torch.exp(predicted_logits) + pred_exp_logits = torch.exp_(predicted_logits) # Sum of exponential of logits along vocab dimension across all GPUs. - sum_exp_logits = torch.exp(shift_logits).sum(dim=-1) + sum_exp_logits = torch.exp_(shift_logits).sum(dim=-1) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()