fix(metrics): remove redundant CUDA memory allocations in metric calculations

pull/557/head
877825076@qq.com 2023-12-25 22:28:10 +08:00
parent de53b17506
commit 48e25fd849
1 changed file with 4 additions and 4 deletions

View File

@@ -67,8 +67,8 @@ class AccPerplex:
if isinstance(logits, (list, tuple)):
logits = logits[0]
logits = logits.detach().clone()
labels = labels.detach().clone()
# logits = logits.detach().clone()
# labels = labels.detach().clone()
if self.tokenizer: # need to calculate bits per bytes
sequences = self.tokenizer.decode_ids(labels.tolist())
@@ -136,9 +136,9 @@ class AccPerplex:
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
shift_logits = shift_logits.to(dtype=self.metric_dtype)
pred_exp_logits = torch.exp(predicted_logits)
pred_exp_logits = torch.exp_(predicted_logits)
# Sum of exponential of logits along vocab dimension across all GPUs.
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
sum_exp_logits = torch.exp_(shift_logits).sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()