Merge pull request #5588 from hpcaitech/feat/online-serving

[Feature] Online Serving
Jianghai 2024-05-09 17:19:45 +08:00 committed by GitHub
commit 492520dbdb
21 changed files with 1172 additions and 34 deletions

View File

@@ -62,6 +62,9 @@ class BatchBucket:
     def current_batch_size(self):
         return self._current_batch_size
 
+    def __len__(self):
+        return self._current_batch_size
+
     @property
     def available_batch_size(self):
         return self.max_batch_size - self._current_batch_size

View File

@@ -1,10 +1,9 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
 import logging
-from dataclasses import dataclass
-from typing import Optional, Union
+from dataclasses import dataclass, fields
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -214,3 +213,18 @@ class InferenceConfig:
             meta_config[type] = getattr(model_config, type)
 
         return GenerationConfig.from_dict(meta_config)
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in fields(cls)]
+        inference_config_args = {}
+        for attr in attrs:
+            if attr in config_dict:
+                inference_config_args[attr] = config_dict[attr]
+            else:
+                inference_config_args[attr] = getattr(cls, attr)
+
+        # Set the attributes from the parsed arguments.
+        inference_config = cls(**inference_config_args)
+        return inference_config
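The new `from_dict` helper keeps only the keys that match `InferenceConfig` dataclass fields and falls back to the class-level defaults for everything else, which is what lets the API server below build a config directly from `vars(args)`. A minimal sketch of how it can be driven (the specific flags here are illustrative, not part of this patch):

```python
# Sketch: build an InferenceConfig from parsed CLI arguments, mirroring the
# InferenceConfig.from_dict(vars(args)) call in the API server below.
# The flag set here is illustrative; keys that are not dataclass fields
# (e.g. "host") are simply ignored, and missing fields keep their defaults.
import argparse

from colossalai.inference.config import InferenceConfig

parser = argparse.ArgumentParser()
parser.add_argument("--max_batch_size", type=int, default=8)
parser.add_argument("--block-size", type=int, default=16)  # argparse stores this as block_size
parser.add_argument("--host", type=str, default="127.0.0.1")  # server flag, not a config field
args = parser.parse_args([])

inference_config = InferenceConfig.from_dict(vars(args))
print(inference_config.max_batch_size, inference_config.block_size)
```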

View File

@@ -0,0 +1,309 @@
import asyncio
import logging
from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
from colossalai.inference.core.engine import InferenceEngine
# CLI logger
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("colossalai-inference")
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
msg = "Task finished unexpectedly. This should never happen! "
try:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
raise RuntimeError(msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
raise exc
class RequstStream:
"""
A stream of Output for a request that can be iterated over asynchronously.
Attributes: 1.request_id: The id of the request.
2._future: A future that will be set when the request is finished.
Methods: set_result and get_result; the result is set exactly once, when the request
finishes, and `self._future` is marked done.
"""
def __init__(self, request_id: int) -> None:
self.request_id = request_id
self._future = asyncio.Future()
def set_result(self, result) -> None:
"""Set final result and signal taht it's ready"""
if not self._future.done():
self._future.set_result(result)
async def get_result(self):
"""Wait for the result to be set and return it."""
return await self._future
@property
def finished(self) -> bool:
"""Check if the stream has finished by checking if the future is done."""
return self._future.done()
class Tracer:
"""
Recording new requests and finished requests.
Attributes: 1._request_streams: We create one stream for each request to trace the output.
2._finished_requests: A queue to store the finished requests.
3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
4.new_requests_event: An event to notify the engine that there are new requests.
"""
def __init__(self) -> None:
self._request_streams: Dict[int, RequstStream] = {}
self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
return item in self._request_streams
def init_event(self):
self.new_requests_event = asyncio.Event()
def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
"""
Propagate an exception to request streams (all if request_id is None).
"""
if request_id is not None:
self._request_streams[request_id].set_result(exc)
else:
for stream in self._request_streams.values():
stream.set_result(exc)
def process_finished_request(self, finished_request) -> None:
"""Process a finished request from the engine."""
request_id = finished_request.request_id
try:
self._request_streams[request_id].set_result(finished_request)
except:
raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
self.abort_request(request_id)
def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
"""
Add a request to be sent to the engine on the next background
loop iteration.
"""
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = RequstStream(request_id)
logger.info(f"Added request {request_id}.")
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
self.new_requests_event.set()
return stream
def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info(f"Aborted request {request_id}.")
self._finished_requests.put_nowait(request_id)
if request_id not in self._request_streams or self._request_streams[request_id].finished:
# The request has already finished or been aborted.
# The requests in new_requests will be aborted when we try to fetch them (if marked aborted)
return
self._request_streams[request_id].set_result(None)
def get_new_requests(self):
"""
Get new requests from http server.
"""
new_requests: List[Dict] = []
finished_requests: Set[int] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if new_request["request_id"] in finished_requests:
# The request has been aborted.
stream.set_result(None)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
class _AsyncInferenceEngine(InferenceEngine):
"""
Async methods for the Inference Engine. This engine extends InferenceEngine; the additional methods are only used for online serving.
Methods: 1. async_step: The async version of Engine.step()
"""
async def async_step(self) -> List[str]:
"""
The async version of Engine.step()
Performs one decoding iteration and returns newly generated results.
It first schedules the sequences to be executed in the next iteration.
Then, it executes the model and updates the scheduler with the model
outputs. Finally, it decodes the sequences and returns the newly
generated results.
"""
batch = self.request_handler.schedule()
loop = asyncio.get_running_loop()
# Use run_in_executor to run the synchronous model forward pass asynchronously.
logits = await loop.run_in_executor(
None,
self.model,
batch,
self.k_cache,
self.v_cache,
)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()
for sequence in finished_sequences:
sequence.output = self.tokenizer.decode(sequence.output_token_id)
return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
class AsyncInferenceEngine:
"""An asynchronous wrapper for the InferenceEngine class.
This class is used to wrap the InferenceEngine class to make it asynchronous.
It uses asyncio to create a background loop that keeps processing incoming
requests. Note that this class does not hold the model directly; when a new request
comes in, `add_request` is called first and the Tracer records it, then hands it to the
background `InferenceEngine` (run in the background loop) for processing. You can
consider this engine an interface.
"""
_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
def __init__(self, start_engine_loop: bool = True, **kwargs):
self.engine = self._init_engine(**kwargs)
self.background_loop = None
# reference to the unshielded loop
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracer = Tracer()
@property
def background_loop_status(self):
return self.background_loop is not None and not self.background_loop.done()
def start_background_loop(self):
if self.background_loop_status:
raise RuntimeError("Existing loop is running")
self._request_tracer.init_event()
self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
)
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, **kwargs):
return self._engine_class(**kwargs)
async def step(self):
"""
Run engine to process requests
Returns True if there are in-progress requests.
"""
new_requests = self._request_tracer.get_new_requests()
for new_request in new_requests:
self.engine.add_single_request(**new_request)
newly_finished_seqs, has_running_requests = await self.engine.async_step()
for seq in newly_finished_seqs:
self._request_tracer.process_finished_request(seq)
return has_running_requests
async def _engine_abort(self, request_ids: Iterable[int]):
self.engine.abort_request(request_ids)
async def abort(self, request_id: int):
"""
Abort a single request
"""
if not self.background_loop_status:
raise RuntimeError("Background loop is not running or launched correctly.")
return self._abort(request_id)
def _abort(self, request_id: int):
self._request_tracer.abort_request(request_id)
async def run_engine_loop(self):
processing_requests = False
while True:
if not processing_requests:
await self._request_tracer.wait_for_new_requests()
processing_requests = await self.step()
await asyncio.sleep(0)
async def add_request(
self,
request_id: int,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
) -> RequstStream:
"""
Add a request to the background tracker(waiting queue), start the background loop if needed.
"""
if not self.background_loop_status:
if self.start_engine_loop:
self.start_background_loop()
else:
raise RuntimeError("Background loop is not running.")
stream = self._request_tracer.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
)
return stream
async def generate(
self,
request_id: int,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
) -> AsyncIterator[str]:
"""
Generate output from a request. It receives the request from http server, adds it into the
waiting queue of the Async Engine and streams the output sequence.
"""
try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
return await stream.get_result()
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the request.
self._abort(request_id)
raise e
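Putting the pieces together: `Tracer` queues incoming requests, `_AsyncInferenceEngine.async_step` runs one decoding iteration off the event loop via `run_in_executor`, and `AsyncInferenceEngine.generate` resolves once the sequence finishes. A rough usage sketch outside the FastAPI server, constructed the same way as in `api_server.py` (the checkpoint path is a placeholder):

```python
# Rough sketch of driving AsyncInferenceEngine directly; kwargs mirror api_server.py.
# "path/to/llama2" is a placeholder checkpoint path, not a real location.
import asyncio

from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.async_engine import AsyncInferenceEngine


async def main():
    model_path = "path/to/llama2"
    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    engine = AsyncInferenceEngine(
        start_engine_loop=True,
        model=model,
        tokenizer=tokenizer,
        inference_config=InferenceConfig(max_batch_size=8),
    )
    # generate() registers the request with the Tracer, lazily starts the
    # background loop, and awaits the finished Sequence.
    finished = await engine.generate(request_id=1, prompt="hello, who are you? ")
    print(finished.output)


asyncio.run(main())
```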

View File

@@ -507,9 +507,9 @@ class InferenceEngine:
     def generate(
         self,
-        prompts: List[str] = None,
+        request_ids: Union[List[int], int] = None,
+        prompts: Union[List[str], str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        request_ids: List[int] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:
         """
@@ -527,6 +527,9 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+            if isinstance(prompts, str) and isinstance(request_ids, int):
+                prompts = [prompts]
+                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
@@ -580,13 +583,13 @@ class InferenceEngine:
         if isinstance(prompts, (list, tuple)):
             return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
         elif isinstance(prompts, str):
-            return self.inference_config.rompt_template.format(input_text=prompts)
+            return self.inference_config.prompt_template.format(input_text=prompts)
         else:
             raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
 
     def add_request(
         self,
-        request_ids: List[int] = None,
+        request_ids: Union[List[int], int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,
@@ -601,11 +604,15 @@ class InferenceEngine:
         """
         # apply the prompt template to the input prompts
         if self.has_prompt_template and prompts is not None:
             prompts = self.format_prompt(prompts)
 
         block_size = self.inference_config.block_size
 
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]
@@ -615,8 +622,10 @@ class InferenceEngine:
                 "input_ids"
             ]
 
+        # list of torch Tensor
         if isinstance(prompts_token_ids, list):
-            pass
+            if isinstance(prompts_token_ids[0], torch.Tensor):
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
@@ -632,8 +641,6 @@ class InferenceEngine:
         for i in range(prompts_num):
             if request_ids:
-                if not isinstance(request_ids, list):
-                    request_ids = [request_ids]
                 assert isinstance(
                     request_ids[0], int
                 ), f"The request_id type must be int, but got {type(request_ids[0])}"
@@ -733,7 +740,6 @@ class InferenceEngine:
             logits = logits[:, -1, :]
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)
-
         finished_sequences = self.request_handler.update()
 
         return finished_sequences

View File

@@ -209,6 +209,7 @@ class RequestHandler:
                     break
 
                 num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
+                # for now the recycle logic is not working
                 remove_list.extend(lst[:num_seqs_to_add])
                 self.running_list.extend(lst[:num_seqs_to_add])
@@ -263,24 +264,27 @@ class RequestHandler:
                 ), f"Sequence {req.request_id} exceeds input length limit"
             self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
 
-    def abort_sequence(self, request_id: str):
+    def abort_sequence(self, request_id: int):
         """
         Abort the request.
         """
-        seq, priority = self._find_sequence(request_id)
-        if seq.status == RequestStatus.WAITING:
-            seq.mark_aborted()
-            self.waiting_list[priority].remove(seq)
-        elif seq.status.is_running():
-            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
-            self.running_list.remove(seq)
-        else:
-            try:
-                self.done_list.remove(seq)
-            except:
-                return
+        result = self._find_sequence(request_id)
+        if result is not None:
+            seq, priority = result
+            if seq.status == RequestStatus.WAITING:
+                seq.mark_aborted()
+                self.waiting_list[priority].remove(seq)
+            elif seq.status.is_running():
+                self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
+                self.running_list.remove(seq)
+            else:
+                try:
+                    self.done_list.remove(seq)
+                except:
+                    return
+        return
 
-    def _find_sequence(self, request_id: str) -> Sequence:
+    def _find_sequence(self, request_id: int) -> Sequence:
         """
         Find the request by request_id.
         """
@@ -324,6 +328,9 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
+    def total_requests_in_batch_bucket(self) -> int:
+        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
+
     def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.

View File

@@ -0,0 +1,27 @@
# Online Service
Colossal-Inference supports a FastAPI-based online service. Both simple completion and chat are supported. Follow the commands below and
you can easily construct a server with both completion and chat functionality. For now only `Llama` models are supported; we will fill in
the blanks quickly.
# Usage
```bash
# First, launch the API server locally.
python3 -m colossalai.inference.server.api_server --model <path of your llama2 model> --chat-template "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"

# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the API.

# For the completion service, you can invoke it with:
curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'

# For the chat service, you can invoke it with:
curl -X POST http://127.0.0.1:8000/chat -H 'Content-Type: application/json' -d '{"messages":
[{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"}],
"stream": "False"}'

# If you just want to test simple generation, turn to the generate API:
curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'
```
We also support streaming output; simply change `stream` to `True` in the request body.
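For programmatic access, the `/generate` endpoint streams each result as a JSON chunk terminated by a NUL byte (see `stream_results` in the API server), so a client can consume it like this. This is a rough sketch; it assumes the server above is running locally and the `requests` package is installed.

```python
# Illustrative streaming client for the /generate endpoint.
import json

import requests

response = requests.post(
    "http://127.0.0.1:8000/generate",
    json={"prompt": "hello, who are you? ", "stream": "True"},
    stream=True,
)
# The server terminates each JSON chunk with a NUL byte.
for chunk in response.iter_lines(delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk)["text"], flush=True)
```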

View File

View File

@@ -0,0 +1,210 @@
"""
Doc:
Feature:
- FastAPI based http server for Colossal-Inference
- Completion Service Supported
Usage: (for local user)
- First, launch the API server locally: `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
- Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the API
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
-H 'Content-Type: application/json' \
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
Version: V1.0
"""
import argparse
import json
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.server.chat_service import ChatServing
from colossalai.inference.server.completion_service import CompletionServing
from colossalai.inference.server.utils import id_generator
from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa
TIMEOUT_KEEP_ALIVE = 5 # seconds.
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
prompt_template_choices = ["llama", "vicuna"]
async_engine = None
chat_serving = None
completion_serving = None
app = FastAPI()
# NOTE: (CjhHa1) models are still under development, need to be updated
@app.get("/models")
def get_available_models() -> Response:
return JSONResponse(supported_models_dict)
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
A request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: generation config attributes (e.g. sampling parameters) to override.
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", "false").lower()
request_id = id_generator()
generation_config = get_generation_config(request_dict)
results = engine.generate(request_id, prompt, generation_config=generation_config)
# Streaming case
def stream_results():
for request_output in results:
ret = {"text": request_output[len(prompt) :]}
yield (json.dumps(ret) + "\0").encode("utf-8")
if stream == "true":
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
for request_output in results:
if await request.is_disconnected():
# Abort the request if the client disconnects.
engine.abort(request_id)
return Response(status_code=499)
final_output = request_output[len(prompt) :]
assert final_output is not None
ret = {"text": final_output}
return JSONResponse(ret)
@app.post("/completion")
async def create_completion(request: Request):
request_dict = await request.json()
stream = request_dict.pop("stream", "false").lower()
generation_config = get_generation_config(request_dict)
result = await completion_serving.create_completion(request, generation_config)
ret = {"request_id": result.request_id, "text": result.output}
if stream == "true":
return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
else:
return JSONResponse(content=ret)
@app.post("/chat")
async def create_chat(request: Request):
request_dict = await request.json()
stream = request_dict.get("stream", "false").lower()
generation_config = get_generation_config(request_dict)
message = await chat_serving.create_chat(request, generation_config)
if stream == "true":
return StreamingResponse(content=message, media_type="text/event-stream")
else:
ret = {"role": message.role, "text": message.content}
return ret
def get_generation_config(request):
generation_config = async_engine.engine.generation_config
for arg in request:
if hasattr(generation_config, arg):
setattr(generation_config, arg, request[arg])
return generation_config
def add_engine_config(parser):
parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
parser.add_argument(
"--max-model-len",
type=int,
default=None,
help="model context length. If unspecified, " "will be automatically derived from the model.",
)
# Parallel arguments
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
# KV cache arguments
parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
# generation arguments
parser.add_argument(
"--prompt_template",
choices=prompt_template_choices,
default=None,
help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
)
return parser
def parse_args():
parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument(
"--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy"
)
parser.add_argument(
"--model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.",
)
parser.add_argument(
"--chat-template",
type=str,
default=None,
help="The file path to the chat template, " "or the template in single-line form " "for the specified model",
)
parser.add_argument(
"--response-role",
type=str,
default="assistant",
help="The role name to return if " "`request.add_generation_prompt=true`.",
)
parser = add_engine_config(parser)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
inference_config = InferenceConfig.from_dict(vars(args))
model = AutoModelForCausalLM.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model)
async_engine = AsyncInferenceEngine(
start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
)
engine = async_engine.engine
completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
chat_serving = ChatServing(
async_engine,
served_model=model.__class__.__name__,
tokenizer=tokenizer,
response_role=args.response_role,
chat_template=args.chat_template,
)
app.root_path = args.root_path
uvicorn.run(
app=app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)

View File

@@ -0,0 +1,142 @@
import asyncio
import codecs
import logging
from fastapi import Request
from colossalai.inference.core.async_engine import AsyncInferenceEngine
from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator
logger = logging.getLogger("colossalai-inference")
class ChatServing:
def __init__(
self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None
):
self.engine = engine
self.served_model = served_model
self.tokenizer = tokenizer
self.response_role = response_role
self._load_chat_template(chat_template)
try:
asyncio.get_running_loop()
except RuntimeError:
pass
async def create_chat(self, request: Request, generation_config):
request_dict = await request.json()
messages = request_dict["messages"]
stream = request_dict.pop("stream", "false").lower()
add_generation_prompt = request_dict.pop("add_generation_prompt", False)
request_id = id_generator()
try:
prompt = self.tokenizer.apply_chat_template(
conversation=messages,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
except Exception as e:
raise RuntimeError(f"Error in applying chat template from request: {str(e)}")
# it is not an intuitive way: the generation config is attached to the engine globally
self.engine.engine.generation_config = generation_config
result_generator = self.engine.generate(request_id, prompt=prompt)
if stream == "true":
return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id)
else:
return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id)
async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int):
# Send first response for each request.n (index) with the role
role = self.get_chat_request_role(request, request_dict)
n = request_dict.get("n", 1)
echo = request_dict.get("echo", "false").lower()
for i in range(n):
choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role))
data = choice_data.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if echo == "true":
last_msg_content = ""
if (
request_dict["messages"]
and isinstance(request_dict["messages"], list)
and request_dict["messages"][-1].get("content")
and request_dict["messages"][-1].get("role") == role
):
last_msg_content = request_dict["messages"][-1]["content"]
if last_msg_content:
for i in range(n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, message=DeltaMessage(content=last_msg_content)
)
data = choice_data.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
result = await result_generator
choice_data = DeltaMessage(content=result.output)
data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True)
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self,
request: Request,
request_dict: dict,
result_generator,
request_id,
):
if await request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return {"error_msg": "Client disconnected"}
result = await result_generator
assert result is not None
role = self.get_chat_request_role(request, request_dict)
choice_data = ChatMessage(role=role, content=result.output)
echo = request_dict.get("echo", "false").lower()
if echo == "true":
last_msg_content = ""
if (
request_dict["messages"]
and isinstance(request_dict["messages"], list)
and request_dict["messages"][-1].get("content")
and request_dict["messages"][-1].get("role") == role
):
last_msg_content = request_dict["messages"][-1]["content"]
full_message = last_msg_content + choice_data.content
choice_data.content = full_message
return choice_data
def get_chat_request_role(self, request: Request, request_dict: dict) -> str:
add_generation_prompt = request_dict.get("add_generation_prompt", False)
if add_generation_prompt:
return self.response_role
else:
return request_dict["messages"][-1]["role"]
def _load_chat_template(self, chat_template):
if chat_template is not None:
try:
with open(chat_template, "r") as f:
self.tokenizer.chat_template = f.read()
except OSError:
# If opening the file fails, fall back to treating the argument itself as the
# template, decoding it so escape sequences are interpreted correctly
self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape")
logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}")
elif self.tokenizer.chat_template is not None:
logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}")
else:
logger.warning("No chat template provided. Chat API will not work.")

View File

@@ -0,0 +1,34 @@
import asyncio
from colossalai.inference.core.async_engine import AsyncInferenceEngine
from .utils import id_generator
class CompletionServing:
def __init__(self, engine: AsyncInferenceEngine, served_model: str):
self.engine = engine
self.served_model = served_model
try:
asyncio.get_running_loop()
except RuntimeError:
pass
async def create_completion(self, request, generation_config):
request_dict = await request.json()
request_id = id_generator()
prompt = request_dict.pop("prompt")
# it is not an intuitive way: the generation config is attached to the engine globally
self.engine.engine.generation_config = generation_config
result_generator = self.engine.generate(request_id, prompt=prompt)
if await request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
raise RuntimeError("Client disconnected")
final_res = await result_generator
return final_res

View File

@@ -0,0 +1,36 @@
from typing import Any, Optional
from pydantic import BaseModel
# make it singleton
class NumericIDGenerator:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(NumericIDGenerator, cls).__new__(cls)
cls._instance.current_id = 0
return cls._instance
def __call__(self):
self.current_id += 1
return self.current_id
id_generator = NumericIDGenerator()
class ChatMessage(BaseModel):
role: str
content: Any
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[Any] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
message: DeltaMessage

View File

@@ -61,6 +61,7 @@ class Sequence:
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
         ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
+        output(str): The output of sequence
     """
 
     request_id: int
@@ -73,6 +74,7 @@ class Sequence:
     max_output_len: int = 256
     # NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
     ignore_eos: bool = False
+    output: str = None
 
     def __post_init__(self):
         self.output_token_id = []
@@ -163,11 +165,13 @@ class Sequence:
 
     def __repr__(self) -> str:
         return (
             f"(request_id={self.request_id}, "
-            f"prompt={self.prompt}, "
-            f"status={self.status.name}, "
-            f"sample_params={self.sample_params}, "
-            f"input_len={self.input_len},"
-            f"output_len={self.output_len})"
+            f"prompt={self.prompt},\n"
+            f"output_token_id={self.output_token_id},\n"
+            f"output={self.output},\n"
+            f"status={self.status.name},\n"
+            f"sample_params={self.sample_params},\n"
+            f"input_len={self.input_len},\n"
+            f"output_len={self.output_len})\n"
         )

View File

@@ -249,7 +249,6 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
         ::
 
             max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
                 renormalized to have norm max_norm. Note: this will modify weight in-place.
-
             norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.

View File

@@ -1,6 +1,7 @@
 import os
 from typing import Dict, List, Tuple
 
+import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
 
@@ -36,7 +37,11 @@ class ShardFormer:
     """
 
     def __init__(self, shard_config: ShardConfig):
-        self.coordinator = DistCoordinator()
+        self.is_distributed = dist.is_initialized()
+        if self.is_distributed:
+            self.coordinator = DistCoordinator()
+        else:
+            self.coordinator = None
         self.shard_config = shard_config
 
     def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:

View File

@@ -0,0 +1,58 @@
from locust import HttpUser, between, tag, task
class QuickstartUser(HttpUser):
wait_time = between(1, 5)
@tag("online-generation")
@task(5)
def completion(self):
self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
@tag("online-generation")
@task(5)
def completion_streaming(self):
self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
@tag("online-chat")
@task(5)
def chat(self):
self.client.post(
"/chat",
json={
"converation": [
{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"},
],
"stream": "False",
},
)
@tag("online-chat")
@task(5)
def chat_streaming(self):
self.client.post(
"/chat",
json={
"converation": [
{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"},
],
"stream": "True",
},
)
@tag("offline-generation")
@task(5)
def generate_streaming(self):
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
@tag("offline-generation")
@task(5)
def generate(self):
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"})
@tag("online-generation", "offline-generation")
@task
def get_models(self):
self.client.get("/models")

View File

@@ -0,0 +1,27 @@
#!/bin/bash
#argument1: model_path
# launch server
model_path=${1:-"lmsys/vicuna-7b-v1.3"}
chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
echo "Model Path: $model_path"
echo "Starting server..."
python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template &
SERVER_PID=$!
# waiting time
sleep 60
# Run Locust
echo "Starting Locust..."
echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
echo "Test completion api first"
locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
echo "Test chat api"
locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
# kill Server
echo "Stopping server..."
kill $SERVER_PID
echo "Test and server shutdown completely"

View File

@@ -0,0 +1,4 @@
#!/bin/bash
echo "Skip the test (this test is slow)"
# bash ./run_benchmark.sh

View File

@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass
import pytest
from colossalai.inference.core.async_engine import AsyncInferenceEngine
@dataclass
class MockSequence:
request_id: int
class MockEngine:
def __init__(self):
self.step_calls = 0
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
async def async_step(self):
self.step_calls += 1
return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
def add_single_request(self, **kwargs):
del kwargs
self.add_request_calls += 1
def generate(self, request_id):
self.request_id = request_id
def stop_generating(self):
self.request_id = None
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
def abort_request(self, request_id):
del request_id # Unused
self.abort_request_calls += 1
class MockAsyncInferenceEngine(AsyncInferenceEngine):
def _init_engine(self, *args, **kwargs):
return MockEngine()
@pytest.mark.asyncio
async def test_new_requests_event():
engine = MockAsyncInferenceEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
await engine.add_request(1, "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1
await engine.add_request(2, "", None)
engine.engine.generate(2)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls == 2
await asyncio.sleep(0)
assert engine.engine.step_calls == 3
engine.engine.stop_generating()
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await engine.add_request(3, "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5

View File

@@ -0,0 +1,68 @@
import pytest
from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence
class SampleEvent:
def __init__(self):
self.flag = False
def set(self):
self.flag = True
def clear(self):
self.flag = False
def test_request_tracer():
tracker = Tracer()
tracker.new_requests_event = SampleEvent()
stream_1 = tracker.add_request(1)
assert tracker.new_requests_event.flag
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == 1
assert not stream_1.finished
stream_2 = tracker.add_request(2)
stream_3 = tracker.add_request(3)
assert tracker.new_requests_event.flag
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == 2
assert new[1]["request_id"] == 3
assert not stream_2.finished
assert not stream_3.finished
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request(1)
assert not tracker.new_requests_event.flag
tracker.abort_request(1)
new = tracker.get_new_requests()
assert not new
stream_4 = tracker.add_request(4)
tracker.abort_request(4)
assert tracker.new_requests_event.flag
new = tracker.get_new_requests()
assert not new
assert stream_4.finished
stream_5 = tracker.add_request(5)
assert tracker.new_requests_event.flag
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == 5
assert stream_2.finished
assert not stream_5.finished
if __name__ == "__main__":
test_request_tracer()

View File

@@ -0,0 +1,103 @@
import random
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def generate_inputs(num_sequences, min_length, max_length):
sequences = []
for _ in range(num_sequences):
length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
# generate sequences of random length
sequence = torch.randint(10, 30000, size=(length,))
sequences.append(sequence)
return sequences
@parameterize(
"test_config",
[
{
"max_batch_size": 8,
"max_output_len": 512,
"max_input_len": 64,
"do_sample": False,
}
],
)
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
setup_seed(20)
max_batch_size = test_config["max_batch_size"]
max_input_len = test_config["max_input_len"]
max_output_len = test_config["max_output_len"]
do_sample = test_config["do_sample"]
top_p = 0.5
top_k = 50
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
model = model.eval()
inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
if use_engine:
inference_config = InferenceConfig(
max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == max_output_len
inference_engine.add_request(prompts_token_ids=inputs_token_ids)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=max_output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert len(outputs) == 10 * max_batch_size
@parameterize("prompt_template", [None, "llama"])
def check_continuous_batching(prompt_template):
check_inference_engine(use_engine=True, prompt_template=prompt_template)
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_continuous_batching()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_continuous_batching():
spawn(run_dist, 1)
if __name__ == "__main__":
test_continuous_batching()

View File

@@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         )
     ).cuda()
     model = model.eval()
-
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
@@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,