mirror of https://github.com/InternLM/InternLM
fix default behavior
parent
347370a58a
commit
fdce50a000
|
@ -179,5 +179,5 @@ monitor = dict(
|
|||
model_type = "INTERNLM_MoE"
|
||||
|
||||
# metric_dtype can be "fp32" or other string
|
||||
# only when set to fp32 or unset will use fp32 to calc in metrics
|
||||
metric_dtype = "fp32"
|
||||
# metrics are computed in fp32 only when metric_dtype is set to "fp32"
|
||||
# metric_dtype = "fp32"
|
||||
|
|
|
@ -179,5 +179,5 @@ monitor = dict(
|
|||
)
|
||||
|
||||
# metric_dtype can be "fp32" or other string
|
||||
# only when set to fp32 or unset will use fp32 to calc in metrics
|
||||
metric_dtype = "fp32"
|
||||
# metrics are computed in fp32 only when metric_dtype is set to "fp32"
|
||||
# metric_dtype = "fp32"
|
||||
|
|
|
@ -26,8 +26,11 @@ class AccPerplex:
|
|||
self.device = device
|
||||
self.right = torch.Tensor([0]).to(device=device)
|
||||
self.total = torch.Tensor([0]).to(device=device)
|
||||
self.metric_dtype = torch.float if gpc.config.get("metric_dtype", "fp32") == "fp32" else None
|
||||
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
|
||||
self.metric_dtype = torch.float if gpc.config.get("metric_dtype", None) == "fp32" else None
|
||||
if self.metric_dtype is not None:
|
||||
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
|
||||
else:
|
||||
self.total_log_probs = torch.Tensor([0]).to(device=device)
|
||||
self.tp_pg = tp_pg
|
||||
self.dp_pg = dp_pg
|
||||
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
|
||||
|
|
Loading…
Reference in New Issue