fix(metric): use float32 to compute ppl (#481)

pull/493/head
Pryest 2023-11-09 20:26:46 +08:00 committed by GitHub
parent a435980e0c
commit 5b67db33d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 1 deletion

View File

@@ -26,7 +26,7 @@ class AccPerplex:
self.device = device
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
self.tp_pg = tp_pg
self.dp_pg = dp_pg
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
@@ -128,6 +128,9 @@ class AccPerplex:
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
predicted_logits = predicted_logits.to(dtype=torch.float)
shift_logits = shift_logits.to(dtype=torch.float)
pred_exp_logits = torch.exp(predicted_logits)
# Sum of exponential of logits along vocab dimension across all GPUs.
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)