feat(utils/evaluation.py): support evaluate (#154)

* style(internlm): fix lint error

* feat(utils/logger.py): support uniscale logger

* fix(utils/logger.py): fix import circular error

* feat(train.py): support dashboard metric panel and fix ci train config

* fix(ci_scripts/train/slurm_train.sh): fix ci train error

* fix(ci_scripts/train/torchrun.sh): fix ci train error

* feat(utils/evaluation.py): support evaluate on validation dataset

* fix(utils/evaluation.py): fix demo error

* fix(ci_scripts/train/ci_7B_sft.py): fix ci train error

* feat(initialize/launch.py): set default value for valid_bsz and valid_every

* fix(ci_scripts/train): restore ci update

* docs(configs/7B_sft.py): update comment for config

* fix(config.json): delete config.json

* fix evaluation bug in scheduler when use_flash_attn=False

* feat(scheduler/no_pipeline_scheduler.py): support micro_bsz>1 in no pp

* modify the jugement in pp and no-pp scheduler

* modify the data_process_func in evaluation

* fix bugs when use_flash_attn=False

* rename symbol

* feat(configs/7B_sft.py): change para valid_bsz to valid_micro_num

* feat(scheduler/no_pipeline_scheduler.py): update para set _grad_accum_batch_size

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: huangting.p <huangting@sensetime.com>
Co-authored-by: yingtongxiong <974106207@qq.com>
pull/166/head^2
huangting4201 2023-08-02 19:03:59 +08:00 committed by GitHub
parent 1f7304a8bb
commit 66a23e326a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 275 additions and 11 deletions

View File

@ -26,12 +26,17 @@ ckpt = dict(
)
TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
valid_every=50,
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
@ -39,6 +44,7 @@ data = dict(
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
)
grad_scaler = dict(

View File

@ -35,9 +35,7 @@ class NonPipelineScheduler(BaseScheduler):
"""
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
self._grad_accum_size = gradient_accumulation_size
self._grad_accum_batch_size = 1 # static batch size for flash attetion.
self._grad_accum_offset = 0
super().__init__(data_process_func)
@ -145,8 +143,9 @@ class NonPipelineScheduler(BaseScheduler):
batch_data, batch_size = engine.load_batch(data_iter)
assert (
batch_size == self._grad_accum_size
), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
batch_size % self._grad_accum_size == 0
), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
self._grad_accum_batch_size = batch_size // self._grad_accum_size
data, label = batch_data

View File

@ -133,7 +133,6 @@ class PipelineScheduler(BaseScheduler):
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
)
self.microbatch_offset += self.microbatch_size
if self.data_process_func:
micro_batch_data["input_ids"] = self.data_process_func(
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
@ -310,7 +309,7 @@ class PipelineScheduler(BaseScheduler):
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
self.load_batch(engine, data_iter)
num_warmup_microbatches = (
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1

56
internlm/data/dataset.py Normal file
View File

@ -0,0 +1,56 @@
import os
from typing import Dict
from torch.utils.data import ConcatDataset
from internlm.data.single_dataset import JsonlDataset
def get_dataset_dict(folder, split="valid") -> Dict:
"""
Return a dictionary of Datasets from a folder containing data files for validation.
Args:
folder (str): The path to the folder containing data files.
split (str): The split of the data files to be used, default is "valid".
Returns:
A dictionary containing Datasets for each folder in the given path
that contains data files with the specified split.
Raises:
AssertionError: If the given folder does not exist.
Example:
If the given folder is as follows,
- data
- zhihu
- xxx.bin
- valid.bin
- baike
- xxx.bin
- valid.bin
The returned dictionary will be,
{
'zhihu': Dataset,
'baike': Dataset
}
"""
assert os.path.exists(folder), f"folder `{folder}` not exists"
data_dict = {}
for root, dirs, files in os.walk(folder, followlinks=True):
dirs.sort() # The order is guaranteed, and the newly added data starting with z needs to be ranked behind
datasets = []
for fn in sorted(files): # Need sorted to ensure that the order is consistent
if fn.endswith(".bin") and split in fn:
fp = os.path.join(root, fn)
ds = JsonlDataset(fp)
datasets.append(ds)
if datasets:
ds = ConcatDataset(datasets=datasets)
data_dict[os.path.basename(root)] = ds
return data_dict

View File

@ -88,6 +88,12 @@ def args_sanity_check():
if "valid_folder" not in data:
data._add_item("valid_folder", None)
if "valid_micro_num" not in data:
data._add_item("valid_micro_num", data.micro_num)
if "valid_every" not in data:
data._add_item("valid_every", 0)
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"seq_len: {data.seq_len}")
@ -96,6 +102,8 @@ def args_sanity_check():
logger.info(f"packed_length: {data.packed_length}")
logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
logger.info(f"min_length: {data.min_length}")
logger.info(f"valid_micro_num: {data.valid_micro_num}")
logger.info(f"valid_every: {data.valid_every}")
# processing the checkpoint config
if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:

View File

@ -0,0 +1,143 @@
from contextlib import contextmanager
import torch
import torch.distributed as dist
from tqdm import tqdm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.metrics import AccPerplex
@contextmanager
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size):
if not gpc.is_using_pp():
prev_data_process_func = trainer.schedule.data_process_func
prev_grad_accum_size = trainer.schedule._grad_accum_size
prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
try:
trainer.schedule.data_process_func = None
trainer.schedule._grad_accum_size = grad_accum_size
trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
yield
finally:
trainer.schedule.data_process_func = prev_data_process_func
trainer.schedule._grad_accum_size = prev_grad_accum_size
trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
@contextmanager
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape):
if gpc.is_using_pp():
pre_data_process_func = trainer.schedule.data_process_func
prev_num_microbatches = trainer.schedule.num_microbatches
prev_tensor_shape = trainer.schedule.tensor_shape
try:
trainer.schedule.data_process_func = None
trainer.schedule.num_microbatches = num_microbatches
trainer.schedule.tensor_shape = tensor_shape
yield
finally:
trainer.schedule.data_process_func = pre_data_process_func
trainer.schedule.num_microbatches = prev_num_microbatches
trainer.schedule.tensor_shape = prev_tensor_shape
def evaluate_on_val_dls(
trainer,
val_dls,
writer,
logger,
step_count,
tokenizer=None,
update_panel: bool = False,
):
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
data_cfg = gpc.config.data
for val_name, val_dl in val_dls.items():
if len(val_dl) == 0 and verbose:
logger.info(f"Validation dataset: {val_name} is empty")
continue
val_metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
tokenizer=tokenizer,
)
val_loss = 0
val_idx = -1
for val_idx, batch in tqdm(
enumerate(val_dl),
desc="Val.",
total=len(val_dl),
position=1,
disable=not verbose,
leave=False,
):
with torch.inference_mode():
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer, num_microbatches=num_microbatches, tensor_shape=tensor_shape
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
)
else:
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
grad_accum_batch_size = data_cfg.micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer, grad_accum_size=grad_accum_size, grad_accum_batch_size=grad_accum_batch_size
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False, post_fn=val_metric
)
if verbose:
val_loss += loss.item()
assert val_idx != -1
dist.barrier()
val_res = val_metric.get_metric()
if verbose and len(val_dl) != 0:
val_loss = val_loss / (val_idx + 1 + 1e-6)
infos = {
f"val/{val_name}_loss": val_loss,
f"val/{val_name}_acc": val_res["acc"],
f"val/{val_name}_plex": val_res["perplexity"],
}
val_metric = {
"step": step_count,
"val_loss": val_loss,
"val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"],
}
for key, value in infos.items():
writer.add_scalar(key=key, value=value, step=step_count)
infos["step"] = step_count
if update_panel:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
extra=val_metric,
)
else:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
)
trainer.train()
torch.cuda.empty_cache()
dist.barrier()

View File

@ -4,7 +4,6 @@
import logging
import os
LOGGER_NAME = "internlm"
LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s"
LOGGER_LEVEL = "info"

View File

@ -17,8 +17,9 @@ from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler
from internlm.data.collaters import packed_collate_fn
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
from internlm.data.dataset import get_dataset_dict
from internlm.data.dummy_dataset import RandomDataset
from internlm.data.packed_dataset import (
PackedDataset,
@ -39,6 +40,7 @@ from internlm.utils.common import (
launch_time,
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import (
@ -196,6 +198,45 @@ def get_train_data_loader(num_worker: int = 0):
return train_dl, dataset_types
def get_validation_data_loader(num_worker: int = 0):
data_cfg = gpc.config.data
if not data_cfg.valid_folder:
val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
else:
val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")
if not isinstance(val_ds, dict):
val_ds = {"val": val_ds}
val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)
val_dls = {}
for val_name, ds in val_ds.items():
# making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
# otherwise too much data may be dropped
batch_size = min(
data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
)
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
if batch_size == 0 and gpc.is_rank_for_log():
logger.info(f"skip validate {val_name}.")
continue
val_dls[val_name] = get_dpsampler_dataloader(
ds, shuffle=False, num_workers=num_worker, batch_size=batch_size, collate_fn=val_collate_fn, drop_last=True
) # drop_last=True, otherwise it may cause problems in the last batch
if gpc.is_rank_for_log():
logger.info(
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
f"samples {str(len(val_dls[val_name]))}."
)
return val_dls
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
"""
Load and return the new batch data based on training data loader.
@ -359,6 +400,7 @@ def main(args):
# init setting
skip_batches = gpc.config.data.skip_batches
total_steps = gpc.config.data.total_steps
valid_every = gpc.config.data.valid_every
load_optimizer = gpc.config.ckpt.load_optimizer
label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
@ -435,8 +477,9 @@ def main(args):
# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
# initialize the train data loader
# initialize the train and validation data loader
train_dl, dataset_types = get_train_data_loader(num_worker=4)
val_dls = get_validation_data_loader()
train_state.init_batch_sampler(train_dl)
# Loading model weights must be done before zero is initialized.
@ -553,8 +596,19 @@ def main(args):
timer("one-batch").stop()
# evaluate on validation data loaders
if valid_every > 0 and train_state.step_count % valid_every == 0:
evaluate_on_val_dls(
trainer=trainer,
val_dls=val_dls,
writer=writer,
logger=logger,
step_count=train_state.step_count,
update_panel=uniscale_logger is not None,
)
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
# # save batch sampler that tracks the true consumed samples
# save batch sampler that tracks the true consumed samples
if enable_save_ckpt and train_state.step_count % checkpoint_every == 0:
save_checkpoint(
folder=save_ckpt_folder,