mirror of https://github.com/InternLM/InternLM
				
				
				
			merge Internlm/develop into feature_add_moe
						commit
						b7ddc42dcd
					
				| 
						 | 
				
			
			@ -0,0 +1,56 @@
 | 
			
		|||
name: e2e-tests
 | 
			
		||||
on: 
 | 
			
		||||
  pull_request:
 | 
			
		||||
    branches:
 | 
			
		||||
      - "main"
 | 
			
		||||
      - "develop"
 | 
			
		||||
    paths-ignore:
 | 
			
		||||
      - "doc/**"
 | 
			
		||||
      - "**.md"
 | 
			
		||||
env:
 | 
			
		||||
  WORKSPACE_PREFIX: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-4)
 | 
			
		||||
  SLURM_PARTITION: llm
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  check-requirements:
 | 
			
		||||
    runs-on: [lmtest]
 | 
			
		||||
    steps:
 | 
			
		||||
    - name: mask env
 | 
			
		||||
      run: |
 | 
			
		||||
        echo "::add-mask::${{env.WORKSPACE_PREFIX}}"
 | 
			
		||||
    - uses: actions/checkout@v3
 | 
			
		||||
      with:
 | 
			
		||||
         fetch-depth: 2
 | 
			
		||||
    - name: check-requirements
 | 
			
		||||
      run: |
 | 
			
		||||
        source activate internlm-env-test
 | 
			
		||||
        changed_files=$(git diff --name-only -r HEAD^1 HEAD)
 | 
			
		||||
        echo $changed_files
 | 
			
		||||
        if [[ $changed_files =~ "runtime.txt" ]]; then
 | 
			
		||||
          pip install -r requirements/runtime.txt
 | 
			
		||||
        fi
 | 
			
		||||
 | 
			
		||||
        if [[ $changed_files =~ "torch.txt"  ]]; then
 | 
			
		||||
          pip install -r requirements/torch.txt
 | 
			
		||||
        fi
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  e2e_tests:
 | 
			
		||||
    if: ${{ always() }}
 | 
			
		||||
    needs: check-requirements
 | 
			
		||||
    runs-on: [lmtest]
 | 
			
		||||
    timeout-minutes: 30
 | 
			
		||||
    steps:
 | 
			
		||||
    - name: mask env
 | 
			
		||||
      run: |
 | 
			
		||||
        echo "::add-mask::${{env.WORKSPACE_PREFIX}}"
 | 
			
		||||
    - uses: actions/checkout@v3
 | 
			
		||||
 | 
			
		||||
    - name: e2e-test
 | 
			
		||||
      run: |
 | 
			
		||||
        source activate internlm-env-test
 | 
			
		||||
        srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU" ./tests/test_training
 | 
			
		||||
        srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TP" ./tests/test_training
 | 
			
		||||
        srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TPSP" ./tests/test_training
 | 
			
		||||
        srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP" ./tests/test_training
 | 
			
		||||
        srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP_InterleavedOverlap" ./tests/test_training
 | 
			
		||||
