feat(model/metrics.py): support calculating accuracy and perplexity m… (#91)

* feat(model/metrics.py): support calculating accuracy and perplexity metrics

* fix(model/metrics.py): fix import error

* feat(train.py): minor update

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: huangting.p <huangting@sensetime.com>
pull/139/head
huangting4201 2023-07-26 16:22:10 +08:00 committed by GitHub
parent fd398fae1a
commit 754c5aa69a
5 changed files with 295 additions and 7 deletions

@@ -118,8 +118,8 @@ zero1 parallel:
     2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
     3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-pipeline parallel: pipeline parallel size, only 1 is accepted currently.
-tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
+pipeline parallel: pipeline parallel size.
+tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
     zero1=8,
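
The doc change above drops the "only 1 is accepted currently" restriction on the pipeline and tensor parallel sizes. As a rough illustration only (the zero1/pipeline/tensor key names follow the docstring above; the values are hypothetical, not taken from this commit), the parallel block of a training config could look like:

# Hypothetical config sketch based on the updated docstring; values are illustrative.
parallel = dict(
    zero1=8,     # shard optimizer states across 8 ranks within each dp group
    pipeline=1,  # pipeline parallel size (no longer documented as fixed to 1)
    tensor=1,    # tensor parallel size, usually the number of GPUs per node
)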

@@ -83,6 +83,7 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         scale_loss: int = 1,
+        post_fn: Callable = None,
     ):
         """Trains one batch of data.

@@ -94,12 +95,16 @@ class NonPipelineScheduler(BaseScheduler):
                 be executed.
             return_loss (bool, optional): Loss will be returned if True.
             scale_loss (int, optional): The scale factor for the loss.
+            post_fn (Callable, optional): Call back function after executing data forward output.
         """

         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
             output = self._call_engine(engine, data)
+
+            if post_fn is not None:
+                post_fn(output, label)
+
             if return_loss:
                 loss = self._call_engine_criterion(engine, output, label)
                 loss /= scale_loss

@@ -120,6 +125,7 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         return_output_label: bool = True,
+        post_fn: Callable = None,
     ):
         """The process function that loads a batch of dataset and feeds it to the model.
         The returned labels and loss will None if :attr:`return_loss` is False.

@@ -131,6 +137,7 @@ class NonPipelineScheduler(BaseScheduler):
                 If True, the model is run for the forward pass, else back propagation will be executed.
             return_loss (bool, optional): Loss will be returned if True.
             return_output_label (bool, optional): Output and label will be returned if True.
+            post_fn (Callable, optional): Call back function after executing data forward output.

         Returns:
             Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.

@@ -168,7 +175,7 @@ class NonPipelineScheduler(BaseScheduler):
             _data, _label = self._load_accum_batch(data, label)

             _output, _loss = self._train_one_batch(
-                _data, _label, engine, forward_only, return_loss, self._grad_accum_size
+                _data, _label, engine, forward_only, return_loss, self._grad_accum_size, post_fn
             )

             if return_loss:
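
The new post_fn hook is simply a callable the scheduler invokes as post_fn(output, label) right after the forward pass of each micro-batch, before the loss is computed. A minimal sketch of a custom hook under that contract (the metric object added in this commit is passed the same way, since AccPerplex.__call__ takes logits and labels):

# Sketch of a post_fn hook; any callable with the (output, label) signature works.
def log_logits_stats(output, label):
    # "output" is whatever the model forward returned (logits, or a tuple whose
    # first element is the logits); "label" is the matching target tensor.
    logits = output[0] if isinstance(output, (list, tuple)) else output
    print(f"logits shape={tuple(logits.shape)}, labels shape={tuple(label.shape)}")

# Hypothetical call site, mirroring the train.py change below:
# trainer.execute_schedule(batch, forward_only=False, return_loss=True,
#                          return_output_label=False, post_fn=log_logits_stats)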

@@ -3,6 +3,7 @@
 from .embedding import Embedding1D, RotaryEmbedding
 from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
+from .metrics import AccPerplex
 from .modeling_internlm import build_model_with_cfg
 from .multi_head_attention import MHA
 from .utils import gather_forward_split_backward

@@ -13,6 +14,7 @@ __all__ = [
     "RotaryEmbedding",
     "RewardModelLinear",
     "ScaleColumnParallelLinear",
+    "AccPerplex",
     "MHA",
     "gather_forward_split_backward",
     "build_model_with_cfg",

internlm/model/metrics.py (new file, 260 additions)

@@ -0,0 +1,260 @@
from typing import List

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch_scatter import scatter

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.parallel import is_no_pp_or_last_stage


class AccPerplex:
    """
    AccPerplex module for calculating model's accuracy and perplexity metrics.

    Args:
        device: The GPU device.
        tp_pg: The tensor parallel process group.
        dp_pg: The data parallel process group.
        tokenizer: For calculating BPB.
        dataset_types (List[str]): Various data types that will be used in the current training process,
            such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
            in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
    """

    def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
        self.device = device
        self.right = torch.Tensor([0]).to(device=device)
        self.total = torch.Tensor([0]).to(device=device)
        self.total_log_probs = torch.Tensor([0]).to(device=device)
        self.tp_pg = tp_pg
        self.dp_pg = dp_pg
        self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
        self.tokenizer = tokenizer
        self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
        self.batch_shift = 0
        self.type_ids = None
        if dataset_types is not None:
            self.dataset_types = dataset_types
            self.total_type_count = len(dataset_types)
            self.ds_right = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
            self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device)

        self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)

    def set_current_type_ids(self, type_ids: torch.Tensor):
        self.batch_shift = 0
        self.type_ids = type_ids.cuda()

    def __call__(self, logits, labels):
        return self.update(logits, labels, type_ids=self.type_ids)

    def update(self, logits, labels, type_ids=None):
        micro_bsz = labels.size(0)
        if type_ids is not None:
            type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
            self.batch_shift += 1
        self.loss_with_type_id.update(logits, labels, type_ids)

        with torch.no_grad():
            if isinstance(logits, (list, tuple)):
                logits = logits[0]

            logits = logits.detach().clone()
            labels = labels.detach().clone()

            if self.tokenizer:  # need to calculate bits per bytes
                sequences = self.tokenizer.decode_ids(labels.tolist())
                self.total_bytes += sum(map(lambda x: len(x.encode("utf-8")), sequences))

            shift_logits = logits.view(-1, logits.size(-1))
            shift_labels = labels.view(-1)
            # There is a shift according to the current rank, because the logits are split
            pred_shift = self.tp_local_rank * logits.shape[-1]

            logits_max = torch.max(shift_logits, dim=-1)[0]
            torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg)
            # Determine whether the maximum value of the current local tensor is the global maximum value
            logits_global = logits_max == torch.max(shift_logits, dim=-1)[0]

            corrects = torch.logical_and(
                (shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global
            ).long()
            mask = shift_labels.ne(-100).long()
            if hasattr(self, "total_type_count"):
                ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum")
                token_num_type = scatter(mask, type_ids, dim=0, reduce="sum")
                if len(ds_acc) < self.total_type_count:
                    ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))])
                    token_num_type = torch.cat(
                        [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
                    )
                self.ds_tokens += token_num_type
                sync_tensor = ds_acc
                torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
                self.ds_right += sync_tensor.view(-1)

            acc = corrects.sum()
            torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
            self.right += acc  # Masked_fill is not needed here because -100 is not available anyway
            self.total += mask.sum()

            # Subtract the maximum value.
            shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1))

            # Get the partition's vocab indecies
            partition_vocab_size = shift_logits.size()[-1]
            vocab_start_index = partition_vocab_size * self.tp_local_rank
            vocab_end_index = vocab_start_index + partition_vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            target_mask = (shift_labels < vocab_start_index) | (shift_labels >= vocab_end_index)
            masked_target = shift_labels - vocab_start_index
            masked_target[target_mask] = 0

            # Get predicted-logits = logits[target].
            # For Simplicity, we convert logits to a 2-D tensor with size
            # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
            logits_2d = shift_logits.view(-1, partition_vocab_size)
            masked_target_1d = masked_target.view(-1)
            arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
            predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
            predicted_logits_1d = predicted_logits_1d.clone().contiguous()
            predicted_logits = predicted_logits_1d.view_as(shift_labels)  # bsz x max_len
            predicted_logits[target_mask] = 0.0
            # All reduce is needed to get the chunks from other GPUs.
            torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)

            pred_exp_logits = torch.exp(predicted_logits)
            # Sum of exponential of logits along vocab dimension across all GPUs.
            sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
            torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)

            total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
            self.total_log_probs += total_log_probs

    def get_metric(self, reset=True):
        if is_no_pp_or_last_stage() and self.dp_pg is not None:
            torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if hasattr(self, "total_type_count"):
                torch.distributed.all_reduce(self.ds_right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
                torch.distributed.all_reduce(self.ds_tokens, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if self.tokenizer:
                torch.distributed.all_reduce(self.total_bytes, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)

        acc = round((self.right / self.total).item(), 4)
        perplexity = round(torch.exp(self.total_log_probs / self.total).item(), 4)
        bits_per_bytes = round((self.total_log_probs / self.total_bytes).item(), 4) if self.tokenizer else 0

        if hasattr(self, "total_type_count"):
            ds_acc = {}
            ds_tokens = {}
            for i in range(self.total_type_count):
                ds_acc[f"acc/{self.dataset_types[i]}"] = round(
                    (self.ds_right[i].float() / (self.ds_tokens[i].float() + 1e-5)).item(), 4
                )
                ds_tokens[f"tokens/{self.dataset_types[i]}"] = self.ds_tokens[i].item()
        if reset:
            self.right.fill_(0)
            self.total.fill_(0)
            self.total_log_probs.fill_(0)
            self.total_bytes.fill_(0)
            if hasattr(self, "total_type_count"):
                self.ds_right.fill_(0)
                self.ds_tokens.fill_(0)

        if self.tokenizer is not None:
            res = {"acc": acc, "perplexity": perplexity, "BPB": bits_per_bytes}
        else:
            res = {"acc": acc, "perplexity": perplexity}
        if hasattr(self, "total_type_count"):
            res.update(ds_acc)
            res.update(ds_tokens)

        loss_res = self.loss_with_type_id.get_metric()
        res.update(loss_res)

        return res


class LossWithTypeId:
    """
    Notice the loss value computed here may be not the same with the main info loss,
    cause loss here is the reduced result of the data parallel.
    """

    def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
        self.device = device
        self.dp_pg = dp_pg

        self.loss = torch.Tensor([0.0]).to(device=device)
        self.token_num = torch.Tensor([0.0]).to(device=device)

        if dataset_types is not None:
            self.dataset_types = dataset_types
            self.total_type_count = len(dataset_types)
            self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
            self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)

        self.loss_fn = FlashCrossEntropyLoss(
            reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
        )

    def update(self, logits, labels, type_ids=None):
        with torch.no_grad():
            if isinstance(logits, (list, tuple)):
                logits = logits[0]
            logits = logits.contiguous().view(-1, logits.size(-1))
            labels = labels.contiguous().view(-1)
            loss_list = self.loss_fn(logits, labels)

            cond = labels != -100
            real_loss_list = loss_list[cond]
            self.loss += real_loss_list.sum()
            self.token_num += real_loss_list.numel()

            if hasattr(self, "total_type_count"):
                type_ids = type_ids.contiguous().view(-1).to(self.device)
                real_type_ids = type_ids[cond]

                loss_list_type = scatter(real_loss_list, real_type_ids, dim=0, reduce="sum")
                token_num_type = scatter(torch.ones_like(real_loss_list), real_type_ids, dim=0, reduce="sum")

                if len(loss_list_type) < self.total_type_count:
                    loss_list_type = torch.cat(
                        [loss_list_type, loss_list_type.new_zeros(self.total_type_count - len(loss_list_type))]
                    )
                    token_num_type = torch.cat(
                        [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
                    )
                self.ds_loss += loss_list_type
                self.ds_token_num += token_num_type

    def get_metric(self, reset=True):
        if is_no_pp_or_last_stage() and self.dp_pg is not None:
            torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if hasattr(self, "total_type_count"):
                torch.distributed.all_reduce(self.ds_loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
                torch.distributed.all_reduce(self.ds_token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)

        loss = round((self.loss / self.token_num).item(), 4)
        res = {
            "loss_from_metric": loss,
        }
        if hasattr(self, "total_type_count"):
            ds_loss = {}
            for i in range(self.total_type_count):
                ds_loss[f"loss/{self.dataset_types[i]}"] = round((self.ds_loss[i] / self.ds_token_num[i]).item(), 4)
            res.update(ds_loss)

        if reset:
            self.loss.fill_(0.0)
            self.token_num.fill_(0.0)
            if hasattr(self, "total_type_count"):
                self.ds_loss.fill_(0.0)
                self.ds_token_num.fill_(0.0)

        return res
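
Taken together with the train.py changes below, the intended lifecycle of AccPerplex is: construct it once with the tensor and data parallel process groups, register the batch's type_ids before the schedule runs, let the scheduler call it as post_fn(output, label) for every micro-batch, and read the aggregated numbers with get_metric(). A rough standalone sketch, assuming an already-initialized distributed/parallel context as set up in train.py (the dataset_types values are illustrative):

# Sketch only: assumes gpc/ParallelMode are already initialized as in train.py.
import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.metrics import AccPerplex

metric = AccPerplex(
    device=torch.cuda.current_device(),
    tp_pg=gpc.get_group(ParallelMode.TENSOR),
    dp_pg=gpc.get_group(ParallelMode.DATA),
    dataset_types=["en", "cn", "code"],  # must match the dataset's type_id order
)

# Per training step (mirroring the train.py diff below):
#   1) hand the batch's type_ids to the metric,
#   2) pass the metric itself as post_fn so the scheduler calls metric(output, label),
#   3) periodically read and reset the accumulated metrics.
# type_ids = batch[0].pop("type_ids", None)
# if type_ids is not None:
#     metric.set_current_type_ids(type_ids=type_ids)
# trainer.execute_schedule(batch, forward_only=False, return_loss=True,
#                          return_output_label=False, post_fn=metric)
# stats = metric.get_metric()  # {"acc": ..., "perplexity": ..., "loss_from_metric": ..., ...}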

@@ -27,6 +27,7 @@ from internlm.data.packed_dataset import (
 )
 from internlm.data.utils import DATASET_TYPE_IDS_MAP
 from internlm.model.loss import FlashGPTLMLoss
+from internlm.model.metrics import AccPerplex
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer

@@ -207,8 +208,6 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
         train_state.num_consumed_samples_in_epoch = 0
     timer("batch-gen").stop()

-    batch[0].pop("type_ids", None)
-
     return batch, train_iter

@@ -254,6 +253,7 @@
     start_time,
     loss,
     grad_norm,
+    metric,
 ):
     """
     Print some training metrics of current batch.

@@ -261,6 +261,8 @@
     if success_update in (0, True):
         train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
+    if is_no_pp_or_last_stage():
+        acc_perplex = metric.get_metric()

     if success_update and gpc.is_rank_for_log():
         lr = optimizer.param_groups[0]["lr"]

@@ -308,6 +310,9 @@
         fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
         infos["fwd_bwd_time"] = fwd_bwd_time

+        for key, value in acc_perplex.items():
+            infos[key] = value
+
         line = ""
         for key, value in infos.items():
             line += f"{key}={value} "

@@ -396,7 +401,7 @@
     criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)

     # initialize the train data loader
-    train_dl, _ = get_train_data_loader(num_worker=4)
+    train_dl, dataset_types = get_train_data_loader(num_worker=4)
     train_state.init_batch_sampler(train_dl)

     # Loading model weights must be done before zero is initialized.

@@ -430,6 +435,14 @@
     # initialize the batch skipper
     batch_skipper = BatchSkipper(skip_batches)

+    # initialize metric for calculating accuracy and perplexity
+    metric = AccPerplex(
+        device=torch.cuda.current_device(),
+        tp_pg=gpc.get_group(ParallelMode.TENSOR),
+        dp_pg=gpc.get_group(ParallelMode.DATA),
+        dataset_types=dataset_types,
+    )
+
     trainer.train()

     # transfer the train data loader into train data iterator

@@ -457,10 +470,15 @@
         # zero the grads of parameters
         trainer.zero_grad()
+        type_ids = batch[0].pop("type_ids", None)
+        if type_ids is not None:
+            metric.set_current_type_ids(type_ids=type_ids)

         # do forward and backward
         timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
+        _, _, loss = trainer.execute_schedule(
+            batch, forward_only=False, return_loss=True, return_output_label=False, post_fn=metric
+        )
         timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)

@@ -490,6 +508,7 @@
             start_time=start_time,
             loss=loss,
             grad_norm=grad_norm,
+            metric=metric,
         )

         timer("one-batch").stop()
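
With these changes, the per-batch log line gains the entries returned by metric.get_metric(). As an illustration only (the numbers are made up, and the per-type keys depend on the configured dataset_types and on whether a tokenizer is passed), the dictionary merged into infos could look like:

# Hypothetical example of what metric.get_metric() might return for
# dataset_types=["en", "cn", "code"] with a tokenizer configured (values are made up).
acc_perplex = {
    "acc": 0.4321,
    "perplexity": 12.3456,
    "BPB": 0.9876,  # only present when a tokenizer is given to AccPerplex
    "acc/en": 0.45, "acc/cn": 0.41, "acc/code": 0.48,
    "tokens/en": 123456, "tokens/cn": 98765, "tokens/code": 45678,
    "loss_from_metric": 2.5134,
    "loss/en": 2.6, "loss/cn": 2.7, "loss/code": 2.1,
}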