fix default behavior

pull/533/head
Pryest 2023-12-11 17:27:09 +08:00
parent 347370a58a
commit fdce50a000
3 changed files with 9 additions and 6 deletions

View File

@@ -179,5 +179,5 @@ monitor = dict(
model_type = "INTERNLM_MoE"
# metric_dtype can be "fp32" or other string
# only when set to fp32 or unset will use fp32 to calc in metrics
metric_dtype = "fp32"
# fp32 is used to calculate metrics only when this is set to "fp32"
# metric_dtype = "fp32"

View File

@@ -179,5 +179,5 @@ monitor = dict(
)
# metric_dtype can be "fp32" or other string
# only when set to fp32 or unset will use fp32 to calc in metrics
metric_dtype = "fp32"
# fp32 is used to calculate metrics only when this is set to "fp32"
# metric_dtype = "fp32"

View File

@@ -26,8 +26,11 @@ class AccPerplex:
self.device = device
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.metric_dtype = torch.float if gpc.config.get("metric_dtype", "fp32") == "fp32" else None
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
self.metric_dtype = torch.float if gpc.config.get("metric_dtype", None) == "fp32" else None
if self.metric_dtype is not None:
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
else:
self.total_log_probs = torch.Tensor([0]).to(device=device)
self.tp_pg = tp_pg
self.dp_pg = dp_pg
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)