From 757e19e01a6e04bea03d09e0cdeb0df16b5b64ea Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Tue, 28 Nov 2023 19:33:46 +0800 Subject: [PATCH 1/6] 1. fix(config): rampup_batch_size defalut value BC. (#515) 2. fix(config): standardize config parameter access. 3. feat(launch): add warmup_process_group 4. feat(memory): add cuda_memory_analyze --- configs/7B_sft.py | 12 +-- internlm/core/scheduler/pipeline_scheduler.py | 12 +-- internlm/initialize/launch.py | 10 ++- internlm/utils/evaluation.py | 2 +- internlm/utils/gputest.py | 79 +++++++++++++++++++ tests/test_training/train_CI.py | 4 +- train.py | 4 +- 7 files changed, 102 insertions(+), 21 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 0218a0b..1cbb5e7 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -44,8 +44,8 @@ ckpt = dict( oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. ) -TRAIN_FOLDER = "/path/to/dataset" -VALID_FOLDER = "/path/to/dataset" +TRAIN_FOLDER = None # "/path/to/dataset" +VALID_FOLDER = None # "/path/to/dataset" data = dict( seq_len=SEQ_LEN, # micro_num means the number of micro_batch contained in one gradient update @@ -64,12 +64,12 @@ data = dict( # each increment. For example, "192 24 8" means that the batch size (micro_num) # starts at 192 and increases by 24 every 8 steps. Defaults to None. # (IMPORTANT): The interval step size is 'micro_bsz'. - rampup_batch_size=None, + rampup_batch_size="", # Datasets with less than 50 rows will be discarded min_length=50, - # train_folder=TRAIN_FOLDER, - # valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=10, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, diag_outlier_ratio=1.1, ) diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index c851789..550584e 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -35,19 +35,19 @@ def get_tensor_shape(): if gpc.config.parallel.sequence_parallel: sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR) tensor_shape = ( - gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size, - gpc.config.HIDDEN_SIZE, + gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"] // sequence_world_size, + gpc.config.model["hidden_size"], ) else: tensor_shape = ( - gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"], - gpc.config.HIDDEN_SIZE, + gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"], + gpc.config.model["hidden_size"], ) else: tensor_shape = ( gpc.config.data["micro_bsz"], - gpc.config.SEQ_LEN, - gpc.config.HIDDEN_SIZE, + gpc.config.data["seq_len"], + gpc.config.model["hidden_size"], ) return tensor_shape else: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e96d2d9..2736532 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -13,6 +13,7 @@ from internlm.core.context import Config from internlm.core.context import global_context as gpc from internlm.monitor import initialize_light_monitor from internlm.utils.common import get_master_node +from internlm.utils.gputest import warmup_process_group from internlm.utils.logger import get_logger from internlm.utils.timeout import llm_timeout @@ -60,6 +61,9 @@ def get_default_parser(): def args_sanity_check(): assert gpc.config is not None, "config is not load!" 
+ if "JOB_NAME" not in gpc.config: + gpc.config._add_item("JOB_NAME", "AnonymousJob") + # the default model type is INTERNLM if "model_type" not in gpc.config: gpc.config._add_item("model_type", "INTERNLM") @@ -144,10 +148,6 @@ def args_sanity_check(): if "diag_outlier_ratio" not in data: data._add_item("diag_outlier_ratio", 1.1) - if "rampup_batch_size" not in data or not data.rampup_batch_size or len(data.rampup_batch_size) == 0: - bsz = data.micro_num - data._add_item("rampup_batch_size", f"{bsz} {bsz} 1") - data.diag_outlier_ratio = max(1, data.diag_outlier_ratio) if gpc.is_rank_for_log(): @@ -423,6 +423,8 @@ def launch( gpc.set_seed(seed) + warmup_process_group() + if gpc.is_rank_for_log(): logger.info( f"Distributed environment is initialized, " diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py index 22d998b..a94784c 100644 --- a/internlm/utils/evaluation.py +++ b/internlm/utils/evaluation.py @@ -101,7 +101,7 @@ def evaluate_on_val_dls( assert total_val_bsz % data_cfg.micro_bsz == 0 num_microbatches = total_val_bsz // data_cfg.micro_bsz tensor_shape = torch.Size( - [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE] + [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]] ) with switch_evaluation_pipeline_scheduler( diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 85d4cdc..48ec0e3 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -27,10 +27,17 @@ from internlm.utils.common import get_current_device logger = get_logger(__file__) +# Gloabl cuda cache flush counter +n_caching_allocator_flushes = 0 + + def empty_cache_and_diag(batch_count, interval=50): """empty cuda cache and run diag bench or tests.""" if interval <= 0: interval = 50 + + cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5) + if batch_count % int(interval) == 0: # there is no need to do diag on the first batch if batch_count > 0: @@ -259,3 +266,75 @@ def bench_gpu(use_flash_attn=True): address=gpc.config.monitor.alert.feishu_alert_address, message=msg, ) + + +""" +Useful utility functions migrated from deepseped. +""" + + +def warmup_process_group(): + # Prevent OOM from nccl communication. 
+ if dist.is_initialized(): + buffer = torch.ones([64]).cuda() + if gpc.is_initialized(ParallelMode.DATA): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.DATA)) + if gpc.is_initialized(ParallelMode.TENSOR): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.TENSOR)) + if gpc.is_initialized(ParallelMode.PIPELINE): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.PIPELINE)) + if gpc.is_initialized(ParallelMode.ZERO1): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO1)) + if gpc.is_initialized(ParallelMode.MODEL): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.MODEL)) + if gpc.is_initialized(ParallelMode.ZERO3_DP): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO3_DP)) + if gpc.is_initialized(ParallelMode.EXPERT_DATA): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT_DATA)) + if gpc.is_initialized(ParallelMode.EXPERT): + dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT)) + + dist.barrier() + del buffer + torch.cuda.empty_cache() + + +def cuda_memory_analyze(step=0, print_mm_suage=False): + global n_caching_allocator_flushes + torch.cuda.synchronize() + + g_rank = gpc.get_global_rank() + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + dp_rank = gpc.get_local_rank(ParallelMode.DATA) + rank_id = f"Rank:{g_rank}-tp{tp_rank}-pp{pp_rank}-dp{dp_rank}" + + if print_mm_suage and gpc.get_local_rank(ParallelMode.DATA) == 0: + logger.info( + f"{rank_id}: Step {step}: " + f"Allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),4 )} GB, " + f"Max_Allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),4)} GB, " + f"Reserved {round(torch.cuda.memory_reserved()/ (1024 * 1024 * 1024),4)} GB, " + f"Max_Reserved {round(torch.cuda.max_memory_reserved()/ (1024 * 1024 * 1024),4)} GB " + ) + + torch.cuda.reset_peak_memory_stats() + + # warn user about caching allocator flushes + memory_stats = torch.cuda.memory_stats() + alloc_retries = memory_stats.get("num_alloc_retries") + if alloc_retries is None: + alloc_retries = 0 + if alloc_retries > n_caching_allocator_flushes: + retry_count = alloc_retries - n_caching_allocator_flushes + if gpc.get_global_rank() == 0: + logger.warning( + f"{rank_id}: pytorch allocator cache flushes {retry_count} times since last step." + "this happens when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. 
If you are unable to " + "make the cache flushes go away consider adding " + "torch.cuda.empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time" + ) + n_caching_allocator_flushes = alloc_retries diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 348c780..98a69c9 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -106,13 +106,13 @@ def main(args): get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, - seq_len=gpc.config.SEQ_LEN, + seq_len=gpc.config.data["seq_len"], hidden_size=gpc.config.model.hidden_size, num_layers=gpc.config.model.num_layers, vocab_size=gpc.config.model.vocab_size, global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA), global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), - mlp_ratio=gpc.config.MLP_RATIO, + mlp_ratio=gpc.config.model["mlp_ratio"], ) # get and broadcast current time diff --git a/train.py b/train.py index 35e39fa..9f0c1ac 100644 --- a/train.py +++ b/train.py @@ -77,13 +77,13 @@ def main(args): get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, - seq_len=gpc.config.SEQ_LEN, + seq_len=gpc.config.data["seq_len"], hidden_size=gpc.config.model.hidden_size, num_layers=gpc.config.model.num_layers, vocab_size=gpc.config.model.vocab_size, global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA), global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), - mlp_ratio=gpc.config.MLP_RATIO, + mlp_ratio=gpc.config.model["mlp_ratio"], ) # get and broadcast current time From b79d5ea7ae326b31d732d888a7f322af0fc0e3fe Mon Sep 17 00:00:00 2001 From: kkscilife <126147887+kkscilife@users.noreply.github.com> Date: Thu, 30 Nov 2023 11:04:07 +0800 Subject: [PATCH 2/6] test(workflow): add workflow for loss test and change trigger event (#513) * add workflow for loss test * change trigger event * optimize trigger event --------- Co-authored-by: wangmengke --- .github/workflows/pr_merged.yaml | 17 ++++++++++ .github/workflows/weekly_test.yaml | 53 +++++++++++++++++++----------- 2 files changed, 50 insertions(+), 20 deletions(-) diff --git a/.github/workflows/pr_merged.yaml b/.github/workflows/pr_merged.yaml index 65e273b..5a09019 100644 --- a/.github/workflows/pr_merged.yaml +++ b/.github/workflows/pr_merged.yaml @@ -50,3 +50,20 @@ jobs: source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 export PYTHONPATH=$PWD:$PYTHONPATH srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=internlm-acc-test-${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python ./tests/test_training/train_CI.py --config ./tests/test_training/7B_check_acc.py + + check_loss_when_swapping_micro_num_and_micro_bsz: + if: ${{ !cancelled() }} + needs: check-requirements + runs-on: [t_cluster] + timeout-minutes: 40 + steps: + - name: mask env + run: | + echo "::add-mask::${{env.WORKSPACE_PREFIX}}" + - uses: actions/checkout@v3 + + - name: loss_tests + run: | + source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 + export PYTHONPATH=$PWD:$PYTHONPATH + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=internlm-loss-test-${GITHUB_RUN_ID}-${GITHUB_JOB} -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_training/test_swap_nb_loss_and_gradnorm.py diff --git a/.github/workflows/weekly_test.yaml b/.github/workflows/weekly_test.yaml index 
bf360c8..133bccc 100644 --- a/.github/workflows/weekly_test.yaml +++ b/.github/workflows/weekly_test.yaml @@ -1,102 +1,115 @@ name: weekly-tests on: - push: - branches: - - "main" - - "develop" + workflow_dispatch: + schedule: + - cron: '56 18 * * 5' env: SLURM_PARTITION: llm_s jobs: training_8GPU: runs-on: [t_cluster] - timeout-minutes: 5 + timeout-minutes: 10 steps: - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: training_8GPU run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU" ./tests/test_training + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU" ./tests/test_training training_16GPU_8DP2TP: runs-on: [t_cluster] - timeout-minutes: 5 + timeout-minutes: 10 steps: - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: training_16GPU_8DP2TP run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 sed -i 's/^.*tensor=.*/ tensor=2,/' ./configs/7B_sft.py - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TP" ./tests/test_training + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TP" ./tests/test_training training_16GPU_8DP2TPSP: runs-on: [t_cluster] - timeout-minutes: 5 + timeout-minutes: 10 steps: - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: training_16GPU_8DP2TPSP run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 sed -i 's/^.*tensor=.*/ tensor=2,/' ./configs/7B_sft.py sed -i 's/^.*sequence_parallel=.*/ sequence_parallel=True,/' ./configs/7B_sft.py - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TPSP" ./tests/test_training + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TPSP" ./tests/test_training training_16GPU_8DP2PP: runs-on: [t_cluster] - timeout-minutes: 5 + timeout-minutes: 10 steps: - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: training_16GPU_8DP2PP run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2),/' ./configs/7B_sft.py - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP" ./tests/test_training + srun -p ${SLURM_PARTITION} 
--kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP" ./tests/test_training training_16GPU_8DP2PP_InterleavedOverlap: runs-on: [t_cluster] - timeout-minutes: 5 + timeout-minutes: 10 steps: - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: training_16GPU_8DP2PP_InterleavedOverlap run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2, interleaved_overlap=True),/' ./configs/7B_sft.py sed -i 's/^.*num_chunks=.*/ num_chunks=2,/' ./configs/7B_sft.py - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP_InterleavedOverlap" ./tests/test_training + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP_InterleavedOverlap" ./tests/test_training unit_test_optimizer: runs-on: [t_cluster] - timeout-minutes: 30 + timeout-minutes: 35 steps: - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: test_optimizer run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_solver/test_optimizer.py + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_solver/test_optimizer.py unit_test_model: runs-on: [t_cluster] - timeout-minutes: 5 + timeout-minutes: 10 steps: - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: test_embedding_accuracy run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_model/test_embedding.py + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_model/test_embedding.py - name: test_model_internlm_accuracy run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_model/test_model_internlm.py + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_model/test_model_internlm.py - name: test_norm_accuracy run: | source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_model/test_norm.py + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -N 1 -n 1 --gres=gpu:8 python -m pytest -s ./tests/test_model/test_norm.py From 
b3be333aa2898749757eb4782b16c55b7cfb6400 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Thu, 30 Nov 2023 19:16:35 +0800 Subject: [PATCH 3/6] fix(ci): fix test model ckpt ci test (#518) --- tests/test_utils/common_fixture.py | 39 ++++++++++++++++------- tests/test_utils/test_model_checkpoint.py | 9 ++++-- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 56e7b21..d0f1455 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -8,22 +8,33 @@ import torch from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer +from internlm.train.utils import create_param_groups from internlm.utils.common import SingletonMeta -OSS_NAME = os.environ.get("OSS_BUCKET_NAME") -OSS_IP = os.environ.get("OSS_IP") -USER = os.environ.get("USER") +OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None) +OSS_IP = os.environ.get("OSS_IP", None) +USER = os.environ.get("USER", None) JOB_NAME = "CI_TEST" LOCAL_SAVE_PATH = "local:local_ckpt" -BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" -BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" +if OSS_NAME is None or OSS_IP is None: + BOTO_SAVE_PATH = None + BOTO_SAVE_PATH_NO_PRFIX = None -VOLC_SAVE_PATH = f"volc:vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" -VOLC_SAVE_PATH_NO_PRFIX = f"vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + VOLC_SAVE_PATH = None + VOLC_SAVE_PATH_NO_PRFIX = None -ALI_SAVE_PATH = f"oss2:ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" -ALI_SAVE_PATH_NO_PRFIX = f"ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + ALI_SAVE_PATH = None + ALI_SAVE_PATH_NO_PRFIX = None +else: + BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" + BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + + VOLC_SAVE_PATH = f"volc:vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" + VOLC_SAVE_PATH_NO_PRFIX = f"vc://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" + + ALI_SAVE_PATH = f"oss2:ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}" + ALI_SAVE_PATH_NO_PRFIX = f"ali://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/" ASYNC_TMP_FOLDER = "./async_tmp_folder" @@ -31,7 +42,12 @@ ASYNC_TMP_FOLDER = "./async_tmp_folder" # 1B init_config = Config( dict( - parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1), + parallel=dict( + zero1=dict(size=1, fsdp=False), + pipeline=dict(size=1, interleaved_overlap=False), + sequence_parallel=False, + tensor=1, + ), model_type="INTERNLM", adam=dict( lr=1e-4, @@ -90,8 +106,9 @@ def init_naive_optim(model): def init_hybrid_optim(model): + params = create_param_groups(model, 0.01) naive_optimizer = torch.optim.AdamW( - params=[{"params": model.parameters(), "weight_decay": 0.01}], + params=params, lr=1e-4, betas=(0.9, 0.95), eps=1e-8, diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 0804455..c50eec1 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -32,7 +32,7 @@ ckpt_config_list = [ checkpoint_every=0, async_upload=True, async_upload_tmp_folder=ASYNC_TMP_FOLDER, - snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), + snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]) if BOTO_SAVE_PATH is not None else None, oss_snapshot_freq=0, 
stop_file_path=None, load_model_only_folder=None, @@ -207,6 +207,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: ckpt_config.checkpoint_every = checkpoint_every ckpt_config.oss_snapshot_freq = oss_snapshot_freq + if ckpt_config.save_ckpt_folder is None: + return + bond_return_latest_save_path = partial( return_latest_save_path, ckpt_config.save_ckpt_folder, @@ -298,12 +301,12 @@ def query_quit_file(rank, world_size=2): ckpt_config = Config( dict( enable_save_ckpt=True, - save_ckpt_folder=BOTO_SAVE_PATH, + save_ckpt_folder=LOCAL_SAVE_PATH, load_optimizer=True, checkpoint_every=0, async_upload=True, async_upload_tmp_folder=ASYNC_TMP_FOLDER, - snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]), + snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]), oss_snapshot_freq=0, stop_file_path=STOP_FILE_PATH, load_model_only_folder=None, From 66bffffe5ca57273a52d0cdc5446d96fd2678afc Mon Sep 17 00:00:00 2001 From: kkscilife <126147887+kkscilife@users.noreply.github.com> Date: Fri, 1 Dec 2023 16:12:39 +0800 Subject: [PATCH 4/6] add unit test case (#524) Co-authored-by: wangmengke --- .github/workflows/unit_tests.yaml | 68 +++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 3f49868..581b293 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -74,3 +74,71 @@ jobs: source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 export PYTHONPATH=$PWD:$PYTHONPATH srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=internlm-ut-${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:8 python -m pytest -s -v ./tests/test_utils/test_storage_manager.py + + unit_tests_model_fused_precision: + if: ${{ !cancelled() }} + needs: check-requirements + runs-on: [t_cluster] + timeout-minutes: 5 + steps: + - name: mask env + run: | + echo "::add-mask::${{env.WORKSPACE_PREFIX}}" + - uses: actions/checkout@v3 + + - name: model_fused_precision + run: | + source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 + export PYTHONPATH=$PWD:$PYTHONPATH + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=internlm-ut-${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:8 python -m pytest -s -v ./tests/test_model/test_fused_precision/test_fused_precision.py + + unit_tests_data_batch_sampler: + if: ${{ !cancelled() }} + needs: check-requirements + runs-on: [t_cluster] + timeout-minutes: 10 + steps: + - name: mask env + run: | + echo "::add-mask::${{env.WORKSPACE_PREFIX}}" + - uses: actions/checkout@v3 + + - name: data_batch_sample + run: | + source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 + export PYTHONPATH=$PWD:$PYTHONPATH + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=internlm-ut-${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:8 python -m pytest -s -v ./tests/test_data/test_batch_sampler.py + + unit_tests_utils_timeout: + if: ${{ !cancelled() }} + needs: check-requirements + runs-on: [t_cluster] + timeout-minutes: 5 + steps: + - name: mask env + run: | + echo "::add-mask::${{env.WORKSPACE_PREFIX}}" + - uses: actions/checkout@v3 + + - name: utils_timeout + run: | + source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 + export PYTHONPATH=$PWD:$PYTHONPATH + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=internlm-ut-${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:1 python -m pytest -s -v ./tests/test_utils/test_timeout.py + + 
unit_tests_utils_model_checkpoint: + if: ${{ !cancelled() }} + needs: check-requirements + runs-on: [t_cluster] + timeout-minutes: 5 + steps: + - name: mask env + run: | + echo "::add-mask::${{env.WORKSPACE_PREFIX}}" + - uses: actions/checkout@v3 + + - name: utils_model_checkpoint + run: | + source /mnt/petrelfs/share_data/llm_env/env/llm-flash2.0 + export PYTHONPATH=$PWD:$PYTHONPATH + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=internlm-ut-${GITHUB_RUN_ID}-${GITHUB_JOB} --quotatype=spot -N 1 -n 1 --gres=gpu:2 python -m pytest -s -v ./tests/test_utils/test_model_checkpoint.py From 1738bee0028444d7343bd9e9df8e13aed3b215a2 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Fri, 1 Dec 2023 17:05:58 +0800 Subject: [PATCH 5/6] feat(storage): use multipart upload when using oss (#520) * multipart upload * upload * storage * storage * storage * storage --- internlm/initialize/launch.py | 3 +- internlm/utils/model_checkpoint.py | 3 +- internlm/utils/storage_manager.py | 132 +++++++++++++++++++---- tests/test_utils/common_fixture.py | 2 +- tests/test_utils/test_storage_manager.py | 3 - 5 files changed, 116 insertions(+), 27 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 2736532..82a4a21 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -178,7 +178,8 @@ def args_sanity_check(): else: if ckpt.async_upload: assert "save_ckpt_folder" in ckpt - if "boto3:" not in ckpt.save_ckpt_folder: + prefix_list = ["boto3:", "volc:", "oss2:"] + if not any(ckpt.save_ckpt_folder.startswith(prefix) for prefix in prefix_list): if gpc.is_rank_for_log(): logger.warning( "Storing ckpt on file system does not support asynchronous storage, will use sync save!" diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index d16db0c..87a303c 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -720,6 +720,7 @@ class CheckpointManager: self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"]) self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"]) + torch.distributed.barrier() # test storage setting is ok. 
if self.enable_save_ckpt: self.try_ping_storage() @@ -1016,7 +1017,7 @@ now step_count is {train_state.step_count}", self.storage_manager.latest_save_step = step def try_ping_storage(self): - if gpc.get_global_rank() % 8 == 0: + if gpc.is_rank_for_log(): buff = torch.ones((1, 64, 64), dtype=torch.bfloat16) test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping") self.storage_manager.save(test_fn, buff) diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index c76b570..14c620b 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -18,10 +18,6 @@ from typing import Any, Awaitable, Callable, Dict, List, Union import torch import torch.distributed as dist -from internlm.core.context import global_context as gpc -from internlm.utils.common import SingletonMeta -from internlm.utils.logger import get_logger - try: import boto3 import botocore @@ -30,16 +26,38 @@ except ImportError: try: import tos + from tos.utils import SizeAdapter except ImportError: pass try: import oss2 + from oss2 import SizedFileAdapter, determine_part_size + from oss2.models import PartInfo except ImportError: pass -logger = get_logger(__file__) +class Logger: + "Dummy logger" + + def info(self, mesage: str): + print(f"Info: {mesage}", flush=True) + + def warning(self, mesage: str): + print(f"Warning: {mesage}", flush=True) + + def error(self, mesage: str): + print(f"Error: {mesage}", flush=True) + + +try: + from internlm.utils.logger import get_logger + + logger = get_logger(__file__) +except ImportError: + logger = Logger() + boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)") volc_url_re = re.compile(r"^(.*?)\.(.*)$") @@ -66,6 +84,12 @@ def llm_save(save_path: str, saved_obj: Any, **kwargs): storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs) +def is_rank_for_log(): + if dist.is_initialized(): + return dist.get_rank() % 8 == 0 + return True + + class StorageClient: """ StorageClient as a client for s3 storage access. 
@@ -267,21 +291,21 @@ def compute_file_md5_by_chunk(file_name: str): def try_get_storage_backend(path: str): if path.startswith("s3:"): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.") return "boto3", path elif path.startswith("vc:"): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.") return "volc", path elif path.startswith("ali:"): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of ali.") return "oss2", path else: sre = path.split(":", maxsplit=1) if len(sre) == 1: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.") return "local", sre[0] else: @@ -399,7 +423,7 @@ class Boto3Client(StorageClient): folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) return list(set(folder_name_list)) else: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{fp}' not found!") return None @@ -530,14 +554,41 @@ class VolcClient(StorageClient): return list(set(folder_name_list)) else: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{fp}' not found!") return None @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): try: - handler.client.put_object_from_file(bucket_name, fp, local_nvme_path) + total_size = os.path.getsize(local_nvme_path) + part_size = 5 * 1024 * 1024 + + multi_result = handler.client.create_multipart_upload(bucket_name, fp) + + upload_id = multi_result.upload_id + parts = [] + + # Upload shard data + with open(local_nvme_path, "rb") as f: + part_number = 1 + offset = 0 + while offset < total_size: + num_to_upload = min(part_size, total_size - offset) + out = handler.client.upload_part( + bucket_name, + fp, + upload_id, + part_number, + content=SizeAdapter(f, num_to_upload, init_offset=offset), + ) + parts.append(out) + offset += num_to_upload + part_number += 1 + + # Complete the multipart upload task + handler.client.complete_multipart_upload(bucket_name, fp, upload_id, parts) + except handler.handler.exceptions.TosClientError as exc: raise RuntimeError( f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}" @@ -548,6 +599,8 @@ class VolcClient(StorageClient): f"error with request id: {exec.request_id}", f"error with message: {exec.message}", f"error with http code: {exec.status_code}", + f"error with ec: {exec.ec}", + f"error with request url: {exec.request_url}", ) from exc except Exception as e: raise e @@ -570,10 +623,10 @@ class AliClient(StorageClient): """Ali object/file storage management class Args: - access_key (str): Ali access key ID. + access_key (str): Ali access key ID.s secret_key (str): Ali secret access key. endpoint (str): Ali tos endpoint. - region (str): Ali tos region. + bucket_name (str): Ali tos bucket_name. 
""" super().__init__(oss2) @@ -634,14 +687,34 @@ class AliClient(StorageClient): return list(set(folder_name_list)) else: - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{fp}' not found!") return None @staticmethod def async_upload_fileobj(handler, fp: str, local_nvme_path: str): try: - handler.client.put_object_from_file(fp, local_nvme_path) + total_size = os.path.getsize(local_nvme_path) + part_size = determine_part_size(total_size, preferred_size=5 * 1024 * 1024) + upload_id = handler.client.init_multipart_upload(fp).upload_id + parts = [] + with open(local_nvme_path, "rb") as fileobj: + part_number = 1 + offset = 0 + while offset < total_size: + num_to_upload = min(part_size, total_size - offset) + # Calling the SizedFileAdapter method will generate a new file object + # and recalculate the starting append position. + result = handler.client.upload_part( + fp, upload_id, part_number, SizedFileAdapter(fileobj, num_to_upload) + ) + parts.append(PartInfo(part_number, result.etag)) + + offset += num_to_upload + part_number += 1 + + headers = dict() + handler.client.complete_multipart_upload(fp, upload_id, parts, headers=headers) except Exception as e: raise e @@ -683,7 +756,7 @@ class LocalClient(StorageClient): @staticmethod def get_fns(folder): if not os.path.exists(folder): - if gpc.is_rank_for_log(): + if is_rank_for_log(): logger.warning(f"'{folder}' not found!") return None else: @@ -815,6 +888,23 @@ def check_tmp_folder_accessibility(tmp_local_folder: str): raise RuntimeError(error_str) +class SingletonMeta(type): + """ + Singleton Meta. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + else: + assert ( + len(args) == 0 and len(kwargs) == 0 + ), f"{cls.__name__} is a singleton class and a instance has been created." + return cls._instances[cls] + + class StorageManager(metaclass=SingletonMeta): """ Storage Manager for saving or loading checkpoint. @@ -898,7 +988,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using boto3, incorrectly setting \ the proxy may make boto3 unavailable or affect performance." @@ -917,7 +1007,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using volc, incorrectly setting \ the proxy may make volc unavailable or affect performance." @@ -936,7 +1026,7 @@ class StorageManager(metaclass=SingletonMeta): or "HTTP_PROXY" in os.environ or "HTTPS_PROXY" in os.environ ): - if not self.has_warning and gpc.is_rank_for_log(): + if not self.has_warning and is_rank_for_log(): logger.warning( "HTTP/HTTPS proxy is detected when using oss2, incorrectly setting \ the proxy may make oss2 unavailable or affect performance." 
@@ -1082,7 +1172,7 @@ class StorageManager(metaclass=SingletonMeta): self._to_be_del_files.clear() self.async_task_peeding = False - if gpc.is_rank_for_log(): + if is_rank_for_log(): self.upload_count += 1 if self.async_mode and self.latest_save_folder: self.save( diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index d0f1455..6096156 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -9,7 +9,7 @@ from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer from internlm.train.utils import create_param_groups -from internlm.utils.common import SingletonMeta +from internlm.utils.storage_manager import SingletonMeta OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None) OSS_IP = os.environ.get("OSS_IP", None) diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py index e102ca1..9454a83 100644 --- a/tests/test_utils/test_storage_manager.py +++ b/tests/test_utils/test_storage_manager.py @@ -100,7 +100,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg init_storage_manager, llm_load, llm_save, - wait_async_upload_finish, ) ckpt_config = Config(ckpt_config) @@ -118,8 +117,6 @@ def test_storage_mm_save_load(ckpt_config): # noqa # pylint: disable=unused-arg tobj = torch.rand(64, 64) save_fn = os.path.join(ckpt_config.save_folder, "test.pt") llm_save(save_fn, tobj) - if ckpt_config.test_id == 0: - wait_async_upload_finish() check_folder(save_fn) assert get_fns(ckpt_config.save_folder)[0] == "test.pt" load_obj = llm_load(save_fn, map_location="cpu") From 2dbbab74180f1bd346d45f1fdccdd642d3b618b4 Mon Sep 17 00:00:00 2001 From: jiaxingli <43110891+li126com@users.noreply.github.com> Date: Mon, 4 Dec 2023 15:38:13 +0800 Subject: [PATCH 6/6] fix test_checkpoint (#526) --- tests/test_utils/test_model_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index c50eec1..2dcabf4 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -8,9 +8,8 @@ import torch.distributed as dist from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer -from internlm.utils.common import SingletonMeta from internlm.utils.model_checkpoint import CheckpointManager -from internlm.utils.storage_manager import wait_async_upload_finish +from internlm.utils.storage_manager import SingletonMeta, wait_async_upload_finish from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import ASYNC_TMP_FOLDER, BOTO_SAVE_PATH,
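Patch 1 adds warmup_process_group(), which all-reduces a small buffer over every initialized parallel group (data, tensor, pipeline, ZeRO, expert, ...) before training starts, so the NCCL communicators and their workspace buffers are created up front rather than during the first real step, where the extra allocation can push a tight memory budget into OOM. A minimal, framework-independent sketch of the same warm-up, assuming torch.distributed is already initialized and the groups of interest are passed in explicitly (warmup_groups is an illustrative name, not an InternLM API):

import torch
import torch.distributed as dist


def warmup_groups(groups):
    """All-reduce a tiny buffer over each group to pre-build NCCL communicators."""
    if not dist.is_initialized():
        return
    buffer = torch.ones(64, device="cuda")
    for group in groups:
        # The reduction result is irrelevant; the call only forces the
        # communicator (and its buffers) for this group to be created now.
        dist.all_reduce(buffer, group=group)
    dist.barrier()
    del buffer
    # Hand the warm-up buffer's cached memory back to the GPU.
    torch.cuda.empty_cache()

Calling warmup_groups([dist.group.WORLD]) right after init_process_group covers the simplest case; the patched launch() instead iterates over every gpc parallel mode that is initialized.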
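Patch 1's cuda_memory_analyze() keys its warning off torch.cuda.memory_stats()["num_alloc_retries"], which counts how many times the caching allocator had to free its cached blocks and retry a cudaMalloc -- a signal of memory pressure that hurts performance. Below is a standalone sketch of the same bookkeeping for a single process (report_cuda_memory and _seen_alloc_retries are illustrative names, not part of the repository):

import torch

# Running count of allocator retries observed so far, mirroring the
# module-level n_caching_allocator_flushes counter in the patch.
_seen_alloc_retries = 0


def report_cuda_memory(step):
    """Print per-step memory usage and warn when the caching allocator flushed."""
    global _seen_alloc_retries
    torch.cuda.synchronize()

    gib = 1024 ** 3
    print(
        f"step {step}: "
        f"allocated {torch.cuda.memory_allocated() / gib:.4f} GiB, "
        f"max allocated {torch.cuda.max_memory_allocated() / gib:.4f} GiB, "
        f"reserved {torch.cuda.memory_reserved() / gib:.4f} GiB"
    )
    # Reset peaks so the next call reports per-step maxima, as the patch does.
    torch.cuda.reset_peak_memory_stats()

    # num_alloc_retries increments whenever cudaMalloc fails and the allocator
    # flushes its cache and retries.
    retries = torch.cuda.memory_stats().get("num_alloc_retries", 0)
    if retries > _seen_alloc_retries:
        print(
            f"warning: {retries - _seen_alloc_retries} allocator cache flush(es) "
            "since the last step; consider reducing memory consumption."
        )
        _seen_alloc_retries = retries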