mirror of https://github.com/InternLM/InternLM
feat(metrics): make float32 logits off by default
parent
81ffb3d824
commit
5c0925cd6c
|
@ -20,13 +20,18 @@ class AccPerplex:
|
|||
dataset_types (List[str]): Various data types that will be used in the current training process,
|
||||
such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
|
||||
in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
|
||||
logits_float32: Use float32 to calculate logits.
|
||||
"""
|
||||
|
||||
def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
|
||||
def __init__(
|
||||
self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None, logits_float32: bool = False
|
||||
):
|
||||
self.device = device
|
||||
self.right = torch.Tensor([0]).to(device=device)
|
||||
self.total = torch.Tensor([0]).to(device=device)
|
||||
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
|
||||
self.total_log_probs = torch.Tensor([0]).to(
|
||||
device=device, dtype=torch.float if logits_float32 else torch.bfloat16
|
||||
)
|
||||
self.tp_pg = tp_pg
|
||||
self.dp_pg = dp_pg
|
||||
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
|
||||
|
@ -34,6 +39,7 @@ class AccPerplex:
|
|||
self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
|
||||
self.batch_shift = 0
|
||||
self.type_ids = None
|
||||
self.logits_float32 = logits_float32
|
||||
if dataset_types is not None:
|
||||
self.dataset_types = dataset_types
|
||||
self.total_type_count = len(dataset_types)
|
||||
|
@ -128,6 +134,7 @@ class AccPerplex:
|
|||
# All reduce is needed to get the chunks from other GPUs.
|
||||
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
|
||||
|
||||
if self.logits_float32:
|
||||
predicted_logits = predicted_logits.to(dtype=torch.float)
|
||||
shift_logits = shift_logits.to(dtype=torch.float)
|
||||
|
||||
|
|
Loading…
Reference in New Issue