diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 1f54d06..2846245 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -26,7 +26,8 @@ class AccPerplex: self.device = device self.right = torch.Tensor([0]).to(device=device) self.total = torch.Tensor([0]).to(device=device) - self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float) + self.metric_dtype = torch.float if gpc.config.get("metric_dtype", "fp32") == "fp32" else None + self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype) self.tp_pg = tp_pg self.dp_pg = dp_pg self.tp_local_rank = torch.distributed.get_rank(self.tp_pg) @@ -128,8 +129,9 @@ class AccPerplex: # All reduce is needed to get the chunks from other GPUs. torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) - predicted_logits = predicted_logits.to(dtype=torch.float) - shift_logits = shift_logits.to(dtype=torch.float) + if self.metric_dtype is not None: + predicted_logits = predicted_logits.to(dtype=self.metric_dtype) + shift_logits = shift_logits.to(dtype=self.metric_dtype) pred_exp_logits = torch.exp(predicted_logits) # Sum of exponential of logits along vocab dimension across all GPUs.