diff --git a/tests/test_training/7B_check_acc.py b/tests/test_training/7B_check_acc.py new file mode 100644 index 0000000..da7e4b3 --- /dev/null +++ b/tests/test_training/7B_check_acc.py @@ -0,0 +1,167 @@ +JOB_NAME = "7b_train" +DO_ALERT = False + +SEQ_LEN = 2048 +HIDDEN_SIZE = 4096 +NUM_ATTENTION_HEAD = 32 +MLP_RATIO = 8 / 3 +NUM_LAYER = 32 +VOCAB_SIZE = 103168 + +MODEL_ONLY_FOLDER = "/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt/init" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +# SAVE_CKPT_FOLDER = "local:llm_ckpts_0925_9" +# LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 100 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + auto_resume=False, + # save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), + # load_ckpt_folder="local:llm_ckpts/", + # # 'load_ckpt_info' setting guide: + # # 1. the 'path' indicate ckpt path, + # # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + # checkpoint_every=CHECKPOINT_EVERY, + # async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + # async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + # oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = "/path/to/dataset" +VALID_FOLDER = "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=2, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=2, + # defaults to the value of micro_num + valid_micro_num=2, + # defaults to 0, means disable evaluate + valid_every=5000, + pack_sample_into_one=False, + total_steps=501, + skip_batches="", + rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + # train_folder=TRAIN_FOLDER, + # valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=500, + diag_outlier_ratio=1.1, + num_worker=4, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=True, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +model = dict( + checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] + num_attention_heads=NUM_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + norm_type="rmsnorm", + layer_norm_epsilon=1e-5, + use_flash_attn=True, + num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. +) +""" +zero1 parallel: + 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. + 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. +tensor parallel: tensor parallel size, usually the number of GPUs per node. +""" +parallel = dict( + zero1=dict(size=8, fsdp=False), + tensor=1, + pipeline=dict(size=1, interleaved_overlap=True), + sequence_parallel=False, +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + ), +) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py new file mode 100644 index 0000000..f81b4c6 --- /dev/null +++ b/tests/test_training/train_CI.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import os +import sys + +script_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(script_dir, "../../")) +sys.path.append(project_root) + +# pylint: disable=C0413,W0612,W0611 +import socket +import time +import traceback +from functools import partial + +import torch +import torch.distributed as dist + +import internlm +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.scheduler import SchedulerMetricHook +from internlm.core.trainer import TrainState +from internlm.initialize import initialize_distributed_env +from internlm.model.loss import FlashGPTLMLoss +from internlm.model.metrics import AccPerplex +from internlm.monitor import initialize_monitor_manager, send_alert_message +from internlm.monitor.monitor import monitor_manager as mm +from internlm.train import ( + get_train_data_loader, + get_validation_data_loader, + initialize_llm_profile, + initialize_model, + initialize_optimizer, + load_new_batch, + record_current_batch_training_metrics, +) +from internlm.utils.common import ( + BatchSkipper, + get_megatron_flops, + launch_time, + parse_args, +) +from internlm.utils.evaluation import evaluate_on_val_dls +from internlm.utils.gputest import empty_cache_and_diag +from internlm.utils.logger import get_logger, initialize_uniscale_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.model_checkpoint import CheckpointManager +from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler +from internlm.utils.writer import Writer + +# global llm logger +logger = get_logger(__file__) + + +def initialize_llm_logger(start_time: str): + """ + Initialize customed uniscale logger. + + Args: + start_time (str): The launch time of current training job. + + Returns: The instance of uniscale logger. + """ + + uniscale_logger = initialize_uniscale_logger( + job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name() + ) + if uniscale_logger is not None: + global logger + logger = uniscale_logger + + return uniscale_logger + + +def check_model_weights(model, ckpt_path): + model1_dict = torch.load(ckpt_path, map_location="cuda") + model2_dict = model.state_dict() + + for key in model1_dict.keys(): + if key in model2_dict: + tensor1 = model1_dict[key] + tensor2 = model2_dict[key] + assert torch.allclose(tensor1, tensor2, rtol=3e-2, atol=3e-2) + + +def main(args): + # init setting + skip_batches = gpc.config.data.skip_batches + total_steps = gpc.config.data.total_steps + valid_every = gpc.config.data.valid_every + label_smoothing = gpc.config.loss.label_smoothing + + get_tflops_func = partial( + get_megatron_flops, + checkpoint=gpc.config.model.checkpoint, + seq_len=gpc.config.SEQ_LEN, + hidden_size=gpc.config.model.hidden_size, + num_layers=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA), + global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), + mlp_ratio=gpc.config.MLP_RATIO, + ) + + # get and broadcast current time + current_time = launch_time() + objs = [current_time] + dist.broadcast_object_list(objs, src=0) + current_time = objs[0] + + # initialize customed llm logger + uniscale_logger = initialize_llm_logger(start_time=current_time) + + # initialize model + model = initialize_model() + + with open(args.config, "r") as f: + config_lines = f.readlines() + + # initialize loss function + criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) + + # initialize the train and validation data loader + train_dl, dataset_types = get_train_data_loader(num_worker=4) + val_dls = get_validation_data_loader() + + # initialize and resume train state + train_state = TrainState(gpc.config, train_dl.batch_sampler) + + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) + + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + model_config_file="".join(config_lines), + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + + # Loading other persistent training states. + ckpt_manager.try_resume_training(train_state, current_time) + + # initialize customed llm writer + writer = Writer( + job_name=gpc.config.JOB_NAME, + launch_time=current_time, + file_name=get_parallel_log_file_name(), + tensorboard_folder=gpc.config.tensorboard_folder, + resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt. + step_count=train_state.step_count, # resume from ckpt. + config=config_lines, + logger=logger, + enable_tb=gpc.config.enable_tb, + ) + + # initialize metric for calculating accuracy and perplexity + metric = AccPerplex( + device=torch.cuda.current_device(), + tp_pg=gpc.get_group(ParallelMode.TENSOR), + dp_pg=gpc.get_group(ParallelMode.DATA), + dataset_types=dataset_types, + ) + + # initialize trainer + scheduler_hooks = [ + SchedulerMetricHook( + metric=metric, + skip=( + gpc.is_using_pp() + and hasattr(gpc.config.model, "num_chunks") + and gpc.config.model.num_chunks > 1 + and gpc.config.parallel["pipeline"].get("interleaved_overlap", False) + ), + ), + ] + + trainer, train_dl, _, _ = internlm.initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dl, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=scheduler_hooks, + ) + + # initialize simple memory profiler + if args.profiling: + memory_profiler = SimpleMemoryProfiler( + model, + optimizer.optim, + log_folder=f"memory_trace/rank{gpc.get_global_rank()}_" + + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" + + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}", + ) + else: + memory_profiler = None + + # initialize the batch skipper + batch_skipper = BatchSkipper(skip_batches) + + trainer.train() + + # transfer the train data loader into train data iterator + train_iter = iter(train_dl) + + with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: + # start iterating the train data and begin training + for batch_count in range(train_state.batch_count, total_steps): + empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) + start_time = time.time() + timer("one-batch").start() + + # load batch data + # batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) + # pylint: disable=C0301 + batch_index = batch_count % 1000 + if batch_index == 0: + data_local_rank = gpc.get_local_rank(ParallelMode.DATA) + batch_step = (batch_count // 1000 + 1) * 1000 + data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt" + data_1000 = torch.load(data_path, map_location=torch.device("cpu")) + batch = data_1000[batch_index] + + # record the consumed samples in training + train_state.batch_count = batch_count + train_state.num_consumed_samples_in_epoch += len(batch[1]) + if batch_skipper(batch_count): # skip this batch + if gpc.is_rank_for_log(): + logger.info(f"Skip batch count:`{batch_count}`...") + timer("one-batch").stop() + continue + + # zero the grads of parameters + trainer.zero_grad() + # process data + if batch[0].get("type_ids", None) is not None: + metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + + # do forward and backward + timer("fwd-bwd").start() + + moe_loss = None + if hasattr(gpc.config.model, "num_experts"): + _, _, loss, moe_loss = trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + else: + _, _, loss = trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + timer("fwd-bwd").stop() + + # update parameters, and returns (success_update, grad_norm) + trainer_result = trainer.step() + assert trainer_result is not None + + success_update, grad_norm_groups = trainer_result + if success_update: # update parameters successfully + train_state.step_count += 1 + else: + train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. + if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case + logger.warning(f"Warning: skip parameter update at step {batch_count}.") + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=f"Warning: skip parameter update at step {batch_count}.", + ) + + # calculate and record the training metrics, eg. loss, accuracy and so on. + record_current_batch_training_metrics( + get_tflops_func=get_tflops_func, + logger=logger, + writer=writer, + success_update=success_update, + batch_count=batch_count, + batch=batch, + train_state=train_state, + optimizer=optimizer, + beta2_scheduler=beta2_scheduler, + trainer=trainer, + start_time=start_time, + loss=loss, + moe_loss=moe_loss, + grad_norm=grad_norm_groups, + metric=metric, + update_panel=uniscale_logger is not None, + ) + + timer("one-batch").stop() + + # evaluate on validation data loaders + if valid_every > 0 and train_state.step_count % valid_every == 0: + evaluate_on_val_dls( + trainer=trainer, + val_dls=val_dls, + writer=writer, + logger=logger, + step_count=train_state.step_count, + update_panel=uniscale_logger is not None, + ) + + if batch_count > 0 and batch_count % 100 == 0: + ckpt_path = os.path.join( + "/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt" + ) + check_model_weights(model, ckpt_path) + + # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" + # # save batch sampler that tracks the true consumed samples + now_break = ckpt_manager.try_save_checkpoint(train_state) + if now_break: + break + + if memory_profiler is not None: + memory_profiler.step() + + if batch_count % 2 == 0: + prof.step() + + ckpt_manager.wait_async_upload_finish() + + +if __name__ == "__main__": + args = parse_args() + hostname = socket.gethostname() + + # initialize distributed environment + initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + assert hasattr(gpc, "config") and gpc.config is not None + + # initialize monitor manager context + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): + try: + main(args) + except Exception: + logger.error( + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", + ) + mm.monitor_exception( + alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() + )