fix(metric): add metric dtype control

pull/533/head
Pryest 2023-12-11 16:16:21 +08:00
parent 81ffb3d824
commit 649af64c59
1 changed files with 5 additions and 3 deletions

View File

@ -26,7 +26,8 @@ class AccPerplex:
self.device = device
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
self.metric_dtype = torch.float if gpc.config.get("metric_dtype", "fp32") == "fp32" else None
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
self.tp_pg = tp_pg
self.dp_pg = dp_pg
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
@ -128,8 +129,9 @@ class AccPerplex:
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
predicted_logits = predicted_logits.to(dtype=torch.float)
shift_logits = shift_logits.to(dtype=torch.float)
if self.metric_dtype is not None:
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
shift_logits = shift_logits.to(dtype=self.metric_dtype)
pred_exp_logits = torch.exp(predicted_logits)
# Sum of exponential of logits along vocab dimension across all GPUs.