mirror of https://github.com/InternLM/InternLM
feat(metrics): make float32 logits off by default
parent 81ffb3d824
commit 5c0925cd6c
@@ -20,13 +20,18 @@ class AccPerplex:
         dataset_types (List[str]): Various data types that will be used in the current training process,
             such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
             in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
+        logits_float32: Use float32 to calculate logits.
     """
 
-    def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
+    def __init__(
+        self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None, logits_float32: bool = False
+    ):
         self.device = device
         self.right = torch.Tensor([0]).to(device=device)
         self.total = torch.Tensor([0]).to(device=device)
-        self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
+        self.total_log_probs = torch.Tensor([0]).to(
+            device=device, dtype=torch.float if logits_float32 else torch.bfloat16
+        )
         self.tp_pg = tp_pg
         self.dp_pg = dp_pg
         self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
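The key change in this hunk is that the running log-prob accumulator now defaults to bfloat16 instead of float32. A minimal sketch of that dtype-selection idiom, with a hypothetical helper name that is not part of the commit:

import torch

def make_log_prob_accumulator(device: str, logits_float32: bool = False) -> torch.Tensor:
    # torch.float is torch.float32. bfloat16 halves the accumulator's memory
    # but carries only about 8 significant bits, so long-running sums round
    # more aggressively unless the caller opts back in to float32.
    dtype = torch.float if logits_float32 else torch.bfloat16
    return torch.zeros(1, device=device, dtype=dtype)

print(make_log_prob_accumulator("cpu").dtype)        # torch.bfloat16 (new default)
print(make_log_prob_accumulator("cpu", True).dtype)  # torch.float32 (opt-in)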
@@ -34,6 +39,7 @@ class AccPerplex:
         self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
         self.batch_shift = 0
         self.type_ids = None
+        self.logits_float32 = logits_float32
         if dataset_types is not None:
             self.dataset_types = dataset_types
             self.total_type_count = len(dataset_types)
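Because the default flipped, call sites that want the previous full-precision behavior must now pass the flag explicitly. A hedged construction sketch; the single-process gloo bootstrap and the import path are assumptions for illustration, not part of the commit:

import torch
import torch.distributed as dist

# One-process bootstrap so real process groups exist (illustrative only).
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
pg = dist.new_group([0])

from internlm.model.metrics import AccPerplex  # assumed import path

metric = AccPerplex(
    device=torch.device("cpu"),
    tp_pg=pg,  # tensor-parallel process group
    dp_pg=pg,  # data-parallel process group
    dataset_types=["en", "cn", "code"],
    logits_float32=True,  # opt back in to float32 metrics (the old behavior)
)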
@@ -128,8 +134,9 @@ class AccPerplex:
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
 
-        predicted_logits = predicted_logits.to(dtype=torch.float)
-        shift_logits = shift_logits.to(dtype=torch.float)
+        if self.logits_float32:
+            predicted_logits = predicted_logits.to(dtype=torch.float)
+            shift_logits = shift_logits.to(dtype=torch.float)
 
         pred_exp_logits = torch.exp(predicted_logits)
         # Sum of exponential of logits along vocab dimension across all GPUs.
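The point of gating the upcast: exponentiating logits in bfloat16 is noticeably less accurate than in float32, and that accuracy is exactly what logits_float32=True buys back. A standalone illustration of the conditional-upcast pattern from the hunk above (the helper name is hypothetical):

import torch

def exp_logits(predicted_logits: torch.Tensor, logits_float32: bool) -> torch.Tensor:
    # Mirrors the gated upcast above: float32 precision is paid for only on request.
    if logits_float32:
        predicted_logits = predicted_logits.to(dtype=torch.float)
    return torch.exp(predicted_logits)

x = torch.randn(4, dtype=torch.bfloat16)
print(exp_logits(x, logits_float32=False).dtype)  # torch.bfloat16
print(exp_logits(x, logits_float32=True).dtype)   # torch.float32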