diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py
index 1f54d06..1f2f1c7 100644
--- a/internlm/model/metrics.py
+++ b/internlm/model/metrics.py
@@ -20,13 +20,18 @@ class AccPerplex:
         dataset_types (List[str]): Various data types that will be used in the current training process,
             such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
             in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
+        logits_float32: Use float32 to calculate logits.
     """
 
-    def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
+    def __init__(
+        self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None, logits_float32: bool = False
+    ):
         self.device = device
         self.right = torch.Tensor([0]).to(device=device)
         self.total = torch.Tensor([0]).to(device=device)
-        self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
+        self.total_log_probs = torch.Tensor([0]).to(
+            device=device, dtype=torch.float if logits_float32 else torch.bfloat16
+        )
         self.tp_pg = tp_pg
         self.dp_pg = dp_pg
         self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
@@ -34,6 +39,7 @@ class AccPerplex:
         self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
         self.batch_shift = 0
         self.type_ids = None
+        self.logits_float32 = logits_float32
         if dataset_types is not None:
             self.dataset_types = dataset_types
             self.total_type_count = len(dataset_types)
@@ -128,8 +134,9 @@ class AccPerplex:
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
 
-        predicted_logits = predicted_logits.to(dtype=torch.float)
-        shift_logits = shift_logits.to(dtype=torch.float)
+        if self.logits_float32:
+            predicted_logits = predicted_logits.to(dtype=torch.float)
+            shift_logits = shift_logits.to(dtype=torch.float)
 
         pred_exp_logits = torch.exp(predicted_logits)
         # Sum of exponential of logits along vocab dimension across all GPUs.
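
Usage sketch (editorial note, not part of the patch): the snippet below shows one way
the new logits_float32 flag might be threaded through from calling code. The
process-group handles tp_pg/dp_pg and the config attribute metric_cfg.logits_float32
are illustrative assumptions for this sketch, not names introduced by the diff.

    import torch

    from internlm.model.metrics import AccPerplex

    def build_metric(metric_cfg, tp_pg, dp_pg, tokenizer=None):
        # With logits_float32=False (the new default), AccPerplex keeps the
        # logits in their native dtype and accumulates total_log_probs in
        # bfloat16, skipping the float32 up-cast in the perplexity path;
        # pass True to restore the previous full-precision behaviour.
        return AccPerplex(
            device=torch.device("cuda"),
            tp_pg=tp_pg,
            dp_pg=dp_pg,
            tokenizer=tokenizer,
            dataset_types=["en", "cn", "code"],
            logits_float32=getattr(metric_cfg, "logits_float32", False),  # assumed config key
        )

The trade-off is precision for memory traffic: exp() and the log-prob accumulation
run in bfloat16 unless the flag is set, which can drift slightly on long evaluation
runs but avoids materialising float32 copies of the shifted logits.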