feat(metrics): make float32 logits off by default

pull/531/head
877825076@qq.com 2023-12-08 00:46:53 +08:00
parent 81ffb3d824
commit 5c0925cd6c
1 changed files with 11 additions and 4 deletions

View File

@ -20,13 +20,18 @@ class AccPerplex:
dataset_types (List[str]): Various data types that will be used in the current training process,
such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
logits_float32: Use float32 to calculate logits.
"""
def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
def __init__(
self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None, logits_float32: bool = False
):
self.device = device
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
self.total_log_probs = torch.Tensor([0]).to(
device=device, dtype=torch.float if logits_float32 else torch.bfloat16
)
self.tp_pg = tp_pg
self.dp_pg = dp_pg
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
@ -34,6 +39,7 @@ class AccPerplex:
self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
self.batch_shift = 0
self.type_ids = None
self.logits_float32 = logits_float32
if dataset_types is not None:
self.dataset_types = dataset_types
self.total_type_count = len(dataset_types)
@ -128,6 +134,7 @@ class AccPerplex:
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
if self.logits_float32:
predicted_logits = predicted_logits.to(dtype=torch.float)
shift_logits = shift_logits.to(dtype=torch.float)