mirror of https://github.com/InternLM/InternLM
fix(metric): add metric dtype control
parent
81ffb3d824
commit
649af64c59
|
@ -26,7 +26,8 @@ class AccPerplex:
|
|||
self.device = device
|
||||
self.right = torch.Tensor([0]).to(device=device)
|
||||
self.total = torch.Tensor([0]).to(device=device)
|
||||
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
|
||||
self.metric_dtype = torch.float if gpc.config.get("metric_dtype", "fp32") == "fp32" else None
|
||||
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
|
||||
self.tp_pg = tp_pg
|
||||
self.dp_pg = dp_pg
|
||||
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
|
||||
|
@ -128,8 +129,9 @@ class AccPerplex:
|
|||
# All reduce is needed to get the chunks from other GPUs.
|
||||
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
||||
predicted_logits = predicted_logits.to(dtype=torch.float)
|
||||
shift_logits = shift_logits.to(dtype=torch.float)
|
||||
if self.metric_dtype is not None:
|
||||
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
|
||||
shift_logits = shift_logits.to(dtype=self.metric_dtype)
|
||||
|
||||
pred_exp_logits = torch.exp(predicted_logits)
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
|
|
Loading…
Reference in New Issue