mirror of https://github.com/InternLM/InternLM
feat(utils/logger.py): support uniscale logger (#152)
* style(internlm): fix lint error
* feat(utils/logger.py): support uniscale logger
* fix(utils/logger.py): fix import circular error
* feat(train.py): support dashboard metric panel and fix ci train config
* fix(ci_scripts/train/slurm_train.sh): fix ci train error
* fix(ci_scripts/train/torchrun.sh): fix ci train error
* fix(ci_scripts/train): restore ci update
* fix(config.json): delete alert webhook
* feat(train.py): optimize func init logger
* feat(config.json): delete config.json

---------

Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>
Co-authored-by: huangting.p <huangting@sensetime.com>
parent 307c4741d1
commit 1f7304a8bb
@@ -11,9 +11,9 @@ VOCAB_SIZE = 103168
 # fs: 'local:/mnt/nfs/XXX'
 # oss: 'boto3:s3://model_weights/XXX'
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
-#SAVE_CKPT_FOLDER = "local:llm_ckpts"
+# SAVE_CKPT_FOLDER = "local:llm_ckpts"
 SAVE_CKPT_FOLDER = "local:llm_ckpts"
-#LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+# LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
 ckpt = dict(
     # Path to save training ckpt.
     save_ckpt_folder=SAVE_CKPT_FOLDER,
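For orientation, a minimal sketch of how these constants feed the checkpoint section of the config; the commented load field is an assumption based on the constants above, not part of this diff:

ckpt = dict(
    save_ckpt_folder=SAVE_CKPT_FOLDER,    # where training checkpoints are written
    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # hypothetical: uncomment to resume from e.g. "local:llm_ckpts/49"
)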
@@ -119,8 +119,8 @@ zero1 parallel:
 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
 For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-pipeline parallel: pipeline parallel size, only 1 is accepted currently.
-tensor parallel: tensor parallel size, usually the number of GPUs per node, only 1 is accepted currently.
+pipeline parallel: pipeline parallel size.
+tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
     zero1=8,
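A minimal sketch of the parallel section this docstring describes; only zero1=8 appears in the hunk, so the pipeline and tensor key names below are assumptions inferred from the docstring:

parallel = dict(
    zero1=8,      # split optimizer states across 8 data-parallel ranks within a node
    pipeline=1,   # assumed key: pipeline parallel size
    tensor=1,     # assumed key: tensor parallel size, usually the number of GPUs per node
)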
@@ -37,7 +37,8 @@ def get_tensor_shape():
         )
     else:
         tensor_shape = (
-            gpc.config.data["micro_bsz"], gpc.config.SEQ_LEN,
+            gpc.config.data["micro_bsz"],
+            gpc.config.SEQ_LEN,
             gpc.config.HIDDEN_SIZE,
         )
     return tensor_shape
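Illustration only, with placeholder config values that are not from this diff: the tuple built in this branch reads (micro batch, sequence length, hidden size).

micro_bsz, SEQ_LEN, HIDDEN_SIZE = 2, 2048, 4096   # assumed example values
tensor_shape = (micro_bsz, SEQ_LEN, HIDDEN_SIZE)  # -> (2, 2048, 4096)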
@@ -18,6 +18,7 @@ def get_dataset_type_id(path):
     assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
     return match_idxes[0]
 
+
 def unpack_data(input_ids, cu_seqlens):
     """
     input_ids: (n, packed_length)
@@ -85,10 +85,6 @@ def initialize_trainer(
     if gpc.is_using_pp():
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
-        # if gpc.config.model.use_flash_attn:
-        #     tensor_shape = get_tensor_shape()
-        # else:
-        #     tensor_shape = None
         use_interleaved = (
             hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
         )
@@ -112,7 +108,9 @@ def initialize_trainer(
             scatter_gather_tensors=scatter_gather,
         )
     else:
-        scheduler = NonPipelineScheduler(data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation)
+        scheduler = NonPipelineScheduler(
+            data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation
+        )
 
     # initialize engine for trainer
     engine = Engine(
@@ -2,6 +2,8 @@
 # -*- encoding: utf-8 -*-
 
 import logging
+import os
 
+
 LOGGER_NAME = "internlm"
 LOGGER_FORMAT = "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s in %(funcName)s -- %(message)s"
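For reference, a record rendered with LOGGER_FORMAT looks roughly like the line below; the timestamp, file, line number and message are illustrative only.

# 2023-08-01 12:00:00,000	INFO train.py:391 in main -- training started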
@@ -11,6 +13,8 @@ LOGGER_LEVEL_HELP = (
     "The logging level threshold, choices=['debug', 'info', 'warning', 'error', 'critical'], default='info'"
 )
 
+uniscale_logger = None
+
 
 def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL) -> logging.Logger:
     """Configure the logger that is used for uniscale framework.
@@ -24,6 +28,10 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
         logger (logging.Logger): the created or modified logger.
 
     """
+
+    if uniscale_logger is not None:
+        return uniscale_logger
+
     logger = logging.getLogger(logger_name)
 
     if logging_level not in LOGGER_LEVEL_CHOICES:
@@ -39,3 +47,53 @@ def get_logger(logger_name: str = LOGGER_NAME, logging_level: str = LOGGER_LEVEL
     logger.addHandler(handler)
 
     return logger
+
+
+def initialize_uniscale_logger(
+    job_name: str = None,
+    launch_time: str = None,
+    file_name: str = None,
+    name: str = LOGGER_NAME,
+    level: str = LOGGER_LEVEL,
+    file_path: str = None,
+    is_std: bool = True,
+):
+    """
+    Initialize uniscale logger.
+
+    Args:
+        job_name (str): The name of training job, defaults to None.
+        launch_time (str): The launch time of training job, defaults to None.
+        file_name (str): The log file name, defaults to None.
+        name (str): The logger name, defaults to "internlm".
+        level (str): The log level, defaults to "info".
+        file_path (str): The log file path, defaults to None.
+        is_std (bool): Whether to output to console, defaults to True.
+
+    Returns:
+        Uniscale logger instance.
+    """
+
+    try:
+        from uniscale_monitoring import get_logger as get_uniscale_logger
+    except ImportError:
+        print("Failed to import module uniscale_monitoring. Use default python logger.")
+        return None
+
+    if not file_path:
+        assert (
+            job_name and launch_time and file_name
+        ), "If file_path is None, job_name, launch_time and file_name must be setted."
+        log_file_name = file_name
+        log_folder = os.path.join(job_name, launch_time, "logs")
+        log_dir = os.path.join(log_folder, log_file_name)
+        file_path = log_dir
+
+    logger = get_uniscale_logger(name=name, level=level, filename=file_path, is_std=is_std)
+    if isinstance(logger, (list, tuple)):
+        logger = list(logger)[0]
+
+    global uniscale_logger
+    uniscale_logger = logger
+
+    return logger
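A usage sketch for the new helper; the job name, launch time and log file name below are placeholders, and the fallback simply mirrors the ImportError path above:

from internlm.utils.logger import get_logger, initialize_uniscale_logger

uniscale = initialize_uniscale_logger(
    job_name="7b_train",                # placeholder
    launch_time="2023-08-01-12:00:00",  # placeholder
    file_name="main_dp=0_tp=0_pp=0",    # placeholder
)
# Fall back to the plain Python logger when uniscale_monitoring is unavailable.
logger = uniscale if uniscale is not None else get_logger(__name__)
logger.info("training started")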
@@ -46,3 +46,16 @@ def sync_model_param_within_tp(model):
 
 def is_no_pp_or_last_stage():
     return not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE)
+
+
+def get_parallel_log_file_name():
+    if gpc.is_rank_for_log():
+        fn_prefix = "main_"  # Indicates a rank with more output information
+    else:
+        fn_prefix = ""
+
+    log_file_name = (
+        f"{fn_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_"
+        f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
+    )
+    return log_file_name
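Example output (ranks illustrative): on the rank selected for logging, with local ranks dp=0, tp=0, pp=0, this returns "main_dp=0_tp=0_pp=0"; every other rank drops the "main_" prefix, e.g. "dp=1_tp=0_pp=0".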
@@ -8,23 +8,9 @@ from functools import partial
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 
 
-def get_tb_log_file_name():
-    if gpc.is_rank_for_log():
-        tb_prefix = "main_"  # Indicates a rank with more output information
-    else:
-        tb_prefix = ""
-
-    tb_log_file_name = (
-        f"{tb_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_"
-        f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
-    )
-    return tb_log_file_name
-
-
 def copy_ignore_folder(source_path, target_path):
     os.system(f"cp -r {source_path}/* {target_path}/")
 
@@ -40,16 +26,18 @@ def tb_save_run_info(writer, config_lines, global_step=0):
 
 
 def init_tb_writer(
-    launch_time,
+    job_name: str,
+    launch_time: str,
+    file_name: str,
     tensorboard_folder: str,
     resume_tb_folder: str,
     step_count: int,
     config: str,
     logger: logging.Logger,
 ):
-    tb_log_file_name = get_tb_log_file_name()
+    tb_log_file_name = file_name
     if not tensorboard_folder:
-        tb_folder = os.path.join(gpc.config.JOB_NAME, launch_time)
+        tb_folder = os.path.join(job_name, launch_time, "tensorboards")
     else:
         tb_folder = tensorboard_folder
 
@@ -62,7 +50,7 @@ def init_tb_writer(
 
     tb_logdir = os.path.join(tb_folder, tb_log_file_name)
     writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
-    writer.add_text(tag="job_name", text_string=gpc.config.JOB_NAME, global_step=step_count)
+    writer.add_text(tag="job_name", text_string=job_name, global_step=step_count)
     writer.add_text(tag="tensorboard_folder", text_string=tb_logdir, global_step=step_count)
 
     torch.distributed.broadcast_object_list([tb_folder], src=0)
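Putting the two hunks together: when tensorboard_folder is not set, the TensorBoard directory becomes {job_name}/{launch_time}/tensorboards/{file_name}. A sketch with placeholder values:

import os

tb_logdir = os.path.join("7b_train", "2023-08-01-12:00:00", "tensorboards", "main_dp=0_tp=0_pp=0")
# -> "7b_train/2023-08-01-12:00:00/tensorboards/main_dp=0_tp=0_pp=0"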
@@ -95,7 +83,9 @@ class Writer:
     Customed writer based on tensorboard for recording training metrics.
 
     Args:
+        job_name (str): The name of training job, defaults to None.
         launch_time (str): A string representing the launch time of the training.
+        file_name (str): The log file name, defaults to None.
         tensorboard_folder (str): A string representing the folder for saving tensorboard logs.
         resume_tb_folder (str): A string representing the folder for resuming tensorboard logs.
         step_count (int): An integer representing the step count of the training.
@@ -107,7 +97,9 @@ class Writer:
 
     def __init__(
         self,
-        launch_time: str,
+        job_name: str = None,
+        launch_time: str = None,
+        file_name: str = None,
        tensorboard_folder: str = None,
         resume_tb_folder: str = None,
         step_count: int = 0,
@@ -117,7 +109,9 @@ class Writer:
     ) -> None:
         self.enable_tb = enable_tb
         self.tb_writer, self.tb_logdir = init_tb_writer(
+            job_name=job_name,
             launch_time=launch_time,
+            file_name=file_name,
             tensorboard_folder=tensorboard_folder,
             resume_tb_folder=resume_tb_folder,
             step_count=step_count,
train.py
@@ -39,7 +39,7 @@ from internlm.utils.common import (
     launch_time,
     parse_args,
 )
-from internlm.utils.logger import get_logger
+from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (
     load_context,
@@ -50,6 +50,7 @@ from internlm.utils.model_checkpoint import (
     save_checkpoint,
 )
 from internlm.utils.parallel import (
+    get_parallel_log_file_name,
     is_no_pp_or_last_stage,
     sync_model_param,
     sync_model_param_within_tp,
@@ -87,6 +88,17 @@ def initialize_distributed_env(config: str, launcher: str = "slurm", master_port
     assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
 
 
+def initialize_llm_logger(start_time: str):
+    uniscale_logger = initialize_uniscale_logger(
+        job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
+    )
+    if uniscale_logger is not None:
+        global logger
+        logger = uniscale_logger
+
+    return uniscale_logger
+
+
 def initialize_model():
     """
     Initialize model.
@@ -254,6 +266,7 @@ def record_current_batch_training_metrics(
     loss,
     grad_norm,
     metric,
+    update_panel,
 ):
     """
     Print some training metrics of current batch.
@@ -318,7 +331,24 @@ def record_current_batch_training_metrics(
             line += f"{key}={value} "
             writer.add_scalar(key=key, value=value, step=train_state.step_count)
 
-        logger.info(line)
+        if update_panel:
+            logger.info(
+                line,
+                extra={
+                    "step": batch_count,
+                    "lr": lr,
+                    "num_consumed_tokens": train_state.num_consumed_tokens,
+                    "grad_norm": grad_norm,
+                    "loss": loss.item(),
+                    "flops": tflops,
+                    "tgs": tk_per_gpu,
+                    "acc": acc_perplex["acc"],
+                    "perplexity": acc_perplex["perplexity"],
+                    "fwd_bwd_time": fwd_bwd_time,
+                },
+            )
+        else:
+            logger.info(line)
 
 
 def main(args):
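The `extra` mapping attaches these metrics as attributes on the emitted LogRecord; how the uniscale dashboard consumes them is outside this diff. A sketch of a hypothetical handler reading those fields (the class and its output are assumptions, only the key names mirror the hunk above):

import logging

class PanelHandler(logging.Handler):
    """Hypothetical sink that forwards structured metrics passed via `extra` to a dashboard."""

    def emit(self, record: logging.LogRecord) -> None:
        # Keys given to logger.info(..., extra={...}) become attributes of the record.
        step = getattr(record, "step", None)
        loss = getattr(record, "loss", None)
        tgs = getattr(record, "tgs", None)
        if step is not None:
            print(f"panel update: step={step} loss={loss} tgs={tgs}")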
@@ -359,11 +389,16 @@ def main(args):
     dist.broadcast_object_list(objs, src=0)
     current_time = objs[0]
 
+    # initialize customed llm logger
+    uniscale_logger = initialize_llm_logger(start_time=current_time)
+
     # initialize customed llm writer
     with open(args.config, "r") as f:
         config_lines = f.readlines()
     writer = Writer(
+        job_name=gpc.config.JOB_NAME,
         launch_time=current_time,
+        file_name=get_parallel_log_file_name(),
         tensorboard_folder=gpc.config.tensorboard_folder,
         resume_tb_folder=gpc.config.resume_tb_folder,
         config=config_lines,
@@ -513,6 +548,7 @@ def main(args):
             loss=loss,
             grad_norm=grad_norm,
             metric=metric,
+            update_panel=uniscale_logger is not None,
         )
 
         timer("one-batch").stop()