mirror of https://github.com/InternLM/InternLM
fix(metric): use float32 to compute ppl (#481)
parent
a435980e0c
commit
5b67db33d0
|
@@ -26,7 +26,7 @@ class AccPerplex:
|
|||
self.device = device
|
||||
self.right = torch.Tensor([0]).to(device=device)
|
||||
self.total = torch.Tensor([0]).to(device=device)
|
||||
-self.total_log_probs = torch.Tensor([0]).to(device=device)
|
||||
+self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
|
||||
self.tp_pg = tp_pg
|
||||
self.dp_pg = dp_pg
|
||||
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
|
||||
|
@@ -128,6 +128,9 @@ class AccPerplex:
|
|||
# All reduce is needed to get the chunks from other GPUs.
|
||||
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
||||
+predicted_logits = predicted_logits.to(dtype=torch.float)
|
||||
+shift_logits = shift_logits.to(dtype=torch.float)
|
||||
|
||||
pred_exp_logits = torch.exp(predicted_logits)
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
|
||||
|
|
Loading…
Reference in New Issue