diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 3a77f8b..1f54d06 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -26,7 +26,7 @@ class AccPerplex: self.device = device self.right = torch.Tensor([0]).to(device=device) self.total = torch.Tensor([0]).to(device=device) - self.total_log_probs = torch.Tensor([0]).to(device=device) + self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float) self.tp_pg = tp_pg self.dp_pg = dp_pg self.tp_local_rank = torch.distributed.get_rank(self.tp_pg) @@ -128,6 +128,9 @@ class AccPerplex: # All reduce is needed to get the chunks from other GPUs. torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) + predicted_logits = predicted_logits.to(dtype=torch.float) + shift_logits = shift_logits.to(dtype=torch.float) + pred_exp_logits = torch.exp(predicted_logits) # Sum of exponential of logits along vocab dimension across all GPUs. sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)