feat(tools): support openai api (#313)

* fix(chat): fix stream_chat to return generator (#123)

* fix(configs/7B_sft.py): model dtype float16 to bfloat16 (#302)

* fix(convert2hf.py): fix the rotary_emb.inv_freq KeyError (#299)

* support openai api to deploy internlm

* update README with information about openai_api.py

* change example in README_EN.md to English

* delete unnecessary print; fix model card typo; fix chat epoch

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Co-authored-by: zhjunqin <zhjunqin@users.noreply.github.com>
Co-authored-by: huangting4201 <1538303371@qq.com>
Co-authored-by: jiangtann <39088437+jiangtann@users.noreply.github.com>
x54-729 2023-09-19 13:49:48 +08:00 committed by GitHub
parent ab513e1ddd
commit cd6426a249
3 changed files with 209 additions and 0 deletions

@@ -109,3 +109,29 @@ InternLM performance on the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5 |
| w tool | 39.2 |
# openai_api.py
A streaming deployment implemented with the OpenAI API format, which can serve as the backend for applications built on ChatGPT. The deployment command is:
```bash
python openai_api.py
```
Then the deployed API can be called with the following code:
```python
import openai

if __name__ == "__main__":
    openai.api_base = "http://localhost:8000/internlm"
    openai.api_key = "none"
    for chunk in openai.ChatCompletion.create(
        model="internlm-chat-7b",
        messages=[
            {"role": "user", "content": "你好"},
        ],
        stream=True
    ):
        if hasattr(chunk.choices[0].delta, "content"):
            print(chunk.choices[0].delta.content, end="", flush=True)
```
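As a quick sanity check, the model list endpoint exposed by `openai_api.py` (`GET /internlm/models`) can be queried through the same client. This is a minimal sketch assuming the pre-1.0 `openai` Python SDK used in the example above:
```python
import openai

if __name__ == "__main__":
    openai.api_base = "http://localhost:8000/internlm"
    openai.api_key = "none"
    # GET /internlm/models returns a ModelList holding a single "internlm" ModelCard
    for model in openai.Model.list()["data"]:
        print(model["id"])
```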

@@ -107,3 +107,29 @@ InternLM performance in the GSM8K dataset with and without tools:
| -------- | -------------------- |
| w/o tool | 34.5 |
| w tool | 39.2 |
# openai_api.py
`openai_api.py` implements a streaming deployment with the OpenAI API format, which can be used as the backend for any application built on ChatGPT. Below is the command to deploy `internlm`:
```bash
python openai_api.py
```
Then the deployed API can be called with the following Python code:
```python
import openai

if __name__ == "__main__":
    openai.api_base = "http://localhost:8000/internlm"
    openai.api_key = "none"
    for chunk in openai.ChatCompletion.create(
        model="internlm-chat-7b",
        messages=[
            {"role": "user", "content": "Hello!"},
        ],
        stream=True
    ):
        if hasattr(chunk.choices[0].delta, "content"):
            print(chunk.choices[0].delta.content, end="", flush=True)
```
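For non-streaming use, the same endpoint can be called without `stream=True`; the server then returns a single `chat.completion` object whose first choice holds the full reply. A minimal sketch, again assuming the pre-1.0 `openai` SDK used above:
```python
import openai

if __name__ == "__main__":
    openai.api_base = "http://localhost:8000/internlm"
    openai.api_key = "none"
    # Without stream=True the server answers with one complete chat.completion response
    response = openai.ChatCompletion.create(
        model="internlm-chat-7b",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)
```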

tools/openai_api.py (new file)

@@ -0,0 +1,157 @@
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoModelForCausalLM, AutoTokenizer


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

# Allow cross-origin requests so browser-based frontends can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request/response schemas mirroring the OpenAI chat completion API.
class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))


@app.get("/internlm/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="internlm")
    return ModelList(data=[model_card])


@app.post("/internlm/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    # Fold an optional leading system message into the query and rebuild the
    # (user, assistant) turn history expected by model.chat / model.stream_chat.
    prev_messages = request.messages[:-1]
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i + 1].content])

    if request.stream:
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0, message=ChatMessage(role="assistant", content=response), finish_reason="stop"
    )
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


async def predict(query: str, history: List[List[str]], model_id: str):
    global model, tokenizer

    # First chunk only announces the assistant role, mirroring OpenAI's streaming format.
    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"), finish_reason=None)
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Stream only the newly generated suffix of each partial response.
    current_length = 0
    for new_response, _ in model.stream_chat(tokenizer, query, history):
        if len(new_response) == current_length:
            continue
        new_text = new_response[current_length:]
        current_length = len(new_response)
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Final empty delta carries finish_reason="stop", followed by the SSE terminator.
    choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason="stop")
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield "[DONE]"


if __name__ == "__main__":
    # Load the chat model and tokenizer; trust_remote_code is required for InternLM's custom modeling code.
    model_name = "internlm/internlm-chat-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.eval()

    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
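
For clients that do not use the `openai` SDK, the stream can also be consumed over plain HTTP. The sketch below is illustrative only: it assumes the `requests` package is installed, the `data:` framing comes from `EventSourceResponse`, and the field names follow the pydantic schemas defined above.
```python
import json

import requests

if __name__ == "__main__":
    url = "http://localhost:8000/internlm/chat/completions"
    payload = {
        "model": "internlm-chat-7b",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    }
    # The SSE body is a sequence of "data: <json chunk>" lines ending with "data: [DONE]".
    with requests.post(url, json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            delta = json.loads(data)["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)
    print()
```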