mirror of https://github.com/InternLM/InternLM
parent
a38af602bc
commit
bd7e501b69
|
@ -0,0 +1,167 @@
|
|||
JOB_NAME = "7b_train"
|
||||
DO_ALERT = False
|
||||
|
||||
SEQ_LEN = 2048
|
||||
HIDDEN_SIZE = 4096
|
||||
NUM_ATTENTION_HEAD = 32
|
||||
MLP_RATIO = 8 / 3
|
||||
NUM_LAYER = 32
|
||||
VOCAB_SIZE = 103168
|
||||
|
||||
MODEL_ONLY_FOLDER = "/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt/init"
|
||||
# Ckpt folder format:
|
||||
# fs: 'local:/mnt/nfs/XXX'
|
||||
# SAVE_CKPT_FOLDER = "local:llm_ckpts_0925_9"
|
||||
# LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
|
||||
|
||||
# boto3 Ckpt folder format:
|
||||
# import os
|
||||
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
|
||||
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
|
||||
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
|
||||
CHECKPOINT_EVERY = 100
|
||||
ckpt = dict(
|
||||
enable_save_ckpt=False, # enable ckpt save.
|
||||
auto_resume=False,
|
||||
# save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
|
||||
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
|
||||
# load_ckpt_folder="local:llm_ckpts/",
|
||||
# # 'load_ckpt_info' setting guide:
|
||||
# # 1. the 'path' indicate ckpt path,
|
||||
# # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
|
||||
# # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported.
|
||||
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
|
||||
# checkpoint_every=CHECKPOINT_EVERY,
|
||||
# async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
|
||||
# async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
|
||||
# oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
||||
)
|
||||
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
VALID_FOLDER = "/path/to/dataset"
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN,
|
||||
# micro_num means the number of micro_batch contained in one gradient update
|
||||
micro_num=2,
|
||||
# packed_length = micro_bsz * SEQ_LEN
|
||||
micro_bsz=2,
|
||||
# defaults to the value of micro_num
|
||||
valid_micro_num=2,
|
||||
# defaults to 0, means disable evaluate
|
||||
valid_every=5000,
|
||||
pack_sample_into_one=False,
|
||||
total_steps=501,
|
||||
skip_batches="",
|
||||
rampup_batch_size="",
|
||||
# Datasets with less than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
# valid_folder=VALID_FOLDER,
|
||||
empty_cache_and_diag_interval=500,
|
||||
diag_outlier_ratio=1.1,
|
||||
num_worker=4,
|
||||
)
|
||||
|
||||
grad_scaler = dict(
|
||||
fp16=dict(
|
||||
# the initial loss scale, defaults to 2**16
|
||||
initial_scale=2**16,
|
||||
# the minimum loss scale, defaults to None
|
||||
min_scale=1,
|
||||
# the number of steps to increase loss scale when no overflow occurs
|
||||
growth_interval=1000,
|
||||
),
|
||||
# the multiplication factor for increasing loss scale, defaults to 2
|
||||
growth_factor=2,
|
||||
# the multiplication factor for decreasing loss scale, defaults to 0.5
|
||||
backoff_factor=0.5,
|
||||
# the maximum loss scale, defaults to None
|
||||
max_scale=2**24,
|
||||
# the number of overflows before decreasing loss scale, defaults to 2
|
||||
hysteresis=2,
|
||||
)
|
||||
|
||||
hybrid_zero_optimizer = dict(
|
||||
# Enable low_level_optimzer overlap_communication
|
||||
overlap_sync_grad=True,
|
||||
overlap_sync_param=True,
|
||||
# bucket size for nccl communication params
|
||||
reduce_bucket_size=512 * 1024 * 1024,
|
||||
# grad clipping
|
||||
clip_grad_norm=1.0,
|
||||
)
|
||||
|
||||
loss = dict(
|
||||
label_smoothing=0,
|
||||
)
|
||||
|
||||
adam = dict(
|
||||
lr=1e-4,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.95,
|
||||
adam_beta2_c=0,
|
||||
adam_eps=1e-8,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
lr_scheduler = dict(
|
||||
total_steps=data["total_steps"],
|
||||
init_steps=0, # optimizer_warmup_step
|
||||
warmup_ratio=0.01,
|
||||
eta_min=1e-5,
|
||||
last_epoch=-1,
|
||||
)
|
||||
|
||||
beta2_scheduler = dict(
|
||||
init_beta2=adam["adam_beta2"],
|
||||
c=adam["adam_beta2_c"],
|
||||
cur_iter=-1,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
|
||||
num_attention_heads=NUM_ATTENTION_HEAD,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
|
||||
so parameters will be divided within the range of dp.
|
||||
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
||||
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
|
||||
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
||||
pipeline parallel (dict):
|
||||
1. size: int, the size of pipeline parallel.
|
||||
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
|
||||
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=dict(size=8, fsdp=False),
|
||||
tensor=1,
|
||||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
sequence_parallel=False,
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
cudnn_benchmark = False
|
||||
|
||||
monitor = dict(
|
||||
# feishu alert configs
|
||||
alert=dict(
|
||||
enable_feishu_alert=DO_ALERT,
|
||||
feishu_alert_address=None, # feishu webhook to send alert message
|
||||
light_monitor_address=None, # light_monitor address to send heartbeat
|
||||
),
|
||||
)
|
|
@ -0,0 +1,355 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
|
||||
sys.path.append(project_root)
|
||||
|
||||
# pylint: disable=C0413,W0612,W0611
|
||||
import socket
|
||||
import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import internlm
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.scheduler import SchedulerMetricHook
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.initialize import initialize_distributed_env
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.monitor import initialize_monitor_manager, send_alert_message
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.train import (
|
||||
get_train_data_loader,
|
||||
get_validation_data_loader,
|
||||
initialize_llm_profile,
|
||||
initialize_model,
|
||||
initialize_optimizer,
|
||||
load_new_batch,
|
||||
record_current_batch_training_metrics,
|
||||
)
|
||||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
get_megatron_flops,
|
||||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
||||
from internlm.utils.gputest import empty_cache_and_diag
|
||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import CheckpointManager
|
||||
from internlm.utils.parallel import get_parallel_log_file_name
|
||||
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
|
||||
from internlm.utils.writer import Writer
|
||||
|
||||
# global llm logger
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def initialize_llm_logger(start_time: str):
|
||||
"""
|
||||
Initialize customed uniscale logger.
|
||||
|
||||
Args:
|
||||
start_time (str): The launch time of current training job.
|
||||
|
||||
Returns: The instance of uniscale logger.
|
||||
"""
|
||||
|
||||
uniscale_logger = initialize_uniscale_logger(
|
||||
job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
|
||||
)
|
||||
if uniscale_logger is not None:
|
||||
global logger
|
||||
logger = uniscale_logger
|
||||
|
||||
return uniscale_logger
|
||||
|
||||
|
||||
def check_model_weights(model, ckpt_path):
|
||||
model1_dict = torch.load(ckpt_path, map_location="cuda")
|
||||
model2_dict = model.state_dict()
|
||||
|
||||
for key in model1_dict.keys():
|
||||
if key in model2_dict:
|
||||
tensor1 = model1_dict[key]
|
||||
tensor2 = model2_dict[key]
|
||||
assert torch.allclose(tensor1, tensor2, rtol=3e-2, atol=3e-2)
|
||||
|
||||
|
||||
def main(args):
|
||||
# init setting
|
||||
skip_batches = gpc.config.data.skip_batches
|
||||
total_steps = gpc.config.data.total_steps
|
||||
valid_every = gpc.config.data.valid_every
|
||||
label_smoothing = gpc.config.loss.label_smoothing
|
||||
|
||||
get_tflops_func = partial(
|
||||
get_megatron_flops,
|
||||
checkpoint=gpc.config.model.checkpoint,
|
||||
seq_len=gpc.config.SEQ_LEN,
|
||||
hidden_size=gpc.config.model.hidden_size,
|
||||
num_layers=gpc.config.model.num_layers,
|
||||
vocab_size=gpc.config.model.vocab_size,
|
||||
global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
|
||||
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
)
|
||||
|
||||
# get and broadcast current time
|
||||
current_time = launch_time()
|
||||
objs = [current_time]
|
||||
dist.broadcast_object_list(objs, src=0)
|
||||
current_time = objs[0]
|
||||
|
||||
# initialize customed llm logger
|
||||
uniscale_logger = initialize_llm_logger(start_time=current_time)
|
||||
|
||||
# initialize model
|
||||
model = initialize_model()
|
||||
|
||||
with open(args.config, "r") as f:
|
||||
config_lines = f.readlines()
|
||||
|
||||
# initialize loss function
|
||||
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
|
||||
|
||||
# initialize the train and validation data loader
|
||||
train_dl, dataset_types = get_train_data_loader(num_worker=4)
|
||||
val_dls = get_validation_data_loader()
|
||||
|
||||
# initialize and resume train state
|
||||
train_state = TrainState(gpc.config, train_dl.batch_sampler)
|
||||
|
||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
|
||||
|
||||
ckpt_manager = CheckpointManager(
|
||||
ckpt_config=gpc.config.ckpt,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
train_dl=train_dl,
|
||||
model_config=gpc.config.model,
|
||||
model_config_file="".join(config_lines),
|
||||
feishu_address=gpc.config.monitor.alert.feishu_alert_address,
|
||||
)
|
||||
|
||||
# Loading other persistent training states.
|
||||
ckpt_manager.try_resume_training(train_state, current_time)
|
||||
|
||||
# initialize customed llm writer
|
||||
writer = Writer(
|
||||
job_name=gpc.config.JOB_NAME,
|
||||
launch_time=current_time,
|
||||
file_name=get_parallel_log_file_name(),
|
||||
tensorboard_folder=gpc.config.tensorboard_folder,
|
||||
resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt.
|
||||
step_count=train_state.step_count, # resume from ckpt.
|
||||
config=config_lines,
|
||||
logger=logger,
|
||||
enable_tb=gpc.config.enable_tb,
|
||||
)
|
||||
|
||||
# initialize metric for calculating accuracy and perplexity
|
||||
metric = AccPerplex(
|
||||
device=torch.cuda.current_device(),
|
||||
tp_pg=gpc.get_group(ParallelMode.TENSOR),
|
||||
dp_pg=gpc.get_group(ParallelMode.DATA),
|
||||
dataset_types=dataset_types,
|
||||
)
|
||||
|
||||
# initialize trainer
|
||||
scheduler_hooks = [
|
||||
SchedulerMetricHook(
|
||||
metric=metric,
|
||||
skip=(
|
||||
gpc.is_using_pp()
|
||||
and hasattr(gpc.config.model, "num_chunks")
|
||||
and gpc.config.model.num_chunks > 1
|
||||
and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
trainer, train_dl, _, _ = internlm.initialize_trainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dl,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
|
||||
# initialize simple memory profiler
|
||||
if args.profiling:
|
||||
memory_profiler = SimpleMemoryProfiler(
|
||||
model,
|
||||
optimizer.optim,
|
||||
log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
|
||||
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
|
||||
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
|
||||
)
|
||||
else:
|
||||
memory_profiler = None
|
||||
|
||||
# initialize the batch skipper
|
||||
batch_skipper = BatchSkipper(skip_batches)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# transfer the train data loader into train data iterator
|
||||
train_iter = iter(train_dl)
|
||||
|
||||
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
|
||||
start_time = time.time()
|
||||
timer("one-batch").start()
|
||||
|
||||
# load batch data
|
||||
# batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
# pylint: disable=C0301
|
||||
batch_index = batch_count % 1000
|
||||
if batch_index == 0:
|
||||
data_local_rank = gpc.get_local_rank(ParallelMode.DATA)
|
||||
batch_step = (batch_count // 1000 + 1) * 1000
|
||||
data_path = f"/mnt/petrelfs/share/quailty_assurance/debug_Qiansanqiang_7B_v16/dp-11{data_local_rank}/batch-{batch_step}.pt"
|
||||
data_1000 = torch.load(data_path, map_location=torch.device("cpu"))
|
||||
batch = data_1000[batch_index]
|
||||
|
||||
# record the consumed samples in training
|
||||
train_state.batch_count = batch_count
|
||||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||
if batch_skipper(batch_count): # skip this batch
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Skip batch count:`{batch_count}`...")
|
||||
timer("one-batch").stop()
|
||||
continue
|
||||
|
||||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
# process data
|
||||
if batch[0].get("type_ids", None) is not None:
|
||||
metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
|
||||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
|
||||
moe_loss = None
|
||||
if hasattr(gpc.config.model, "num_experts"):
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
)
|
||||
else:
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
trainer_result = trainer.step()
|
||||
assert trainer_result is not None
|
||||
|
||||
success_update, grad_norm_groups = trainer_result
|
||||
if success_update: # update parameters successfully
|
||||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.monitor.alert.feishu_alert_address,
|
||||
message=f"Warning: skip parameter update at step {batch_count}.",
|
||||
)
|
||||
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
logger=logger,
|
||||
writer=writer,
|
||||
success_update=success_update,
|
||||
batch_count=batch_count,
|
||||
batch=batch,
|
||||
train_state=train_state,
|
||||
optimizer=optimizer,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
moe_loss=moe_loss,
|
||||
grad_norm=grad_norm_groups,
|
||||
metric=metric,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
timer("one-batch").stop()
|
||||
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||
evaluate_on_val_dls(
|
||||
trainer=trainer,
|
||||
val_dls=val_dls,
|
||||
writer=writer,
|
||||
logger=logger,
|
||||
step_count=train_state.step_count,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
if batch_count > 0 and batch_count % 100 == 0:
|
||||
ckpt_path = os.path.join(
|
||||
"/mnt/petrelfs/share/quailty_assurance/7B_model_weights_ckpt", str(batch_count), "model_tp0_pp0.pt"
|
||||
)
|
||||
check_model_weights(model, ckpt_path)
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# # save batch sampler that tracks the true consumed samples
|
||||
now_break = ckpt_manager.try_save_checkpoint(train_state)
|
||||
if now_break:
|
||||
break
|
||||
|
||||
if memory_profiler is not None:
|
||||
memory_profiler.step()
|
||||
|
||||
if batch_count % 2 == 0:
|
||||
prof.step()
|
||||
|
||||
ckpt_manager.wait_async_upload_finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
# initialize distributed environment
|
||||
initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
|
||||
assert hasattr(gpc, "config") and gpc.config is not None
|
||||
|
||||
# initialize monitor manager context
|
||||
with initialize_monitor_manager(
|
||||
job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
|
||||
):
|
||||
try:
|
||||
main(args)
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
|
||||
)
|
||||
mm.monitor_exception(
|
||||
alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
|
||||
)
|
Loading…
Reference in New Issue