fix(metric): fix metric behavior when ppl exceeds float16

pull/481/head
Pryest 2023-11-09 16:01:27 +08:00
parent 79e84fade3
commit 745fb33ca5
1 changed files with 4 additions and 1 deletions

View File

@ -26,7 +26,7 @@ class AccPerplex:
self.device = device
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
self.tp_pg = tp_pg
self.dp_pg = dp_pg
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
@ -128,6 +128,9 @@ class AccPerplex:
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
predicted_logits = predicted_logits.to(dtype=torch.float)
shift_logits = shift_logits.to(dtype=torch.float)
pred_exp_logits = torch.exp(predicted_logits)
# Sum of exponential of logits along vocab dimension across all GPUs.
sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)