mirror of https://github.com/InternLM/InternLM
fix(metrics): remove redundant cuda memory in metric calculations (#557)
parent
c39d758a8a
commit
220953d7e5
|
@ -67,8 +67,8 @@ class AccPerplex:
|
|||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
|
||||
logits = logits.detach().clone()
|
||||
labels = labels.detach().clone()
|
||||
# logits = logits.detach().clone()
|
||||
# labels = labels.detach().clone()
|
||||
|
||||
if self.tokenizer: # need to calculate bits per bytes
|
||||
sequences = self.tokenizer.decode_ids(labels.tolist())
|
||||
|
@ -136,9 +136,9 @@ class AccPerplex:
|
|||
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
|
||||
shift_logits = shift_logits.to(dtype=self.metric_dtype)
|
||||
|
||||
pred_exp_logits = torch.exp(predicted_logits)
|
||||
pred_exp_logits = torch.exp_(predicted_logits)
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
|
||||
sum_exp_logits = torch.exp_(shift_logits).sum(dim=-1)
|
||||
torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
||||
total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
|
||||
|
|
Loading…
Reference in New Issue