mirror of https://github.com/InternLM/InternLM
feat(model/metrics.py): support calculating accuracy and perplexity m… (#91)
* feat(model/metrics.py): support calculating accuracy and perplexity metrics * fix(model/metrics.py): fix import error * feat(train.py): minor update --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: huangting.p <huangting@sensetime.com>pull/139/head
parent
fd398fae1a
commit
754c5aa69a
|
@ -118,8 +118,8 @@ zero1 parallel:
|
||||||
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
||||||
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
|
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
|
||||||
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
||||||
pipeline parallel: pipeline parallel size, only 1 is accepted currently.
|
pipeline parallel: pipeline parallel size.
|
||||||
tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
|
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||||
"""
|
"""
|
||||||
parallel = dict(
|
parallel = dict(
|
||||||
zero1=8,
|
zero1=8,
|
||||||
|
|
|
@ -83,6 +83,7 @@ class NonPipelineScheduler(BaseScheduler):
|
||||||
forward_only: bool = False,
|
forward_only: bool = False,
|
||||||
return_loss: bool = True,
|
return_loss: bool = True,
|
||||||
scale_loss: int = 1,
|
scale_loss: int = 1,
|
||||||
|
post_fn: Callable = None,
|
||||||
):
|
):
|
||||||
"""Trains one batch of data.
|
"""Trains one batch of data.
|
||||||
|
|
||||||
|
@ -94,12 +95,16 @@ class NonPipelineScheduler(BaseScheduler):
|
||||||
be executed.
|
be executed.
|
||||||
return_loss (bool, optional): Loss will be returned if True.
|
return_loss (bool, optional): Loss will be returned if True.
|
||||||
scale_loss (int, optional): The scale factor for the loss.
|
scale_loss (int, optional): The scale factor for the loss.
|
||||||
|
post_fn (Callable, optional): Call back function after executing data forward output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
with conditional_context(torch.no_grad(), enable=forward_only):
|
with conditional_context(torch.no_grad(), enable=forward_only):
|
||||||
output = self._call_engine(engine, data)
|
output = self._call_engine(engine, data)
|
||||||
|
|
||||||
|
if post_fn is not None:
|
||||||
|
post_fn(output, label)
|
||||||
|
|
||||||
if return_loss:
|
if return_loss:
|
||||||
loss = self._call_engine_criterion(engine, output, label)
|
loss = self._call_engine_criterion(engine, output, label)
|
||||||
loss /= scale_loss
|
loss /= scale_loss
|
||||||
|
@ -120,6 +125,7 @@ class NonPipelineScheduler(BaseScheduler):
|
||||||
forward_only: bool = False,
|
forward_only: bool = False,
|
||||||
return_loss: bool = True,
|
return_loss: bool = True,
|
||||||
return_output_label: bool = True,
|
return_output_label: bool = True,
|
||||||
|
post_fn: Callable = None,
|
||||||
):
|
):
|
||||||
"""The process function that loads a batch of dataset and feeds it to the model.
|
"""The process function that loads a batch of dataset and feeds it to the model.
|
||||||
The returned labels and loss will be None if :attr:`return_loss` is False.
|
The returned labels and loss will be None if :attr:`return_loss` is False.
|
||||||
|
@ -131,6 +137,7 @@ class NonPipelineScheduler(BaseScheduler):
|
||||||
If True, the model is run for the forward pass, else back propagation will be executed.
|
If True, the model is run for the forward pass, else back propagation will be executed.
|
||||||
return_loss (bool, optional): Loss will be returned if True.
|
return_loss (bool, optional): Loss will be returned if True.
|
||||||
return_output_label (bool, optional): Output and label will be returned if True.
|
return_output_label (bool, optional): Output and label will be returned if True.
|
||||||
|
post_fn (Callable, optional): Call back function after executing data forward output.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
||||||
|
@ -168,7 +175,7 @@ class NonPipelineScheduler(BaseScheduler):
|
||||||
_data, _label = self._load_accum_batch(data, label)
|
_data, _label = self._load_accum_batch(data, label)
|
||||||
|
|
||||||
_output, _loss = self._train_one_batch(
|
_output, _loss = self._train_one_batch(
|
||||||
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
|
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, post_fn
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_loss:
|
if return_loss:
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
from .embedding import Embedding1D, RotaryEmbedding
|
from .embedding import Embedding1D, RotaryEmbedding
|
||||||
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
|
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
|
||||||
|
from .metrics import AccPerplex
|
||||||
from .modeling_internlm import build_model_with_cfg
|
from .modeling_internlm import build_model_with_cfg
|
||||||
from .multi_head_attention import MHA
|
from .multi_head_attention import MHA
|
||||||
from .utils import gather_forward_split_backward
|
from .utils import gather_forward_split_backward
|
||||||
|
@ -13,6 +14,7 @@ __all__ = [
|
||||||
"RotaryEmbedding",
|
"RotaryEmbedding",
|
||||||
"RewardModelLinear",
|
"RewardModelLinear",
|
||||||
"ScaleColumnParallelLinear",
|
"ScaleColumnParallelLinear",
|
||||||
|
"AccPerplex",
|
||||||
"MHA",
|
"MHA",
|
||||||
"gather_forward_split_backward",
|
"gather_forward_split_backward",
|
||||||
"build_model_with_cfg",
|
"build_model_with_cfg",
|
||||||
|
|
|
@ -0,0 +1,260 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
|
||||||
|
from torch_scatter import scatter
|
||||||
|
|
||||||
|
from internlm.core.context import ParallelMode
|
||||||
|
from internlm.core.context import global_context as gpc
|
||||||
|
from internlm.utils.parallel import is_no_pp_or_last_stage
|
||||||
|
|
||||||
|
|
||||||
|
class AccPerplex:
    """
    AccPerplex module for calculating model's accuracy and perplexity metrics.

    Instances are callable: ``metric(logits, labels)`` forwards to :meth:`update`
    using the type ids most recently registered through
    :meth:`set_current_type_ids`.

    Args:
        device: The GPU device.
        tp_pg: The tensor parallel process group.
        dp_pg: The data parallel process group.
        tokenizer: For calculating BPB.
        dataset_types (List[str]): Various data types that will be used in the current training process,
            such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
            in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
    """

    def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
        self.device = device
        # Running sums, accumulated by update() and consumed by get_metric().
        self.right = torch.Tensor([0]).to(device=device)
        self.total = torch.Tensor([0]).to(device=device)
        self.total_log_probs = torch.Tensor([0]).to(device=device)
        self.tp_pg = tp_pg
        self.dp_pg = dp_pg
        self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
        self.tokenizer = tokenizer
        self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
        # Index of the next micro batch inside the currently registered type ids.
        self.batch_shift = 0
        self.type_ids = None
        if dataset_types is not None:
            # The per-type attributes only exist when dataset types are given;
            # the rest of the class tests for them with hasattr().
            self.dataset_types = dataset_types
            self.total_type_count = len(dataset_types)
            self.ds_right = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
            self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device)

        self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)

    def set_current_type_ids(self, type_ids: torch.Tensor):
        """Register the type ids of a new batch and restart micro-batch slicing."""
        self.batch_shift = 0
        # Use the configured device (as the accumulators do) instead of a bare .cuda().
        self.type_ids = type_ids.to(device=self.device)

    def __call__(self, logits, labels):
        return self.update(logits, labels, type_ids=self.type_ids)

    def update(self, logits, labels, type_ids=None):
        """Accumulate accuracy / perplexity statistics for one micro batch.

        Args:
            logits: Model output, or a list/tuple whose first element is the output.
                The last dimension is treated as this rank's vocabulary partition.
            labels: Target ids; positions equal to -100 are ignored.
            type_ids: Type ids of the whole batch. The slice belonging to the current
                micro batch is selected with ``batch_shift``. When None, the per-type
                statistics are left untouched.
        """
        micro_bsz = labels.size(0)
        if type_ids is not None:
            start = self.batch_shift * micro_bsz
            type_ids = type_ids[start : start + micro_bsz].view(-1)
            self.batch_shift += 1
        self.loss_with_type_id.update(logits, labels, type_ids)

        with torch.no_grad():
            if isinstance(logits, (list, tuple)):
                logits = logits[0]

            logits = logits.detach().clone()
            labels = labels.detach().clone()

            if self.tokenizer:  # need to calculate bits per bytes
                sequences = self.tokenizer.decode_ids(labels.tolist())
                self.total_bytes += sum(len(seq.encode("utf-8")) for seq in sequences)

            shift_logits = logits.view(-1, logits.size(-1))
            shift_labels = labels.view(-1)
            # There is a shift according to the current rank, because the logits are split
            pred_shift = self.tp_local_rank * logits.shape[-1]

            logits_max = torch.max(shift_logits, dim=-1)[0]
            torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg)
            # Determine whether the maximum value of the current local tensor is the global maximum value
            logits_global = logits_max == torch.max(shift_logits, dim=-1)[0]

            corrects = torch.logical_and(
                (shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global
            ).long()
            mask = shift_labels.ne(-100).long()
            # Per-type statistics need both configured dataset types and type ids
            # for this batch; without type ids there is nothing to scatter on.
            if hasattr(self, "total_type_count") and type_ids is not None:
                ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum")
                token_num_type = scatter(mask, type_ids, dim=0, reduce="sum")
                if len(ds_acc) < self.total_type_count:
                    # scatter's output is shorter than the number of types when the
                    # highest type ids are absent from this micro batch; pad with zeros.
                    ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))])
                    token_num_type = torch.cat(
                        [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
                    )
                self.ds_tokens += token_num_type
                sync_tensor = ds_acc
                torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
                self.ds_right += sync_tensor.view(-1)

            acc = corrects.sum()
            torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
            self.right += acc  # Masked_fill is not needed here because -100 is not available anyway
            self.total += mask.sum()

            # Subtract the maximum value.
            shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1))

            # Get the partition's vocab indices
            partition_vocab_size = shift_logits.size()[-1]
            vocab_start_index = partition_vocab_size * self.tp_local_rank
            vocab_end_index = vocab_start_index + partition_vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            target_mask = (shift_labels < vocab_start_index) | (shift_labels >= vocab_end_index)
            masked_target = shift_labels - vocab_start_index
            masked_target[target_mask] = 0

            # Get predicted-logits = logits[target].
            # For Simplicity, we convert logits to a 2-D tensor with size
            # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
            logits_2d = shift_logits.view(-1, partition_vocab_size)
            masked_target_1d = masked_target.view(-1)
            arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
            predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
            predicted_logits_1d = predicted_logits_1d.clone().contiguous()
            predicted_logits = predicted_logits_1d.view_as(shift_labels)  # bsz x max_len
            predicted_logits[target_mask] = 0.0
            # All reduce is needed to get the chunks from other GPUs.
            torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)

            pred_exp_logits = torch.exp(predicted_logits)
            # Sum of exponential of logits along vocab dimension across all GPUs.
            sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
            torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)

            # Negative log-likelihood summed over the positions whose label is not -100.
            total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
            self.total_log_probs += total_log_probs

    def get_metric(self, reset=True):
        """Reduce the accumulated statistics over the data parallel group and report them.

        Args:
            reset (bool): Zero the accumulators (including those of the nested
                LossWithTypeId) after the metrics have been computed.

        Returns:
            dict: "acc" and "perplexity"; "BPB" when a tokenizer was given;
            "acc/<type>" and "tokens/<type>" per dataset type when dataset types were
            given; plus every entry returned by LossWithTypeId.get_metric().
        """
        # NOTE(review): the reductions below write into the accumulators themselves,
        # so with reset=False they hold group-wide sums afterwards -- confirm this is
        # acceptable for callers that keep accumulating.
        if is_no_pp_or_last_stage() and self.dp_pg is not None:
            torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if hasattr(self, "total_type_count"):
                torch.distributed.all_reduce(self.ds_right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
                torch.distributed.all_reduce(self.ds_tokens, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if self.tokenizer:
                torch.distributed.all_reduce(self.total_bytes, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)

        acc = round((self.right / self.total).item(), 4)
        perplexity = round(torch.exp(self.total_log_probs / self.total).item(), 4)
        # NOTE(review): total_log_probs is built with the natural logarithm, so this
        # ratio is in nats per byte; confirm whether a 1/ln(2) factor is intended.
        bits_per_bytes = round((self.total_log_probs / self.total_bytes).item(), 4) if self.tokenizer else 0

        if hasattr(self, "total_type_count"):
            ds_acc = {}
            ds_tokens = {}
            for i in range(self.total_type_count):
                # The small epsilon keeps types without any token at accuracy 0.
                ds_acc[f"acc/{self.dataset_types[i]}"] = round(
                    (self.ds_right[i].float() / (self.ds_tokens[i].float() + 1e-5)).item(), 4
                )
                ds_tokens[f"tokens/{self.dataset_types[i]}"] = self.ds_tokens[i].item()
        if reset:
            self.right.fill_(0)
            self.total.fill_(0)
            self.total_log_probs.fill_(0)
            self.total_bytes.fill_(0)
            if hasattr(self, "total_type_count"):
                self.ds_right.fill_(0)
                self.ds_tokens.fill_(0)
        if self.tokenizer is not None:
            res = {"acc": acc, "perplexity": perplexity, "BPB": bits_per_bytes}
        else:
            res = {"acc": acc, "perplexity": perplexity}
        if hasattr(self, "total_type_count"):
            res.update(ds_acc)
            res.update(ds_tokens)

        # Forward ``reset`` so that reset=False also preserves the nested loss statistics.
        loss_res = self.loss_with_type_id.get_metric(reset)
        res.update(loss_res)

        return res
|
||||||
|
|
||||||
|
|
||||||
|
class LossWithTypeId:
    """
    Accumulates the token-level cross entropy loss, overall and per dataset type.

    Notice the loss value computed here may be not the same with the main info loss,
    cause loss here is the reduced result of the data parallel.

    Args:
        device: Device on which the accumulators are kept.
        dp_pg: The data parallel process group used by get_metric(); reductions are
            skipped when it is None.
        dataset_types (List[str]): Names of the dataset types; entry ``i`` names type
            id ``i`` in the reported keys. Per-type statistics are only kept when this
            is given.
    """

    def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
        self.device = device
        self.dp_pg = dp_pg

        # Running sum of the loss and number of tokens that contributed to it.
        self.loss = torch.Tensor([0.0]).to(device=device)
        self.token_num = torch.Tensor([0.0]).to(device=device)

        if dataset_types is not None:
            # The per-type attributes only exist when dataset types are given;
            # the other methods test for them with hasattr().
            self.dataset_types = dataset_types
            self.total_type_count = len(dataset_types)
            self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
            self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)

        # reduction="none" keeps one loss value per token so it can be grouped by type.
        self.loss_fn = FlashCrossEntropyLoss(
            reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
        )

    def update(self, logits, labels, type_ids=None):
        """Accumulate the loss of one micro batch.

        Args:
            logits: Model output, or a list/tuple whose first element is the output.
            labels: Target ids; positions equal to -100 are excluded.
            type_ids: Type id of every label position. When None, only the overall
                loss is accumulated and the per-type statistics are left untouched.
        """
        with torch.no_grad():
            if isinstance(logits, (list, tuple)):
                logits = logits[0]
            logits = logits.contiguous().view(-1, logits.size(-1))
            labels = labels.contiguous().view(-1)
            loss_list = self.loss_fn(logits, labels)

            cond = labels != -100
            real_loss_list = loss_list[cond]
            self.loss += real_loss_list.sum()
            self.token_num += real_loss_list.numel()

            # ``type_ids`` defaults to None, so it must be checked before it is used.
            if hasattr(self, "total_type_count") and type_ids is not None:
                type_ids = type_ids.contiguous().view(-1).to(self.device)
                real_type_ids = type_ids[cond]

                loss_list_type = scatter(real_loss_list, real_type_ids, dim=0, reduce="sum")
                token_num_type = scatter(torch.ones_like(real_loss_list), real_type_ids, dim=0, reduce="sum")

                if len(loss_list_type) < self.total_type_count:
                    # scatter's output is shorter than the number of types when the
                    # highest type ids are absent from this micro batch; pad with zeros.
                    loss_list_type = torch.cat(
                        [loss_list_type, loss_list_type.new_zeros(self.total_type_count - len(loss_list_type))]
                    )
                    token_num_type = torch.cat(
                        [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
                    )
                self.ds_loss += loss_list_type
                self.ds_token_num += token_num_type

    def get_metric(self, reset=True):
        """Reduce the accumulated loss over the data parallel group and report it.

        Args:
            reset (bool): Zero the accumulators after the metrics have been computed.

        Returns:
            dict: "loss_from_metric" (mean loss per counted token) and, when dataset
            types were given, one "loss/<type>" entry per type.
        """
        # NOTE(review): the reductions below write into the accumulators themselves,
        # so with reset=False they hold group-wide sums afterwards -- confirm this is
        # acceptable for callers that keep accumulating.
        if is_no_pp_or_last_stage() and self.dp_pg is not None:
            torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if hasattr(self, "total_type_count"):
                torch.distributed.all_reduce(self.ds_loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
                torch.distributed.all_reduce(self.ds_token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)

        loss = round((self.loss / self.token_num).item(), 4)
        res = {
            "loss_from_metric": loss,
        }
        if hasattr(self, "total_type_count"):
            ds_loss = {}
            for i in range(self.total_type_count):
                # A type without any counted token divides 0 by 0 here.
                ds_loss[f"loss/{self.dataset_types[i]}"] = round((self.ds_loss[i] / self.ds_token_num[i]).item(), 4)
            res.update(ds_loss)

        if reset:
            self.loss.fill_(0.0)
            self.token_num.fill_(0.0)
            if hasattr(self, "total_type_count"):
                self.ds_loss.fill_(0.0)
                self.ds_token_num.fill_(0.0)

        return res
|
27
train.py
27
train.py
|
@ -27,6 +27,7 @@ from internlm.data.packed_dataset import (
|
||||||
)
|
)
|
||||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP
|
from internlm.data.utils import DATASET_TYPE_IDS_MAP
|
||||||
from internlm.model.loss import FlashGPTLMLoss
|
from internlm.model.loss import FlashGPTLMLoss
|
||||||
|
from internlm.model.metrics import AccPerplex
|
||||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||||
|
@ -207,8 +208,6 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
|
||||||
train_state.num_consumed_samples_in_epoch = 0
|
train_state.num_consumed_samples_in_epoch = 0
|
||||||
timer("batch-gen").stop()
|
timer("batch-gen").stop()
|
||||||
|
|
||||||
batch[0].pop("type_ids", None)
|
|
||||||
|
|
||||||
return batch, train_iter
|
return batch, train_iter
|
||||||
|
|
||||||
|
|
||||||
|
@ -254,6 +253,7 @@ def record_current_batch_training_metrics(
|
||||||
start_time,
|
start_time,
|
||||||
loss,
|
loss,
|
||||||
grad_norm,
|
grad_norm,
|
||||||
|
metric,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Print some training metrics of current batch.
|
Print some training metrics of current batch.
|
||||||
|
@ -261,6 +261,8 @@ def record_current_batch_training_metrics(
|
||||||
|
|
||||||
if success_update in (0, True):
|
if success_update in (0, True):
|
||||||
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
||||||
|
if is_no_pp_or_last_stage():
|
||||||
|
acc_perplex = metric.get_metric()
|
||||||
|
|
||||||
if success_update and gpc.is_rank_for_log():
|
if success_update and gpc.is_rank_for_log():
|
||||||
lr = optimizer.param_groups[0]["lr"]
|
lr = optimizer.param_groups[0]["lr"]
|
||||||
|
@ -308,6 +310,9 @@ def record_current_batch_training_metrics(
|
||||||
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
|
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
|
||||||
infos["fwd_bwd_time"] = fwd_bwd_time
|
infos["fwd_bwd_time"] = fwd_bwd_time
|
||||||
|
|
||||||
|
for key, value in acc_perplex.items():
|
||||||
|
infos[key] = value
|
||||||
|
|
||||||
line = ""
|
line = ""
|
||||||
for key, value in infos.items():
|
for key, value in infos.items():
|
||||||
line += f"{key}={value} "
|
line += f"{key}={value} "
|
||||||
|
@ -396,7 +401,7 @@ def main(args):
|
||||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
|
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
|
||||||
|
|
||||||
# initialize the train data loader
|
# initialize the train data loader
|
||||||
train_dl, _ = get_train_data_loader(num_worker=4)
|
train_dl, dataset_types = get_train_data_loader(num_worker=4)
|
||||||
train_state.init_batch_sampler(train_dl)
|
train_state.init_batch_sampler(train_dl)
|
||||||
|
|
||||||
# Loading model weights must be done before zero is initialized.
|
# Loading model weights must be done before zero is initialized.
|
||||||
|
@ -430,6 +435,14 @@ def main(args):
|
||||||
# initialize the batch skipper
|
# initialize the batch skipper
|
||||||
batch_skipper = BatchSkipper(skip_batches)
|
batch_skipper = BatchSkipper(skip_batches)
|
||||||
|
|
||||||
|
# initialize metric for calculating accuracy and perplexity
|
||||||
|
metric = AccPerplex(
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
tp_pg=gpc.get_group(ParallelMode.TENSOR),
|
||||||
|
dp_pg=gpc.get_group(ParallelMode.DATA),
|
||||||
|
dataset_types=dataset_types,
|
||||||
|
)
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
# transfer the train data loader into train data iterator
|
# transfer the train data loader into train data iterator
|
||||||
|
@ -457,10 +470,15 @@ def main(args):
|
||||||
|
|
||||||
# zero the grads of parameters
|
# zero the grads of parameters
|
||||||
trainer.zero_grad()
|
trainer.zero_grad()
|
||||||
|
type_ids = batch[0].pop("type_ids", None)
|
||||||
|
if type_ids is not None:
|
||||||
|
metric.set_current_type_ids(type_ids=type_ids)
|
||||||
|
|
||||||
# do forward and backward
|
# do forward and backward
|
||||||
timer("fwd-bwd").start()
|
timer("fwd-bwd").start()
|
||||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
_, _, loss = trainer.execute_schedule(
|
||||||
|
batch, forward_only=False, return_loss=True, return_output_label=False, post_fn=metric
|
||||||
|
)
|
||||||
timer("fwd-bwd").stop()
|
timer("fwd-bwd").stop()
|
||||||
|
|
||||||
# update parameters, and returns (success_update, grad_norm)
|
# update parameters, and returns (success_update, grad_norm)
|
||||||
|
@ -490,6 +508,7 @@ def main(args):
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
grad_norm=grad_norm,
|
grad_norm=grad_norm,
|
||||||
|
metric=metric,
|
||||||
)
|
)
|
||||||
|
|
||||||
timer("one-batch").stop()
|
timer("one-batch").stop()
|
||||||
|
|
Loading…
Reference in New Issue