diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 0218a0b..1cbb5e7 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -44,8 +44,8 @@ ckpt = dict(
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )

-TRAIN_FOLDER = "/path/to/dataset"
-VALID_FOLDER = "/path/to/dataset"
+TRAIN_FOLDER = None  # "/path/to/dataset"
+VALID_FOLDER = None  # "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
@@ -64,12 +64,12 @@ data = dict(
     # each increment. For example, "192 24 8" means that the batch size (micro_num)
     # starts at 192 and increases by 24 every 8 steps. Defaults to None.
     # (IMPORTANT): The interval step size is 'micro_bsz'.
-    rampup_batch_size=None,
+    rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
     min_length=50,
-    # train_folder=TRAIN_FOLDER,
-    # valid_folder=VALID_FOLDER,
-    empty_cache_and_diag_interval=10,
+    train_folder=TRAIN_FOLDER,
+    valid_folder=VALID_FOLDER,
+    empty_cache_and_diag_interval=200,
     diag_outlier_ratio=1.1,
 )

diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index c851789..550584e 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -35,19 +35,19 @@ def get_tensor_shape():
             if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
-                    gpc.config.HIDDEN_SIZE,
+                    gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.model["hidden_size"],
                 )
             else:
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
-                    gpc.config.HIDDEN_SIZE,
+                    gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"],
+                    gpc.config.model["hidden_size"],
                 )
         else:
             tensor_shape = (
                 gpc.config.data["micro_bsz"],
-                gpc.config.SEQ_LEN,
-                gpc.config.HIDDEN_SIZE,
+                gpc.config.data["seq_len"],
+                gpc.config.model["hidden_size"],
             )
         return tensor_shape
     else:
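For context, the migrated lookups compute the pipeline p2p buffer shape from the nested config sections rather than the removed top-level `SEQ_LEN`/`HIDDEN_SIZE` aliases. A minimal standalone sketch of that arithmetic, with made-up example values (the real ones come from `gpc.config.data` and `gpc.config.model`):

```python
# Hypothetical example values; actual ones come from the training config.
seq_len = 2048      # gpc.config.data["seq_len"]
micro_bsz = 2       # gpc.config.data["micro_bsz"]
hidden_size = 4096  # gpc.config.model["hidden_size"]
tp_size = 8         # gpc.get_world_size(ParallelMode.TENSOR)

# With sequence parallelism the sequence dimension is sharded across the
# tensor-parallel group, so each pipeline p2p buffer shrinks accordingly:
sp_shape = (seq_len * micro_bsz // tp_size, hidden_size)  # (512, 4096)
# Without it, the full flattened sequence is communicated:
dense_shape = (seq_len * micro_bsz, hidden_size)          # (4096, 4096)
assert sp_shape == (512, 4096) and dense_shape == (4096, 4096)
```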
+ if "JOB_NAME" not in gpc.config: + gpc.config._add_item("JOB_NAME", "AnonymousJob") + # the default model type is INTERNLM if "model_type" not in gpc.config: gpc.config._add_item("model_type", "INTERNLM") @@ -144,10 +148,6 @@ def args_sanity_check(): if "diag_outlier_ratio" not in data: data._add_item("diag_outlier_ratio", 1.1) - if "rampup_batch_size" not in data or not data.rampup_batch_size or len(data.rampup_batch_size) == 0: - bsz = data.micro_num - data._add_item("rampup_batch_size", f"{bsz} {bsz} 1") - data.diag_outlier_ratio = max(1, data.diag_outlier_ratio) if gpc.is_rank_for_log(): @@ -423,6 +423,8 @@ def launch( gpc.set_seed(seed) + warmup_process_group() + if gpc.is_rank_for_log(): logger.info( f"Distributed environment is initialized, " diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 22d998b..a94784c 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -101,7 +101,7 @@ def evaluate_on_val_dls( assert total_val_bsz % data_cfg.micro_bsz == 0 num_microbatches = total_val_bsz // data_cfg.micro_bsz tensor_shape = torch.Size( - [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] + [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]] ) with switch_evaluation_pipeline_scheduler( diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 85d4cdc..48ec0e3 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -27,10 +27,17 @@ from internlm.utils.common import get_current_device logger = get_logger(__file__) +# Gloabl cuda cache flush counter +n_caching_allocator_flushes = 0 + + def empty_cache_and_diag(batch_count, interval=50): """empty cuda cache and run diag bench or tests.""" if interval <= 0: interval = 50 + + cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5) + if batch_count % int(interval) == 0: # there is no need to do diag on the first batch if batch_count > 0: @@ -259,3 +266,75 @@ def bench_gpu(use_flash_attn=True): address=gpc.config.monitor.alert.feishu_alert_address, message=msg, ) + + +""" +Useful utility functions migrated from deepseped. +""" + + +def warmup_process_group(): + # Prevent OOM from nccl communication. 
diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py
index 85d4cdc..48ec0e3 100644
--- a/internlm/utils/gputest.py
+++ b/internlm/utils/gputest.py
@@ -27,10 +27,17 @@ from internlm.utils.common import get_current_device

 logger = get_logger(__file__)

+# Global CUDA cache flush counter
+n_caching_allocator_flushes = 0
+
+
 def empty_cache_and_diag(batch_count, interval=50):
     """empty cuda cache and run diag bench or tests."""
     if interval <= 0:
         interval = 50
+
+    cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
+
     if batch_count % int(interval) == 0:
         # there is no need to do diag on the first batch
         if batch_count > 0:
@@ -259,3 +266,75 @@ def bench_gpu(use_flash_attn=True):
                 address=gpc.config.monitor.alert.feishu_alert_address,
                 message=msg,
             )
+
+
+"""
+Useful utility functions migrated from DeepSpeed.
+"""
+
+
+def warmup_process_group():
+    # Prevent OOM from NCCL communication.
+    if dist.is_initialized():
+        buffer = torch.ones([64]).cuda()
+        if gpc.is_initialized(ParallelMode.DATA):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.DATA))
+        if gpc.is_initialized(ParallelMode.TENSOR):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.TENSOR))
+        if gpc.is_initialized(ParallelMode.PIPELINE):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.PIPELINE))
+        if gpc.is_initialized(ParallelMode.ZERO1):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO1))
+        if gpc.is_initialized(ParallelMode.MODEL):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.MODEL))
+        if gpc.is_initialized(ParallelMode.ZERO3_DP):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO3_DP))
+        if gpc.is_initialized(ParallelMode.EXPERT_DATA):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT_DATA))
+        if gpc.is_initialized(ParallelMode.EXPERT):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT))
+
+        dist.barrier()
+        del buffer
+        torch.cuda.empty_cache()
+
+
+def cuda_memory_analyze(step=0, print_mm_usage=False):
+    global n_caching_allocator_flushes
+    torch.cuda.synchronize()
+
+    g_rank = gpc.get_global_rank()
+    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
+    rank_id = f"Rank:{g_rank}-tp{tp_rank}-pp{pp_rank}-dp{dp_rank}"
+
+    if print_mm_usage and gpc.get_local_rank(ParallelMode.DATA) == 0:
+        logger.info(
+            f"{rank_id}: Step {step}: "
+            f"Allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Max_Allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Reserved {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Max_Reserved {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 4)} GB "
+        )
+
+    torch.cuda.reset_peak_memory_stats()
+
+    # Warn the user about caching allocator flushes.
+    memory_stats = torch.cuda.memory_stats()
+    alloc_retries = memory_stats.get("num_alloc_retries")
+    if alloc_retries is None:
+        alloc_retries = 0
+    if alloc_retries > n_caching_allocator_flushes:
+        retry_count = alloc_retries - n_caching_allocator_flushes
+        if gpc.get_global_rank() == 0:
+            logger.warning(
+                f"{rank_id}: PyTorch allocator cache flushed {retry_count} times since last step. "
+                "This happens when there is high memory pressure and is detrimental to "
+                "performance. If this is happening frequently, consider adjusting "
+                "settings to reduce memory consumption. If you are unable to "
+                "make the cache flushes go away, consider adding "
+                "torch.cuda.empty_cache() calls in your training loop to ensure "
+                "that all ranks flush their caches at the same time."
+            )
+        n_caching_allocator_flushes = alloc_retries
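The flush warning in `cuda_memory_analyze` keys off PyTorch's allocator statistics: `torch.cuda.memory_stats()["num_alloc_retries"]` counts how often an allocation had to be retried after the caching allocator flushed its cached blocks. A self-contained sketch of the same delta check (the helper name is mine, not from the repo):

```python
import torch


def new_alloc_retries(prev_total: int) -> tuple[int, int]:
    """Return (retries since prev_total, new running total).

    A nonzero delta between training steps means the caching allocator
    flushed and re-allocated, which usually signals memory pressure."""
    total = torch.cuda.memory_stats().get("num_alloc_retries", 0)
    return total - prev_total, total


if torch.cuda.is_available():
    seen = 0
    delta, seen = new_alloc_retries(seen)
    if delta > 0:
        print(f"allocator cache flushed {delta} times since last check")
```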
diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py
index 348c780..98a69c9 100644
--- a/tests/test_training/train_CI.py
+++ b/tests/test_training/train_CI.py
@@ -106,13 +106,13 @@ def main(args):
     get_tflops_func = partial(
         get_megatron_flops,
         checkpoint=gpc.config.model.checkpoint,
-        seq_len=gpc.config.SEQ_LEN,
+        seq_len=gpc.config.data["seq_len"],
         hidden_size=gpc.config.model.hidden_size,
         num_layers=gpc.config.model.num_layers,
         vocab_size=gpc.config.model.vocab_size,
         global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-        mlp_ratio=gpc.config.MLP_RATIO,
+        mlp_ratio=gpc.config.model["mlp_ratio"],
     )

     # get and broadcast current time
diff --git a/train.py b/train.py
index 35e39fa..9f0c1ac 100644
--- a/train.py
+++ b/train.py
@@ -77,13 +77,13 @@ def main(args):
     get_tflops_func = partial(
         get_megatron_flops,
         checkpoint=gpc.config.model.checkpoint,
-        seq_len=gpc.config.SEQ_LEN,
+        seq_len=gpc.config.data["seq_len"],
         hidden_size=gpc.config.model.hidden_size,
         num_layers=gpc.config.model.num_layers,
         vocab_size=gpc.config.model.vocab_size,
         global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-        mlp_ratio=gpc.config.MLP_RATIO,
+        mlp_ratio=gpc.config.model["mlp_ratio"],
     )

     # get and broadcast current time
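A worked example of the `global_batch_size` expression that both entry points pass to `get_megatron_flops`, with made-up parallel sizes:

```python
# Hypothetical values; the real ones come from gpc.config.data and the
# size of the data-parallel process group.
micro_bsz, micro_num, dp_world_size = 2, 4, 16
seq_len = 2048  # now read from gpc.config.data["seq_len"], not gpc.config.SEQ_LEN

global_batch_size = micro_bsz * micro_num * dp_world_size  # samples per step
tokens_per_step = global_batch_size * seq_len
assert global_batch_size == 128 and tokens_per_step == 262_144
```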