fix(metrics): remove redundant CUDA memory usage in metric calculations (#557)

pull/570/head
Guoteng 2023-12-29 20:21:24 +08:00 committed by GitHub
parent c39d758a8a
commit 220953d7e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 4 deletions

View File

@@ -67,8 +67,8 @@ class AccPerplex:
if isinstance(logits, (list, tuple)):
logits = logits[0]
-        logits = logits.detach().clone()
-        labels = labels.detach().clone()
+        # logits = logits.detach().clone()
+        # labels = labels.detach().clone()
if self.tokenizer: # need to calculate bits per bytes
sequences = self.tokenizer.decode_ids(labels.tolist())
@@ -136,9 +136,9 @@ class AccPerplex:
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
shift_logits = shift_logits.to(dtype=self.metric_dtype)
-        pred_exp_logits = torch.exp(predicted_logits)
+        pred_exp_logits = torch.exp_(predicted_logits)
# Sum of exponential of logits along vocab dimension across all GPUs.
-        sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
+        sum_exp_logits = torch.exp_(shift_logits).sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()