| 
						 | 
				
			
			@ -263,6 +263,12 @@ def args_sanity_check():
 | 
			
		|||
            gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
 | 
			
		||||
        ), "sequence parallel does not support use_flash_attn=False"
 | 
			
		||||
 | 
			
		||||
    # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
 | 
			
		||||
    if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
 | 
			
		||||
        assert (
 | 
			
		||||
            gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True
 | 
			
		||||
        ), "only support interleaved pipeline scheduler with overlap"
 | 
			
		||||
 | 
			
		||||
    # monitoring default config
 | 
			
		||||
    monitor_default_config = {
 | 
			
		||||
        "alert_address": None,  # compatible with old alert config
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,14 +4,13 @@
 | 
			
		|||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 | 
			
		||||
from flash_attn.utils.distributed import all_reduce, reduce_scatter
 | 
			
		||||
from torch import nn
 | 
			
		||||
 | 
			
		||||
from internlm.core.context import ParallelMode
 | 
			
		||||
from internlm.core.context import global_context as gpc
 | 
			
		||||
from internlm.model.utils import fused_dense_func_torch
 | 
			
		||||
from internlm.model.utils import Silu, fused_dense_func_torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScaleColumnParallelLinear(nn.Linear):
 | 
			
		||||
| 
						 | 
				
			
			@ -197,5 +196,7 @@ class FeedForward(nn.Module):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        out = self.w3(F.silu(self.w1(x)) * self.w2(x))
 | 
			
		||||
        w1_o = self.w1(x)
 | 
			
		||||
        w2_o = self.w2(x)
 | 
			
		||||
        out = self.w3(Silu(w1_o, w2_o))
 | 
			
		||||
        return out
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -225,3 +225,10 @@ def is_norm_param(param: torch.Tensor) -> bool:
 | 
			
		|||
    if hasattr(param, "is_norm") and param.is_norm:
 | 
			
		||||
        return True
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def Silu(w1_o, w2_o):
 | 
			
		||||
    return F.silu(w1_o) * w2_o
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Silu = torch.jit.script(Silu)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,390 @@
 | 
			
		|||
import math
 | 
			
		||||
import subprocess
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
import torch.distributed as dist
 | 
			
		||||
 | 
			
		||||
import internlm
 | 
			
		||||
from internlm.core.context import ParallelMode
 | 
			
		||||
from internlm.core.context import global_context as gpc
 | 
			
		||||
from internlm.core.scheduler import SchedulerMetricHook
 | 
			
		||||
from internlm.core.trainer import TrainState
 | 
			
		||||
from internlm.initialize import initialize_distributed_env
 | 
			
		||||
from internlm.model.loss import FlashGPTLMLoss
 | 
			
		||||
from internlm.model.metrics import AccPerplex
 | 
			
		||||
from internlm.train import (
 | 
			
		||||
    get_train_data_loader,
 | 
			
		||||
    initialize_model,
 | 
			
		||||
    initialize_optimizer,
 | 
			
		||||
    load_new_batch,
 | 
			
		||||
)
 | 
			
		||||
from internlm.utils.common import BatchSkipper, launch_time
 | 
			
		||||
from internlm.utils.gputest import empty_cache_and_diag
 | 
			
		||||
from internlm.utils.megatron_timers import megatron_timer as timer
 | 
			
		||||
from internlm.utils.model_checkpoint import CheckpointManager
 | 
			
		||||
 | 
			
		||||
CONFIG_FILE_PATH = "./configs/7B_sft.py"
 | 
			
		||||
TOTAL_STEPS = 10
 | 
			
		||||
LOSS_SPIKE_LIMIT = 1.5
 | 
			
		||||
LOSS_DEVIATION_LIMIT = 0.2
 | 
			
		||||
BASELINE_LOSS_LIST = [
 | 
			
		||||
    11.64188003540039,
 | 
			
		||||
    7.9205322265625,
 | 
			
		||||
    6.944362163543701,
 | 
			
		||||
    6.147305488586426,
 | 
			
		||||
    6.060564994812012,
 | 
			
		||||
    5.660439491271973,
 | 
			
		||||
    5.19430685043335,
 | 
			
		||||
    5.157323837280273,
 | 
			
		||||
    4.769168376922607,
 | 
			
		||||
    4.449280738830566,
 | 
			
		||||
]
 | 
			
		||||
cur_loss_list = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train():
 | 
			
		||||
    # initialize distributed environment
 | 
			
		||||
    initialize_distributed_env(config=CONFIG_FILE_PATH)
 | 
			
		||||
    assert hasattr(gpc, "config") and gpc.config is not None
 | 
			
		||||
 | 
			
		||||
    # init setting
 | 
			
		||||
    gpc.config.data.total_steps = TOTAL_STEPS
 | 
			
		||||
    gpc.config.lr_scheduler.total_steps = TOTAL_STEPS
 | 
			
		||||
    total_steps = gpc.config.data.total_steps
 | 
			
		||||
    skip_batches = gpc.config.data.skip_batches
 | 
			
		||||
    label_smoothing = gpc.config.loss.label_smoothing
 | 
			
		||||
 | 
			
		||||
    # get and broadcast current time
 | 
			
		||||
    current_time = launch_time()
 | 
			
		||||
    objs = [current_time]
 | 
			
		||||
    dist.broadcast_object_list(objs, src=0)
 | 
			
		||||
    current_time = objs[0]
 | 
			
		||||
 | 
			
		||||
    # initialize model
 | 
			
		||||
    model = initialize_model()
 | 
			
		||||
 | 
			
		||||
    # initialize loss function
 | 
			
		||||
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
 | 
			
		||||
 | 
			
		||||
    # initialize the train data loader
 | 
			
		||||
    train_dl, dataset_types = get_train_data_loader(num_worker=4)
 | 
			
		||||
 | 
			
		||||
    # initialize and resume train state
 | 
			
		||||
    train_state = TrainState(gpc.config, train_dl.batch_sampler)
 | 
			
		||||
 | 
			
		||||
    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
 | 
			
		||||
 | 
			
		||||
    with open(CONFIG_FILE_PATH, "r") as f:
 | 
			
		||||
        config_lines = f.readlines()
 | 
			
		||||
    ckpt_manager = CheckpointManager(
 | 
			
		||||
        ckpt_config=gpc.config.ckpt,
 | 
			
		||||
        model=model,
 | 
			
		||||
        optimizer=optimizer,
 | 
			
		||||
        lr_scheduler=lr_scheduler,
 | 
			
		||||
        train_dl=train_dl,
 | 
			
		||||
        model_config=gpc.config.model,
 | 
			
		||||
        model_config_file="".join(config_lines),
 | 
			
		||||
        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Loading other persistent training states.
 | 
			
		||||
    ckpt_manager.try_resume_training(train_state, current_time)
 | 
			
		||||
 | 
			
		||||
    # initialize metric for calculating accuracy and perplexity
 | 
			
		||||
    metric = AccPerplex(
 | 
			
		||||
        device=torch.cuda.current_device(),
 | 
			
		||||
        tp_pg=gpc.get_group(ParallelMode.TENSOR),
 | 
			
		||||
        dp_pg=gpc.get_group(ParallelMode.DATA),
 | 
			
		||||
        dataset_types=dataset_types,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # initialize trainer
 | 
			
		||||
    scheduler_hooks = [
 | 
			
		||||
        SchedulerMetricHook(
 | 
			
		||||
            metric=metric,
 | 
			
		||||
            skip=(
 | 
			
		||||
                gpc.is_using_pp()
 | 
			
		||||
                and hasattr(gpc.config.model, "num_chunks")
 | 
			
		||||
                and gpc.config.model.num_chunks > 1
 | 
			
		||||
                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    trainer, train_dl, _, _ = internlm.initialize_trainer(
 | 
			
		||||
        model=model,
 | 
			
		||||
        optimizer=optimizer,
 | 
			
		||||
        criterion=criterion,
 | 
			
		||||
        train_dataloader=train_dl,
 | 
			
		||||
        lr_scheduler=lr_scheduler,
 | 
			
		||||
        beta2_scheduler=beta2_scheduler,
 | 
			
		||||
        scheduler_hooks=scheduler_hooks,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # initialize the batch skipper
 | 
			
		||||
    batch_skipper = BatchSkipper(skip_batches)
 | 
			
		||||
 | 
			
		||||
    trainer.train()
 | 
			
		||||
 | 
			
		||||
    # transfer the train data loader into train data iterator
 | 
			
		||||
    train_iter = iter(train_dl)
 | 
			
		||||
 | 
			
		||||
    # start iterating the train data and begin training
 | 
			
		||||
    for batch_count in range(train_state.batch_count, total_steps):
 | 
			
		||||
        empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
 | 
			
		||||
        timer("one-batch").start()
 | 
			
		||||
 | 
			
		||||
        # load batch data
 | 
			
		||||
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
 | 
			
		||||
 | 
			
		||||
        # record the consumed samples in training
 | 
			
		||||
        train_state.batch_count = batch_count
 | 
			
		||||
        train_state.num_consumed_samples_in_epoch += len(batch[1])
 | 
			
		||||
        if batch_skipper(batch_count):  # skip this batch
 | 
			
		||||
            if gpc.is_rank_for_log():
 | 
			
		||||
                print(f"Skip batch count:`{batch_count}`...")
 | 
			
		||||
            timer("one-batch").stop()
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # zero the grads of parameters
 | 
			
		||||
        trainer.zero_grad()
 | 
			
		||||
        # process data
 | 
			
		||||
        if batch[0].get("type_ids", None) is not None:
 | 
			
		||||
            metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
 | 
			
		||||
 | 
			
		||||
        # do forward and backward
 | 
			
		||||
        timer("fwd-bwd").start()
 | 
			
		||||
 | 
			
		||||
        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
 | 
			
		||||
        if gpc.is_rank_for_log():
 | 
			
		||||
            assert loss is not None and not math.isnan(loss.item())
 | 
			
		||||
            global cur_loss_list
 | 
			
		||||
            cur_loss_list.append(loss.item())
 | 
			
		||||
        timer("fwd-bwd").stop()
 | 
			
		||||
 | 
			
		||||
        # update parameters, and returns (success_update, grad_norm)
 | 
			
		||||
        trainer_result = trainer.step()
 | 
			
		||||
        assert trainer_result is not None
 | 
			
		||||
 | 
			
		||||
        success_update, _ = trainer_result
 | 
			
		||||
        assert success_update, "Error: grad norm inf or nan occurs!"
 | 
			
		||||
        if success_update:  # update parameters successfully
 | 
			
		||||
            train_state.step_count += 1
 | 
			
		||||
        else:
 | 
			
		||||
            train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
 | 
			
		||||
 | 
			
		||||
        timer("one-batch").stop()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_loss_spike():
 | 
			
		||||
    if gpc.is_rank_for_log():
 | 
			
		||||
        for step in range(1, TOTAL_STEPS):
 | 
			
		||||
            assert (
 | 
			
		||||
                cur_loss_list[step] < cur_loss_list[step - 1] * LOSS_SPIKE_LIMIT
 | 
			
		||||
            ), f"The loss spike occurs, {cur_loss_list[step - 1]}->{cur_loss_list[step]}, please check it!"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_loss_accuracy():
 | 
			
		||||
    if gpc.is_rank_for_log():
 | 
			
		||||
        for cur, target in zip(cur_loss_list, BASELINE_LOSS_LIST):
 | 
			
		||||
            assert (
 | 
			
		||||
                abs(cur - target) < LOSS_DEVIATION_LIMIT
 | 
			
		||||
            ), f"The loss accuracy is abnormal, {target}->{cur}, please check it!"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCaseTrain8GPU:
 | 
			
		||||
    """
 | 
			
		||||
    Test cases for Model Training with 8 GPUs.
 | 
			
		||||
    Parallel Config:
 | 
			
		||||
        data parallel size = 8.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def setup_class():
 | 
			
		||||
        # model training
 | 
			
		||||
        train()
 | 
			
		||||
 | 
			
		||||
        # print loss value
 | 
			
		||||
        print(f"cur_loss_list: {cur_loss_list}", flush=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_8GPU
 | 
			
		||||
    def test_loss_spike_with_dp8():
 | 
			
		||||
        check_loss_spike()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_8GPU
 | 
			
		||||
    def test_loss_accuracy_with_dp8():
 | 
			
		||||
        check_loss_accuracy()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCaseTrain16GPUWith8DP2TP:
 | 
			
		||||
    """
 | 
			
		||||
    Test cases for Model Training with 16 GPUs.
 | 
			
		||||
    Parallel Config:
 | 
			
		||||
        data parallel size = 8.
 | 
			
		||||
        tensor parallel size = 2.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def setup_class():
 | 
			
		||||
        # update config tensor parallel size
 | 
			
		||||
        command = f"sed -i 's/^.*tensor=.*/    tensor=2,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
 | 
			
		||||
        # model training
 | 
			
		||||
        train()
 | 
			
		||||
 | 
			
		||||
        # print loss value
 | 
			
		||||
        print(f"cur_loss_list: {cur_loss_list}", flush=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2TP
 | 
			
		||||
    def test_loss_spike_with_dp8_tp2():
 | 
			
		||||
        check_loss_spike()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2TP
 | 
			
		||||
    def test_loss_accuracy_with_dp8_tp2():
 | 
			
		||||
        check_loss_accuracy()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCaseTrain16GPUWith8DP2TPSP:
 | 
			
		||||
    """
 | 
			
		||||
    Test cases for Model Training with 16 GPUs.
 | 
			
		||||
    Parallel Config:
 | 
			
		||||
        data parallel size = 8.
 | 
			
		||||
        tensor parallel size = 2.
 | 
			
		||||
        sequence parallel = True.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def setup_class():
 | 
			
		||||
        # update config tensor parallel size and sequence parallel
 | 
			
		||||
        command = f"sed -i 's/^.*tensor=.*/    tensor=2,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
        command = f"sed -i 's/^.*sequence_parallel=.*/    sequence_parallel=True,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
 | 
			
		||||
        # model training
 | 
			
		||||
        train()
 | 
			
		||||
 | 
			
		||||
        # print loss value
 | 
			
		||||
        print(f"cur_loss_list: {cur_loss_list}", flush=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2TPSP
 | 
			
		||||
    def test_loss_spike_with_dp8_tp2_sp():
 | 
			
		||||
        check_loss_spike()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2TPSP
 | 
			
		||||
    def test_loss_accuracy_with_dp8_tp2_sp():
 | 
			
		||||
        check_loss_accuracy()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCaseTrain16GPUWith8DP2PP:
 | 
			
		||||
    """
 | 
			
		||||
    Test cases for Model Training with 16 GPUs.
 | 
			
		||||
    Parallel Config:
 | 
			
		||||
        data parallel size = 8.
 | 
			
		||||
        pipeline parallel size = 2.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def setup_class():
 | 
			
		||||
        # update config pipeline parallel size
 | 
			
		||||
        command = f"sed -i 's/^.*pipeline=.*/    pipeline=dict(size=2),/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
        command = f"sed -i 's/^.*tensor=.*/    tensor=1,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
 | 
			
		||||
        # model training
 | 
			
		||||
        train()
 | 
			
		||||
 | 
			
		||||
        # print loss value
 | 
			
		||||
        print(f"cur_loss_list: {cur_loss_list}", flush=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2PP
 | 
			
		||||
    def test_loss_spike_with_dp8_pp2():
 | 
			
		||||
        check_loss_spike()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2PP
 | 
			
		||||
    def test_loss_accuracy_with_dp8_pp2():
 | 
			
		||||
        check_loss_accuracy()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCaseTrain16GPUWith8DP2PPInterleaved:
 | 
			
		||||
    """
 | 
			
		||||
    Test cases for Model Training with 16 GPUs.
 | 
			
		||||
    Parallel Config:
 | 
			
		||||
        data parallel size = 8.
 | 
			
		||||
        pipeline parallel size = 2.
 | 
			
		||||
        interleaved scheduler = True.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def setup_class():
 | 
			
		||||
        # update config pipeline parallel size
 | 
			
		||||
        command = f"sed -i 's/^.*pipeline=.*/    pipeline=dict(size=2),/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
        command = f"sed -i 's/^.*num_chunks=.*/    num_chunks=2,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
        command = f"sed -i 's/^.*tensor=.*/    tensor=1,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=False)
 | 
			
		||||
 | 
			
		||||
        # model training
 | 
			
		||||
        train()
 | 
			
		||||
 | 
			
		||||
        # print loss value
 | 
			
		||||
        print(f"cur_loss_list: {cur_loss_list}", flush=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2PP_Interleaved
 | 
			
		||||
    def test_loss_spike_with_dp8_pp2_interleaved():
 | 
			
		||||
        check_loss_spike()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2PP_Interleaved
 | 
			
		||||
    def test_loss_accuracy_with_dp8_pp2_interleaved():
 | 
			
		||||
        check_loss_accuracy()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestCaseTrain16GPUWith8DP2PPInterleavedOverlap:
 | 
			
		||||
    """
 | 
			
		||||
    Test cases for Model Training with 16 GPUs.
 | 
			
		||||
    Parallel Config:
 | 
			
		||||
        data parallel size = 8.
 | 
			
		||||
        pipeline parallel size = 2.
 | 
			
		||||
        interleaved scheduler = True.
 | 
			
		||||
        interleaved overlap = True.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def setup_class():
 | 
			
		||||
        # update config pipeline parallel size
 | 
			
		||||
        command = f"sed -i 's/^.*pipeline=.*/    pipeline=dict(size=2, interleaved_overlap=True),/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
        command = f"sed -i 's/^.*num_chunks=.*/    num_chunks=2,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
        command = f"sed -i 's/^.*tensor=.*/    tensor=1,/' {CONFIG_FILE_PATH}"
 | 
			
		||||
        subprocess.run(command, shell=True, check=True)
 | 
			
		||||
 | 
			
		||||
        # model training
 | 
			
		||||
        train()
 | 
			
		||||
 | 
			
		||||
        # print loss value
 | 
			
		||||
        print(f"cur_loss_list: {cur_loss_list}", flush=True)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2PP_InterleavedOverlap
 | 
			
		||||
    def test_loss_spike_with_dp8_pp2_interleaved_overlap():
 | 
			
		||||
        check_loss_spike()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @pytest.mark.training_16GPU_8DP2PP_InterleavedOverlap
 | 
			
		||||
    def test_loss_accuracy_with_dp8_pp2_interleaved_overlap():
 | 
			
		||||
        check_loss_accuracy()
 | 
			
		||||
| 
						 | 
				
			
			@ -109,3 +109,29 @@ InternLM 在 GSM8K 数据集中带工具和不带工具的性能表现:
 | 
			
		|||
| -------- | -------------------- |
 | 
			
		||||
| w/o tool | 34.5                 |
 | 
			
		||||
| w tool   | 39.2                 |
 | 
			
		||||
 | 
			
		||||
# openai_api.py
 | 
			
		||||
 | 
			
		||||
使用 OpenAI 接口实现的流式部署,可以应用于基于 ChatGPT 的应用的后端。部署的命令为:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python openai_api.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
然后可以通过下面代码调用部署好的 api:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
import openai
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    openai.api_base = "http://localhost:8000/internlm"
 | 
			
		||||
    openai.api_key = "none"
 | 
			
		||||
    for chunk in openai.ChatCompletion.create(
 | 
			
		||||
        model="internlm-chat-7b",
 | 
			
		||||
        messages=[
 | 
			
		||||
            {"role": "user", "content": "你好"},
 | 
			
		||||
        ],
 | 
			
		||||
        stream=True
 | 
			
		||||
    ):
 | 
			
		||||
        if hasattr(chunk.choices[0].delta, "content"):
 | 
			
		||||
            print(chunk.choices[0].delta.content, end="", flush=True)
 | 
			
		||||
```
 | 
			
		||||
| 
						 | 
				
			
			@ -107,3 +107,29 @@ InternLM performance in the GSM8K dataset with and without tools:
 | 
			
		|||
| -------- | -------------------- |
 | 
			
		||||
| w/o tool | 34.5                 |
 | 
			
		||||
| w tool   | 39.2                 |
 | 
			
		||||
 | 
			
		||||
# openai_api.py
 | 
			
		||||
 | 
			
		||||
`openai_api.py` implements stream deployment with OpenAI APIs which an be used on any applications based on ChatGPT. Below is the command to deploy `internlm`:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
python openai_api.py
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Then it is able to call the deployed API using the following python code:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
import openai
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    openai.api_base = "http://localhost:8000/internlm"
 | 
			
		||||
    openai.api_key = "none"
 | 
			
		||||
    for chunk in openai.ChatCompletion.create(
 | 
			
		||||
        model="internlm-chat-7b",
 | 
			
		||||
        messages=[
 | 
			
		||||
            {"role": "user", "content": "Hello!"},
 | 
			
		||||
        ],
 | 
			
		||||
        stream=True
 | 
			
		||||
    ):
 | 
			
		||||
        if hasattr(chunk.choices[0].delta, "content"):
 | 
			
		||||
            print(chunk.choices[0].delta.content, end="", flush=True)
 | 
			
		||||
```
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,157 @@
 | 
			
		|||
import time
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from typing import List, Literal, Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import uvicorn
 | 
			
		||||
from fastapi import FastAPI, HTTPException
 | 
			
		||||
from fastapi.middleware.cors import CORSMiddleware
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
from sse_starlette.sse import EventSourceResponse
 | 
			
		||||
from transformers import AutoModelForCausalLM, AutoTokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def lifespan(app: FastAPI):  # collects GPU memory
 | 
			
		||||
    yield
 | 
			
		||||
    if torch.cuda.is_available():
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
        torch.cuda.ipc_collect()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app = FastAPI(lifespan=lifespan)
 | 
			
		||||
 | 
			
		||||
app.add_middleware(
 | 
			
		||||
    CORSMiddleware,
 | 
			
		||||
    allow_origins=["*"],
 | 
			
		||||
    allow_credentials=True,
 | 
			
		||||
    allow_methods=["*"],
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelCard(BaseModel):
 | 
			
		||||
    id: str
 | 
			
		||||
    object: str = "model"
 | 
			
		||||
    created: int = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
    owned_by: str = "owner"
 | 
			
		||||
    root: Optional[str] = None
 | 
			
		||||
    parent: Optional[str] = None
 | 
			
		||||
    permission: Optional[list] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelList(BaseModel):
 | 
			
		||||
    object: str = "list"
 | 
			
		||||
    data: List[ModelCard] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatMessage(BaseModel):
 | 
			
		||||
    role: Literal["user", "assistant", "system"]
 | 
			
		||||
    content: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeltaMessage(BaseModel):
 | 
			
		||||
    role: Optional[Literal["user", "assistant", "system"]] = None
 | 
			
		||||
    content: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionRequest(BaseModel):
 | 
			
		||||
    model: str
 | 
			
		||||
    messages: List[ChatMessage]
 | 
			
		||||
    temperature: Optional[float] = None
 | 
			
		||||
    top_p: Optional[float] = None
 | 
			
		||||
    max_length: Optional[int] = None
 | 
			
		||||
    stream: Optional[bool] = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseChoice(BaseModel):
 | 
			
		||||
    index: int
 | 
			
		||||
    message: ChatMessage
 | 
			
		||||
    finish_reason: Literal["stop", "length"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseStreamChoice(BaseModel):
 | 
			
		||||
    index: int
 | 
			
		||||
    delta: DeltaMessage
 | 
			
		||||
    finish_reason: Optional[Literal["stop", "length"]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponse(BaseModel):
 | 
			
		||||
    model: str
 | 
			
		||||
    object: Literal["chat.completion", "chat.completion.chunk"]
 | 
			
		||||
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
 | 
			
		||||
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/internlm/models", response_model=ModelList)
 | 
			
		||||
async def list_models():
 | 
			
		||||
    model_card = ModelCard(id="internlm")
 | 
			
		||||
    return ModelList(data=[model_card])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/internlm/chat/completions", response_model=ChatCompletionResponse)
 | 
			
		||||
async def create_chat_completion(request: ChatCompletionRequest):
 | 
			
		||||
    global model, tokenizer
 | 
			
		||||
 | 
			
		||||
    if request.messages[-1].role != "user":
 | 
			
		||||
        raise HTTPException(status_code=400, detail="Invalid request")
 | 
			
		||||
    query = request.messages[-1].content
 | 
			
		||||
 | 
			
		||||
    prev_messages = request.messages[:-1]
 | 
			
		||||
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
 | 
			
		||||
        query = prev_messages.pop(0).content + query
 | 
			
		||||
 | 
			
		||||
    history = []
 | 
			
		||||
    if len(prev_messages) % 2 == 0:
 | 
			
		||||
        for i in range(0, len(prev_messages), 2):
 | 
			
		||||
            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
 | 
			
		||||
                history.append([prev_messages[i].content, prev_messages[i + 1].content])
 | 
			
		||||
 | 
			
		||||
    if request.stream:
 | 
			
		||||
        generate = predict(query, history, request.model)
 | 
			
		||||
        return EventSourceResponse(generate, media_type="text/event-stream")
 | 
			
		||||
 | 
			
		||||
    response, _ = model.chat(tokenizer, query, history=history)
 | 
			
		||||
    choice_data = ChatCompletionResponseChoice(
 | 
			
		||||
        index=0, message=ChatMessage(role="assistant", content=response), finish_reason="stop"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def predict(query: str, history: List[List[str]], model_id: str):
 | 
			
		||||
    global model, tokenizer
 | 
			
		||||
 | 
			
		||||
    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"), finish_reason=None)
 | 
			
		||||
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 | 
			
		||||
 | 
			
		||||
    current_length = 0
 | 
			
		||||
 | 
			
		||||
    for new_response, _ in model.stream_chat(tokenizer, query, history):
 | 
			
		||||
        if len(new_response) == current_length:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        new_text = new_response[current_length:]
 | 
			
		||||
 | 
			
		||||
        current_length = len(new_response)
 | 
			
		||||
 | 
			
		||||
        choice_data = ChatCompletionResponseStreamChoice(
 | 
			
		||||
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
 | 
			
		||||
        )
 | 
			
		||||
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 | 
			
		||||
 | 
			
		||||
    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason="stop")
 | 
			
		||||
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 | 
			
		||||
    yield "[DONE]"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    model_name = "internlm/internlm-chat-7b"
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
 | 
			
		||||
| 
						 | 
				
			
			@ -869,7 +869,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
 | 
			
		|||
            producer.start()
 | 
			
		||||
            while True:
 | 
			
		||||
                res = response_queue.get()
 | 
			
		||||
                if res is not None:
 | 
			
		||||
                if res is None:
 | 
			
		||||
                    return
 | 
			
		||||
                yield res
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue