diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index d2ead39..cc94cdc 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -179,5 +179,5 @@ monitor = dict( model_type = "INTERNLM_MoE" # metric_dtype can be "fp32" or other string -# only when set to fp32 or unset will use fp32 to calc in metrics -metric_dtype = "fp32" +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 39a84fc..360ee92 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -179,5 +179,5 @@ monitor = dict( ) # metric_dtype can be "fp32" or other string -# only when set to fp32 or unset will use fp32 to calc in metrics -metric_dtype = "fp32" +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 2846245..704d2d6 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -26,8 +26,11 @@ class AccPerplex: self.device = device self.right = torch.Tensor([0]).to(device=device) self.total = torch.Tensor([0]).to(device=device) - self.metric_dtype = torch.float if gpc.config.get("metric_dtype", "fp32") == "fp32" else None - self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype) + self.metric_dtype = torch.float if gpc.config.get("metric_dtype", None) == "fp32" else None + if self.metric_dtype is not None: + self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype) + else: + self.total_log_probs = torch.Tensor([0]).to(device=device) self.tp_pg = tp_pg self.dp_pg = dp_pg self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)