merge Internlm/develop into feature_add_moe

pull/182/head
Qu Wenwen 2023-09-19 17:44:12 +08:00
commit b7ddc42dcd
10 changed files with 673 additions and 4 deletions

.github/workflows/e2e_test.yaml (new file, 56 lines)

@@ -0,0 +1,56 @@
name: e2e-tests
on:
  pull_request:
    branches:
      - "main"
      - "develop"
    paths-ignore:
      - "doc/**"
      - "**.md"

env:
  WORKSPACE_PREFIX: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-4)
  SLURM_PARTITION: llm

jobs:
  check-requirements:
    runs-on: [lmtest]
    steps:
      - name: mask env
        run: |
          echo "::add-mask::${{env.WORKSPACE_PREFIX}}"
      - uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: check-requirements
        run: |
          source activate internlm-env-test
          changed_files=$(git diff --name-only -r HEAD^1 HEAD)
          echo $changed_files
          if [[ $changed_files =~ "runtime.txt" ]]; then
            pip install -r requirements/runtime.txt
          fi
          if [[ $changed_files =~ "torch.txt" ]]; then
            pip install -r requirements/torch.txt
          fi

  e2e_tests:
    if: ${{ always() }}
    needs: check-requirements
    runs-on: [lmtest]
    timeout-minutes: 30
    steps:
      - name: mask env
        run: |
          echo "::add-mask::${{env.WORKSPACE_PREFIX}}"
      - uses: actions/checkout@v3
      - name: e2e-test
        run: |
          source activate internlm-env-test
          srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU" ./tests/test_training
          srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TP" ./tests/test_training
          srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2TPSP" ./tests/test_training
          srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP" ./tests/test_training
          srun -p ${SLURM_PARTITION} --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n16 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_16GPU_8DP2PP_InterleavedOverlap" ./tests/test_training


@@ -263,6 +263,12 @@ def args_sanity_check():
            gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
        ), "sequence parallel does not support use_flash_attn=False"

    # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
    if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
        assert (
            gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True
        ), "only support interleaved pipeline scheduler with overlap"

    # monitoring default config
    monitor_default_config = {
        "alert_address": None,  # compatible with old alert config

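Concretely, the check above ties the interleaved pipeline schedule to overlap. A minimal sketch of the configuration combination it accepts, with illustrative values rather than a full InternLM config:

```python
# Illustrative config fragments (not a complete InternLM config file):
# interleaved pipeline scheduling (num_chunks > 1) is only allowed together
# with interleaved_overlap enabled on the pipeline parallel group.
model = dict(num_chunks=2)
parallel = dict(pipeline=dict(size=2, interleaved_overlap=True), tensor=1)

if model["num_chunks"] > 1:
    assert (
        parallel["pipeline"].get("interleaved_overlap", False) is True
    ), "only support interleaved pipeline scheduler with overlap"
```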

@@ -4,14 +4,13 @@
from typing import Optional

import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import fused_dense_func_torch
from internlm.model.utils import Silu, fused_dense_func_torch


class ScaleColumnParallelLinear(nn.Linear):
@@ -197,5 +196,7 @@ class FeedForward(nn.Module):
        )

    def forward(self, x):
        out = self.w3(F.silu(self.w1(x)) * self.w2(x))
        w1_o = self.w1(x)
        w2_o = self.w2(x)
        out = self.w3(Silu(w1_o, w2_o))
        return out


@@ -225,3 +225,10 @@ def is_norm_param(param: torch.Tensor) -> bool:
    if hasattr(param, "is_norm") and param.is_norm:
        return True
    return False


def Silu(w1_o, w2_o):
    return F.silu(w1_o) * w2_o


Silu = torch.jit.script(Silu)
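For context on the change above: `Silu` fuses the SwiGLU-style gate `F.silu(w1_o) * w2_o` through TorchScript so the two elementwise operations can be compiled together. A minimal standalone sketch of the same pattern, with illustrative names rather than the InternLM API:

```python
import torch
import torch.nn.functional as F


def silu_gate(w1_o: torch.Tensor, w2_o: torch.Tensor) -> torch.Tensor:
    # SwiGLU-style gating: SiLU(w1(x)) multiplied elementwise by w2(x)
    return F.silu(w1_o) * w2_o


# torch.jit.script compiles the function, letting the backend fuse the two elementwise ops
silu_gate_scripted = torch.jit.script(silu_gate)

a, b = torch.randn(4, 8), torch.randn(4, 8)
# The scripted version is numerically identical to the eager one
assert torch.allclose(silu_gate(a, b), silu_gate_scripted(a, b))
```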


@@ -0,0 +1,390 @@
import math
import subprocess

import pytest
import torch
import torch.distributed as dist

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.initialize import initialize_distributed_env
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.train import (
    get_train_data_loader,
    initialize_model,
    initialize_optimizer,
    load_new_batch,
)
from internlm.utils.common import BatchSkipper, launch_time
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager

CONFIG_FILE_PATH = "./configs/7B_sft.py"
TOTAL_STEPS = 10
LOSS_SPIKE_LIMIT = 1.5
LOSS_DEVIATION_LIMIT = 0.2
BASELINE_LOSS_LIST = [
    11.64188003540039,
    7.9205322265625,
    6.944362163543701,
    6.147305488586426,
    6.060564994812012,
    5.660439491271973,
    5.19430685043335,
    5.157323837280273,
    4.769168376922607,
    4.449280738830566,
]
cur_loss_list = []


def train():
    # initialize distributed environment
    initialize_distributed_env(config=CONFIG_FILE_PATH)
    assert hasattr(gpc, "config") and gpc.config is not None

    # init setting
    gpc.config.data.total_steps = TOTAL_STEPS
    gpc.config.lr_scheduler.total_steps = TOTAL_STEPS
    total_steps = gpc.config.data.total_steps
    skip_batches = gpc.config.data.skip_batches
    label_smoothing = gpc.config.loss.label_smoothing

    # get and broadcast current time
    current_time = launch_time()
    objs = [current_time]
    dist.broadcast_object_list(objs, src=0)
    current_time = objs[0]

    # initialize model
    model = initialize_model()

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)

    # initialize the train data loader
    train_dl, dataset_types = get_train_data_loader(num_worker=4)

    # initialize and resume train state
    train_state = TrainState(gpc.config, train_dl.batch_sampler)

    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

    with open(CONFIG_FILE_PATH, "r") as f:
        config_lines = f.readlines()
    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dl=train_dl,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
    )

    # Loading other persistent training states.
    ckpt_manager.try_resume_training(train_state, current_time)

    # initialize metric for calculating accuracy and perplexity
    metric = AccPerplex(
        device=torch.cuda.current_device(),
        tp_pg=gpc.get_group(ParallelMode.TENSOR),
        dp_pg=gpc.get_group(ParallelMode.DATA),
        dataset_types=dataset_types,
    )

    # initialize trainer
    scheduler_hooks = [
        SchedulerMetricHook(
            metric=metric,
            skip=(
                gpc.is_using_pp()
                and hasattr(gpc.config.model, "num_chunks")
                and gpc.config.model.num_chunks > 1
                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
            ),
        ),
    ]

    trainer, train_dl, _, _ = internlm.initialize_trainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dl,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        scheduler_hooks=scheduler_hooks,
    )

    # initialize the batch skipper
    batch_skipper = BatchSkipper(skip_batches)

    trainer.train()

    # transfer the train data loader into train data iterator
    train_iter = iter(train_dl)

    # start iterating the train data and begin training
    for batch_count in range(train_state.batch_count, total_steps):
        empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
        timer("one-batch").start()

        # load batch data
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)

        # record the consumed samples in training
        train_state.batch_count = batch_count
        train_state.num_consumed_samples_in_epoch += len(batch[1])
        if batch_skipper(batch_count):  # skip this batch
            if gpc.is_rank_for_log():
                print(f"Skip batch count:`{batch_count}`...")
            timer("one-batch").stop()
            continue

        # zero the grads of parameters
        trainer.zero_grad()
        # process data
        if batch[0].get("type_ids", None) is not None:
            metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))

        # do forward and backward
        timer("fwd-bwd").start()

        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
        if gpc.is_rank_for_log():
            assert loss is not None and not math.isnan(loss.item())
            global cur_loss_list
            cur_loss_list.append(loss.item())
        timer("fwd-bwd").stop()

        # update parameters, and returns (success_update, grad_norm)
        trainer_result = trainer.step()
        assert trainer_result is not None

        success_update, _ = trainer_result
        assert success_update, "Error: grad norm inf or nan occurs!"
        if success_update:  # update parameters successfully
            train_state.step_count += 1
        else:
            train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.

        timer("one-batch").stop()


def check_loss_spike():
    if gpc.is_rank_for_log():
        for step in range(1, TOTAL_STEPS):
            assert (
                cur_loss_list[step] < cur_loss_list[step - 1] * LOSS_SPIKE_LIMIT
            ), f"The loss spike occurs, {cur_loss_list[step - 1]}->{cur_loss_list[step]}, please check it!"


def check_loss_accuracy():
    if gpc.is_rank_for_log():
        for cur, target in zip(cur_loss_list, BASELINE_LOSS_LIST):
            assert (
                abs(cur - target) < LOSS_DEVIATION_LIMIT
            ), f"The loss accuracy is abnormal, {target}->{cur}, please check it!"


class TestCaseTrain8GPU:
    """
    Test cases for Model Training with 8 GPUs.
    Parallel Config:
        data parallel size = 8.
    """

    @staticmethod
    def setup_class():
        # model training
        train()

        # print loss value
        print(f"cur_loss_list: {cur_loss_list}", flush=True)

    @staticmethod
    @pytest.mark.training_8GPU
    def test_loss_spike_with_dp8():
        check_loss_spike()

    @staticmethod
    @pytest.mark.training_8GPU
    def test_loss_accuracy_with_dp8():
        check_loss_accuracy()


class TestCaseTrain16GPUWith8DP2TP:
    """
    Test cases for Model Training with 16 GPUs.
    Parallel Config:
        data parallel size = 8.
        tensor parallel size = 2.
    """

    @staticmethod
    def setup_class():
        # update config tensor parallel size
        command = f"sed -i 's/^.*tensor=.*/ tensor=2,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)

        # model training
        train()

        # print loss value
        print(f"cur_loss_list: {cur_loss_list}", flush=True)

    @staticmethod
    @pytest.mark.training_16GPU_8DP2TP
    def test_loss_spike_with_dp8_tp2():
        check_loss_spike()

    @staticmethod
    @pytest.mark.training_16GPU_8DP2TP
    def test_loss_accuracy_with_dp8_tp2():
        check_loss_accuracy()


class TestCaseTrain16GPUWith8DP2TPSP:
    """
    Test cases for Model Training with 16 GPUs.
    Parallel Config:
        data parallel size = 8.
        tensor parallel size = 2.
        sequence parallel = True.
    """

    @staticmethod
    def setup_class():
        # update config tensor parallel size and sequence parallel
        command = f"sed -i 's/^.*tensor=.*/ tensor=2,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)
        command = f"sed -i 's/^.*sequence_parallel=.*/ sequence_parallel=True,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)

        # model training
        train()

        # print loss value
        print(f"cur_loss_list: {cur_loss_list}", flush=True)

    @staticmethod
    @pytest.mark.training_16GPU_8DP2TPSP
    def test_loss_spike_with_dp8_tp2_sp():
        check_loss_spike()

    @staticmethod
    @pytest.mark.training_16GPU_8DP2TPSP
    def test_loss_accuracy_with_dp8_tp2_sp():
        check_loss_accuracy()


class TestCaseTrain16GPUWith8DP2PP:
    """
    Test cases for Model Training with 16 GPUs.
    Parallel Config:
        data parallel size = 8.
        pipeline parallel size = 2.
    """

    @staticmethod
    def setup_class():
        # update config pipeline parallel size
        command = f"sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2),/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)
        command = f"sed -i 's/^.*tensor=.*/ tensor=1,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)

        # model training
        train()

        # print loss value
        print(f"cur_loss_list: {cur_loss_list}", flush=True)

    @staticmethod
    @pytest.mark.training_16GPU_8DP2PP
    def test_loss_spike_with_dp8_pp2():
        check_loss_spike()

    @staticmethod
    @pytest.mark.training_16GPU_8DP2PP
    def test_loss_accuracy_with_dp8_pp2():
        check_loss_accuracy()


class TestCaseTrain16GPUWith8DP2PPInterleaved:
    """
    Test cases for Model Training with 16 GPUs.
    Parallel Config:
        data parallel size = 8.
        pipeline parallel size = 2.
        interleaved scheduler = True.
    """

    @staticmethod
    def setup_class():
        # update config pipeline parallel size
        command = f"sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2),/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)
        command = f"sed -i 's/^.*num_chunks=.*/ num_chunks=2,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)
        command = f"sed -i 's/^.*tensor=.*/ tensor=1,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=False)

        # model training
        train()

        # print loss value
        print(f"cur_loss_list: {cur_loss_list}", flush=True)

    @staticmethod
    @pytest.mark.training_16GPU_8DP2PP_Interleaved
    def test_loss_spike_with_dp8_pp2_interleaved():
        check_loss_spike()

    @staticmethod
    @pytest.mark.training_16GPU_8DP2PP_Interleaved
    def test_loss_accuracy_with_dp8_pp2_interleaved():
        check_loss_accuracy()


class TestCaseTrain16GPUWith8DP2PPInterleavedOverlap:
    """
    Test cases for Model Training with 16 GPUs.
    Parallel Config:
        data parallel size = 8.
        pipeline parallel size = 2.
        interleaved scheduler = True.
        interleaved overlap = True.
    """

    @staticmethod
    def setup_class():
        # update config pipeline parallel size
        command = f"sed -i 's/^.*pipeline=.*/ pipeline=dict(size=2, interleaved_overlap=True),/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)
        command = f"sed -i 's/^.*num_chunks=.*/ num_chunks=2,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)
        command = f"sed -i 's/^.*tensor=.*/ tensor=1,/' {CONFIG_FILE_PATH}"
        subprocess.run(command, shell=True, check=True)

        # model training
        train()

        # print loss value
        print(f"cur_loss_list: {cur_loss_list}", flush=True)

    @staticmethod
    @pytest.mark.training_16GPU_8DP2PP_InterleavedOverlap
    def test_loss_spike_with_dp8_pp2_interleaved_overlap():
        check_loss_spike()

    @staticmethod
    @pytest.mark.training_16GPU_8DP2PP_InterleavedOverlap
    def test_loss_accuracy_with_dp8_pp2_interleaved_overlap():
        check_loss_accuracy()
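A note on the `training_*` marks used above and in the CI workflow: pytest warns about unknown marks unless they are registered. A minimal sketch of registering them in a `conftest.py`, assuming the repository does not already declare them elsewhere (e.g. in `pytest.ini` or `setup.cfg`):

```python
# conftest.py (sketch): register the custom marks so selection such as
# `pytest -m "training_8GPU" ./tests/test_training` in e2e_test.yaml runs cleanly.
def pytest_configure(config):
    for marker in (
        "training_8GPU",
        "training_16GPU_8DP2TP",
        "training_16GPU_8DP2TPSP",
        "training_16GPU_8DP2PP",
        "training_16GPU_8DP2PP_Interleaved",
        "training_16GPU_8DP2PP_InterleavedOverlap",
    ):
        config.addinivalue_line("markers", f"{marker}: e2e training test selector")
```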


@@ -109,3 +109,29 @@ InternLM performance on the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5                 |
| w tool   | 39.2                 |

# openai_api.py

Streaming deployment implemented with the OpenAI API, which can serve as the backend for applications built on ChatGPT. The deployment command is:

```bash
python openai_api.py
```

The deployed API can then be called with the following code:

```python
import openai

if __name__ == "__main__":
    openai.api_base = "http://localhost:8000/internlm"
    openai.api_key = "none"
    for chunk in openai.ChatCompletion.create(
        model="internlm-chat-7b",
        messages=[
            {"role": "user", "content": "你好"},
        ],
        stream=True
    ):
        if hasattr(chunk.choices[0].delta, "content"):
            print(chunk.choices[0].delta.content, end="", flush=True)
```


@@ -107,3 +107,29 @@ InternLM performance on the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5                 |
| w tool   | 39.2                 |

# openai_api.py

`openai_api.py` implements streaming deployment with the OpenAI API, which can serve as the backend for any application built on ChatGPT. Below is the command to deploy `internlm`:

```bash
python openai_api.py
```

The deployed API can then be called with the following Python code:

```python
import openai

if __name__ == "__main__":
    openai.api_base = "http://localhost:8000/internlm"
    openai.api_key = "none"
    for chunk in openai.ChatCompletion.create(
        model="internlm-chat-7b",
        messages=[
            {"role": "user", "content": "Hello!"},
        ],
        stream=True
    ):
        if hasattr(chunk.choices[0].delta, "content"):
            print(chunk.choices[0].delta.content, end="", flush=True)
```
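For comparison, a non-streaming call can be made against the same endpoint by omitting `stream=True`; a minimal sketch assuming the same legacy `openai<1.0` client as above:

```python
import openai

openai.api_base = "http://localhost:8000/internlm"
openai.api_key = "none"

# Without stream=True the server returns the full reply in a single response object.
completion = openai.ChatCompletion.create(
    model="internlm-chat-7b",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)
```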

tools/openai_api.py (new file, 157 lines)

@@ -0,0 +1,157 @@
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoModelForCausalLM, AutoTokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get("/internlm/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="internlm")
    return ModelList(data=[model_card])


@app.post("/internlm/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    prev_messages = request.messages[:-1]
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i + 1].content])

    if request.stream:
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0, message=ChatMessage(role="assistant", content=response), finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


async def predict(query: str, history: List[List[str]], model_id: str):
    global model, tokenizer

    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"), finish_reason=None)
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    current_length = 0

    for new_response, _ in model.stream_chat(tokenizer, query, history):
        if len(new_response) == current_length:
            continue

        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason="stop")
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield "[DONE]"


if __name__ == "__main__":
    model_name = "internlm/internlm-chat-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.eval()

    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
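Not shown in the README examples above: the service also exposes a model-listing route. A minimal sketch of querying it with `requests`, assuming the server is running locally on port 8000:

```python
import requests

# GET /internlm/models returns the ModelList defined above.
resp = requests.get("http://localhost:8000/internlm/models", timeout=10)
resp.raise_for_status()
for card in resp.json()["data"]:
    print(card["id"])  # expected: "internlm"
```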


@@ -869,7 +869,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
            producer.start()
            while True:
                res = response_queue.get()
                if res is not None:
                if res is None:
                    return
                yield res
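The one-line fix above makes `None` act as the end-of-stream sentinel: the producer thread pushes generated chunks into `response_queue` and finally `None`, and the consumer generator should stop only when it sees that sentinel (the previous condition returned as soon as the first real chunk arrived). A minimal standalone sketch of that producer/consumer pattern, with illustrative names rather than the InternLM implementation:

```python
import queue
import threading


def producer(response_queue: queue.Queue) -> None:
    # Push generated chunks, then None as the end-of-stream sentinel.
    for chunk in ["Hello", ", ", "world"]:
        response_queue.put(chunk)
    response_queue.put(None)


def stream():
    response_queue: queue.Queue = queue.Queue()
    threading.Thread(target=producer, args=(response_queue,), daemon=True).start()
    while True:
        res = response_queue.get()
        if res is None:  # sentinel reached: stop yielding
            return
        yield res


print("".join(stream()))  # -> "Hello, world"
```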