mirror of https://github.com/InternLM/InternLM
feat(utils/evaluation.py): support evaluate (#154)
* style(internlm): fix lint error * feat(utils/logger.py): support uniscale logger * fix(utils/logger.py): fix import circular error * feat(train.py): support dashboard metric panel and fix ci train config * fix(ci_scripts/train/slurm_train.sh): fix ci train error * fix(ci_scripts/train/torchrun.sh): fix ci train error * feat(utils/evaluation.py): support evaluate on validation dataset * fix(utils/evaluation.py): fix demo error * fix(ci_scripts/train/ci_7B_sft.py): fix ci train error * feat(initialize/launch.py): set default value for valid_bsz and valid_every * fix(ci_scripts/train): restore ci update * docs(configs/7B_sft.py): update comment for config * fix(config.json): delete config.json * fix evaluation bug in scheduler when use_flash_attn=False * feat(scheduler/no_pipeline_scheduler.py): support micro_bsz>1 in no pp * modify the jugement in pp and no-pp scheduler * modify the data_process_func in evaluation * fix bugs when use_flash_attn=False * rename symbol * feat(configs/7B_sft.py): change para valid_bsz to valid_micro_num * feat(scheduler/no_pipeline_scheduler.py): update para set _grad_accum_batch_size --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local> Co-authored-by: huangting.p <huangting@sensetime.com> Co-authored-by: yingtongxiong <974106207@qq.com>pull/166/head^2
parent
1f7304a8bb
commit
66a23e326a
|
@ -26,12 +26,17 @@ ckpt = dict(
|
|||
)
|
||||
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
VALID_FOLDER = "/path/to/dataset"
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN,
|
||||
# micro_num means the number of micro_batch contained in one gradient update
|
||||
micro_num=4,
|
||||
# packed_length = micro_bsz * SEQ_LEN
|
||||
micro_bsz=2,
|
||||
# defaults to the value of micro_num
|
||||
valid_micro_num=4,
|
||||
# defaults to 0, means disable evaluate
|
||||
valid_every=50,
|
||||
pack_sample_into_one=False,
|
||||
total_steps=50000,
|
||||
skip_batches="",
|
||||
|
@ -39,6 +44,7 @@ data = dict(
|
|||
# Datasets with less than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
# valid_folder=VALID_FOLDER,
|
||||
)
|
||||
|
||||
grad_scaler = dict(
|
||||
|
|
|
@ -35,9 +35,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
|
||||
|
||||
self._grad_accum_size = gradient_accumulation_size
|
||||
self._grad_accum_batch_size = 1 # static batch size for flash attetion.
|
||||
self._grad_accum_offset = 0
|
||||
|
||||
super().__init__(data_process_func)
|
||||
|
@ -145,8 +143,9 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
batch_data, batch_size = engine.load_batch(data_iter)
|
||||
|
||||
assert (
|
||||
batch_size == self._grad_accum_size
|
||||
), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
|
||||
batch_size % self._grad_accum_size == 0
|
||||
), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
|
||||
self._grad_accum_batch_size = batch_size // self._grad_accum_size
|
||||
|
||||
data, label = batch_data
|
||||
|
||||
|
|
|
@ -133,7 +133,6 @@ class PipelineScheduler(BaseScheduler):
|
|||
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
|
||||
)
|
||||
self.microbatch_offset += self.microbatch_size
|
||||
|
||||
if self.data_process_func:
|
||||
micro_batch_data["input_ids"] = self.data_process_func(
|
||||
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
|
||||
|
@ -310,7 +309,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
|
||||
|
||||
self.load_batch(engine, data_iter)
|
||||
num_warmup_microbatches = (
|
||||
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
import os
|
||||
from typing import Dict
|
||||
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from internlm.data.single_dataset import JsonlDataset
|
||||
|
||||
|
||||
def get_dataset_dict(folder, split="valid") -> Dict:
|
||||
"""
|
||||
Return a dictionary of Datasets from a folder containing data files for validation.
|
||||
|
||||
Args:
|
||||
folder (str): The path to the folder containing data files.
|
||||
split (str): The split of the data files to be used, default is "valid".
|
||||
|
||||
Returns:
|
||||
A dictionary containing Datasets for each folder in the given path
|
||||
that contains data files with the specified split.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the given folder does not exist.
|
||||
|
||||
Example:
|
||||
If the given folder is as follows,
|
||||
- data
|
||||
- zhihu
|
||||
- xxx.bin
|
||||
- valid.bin
|
||||
- baike
|
||||
- xxx.bin
|
||||
- valid.bin
|
||||
|
||||
The returned dictionary will be,
|
||||
{
|
||||
'zhihu': Dataset,
|
||||
'baike': Dataset
|
||||
}
|
||||
"""
|
||||
|
||||
assert os.path.exists(folder), f"folder `{folder}` not exists"
|
||||
data_dict = {}
|
||||
|
||||
for root, dirs, files in os.walk(folder, followlinks=True):
|
||||
dirs.sort() # The order is guaranteed, and the newly added data starting with z needs to be ranked behind
|
||||
datasets = []
|
||||
for fn in sorted(files): # Need sorted to ensure that the order is consistent
|
||||
if fn.endswith(".bin") and split in fn:
|
||||
fp = os.path.join(root, fn)
|
||||
ds = JsonlDataset(fp)
|
||||
datasets.append(ds)
|
||||
if datasets:
|
||||
ds = ConcatDataset(datasets=datasets)
|
||||
data_dict[os.path.basename(root)] = ds
|
||||
|
||||
return data_dict
|
|
@ -88,6 +88,12 @@ def args_sanity_check():
|
|||
if "valid_folder" not in data:
|
||||
data._add_item("valid_folder", None)
|
||||
|
||||
if "valid_micro_num" not in data:
|
||||
data._add_item("valid_micro_num", data.micro_num)
|
||||
|
||||
if "valid_every" not in data:
|
||||
data._add_item("valid_every", 0)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
|
||||
logger.info(f"seq_len: {data.seq_len}")
|
||||
|
@ -96,6 +102,8 @@ def args_sanity_check():
|
|||
logger.info(f"packed_length: {data.packed_length}")
|
||||
logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
|
||||
logger.info(f"min_length: {data.min_length}")
|
||||
logger.info(f"valid_micro_num: {data.valid_micro_num}")
|
||||
logger.info(f"valid_every: {data.valid_every}")
|
||||
|
||||
# processing the checkpoint config
|
||||
if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.metrics import AccPerplex
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size):
|
||||
if not gpc.is_using_pp():
|
||||
prev_data_process_func = trainer.schedule.data_process_func
|
||||
prev_grad_accum_size = trainer.schedule._grad_accum_size
|
||||
prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
|
||||
try:
|
||||
trainer.schedule.data_process_func = None
|
||||
trainer.schedule._grad_accum_size = grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
|
||||
yield
|
||||
finally:
|
||||
trainer.schedule.data_process_func = prev_data_process_func
|
||||
trainer.schedule._grad_accum_size = prev_grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape):
|
||||
if gpc.is_using_pp():
|
||||
pre_data_process_func = trainer.schedule.data_process_func
|
||||
prev_num_microbatches = trainer.schedule.num_microbatches
|
||||
prev_tensor_shape = trainer.schedule.tensor_shape
|
||||
try:
|
||||
trainer.schedule.data_process_func = None
|
||||
trainer.schedule.num_microbatches = num_microbatches
|
||||
trainer.schedule.tensor_shape = tensor_shape
|
||||
yield
|
||||
finally:
|
||||
trainer.schedule.data_process_func = pre_data_process_func
|
||||
trainer.schedule.num_microbatches = prev_num_microbatches
|
||||
trainer.schedule.tensor_shape = prev_tensor_shape
|
||||
|
||||
|
||||
def evaluate_on_val_dls(
|
||||
trainer,
|
||||
val_dls,
|
||||
writer,
|
||||
logger,
|
||||
step_count,
|
||||
tokenizer=None,
|
||||
update_panel: bool = False,
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
trainer.eval()
|
||||
verbose = gpc.is_rank_for_log()
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
for val_name, val_dl in val_dls.items():
|
||||
if len(val_dl) == 0 and verbose:
|
||||
logger.info(f"Validation dataset: {val_name} is empty")
|
||||
continue
|
||||
|
||||
val_metric = AccPerplex(
|
||||
device=torch.cuda.current_device(),
|
||||
tp_pg=gpc.get_group(ParallelMode.TENSOR),
|
||||
dp_pg=gpc.get_group(ParallelMode.DATA),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
val_loss = 0
|
||||
val_idx = -1
|
||||
for val_idx, batch in tqdm(
|
||||
enumerate(val_dl),
|
||||
desc="Val.",
|
||||
total=len(val_dl),
|
||||
position=1,
|
||||
disable=not verbose,
|
||||
leave=False,
|
||||
):
|
||||
with torch.inference_mode():
|
||||
if gpc.is_using_pp():
|
||||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
num_microbatches = total_val_bsz // data_cfg.micro_bsz
|
||||
tensor_shape = torch.Size(
|
||||
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
|
||||
)
|
||||
|
||||
with switch_evaluation_pipeline_scheduler(
|
||||
trainer=trainer, num_microbatches=num_microbatches, tensor_shape=tensor_shape
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
|
||||
)
|
||||
else:
|
||||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
|
||||
grad_accum_batch_size = data_cfg.micro_bsz
|
||||
|
||||
with switch_evaluation_no_pipeline_scheduler(
|
||||
trainer=trainer, grad_accum_size=grad_accum_size, grad_accum_batch_size=grad_accum_batch_size
|
||||
):
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
|
||||
)
|
||||
if verbose:
|
||||
val_loss += loss.item()
|
||||
|
||||
assert val_idx != -1
|
||||
dist.barrier()
|
||||
val_res = val_metric.get_metric()
|
||||
|
||||
if verbose and len(val_dl) != 0:
|
||||
val_loss = val_loss / (val_idx + 1 + 1e-6)
|
||||
infos = {
|
||||
f"val/{val_name}_loss": val_loss,
|
||||
f"val/{val_name}_acc": val_res["acc"],
|
||||
f"val/{val_name}_plex": val_res["perplexity"],
|
||||
}
|
||||
val_metric = {
|
||||
"step": step_count,
|
||||
"val_loss": val_loss,
|
||||
"val_acc": val_res["acc"],
|
||||
"val_perplexity": val_res["perplexity"],
|
||||
}
|
||||
for key, value in infos.items():
|
||||
writer.add_scalar(key=key, value=value, step=step_count)
|
||||
infos["step"] = step_count
|
||||
if update_panel:
|
||||
logger.info(
|
||||
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
|
||||
extra=val_metric,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
torch.cuda.empty_cache()
|
||||
dist.barrier()
|
|
@ -4,7 +4,6 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
|
||||
LOGGER_NAME = "internlm"
|
||||
LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s"
|
||||
LOGGER_LEVEL = "info"
|
||||
|
|
62
train.py
62
train.py
|
@ -17,8 +17,9 @@ from internlm.core.context import ParallelMode
|
|||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.data.batch_sampler import StaticBatchSampler
|
||||
from internlm.data.collaters import packed_collate_fn
|
||||
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
|
||||
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
|
||||
from internlm.data.dataset import get_dataset_dict
|
||||
from internlm.data.dummy_dataset import RandomDataset
|
||||
from internlm.data.packed_dataset import (
|
||||
PackedDataset,
|
||||
|
@ -39,6 +40,7 @@ from internlm.utils.common import (
|
|||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import (
|
||||
|
@ -196,6 +198,45 @@ def get_train_data_loader(num_worker: int = 0):
|
|||
return train_dl, dataset_types
|
||||
|
||||
|
||||
def get_validation_data_loader(num_worker: int = 0):
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
if not data_cfg.valid_folder:
|
||||
val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
|
||||
else:
|
||||
val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")
|
||||
|
||||
if not isinstance(val_ds, dict):
|
||||
val_ds = {"val": val_ds}
|
||||
|
||||
val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)
|
||||
|
||||
val_dls = {}
|
||||
for val_name, ds in val_ds.items():
|
||||
# making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
|
||||
# otherwise too much data may be dropped
|
||||
batch_size = min(
|
||||
data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
|
||||
)
|
||||
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
|
||||
|
||||
if batch_size == 0 and gpc.is_rank_for_log():
|
||||
logger.info(f"skip validate {val_name}.")
|
||||
continue
|
||||
|
||||
val_dls[val_name] = get_dpsampler_dataloader(
|
||||
ds, shuffle=False, num_workers=num_worker, batch_size=batch_size, collate_fn=val_collate_fn, drop_last=True
|
||||
) # drop_last=True, otherwise it may cause problems in the last batch
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
|
||||
f"samples {str(len(val_dls[val_name]))}."
|
||||
)
|
||||
|
||||
return val_dls
|
||||
|
||||
|
||||
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
|
||||
"""
|
||||
Load and return the new batch data based on training data loader.
|
||||
|
@ -359,6 +400,7 @@ def main(args):
|
|||
# init setting
|
||||
skip_batches = gpc.config.data.skip_batches
|
||||
total_steps = gpc.config.data.total_steps
|
||||
valid_every = gpc.config.data.valid_every
|
||||
load_optimizer = gpc.config.ckpt.load_optimizer
|
||||
label_smoothing = gpc.config.loss.label_smoothing
|
||||
lr = gpc.config.adam.lr
|
||||
|
@ -435,8 +477,9 @@ def main(args):
|
|||
# initialize loss function
|
||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
|
||||
|
||||
# initialize the train data loader
|
||||
# initialize the train and validation data loader
|
||||
train_dl, dataset_types = get_train_data_loader(num_worker=4)
|
||||
val_dls = get_validation_data_loader()
|
||||
train_state.init_batch_sampler(train_dl)
|
||||
|
||||
# Loading model weights must be done before zero is initialized.
|
||||
|
@ -553,8 +596,19 @@ def main(args):
|
|||
|
||||
timer("one-batch").stop()
|
||||
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||
evaluate_on_val_dls(
|
||||
trainer=trainer,
|
||||
val_dls=val_dls,
|
||||
writer=writer,
|
||||
logger=logger,
|
||||
step_count=train_state.step_count,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# # save batch sampler that tracks the true consumed samples
|
||||
# save batch sampler that tracks the true consumed samples
|
||||
if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
|
||||
save_checkpoint(
|
||||
folder=save_ckpt_folder,
|
||||
|
|
Loading…
Reference in New Issue