mirror of https://github.com/InternLM/InternLM
fix(metrics): remove redundant cuda memory in metric calculations (#557)
parent
c39d758a8a
commit
220953d7e5
|
@ -67,8 +67,8 @@ class AccPerplex:
|
||||||
if isinstance(logits, (list, tuple)):
|
if isinstance(logits, (list, tuple)):
|
||||||
logits = logits[0]
|
logits = logits[0]
|
||||||
|
|
||||||
logits = logits.detach().clone()
|
# logits = logits.detach().clone()
|
||||||
labels = labels.detach().clone()
|
# labels = labels.detach().clone()
|
||||||
|
|
||||||
if self.tokenizer: # need to calculate bits per bytes
|
if self.tokenizer: # need to calculate bits per bytes
|
||||||
sequences = self.tokenizer.decode_ids(labels.tolist())
|
sequences = self.tokenizer.decode_ids(labels.tolist())
|
||||||
|
@ -136,9 +136,9 @@ class AccPerplex:
|
||||||
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
|
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
|
||||||
shift_logits = shift_logits.to(dtype=self.metric_dtype)
|
shift_logits = shift_logits.to(dtype=self.metric_dtype)
|
||||||
|
|
||||||
pred_exp_logits = torch.exp(predicted_logits)
|
pred_exp_logits = torch.exp_(predicted_logits)
|
||||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||||
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
|
sum_exp_logits = torch.exp_(shift_logits).sum(dim=-1)
|
||||||
torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||||
|
|
||||||
total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
|
total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
|
||||||
|
|
Loading…
Reference in New Issue