From 1f7304a8bb97007e1022441253bedf8fdafba5c6 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Tue, 1 Aug 2023 17:37:32 +0800
Subject: [PATCH] feat(utils/logger.py): support uniscale logger (#152)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* style(internlm): fix lint error
* feat(utils/logger.py): support uniscale logger
* fix(utils/logger.py): fix import circular error
* feat(train.py): support dashboard metric panel and fix ci train config
* fix(ci_scripts/train/slurm_train.sh): fix ci train error
* fix(ci_scripts/train/torchrun.sh): fix ci train error
* fix(ci_scripts/train): restore ci update
* fix(config.json): delete alert webhook
* feat(train.py): optimize func init logger
* feat(config.json): delete config.json

---------

Co-authored-by: 黄婷
Co-authored-by: huangting.p
---
 ci_scripts/train/ci_7B_sft.py                 |  8 +--
 .../core/scheduler/no_pipeline_scheduler.py   |  2 +-
 internlm/core/scheduler/pipeline_scheduler.py |  7 ++-
 internlm/data/packed_dataset.py               |  4 +-
 internlm/data/utils.py                        |  3 +-
 internlm/initialize/initialize_trainer.py     |  8 +--
 internlm/model/modeling_internlm.py           |  2 +-
 internlm/utils/logger.py                      | 58 +++++++++++++++++++
 internlm/utils/parallel.py                    | 13 +++++
 internlm/utils/writer.py                      | 32 +++++-----
 train.py                                      | 40 ++++++++++++-
 11 files changed, 139 insertions(+), 38 deletions(-)

diff --git a/ci_scripts/train/ci_7B_sft.py b/ci_scripts/train/ci_7B_sft.py
index 52b6f33..bc881c0 100644
--- a/ci_scripts/train/ci_7B_sft.py
+++ b/ci_scripts/train/ci_7B_sft.py
@@ -11,9 +11,9 @@ VOCAB_SIZE = 103168
 # fs: 'local:/mnt/nfs/XXX'
 # oss: 'boto3:s3://model_weights/XXX'
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
-#SAVE_CKPT_FOLDER = "local:llm_ckpts"
+# SAVE_CKPT_FOLDER = "local:llm_ckpts"
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
-#LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+# LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
 ckpt = dict(
     # Path to save training ckpt.
     save_ckpt_folder=SAVE_CKPT_FOLDER,
@@ -119,8 +119,8 @@ zero1 parallel:
     2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
     3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
     For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-pipeline parallel: pipeline parallel size, only 1 is accepted currently.
-tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
+pipeline parallel: pipeline parallel size.
+tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
     zero1=8,
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index cdf3edc..65fc0af 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -62,7 +62,7 @@ class NonPipelineScheduler(BaseScheduler):
             data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
         )
         self._grad_accum_offset += self._grad_accum_batch_size
-
+
         if self.data_process_func:
             _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
             _label = self.data_process_func(_label, _data["cu_seqlens"])
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 1ce35ee..3d72ddb 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -37,7 +37,8 @@ def get_tensor_shape():
         )
     else:
         tensor_shape = (
-            gpc.config.data["micro_bsz"], gpc.config.SEQ_LEN,
+            gpc.config.data["micro_bsz"],
+            gpc.config.SEQ_LEN,
             gpc.config.HIDDEN_SIZE,
         )
     return tensor_shape
@@ -138,7 +139,7 @@ class PipelineScheduler(BaseScheduler):
                     micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
                 )
                 micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
-
+
                 micro_batch_data.pop("cu_seqlens")
                 micro_batch_data.pop("indexes")
@@ -856,4 +857,4 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output, label = pack_return_tensors(return_tensors)
             return output, label, accum_loss
         else:
-            return None, None, accum_loss
\ No newline at end of file
+            return None, None, accum_loss
diff --git a/internlm/data/packed_dataset.py b/internlm/data/packed_dataset.py
index 25a41f5..c0d689f 100644
--- a/internlm/data/packed_dataset.py
+++ b/internlm/data/packed_dataset.py
@@ -185,7 +185,7 @@ class PackedDataset(torch.utils.data.Dataset):
         out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
         return out
-
+
     def __getitem__(self, item: int) -> Dict:
         """Given the index, it returns a dict as
         {
@@ -199,7 +199,7 @@ class PackedDataset(torch.utils.data.Dataset):
         if gpc.config.model.use_flash_attn:
             pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
             return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
-
+
         return self.build_unpack(item)
diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index 4d9c775..a86984a 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -18,6 +18,7 @@ def get_dataset_type_id(path):
     assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
     return match_idxes[0]

+
 def unpack_data(input_ids, cu_seqlens):
     """
     input_ids: (n, packed_length)
@@ -42,4 +43,4 @@ def unpack_data(input_ids, cu_seqlens):
     if bsz == 1:
         outputs = outputs.squeeze(0)
-    return outputs
\ No newline at end of file
+    return outputs
diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py
index 801012a..0758bc2 100644
--- a/internlm/initialize/initialize_trainer.py
+++ b/internlm/initialize/initialize_trainer.py
@@ -85,10 +85,6 @@ def initialize_trainer(
     if gpc.is_using_pp():
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
-        # if gpc.config.model.use_flash_attn:
-        #     tensor_shape = get_tensor_shape()
-        # else:
-        #     tensor_shape = None
     use_interleaved = (
         hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
     )
@@ -112,7 +108,9 @@ def initialize_trainer(
             scatter_gather_tensors=scatter_gather,
         )
     else:
-        scheduler = NonPipelineScheduler(data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation)
+        scheduler = NonPipelineScheduler(
+            data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation
+        )

     # initialize engine for trainer
     engine = Engine(
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 11340a4..8a41068 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -365,7 +365,7 @@ class PackedFlashInternLm1D(nn.Module):
         if isinstance(cu_seqlens, list):
             assert len(cu_seqlens) == 1
             cu_seqlens = cu_seqlens[0].to(hidden_states.device)
-
+
         if cu_seqlens is not None:
             cu_seqlens = cu_seqlens.squeeze(0)
             hidden_states = hidden_states.squeeze(0)  # If cu_seqlens is passed in,it indicated a packed state,
diff --git a/internlm/utils/logger.py b/internlm/utils/logger.py
index a4a9f03..c5906a8 100644
--- a/internlm/utils/logger.py
+++ b/internlm/utils/logger.py
@@ -2,6 +2,8 @@
 # -*- encoding: utf-8 -*-

 import logging
+import os
+

 LOGGER_NAME = "internlm"
 LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s"
@@ -11,6 +13,8 @@ LOGGER_LEVEL_HELP = (
     "The logging level threshold, choices=['debug', 'info', 'warning', 'error', 'critical'], default='info'"
 )

+uniscale_logger = None
+

 def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL) -> logging.Logger:
     """Configure the logger that is used for uniscale framework.
@@ -24,6 +28,10 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
         logger (logging.Logger): the created or modified logger.

     """
+
+    if uniscale_logger is not None:
+        return uniscale_logger
+
     logger = logging.getLogger(logger_name)

     if logging_level not in LOGGER_LEVEL_CHOICES:
@@ -39,3 +47,53 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
         logger.addHandler(handler)

     return logger
+
+
+def initialize_uniscale_logger(
+    job_name: str = None,
+    launch_time: str = None,
+    file_name: str = None,
+    name: str = LOGGER_NAME,
+    level: str = LOGGER_LEVEL,
+    file_path: str = None,
+    is_std: bool = True,
+):
+    """
+    Initialize uniscale logger.
+
+    Args:
+        job_name (str): The name of the training job, defaults to None.
+        launch_time (str): The launch time of the training job, defaults to None.
+        file_name (str): The log file name, defaults to None.
+        name (str): The logger name, defaults to "internlm".
+        level (str): The log level, defaults to "info".
+        file_path (str): The log file path, defaults to None.
+        is_std (bool): Whether to output to console, defaults to True.
+
+    Returns:
+        Uniscale logger instance.
+    """
+
+    try:
+        from uniscale_monitoring import get_logger as get_uniscale_logger
+    except ImportError:
+        print("Failed to import module uniscale_monitoring. Use default python logger.")
+        return None
+
+    if not file_path:
+        assert (
+            job_name and launch_time and file_name
+        ), "If file_path is None, job_name, launch_time and file_name must be set."
+        log_file_name = file_name
+        log_folder = os.path.join(job_name, launch_time, "logs")
+        log_dir = os.path.join(log_folder, log_file_name)
+        file_path = log_dir
+
+    logger = get_uniscale_logger(name=name, level=level, filename=file_path, is_std=is_std)
+    if isinstance(logger, (list, tuple)):
+        logger = list(logger)[0]
+
+    global uniscale_logger
+    uniscale_logger = logger
+
+    return logger
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 87ea3a6..cffcdc1 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -46,3 +46,16 @@ def sync_model_param_within_tp(model):

 def is_no_pp_or_last_stage():
     return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
+
+
+def get_parallel_log_file_name():
+    if gpc.is_rank_for_log():
+        fn_prefix = "main_"  # Indicates a rank with more output information
+    else:
+        fn_prefix = ""
+
+    log_file_name = (
+        f"{fn_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_"
+        f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
+    )
+    return log_file_name
diff --git a/internlm/utils/writer.py b/internlm/utils/writer.py
index 9aaf750..311c6b3 100644
--- a/internlm/utils/writer.py
+++ b/internlm/utils/writer.py
@@ -8,23 +8,9 @@ from functools import partial
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc


-def get_tb_log_file_name():
-    if gpc.is_rank_for_log():
-        tb_prefix = "main_"  # Indicates a rank with more output information
-    else:
-        tb_prefix = ""
-
-    tb_log_file_name = (
-        f"{tb_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_"
-        f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
-    )
-    return tb_log_file_name
-
-
 def copy_ignore_folder(source_path, target_path):
     os.system(f"cp -r {source_path}/* {target_path}/")
@@ -40,16 +26,18 @@ def tb_save_run_info(writer, config_lines, global_step=0):


 def init_tb_writer(
-    launch_time,
+    job_name: str,
+    launch_time: str,
+    file_name: str,
     tensorboard_folder: str,
     resume_tb_folder: str,
     step_count: int,
     config: str,
     logger: logging.Logger,
 ):
-    tb_log_file_name = get_tb_log_file_name()
+    tb_log_file_name = file_name
     if not tensorboard_folder:
-        tb_folder = os.path.join(gpc.config.JOB_NAME, launch_time)
+        tb_folder = os.path.join(job_name, launch_time, "tensorboards")
     else:
         tb_folder = tensorboard_folder
@@ -62,7 +50,7 @@ def init_tb_writer(
             tb_logdir = os.path.join(tb_folder, tb_log_file_name)
             writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
-            writer.add_text(tag="job_name", text_string=gpc.config.JOB_NAME, global_step=step_count)
+            writer.add_text(tag="job_name", text_string=job_name, global_step=step_count)
             writer.add_text(tag="tensorboard_folder", text_string=tb_logdir, global_step=step_count)

             torch.distributed.broadcast_object_list([tb_folder], src=0)
@@ -95,7 +83,9 @@ class Writer:
     Customed writer based on tensorboard for recording training metrics.

     Args:
+        job_name (str): The name of training job, defaults to None.
         launch_time (str): A string representing the launch time of the training.
+        file_name (str): The log file name, defaults to None.
         tensorboard_folder (str): A string representing the folder for saving tensorboard logs.
         resume_tb_folder (str): A string representing the folder for resuming tensorboard logs.
         step_count (int): An integer representing the step count of the training.
@@ -107,7 +97,9 @@ class Writer:

     def __init__(
         self,
-        launch_time: str,
+        job_name: str = None,
+        launch_time: str = None,
+        file_name: str = None,
         tensorboard_folder: str = None,
         resume_tb_folder: str = None,
         step_count: int = 0,
@@ -117,7 +109,9 @@ class Writer:
     ) -> None:
         self.enable_tb = enable_tb
         self.tb_writer, self.tb_logdir = init_tb_writer(
+            job_name=job_name,
             launch_time=launch_time,
+            file_name=file_name,
             tensorboard_folder=tensorboard_folder,
             resume_tb_folder=resume_tb_folder,
             step_count=step_count,
diff --git a/train.py b/train.py
index bca8b54..d950c40 100644
--- a/train.py
+++ b/train.py
@@ -39,7 +39,7 @@ from internlm.utils.common import (
     launch_time,
     parse_args,
 )
-from internlm.utils.logger import get_logger
+from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (
     load_context,
@@ -50,6 +50,7 @@ from internlm.utils.model_checkpoint import (
     save_checkpoint,
 )
 from internlm.utils.parallel import (
+    get_parallel_log_file_name,
     is_no_pp_or_last_stage,
     sync_model_param,
     sync_model_param_within_tp,
@@ -87,6 +88,17 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
     assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"


+def initialize_llm_logger(start_time: str):
+    uniscale_logger = initialize_uniscale_logger(
+        job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
+    )
+    if uniscale_logger is not None:
+        global logger
+        logger = uniscale_logger
+
+    return uniscale_logger
+
+
 def initialize_model():
     """
     Initialize model.
@@ -254,6 +266,7 @@ def record_current_batch_training_metrics(
     loss,
     grad_norm,
     metric,
+    update_panel,
 ):
     """
     Print some training metrics of current batch.
@@ -318,7 +331,24 @@ def record_current_batch_training_metrics(
                 line += f"{key}={value} "
                 writer.add_scalar(key=key, value=value, step=train_state.step_count)

-        logger.info(line)
+        if update_panel:
+            logger.info(
+                line,
+                extra={
+                    "step": batch_count,
+                    "lr": lr,
+                    "num_consumed_tokens": train_state.num_consumed_tokens,
+                    "grad_norm": grad_norm,
+                    "loss": loss.item(),
+                    "flops": tflops,
+                    "tgs": tk_per_gpu,
+                    "acc": acc_perplex["acc"],
+                    "perplexity": acc_perplex["perplexity"],
+                    "fwd_bwd_time": fwd_bwd_time,
+                },
+            )
+        else:
+            logger.info(line)


 def main(args):
@@ -359,11 +389,16 @@ def main(args):
     dist.broadcast_object_list(objs, src=0)
     current_time = objs[0]

+    # initialize customed llm logger
+    uniscale_logger = initialize_llm_logger(start_time=current_time)
+
     # initialize customed llm writer
     with open(args.config, "r") as f:
         config_lines = f.readlines()
     writer = Writer(
+        job_name=gpc.config.JOB_NAME,
         launch_time=current_time,
+        file_name=get_parallel_log_file_name(),
         tensorboard_folder=gpc.config.tensorboard_folder,
         resume_tb_folder=gpc.config.resume_tb_folder,
         config=config_lines,
@@ -513,6 +548,7 @@ def main(args):
                 loss=loss,
                 grad_norm=grad_norm,
                 metric=metric,
+                update_panel=uniscale_logger is not None,
             )

         timer("one-batch").stop()
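
Usage note: below is a minimal sketch, not part of the patch, of how the new logger initialization could be wired up outside train.py. The job name, launch timestamp, and rank suffix are hypothetical placeholders; only initialize_uniscale_logger, get_logger, and the {job_name}/{launch_time}/logs/{file_name} layout come from this change, and the call falls back to the default Python logger when the uniscale_monitoring package is not installed.

import os

from internlm.utils.logger import get_logger, initialize_uniscale_logger

# Hypothetical values; in train.py they come from gpc.config.JOB_NAME, the
# broadcast launch time, and get_parallel_log_file_name().
job_name = "7b_sft"
launch_time = "2023-08-01-17-37-32"
file_name = "main_dp=0_tp=0_pp=0"

# Returns None (after printing a notice) when uniscale_monitoring cannot be
# imported, in which case the standard "internlm" logger is used instead.
uniscale_logger = initialize_uniscale_logger(
    job_name=job_name, launch_time=launch_time, file_name=file_name
)
logger = uniscale_logger if uniscale_logger is not None else get_logger()

# With uniscale active, the log file lands at {job_name}/{launch_time}/logs/{file_name}.
logger.info(f"expected log path: {os.path.join(job_name, launch_time, 'logs', file_name)}")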