mirror of https://github.com/hpcaitech/ColossalAI
[Inference] ADD async and sync Api server using FastAPI (#5396)
* add api server
* fix
* add
* add completion service and fix bug
* add generation config
* revise shardformer
* fix bugs
* add docstrings and fix some bugs
* fix bugs and add choices for prompt template

Branch: feat/online-serving
parent d482922035
commit 69cd7e069d
@@ -62,6 +62,9 @@ class BatchBucket:
    def current_batch_size(self):
        return self._current_batch_size

    def __len__(self):
        return self._current_batch_size

    @property
    def available_batch_size(self):
        return self.max_batch_size - self._current_batch_size
@@ -1,10 +1,10 @@
"""
Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
"""

import dataclasses
import logging
from dataclasses import dataclass
from typing import Optional, Union
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed as dist
@@ -214,3 +214,18 @@ class InferenceConfig:
            meta_config[type] = getattr(model_config, type)

        return GenerationConfig.from_dict(meta_config)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        inference_config_args = {}
        for attr in attrs:
            if attr in config_dict:
                inference_config_args[attr] = config_dict[attr]
            else:
                inference_config_args[attr] = getattr(cls, attr)

        # Set the attributes from the parsed arguments.
        inference_config = cls(**inference_config_args)
        return inference_config
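To illustrate the from_dict pattern above (copy only the keys that are actual dataclass fields, fall back to defaults for the rest), here is a minimal, self-contained sketch with a hypothetical config class; the field names are illustrative, not the real InferenceConfig fields:

import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class TinyConfig:
    # Hypothetical fields used only for illustration.
    max_batch_size: int = 8
    block_size: int = 16
    prompt_template: Optional[str] = None

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "TinyConfig":
        # Keep only keys that are dataclass fields; ignore everything else
        # (e.g. unrelated CLI arguments such as --host or --port).
        attrs = [field.name for field in dataclasses.fields(cls)]
        kwargs = {attr: config_dict[attr] for attr in attrs if attr in config_dict}
        return cls(**kwargs)


# Mimics InferenceConfig.from_dict(vars(args)) in the api_server entry point below.
cli_args = {"max_batch_size": 4, "host": "127.0.0.1", "port": 8000}
print(TinyConfig.from_dict(cli_args))
# TinyConfig(max_batch_size=4, block_size=16, prompt_template=None)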
@@ -0,0 +1,318 @@
import asyncio
from functools import partial
from logging import Logger
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

from colossalai.inference.core.engine import InferenceEngine


class AsyncEngineDeadError(RuntimeError):
    pass


def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
    msg = "Task finished unexpectedly. This should never happen! "
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            return
        except Exception as exc:
            raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc
        raise AsyncEngineDeadError(msg)
    except Exception as exc:
        request_tracker.propagate_exception(exc)
        raise exc


class AsyncStream:
    """A stream of Output for a request that can be
    iterated over asynchronously."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue = asyncio.Queue()
        self._finished = False

    def put(self, item) -> None:
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(StopIteration)
        self._finished = True

    @property
    def finished(self) -> bool:
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self):
        result = await self._queue.get()
        if result is StopIteration:
            raise StopAsyncIteration
        elif isinstance(result, Exception):
            raise result
        return result
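A minimal sketch of how a producer and a consumer interact through AsyncStream: the engine side calls put()/finish() while the HTTP handler side iterates with async for. It assumes the module path colossalai.inference.core.async_engine added by this commit is importable.

import asyncio

from colossalai.inference.core.async_engine import AsyncStream


async def consume(stream: AsyncStream) -> None:
    # The HTTP handler side: iterate until finish() pushes StopIteration.
    async for item in stream:
        print("received:", item)


async def main() -> None:
    stream = AsyncStream(request_id=1)
    consumer = asyncio.create_task(consume(stream))
    # The engine side: push partial outputs, then mark the stream as finished.
    for token in ["Hello", ",", " world"]:
        stream.put(token)
    stream.finish()
    await consumer


asyncio.run(main())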

class RequestTracker:
    """Synchronous abstraction for tracking requests."""

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        return item in self._request_streams

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
        """
        Propagate an exception to request streams (all if request_id is None).
        """
        if request_id is not None:
            self._request_streams[request_id].put(exc)
        else:
            for stream in self._request_streams.values():
                stream.put(exc)

    def process_finished_request(self, finished_request) -> None:
        """Process a finished request from the engine."""
        request_id = finished_request.request_id

        self._request_streams[request_id].put(finished_request)
        self.abort_request(request_id)

    def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream:
        """
        Add a request to be sent to the engine on the next background
        loop iteration.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            Logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_requests(self):
        """
        Get new requests from http server.
        """
        new_requests: List[Dict] = []

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        self.new_requests_event.clear()

        return new_requests

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[int] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        self.new_requests_event.clear()

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()
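A brief sketch of the tracker handshake between the server side (add_request/abort_request) and the engine-loop side (get_new_and_finished_requests). It assumes the module path added by this commit and mirrors the unit test at the end of this diff.

import asyncio

from colossalai.inference.core.async_engine import RequestTracker


async def main() -> None:
    tracker = RequestTracker()
    tracker.init_event()  # must be called from inside a running event loop

    # Server side: register a request and keep the stream for the response.
    stream = tracker.add_request(1, prompt="hello", prompt_token_ids=None)

    # Engine-loop side: drain pending work on the next iteration.
    new, finished = tracker.get_new_and_finished_requests()
    print(new)       # [{'request_id': 1, 'prompt': 'hello', 'prompt_token_ids': None}]
    print(finished)  # set()

    # Server side: the client went away, so abort; the stream is closed.
    tracker.abort_request(1)
    _, finished = tracker.get_new_and_finished_requests()
    print(finished, stream.finished)  # {1} True


asyncio.run(main())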

class _AsyncInferenceEngine(InferenceEngine):
    """
    Async methods for Inference Engine.
    """

    async def async_step(self) -> List[str]:
        """
        The async version of Engine.step()
        Performs one decoding iteration and returns newly generated results.

        It first schedules the sequences to be executed in the next iteration.
        Then, it executes the model and updates the scheduler with the model
        outputs. Finally, it decodes the sequences and returns the newly
        generated results.
        """
        batch = self.request_handler.schedule()
        loop = asyncio.get_running_loop()

        # Use run_in_executor to asynchronously run the sync method model.forward().
        logits = await loop.run_in_executor(
            None,
            self.model,
            batch,
            self.k_cache,
            self.v_cache,
        )

        if self.inference_config.pad_input:
            logits = logits[:, -1, :]
        self.request_handler.search_tokens(self.generation_config, logits)
        # Return: List[Sequence]
        finished_sequences = self.request_handler.update()

        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
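The key trick in async_step is offloading the blocking model forward pass to a thread executor so the event loop stays responsive. A standalone sketch of that pattern, where slow_forward is a stand-in for model(batch, k_cache, v_cache):

import asyncio
import time


def slow_forward(batch: list) -> list:
    # Stand-in for the blocking model forward pass.
    time.sleep(0.1)
    return [x * 2 for x in batch]


async def async_step(batch: list) -> list:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; positional args follow the callable.
    return await loop.run_in_executor(None, slow_forward, batch)


async def main() -> None:
    # While slow_forward runs in a worker thread, the loop can serve other coroutines.
    results = await asyncio.gather(async_step([1, 2]), async_step([3, 4]))
    print(results)  # [[2, 4], [6, 8]]


asyncio.run(main())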

class AsyncInferenceEngine:
    """An asynchronous wrapper for the InferenceEngine.

    This class is used to wrap the InferenceEngine class to make it asynchronous.
    It uses asyncio to create a background loop that keeps processing incoming
    requests. The InferenceEngine is kicked by the generate method when there are
    requests in the waiting queue. The generate method yields the outputs
    from the InferenceEngine to the caller.
    """

    _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine

    def __init__(self, start_engine_loop: bool = True, **kwargs):
        self.engine = self._init_engine(**kwargs)
        self.background_loop = None
        # reference to the unshielded loop
        self._background_loop_unshielded = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    @property
    def background_loop_status(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.background_loop_status:
            raise RuntimeError("Existing loop is running")

        self._request_tracker.init_event()

        self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
        )
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, **kwargs):
        return self._engine_class(**kwargs)

    async def step(self):
        """
        Run the engine to process requests.

        Returns True if there are in-progress requests.
        """
        new_requests = self._request_tracker.get_new_requests()
        for new_request in new_requests:
            self.engine.add_single_request(**new_request)
        newly_finished_seqs, has_running_requests = await self.engine.async_step()
        for seq in newly_finished_seqs:
            self._request_tracker.process_finished_request(seq)

        return has_running_requests

    async def _engine_abort(self, request_ids: Iterable[int]):
        self.engine.abort_request(request_ids)

    async def abort(self, request_id: int):
        """
        Abort a single request.
        """
        if not self.background_loop_status:
            raise RuntimeError("Background loop is not running or launched correctly.")
        return self._abort(request_id)

    def _abort(self, request_id: int):
        self._request_tracker.abort_request(request_id)

    async def run_engine_loop(self):
        processing_requests = False
        while True:
            if not processing_requests:
                await self._request_tracker.wait_for_new_requests()
            processing_requests = await self.step()
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
    ) -> AsyncStream:
        """
        Add a request to the background tracker (waiting queue) and start the background loop if needed.
        """
        if not self.background_loop_status:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise RuntimeError("Background loop is not running.")
        stream = self._request_tracker.add_request(
            request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
        )
        return stream

    async def generate(
        self,
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
    ) -> AsyncIterator[str]:
        """
        Generate output from a request. It receives the request from the http server, adds it into the
        waiting queue of the Async Engine and streams the output sequence.
        """
        try:
            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
            async for request_output in stream:
                yield request_output

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self._abort(request_id)
            raise e
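A hedged end-to-end sketch of using the async engine outside the HTTP server. It assumes a local Llama-2 checkpoint path (hypothetical) and that the constructor keyword arguments mirror the __main__ block of the api_server further down in this diff (model, tokenizer, inference_config).

import asyncio

from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.async_engine import AsyncInferenceEngine

MODEL_PATH = "/path/to/llama2-7b"  # hypothetical checkpoint location


async def main() -> None:
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    engine = AsyncInferenceEngine(
        start_engine_loop=True,
        model=model,
        tokenizer=tokenizer,
        inference_config=InferenceConfig(),  # assumes default field values are usable
    )
    # generate() is an async generator: it starts the background loop on first use
    # and yields outputs for this request_id as they become available.
    async for output in engine.generate(request_id=1, prompt="hello, who are you?"):
        print(output)


asyncio.run(main())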
@@ -1,6 +1,6 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union, Iterable

import numpy as np
import torch
@@ -507,9 +507,9 @@ class InferenceEngine:

    def generate(
        self,
        prompts: List[str] = None,
        request_ids: Union[List[int], int] = None,
        prompts: Union[List[str], str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        request_ids: List[int] = None,
        return_token_ids: bool = False,
        generation_config: Optional[GenerationConfig] = None,
    ) -> List[str]:
@@ -527,6 +527,11 @@ class InferenceEngine:
            List[str]: Inference result returned by one generation.
        """
        with torch.inference_mode():

            if isinstance(prompts, str) and isinstance(request_ids, int):
                prompts = [prompts]
                request_ids = [request_ids]

            if prompts is not None or prompts_token_ids is not None:
                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                self.add_request(
@@ -535,7 +540,7 @@ class InferenceEngine:
                    prompts_token_ids=prompts_token_ids,
                    **gen_config_dict,
                )

            output_seqs_list = []
            total_tokens_list = []
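With this change a bare string prompt and an int request_id are normalized to one-element lists before being queued, so a single-prompt call to the synchronous engine looks like the sketch below. The checkpoint path and generation settings are illustrative, and the keyword constructor usage is assumed to match the api_server __main__ block later in this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

MODEL_PATH = "/path/to/llama2-7b"  # hypothetical checkpoint location

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
engine = InferenceEngine(model=model, tokenizer=tokenizer, inference_config=InferenceConfig())

# A single string prompt and an int request_id are accepted directly.
outputs = engine.generate(
    prompts="hello, who are you?",
    request_ids=1,
    generation_config=GenerationConfig(max_new_tokens=64),
)
print(outputs[0])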
@@ -580,13 +585,13 @@ class InferenceEngine:
        if isinstance(prompts, (list, tuple)):
            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
        elif isinstance(prompts, str):
            return self.inference_config.rompt_template.format(input_text=prompts)
            return self.inference_config.prompt_template.format(input_text=prompts)
        else:
            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")

    def add_request(
        self,
        request_ids: List[int] = None,
        request_ids: Union[List[int], int] = None,
        prompts: List[str] = None,
        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
        **kwargs,
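The prompt_template is a plain format string with an input_text placeholder, so format_prompt reduces to str.format. A small standalone sketch with a made-up template; the real "llama" and "vicuna" templates shipped with Colossal-Inference may differ.

# Hypothetical template; the actual "llama"/"vicuna" templates may differ.
prompt_template = "[INST] {input_text} [/INST]"


def format_prompt(prompts):
    if isinstance(prompts, (list, tuple)):
        return [prompt_template.format(input_text=p) for p in prompts]
    elif isinstance(prompts, str):
        return prompt_template.format(input_text=prompts)
    raise TypeError(f"Expected list, tuple, or str, but got {type(prompts)}.")


print(format_prompt("hello, who are you?"))
# [INST] hello, who are you? [/INST]
print(format_prompt(["hi", "bye"]))
# ['[INST] hi [/INST]', '[INST] bye [/INST]']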
@@ -601,6 +606,7 @@ class InferenceEngine:
        """

        # apply the prompt template to the input prompts

        if self.has_prompt_template and prompts is not None:
            prompts = self.format_prompt(prompts)
@@ -614,6 +620,7 @@ class InferenceEngine:
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                "input_ids"
            ]
            print(prompts_token_ids)

        if isinstance(prompts_token_ids, list):
            pass
@@ -632,8 +639,6 @@ class InferenceEngine:

        for i in range(prompts_num):
            if request_ids:
                if not isinstance(request_ids, list):
                    request_ids = [request_ids]
                assert isinstance(
                    request_ids[0], int
                ), f"The request_id type must be int, but got {type(request_ids[0])}"
@@ -734,6 +739,9 @@ class InferenceEngine:
        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
        self.request_handler.append_next_tokens(next_tokens)

        print("in step", logits)

        self.request_handler.search_tokens(self.generation_config, logits)
        finished_sequences = self.request_handler.update()

        return finished_sequences
@@ -263,24 +263,27 @@ class RequestHandler:
        ), f"Sequence {req.request_id} exceeds input length limit"
        self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)

    def abort_sequence(self, request_id: str):
    def abort_sequence(self, request_id: int):
        """
        Abort the request.
        """
        seq, priority = self._find_sequence(request_id)
        if seq.status == RequestStatus.WAITING:
            seq.mark_aborted()
            self.waiting_list[priority].remove(seq)
        elif seq.status.is_running():
            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
            self.running_list.remove(seq)
        else:
            try:
                self.done_list.remove(seq)
            except:
                return
        result = self._find_sequence(request_id)
        if result is not None:
            seq, priority = result
            if seq.status == RequestStatus.WAITING:
                seq.mark_aborted()
                self.waiting_list[priority].remove(seq)
            elif seq.status.is_running():
                self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
                self.running_list.remove(seq)
            else:
                try:
                    self.done_list.remove(seq)
                except:
                    return
        return

    def _find_sequence(self, request_id: str) -> Sequence:
    def _find_sequence(self, request_id: int) -> Sequence:
        """
        Find the request by request_id.
        """
@@ -324,6 +327,9 @@ class RequestHandler:
    def check_unfinished_seqs(self) -> bool:
        return self._has_waiting() or not self.running_list.is_empty()

    def current_requests_in_batch(self) -> int:
        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

    def search_tokens(self, generation_config: GenerationConfig, logits):
        """
        Sample tokens for finished requests.
@@ -0,0 +1,200 @@
"""
Doc:
    Feature:
        - FastAPI based http server for Colossal-Inference
        - Completion Service Supported
    Usage: (for local user)
        - First, launch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
        - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the API
        - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
            -H 'Content-Type: application/json' \
            -d '{"prompt":"hello, who are you? ","stream":"False"}'`
"""


import argparse
import json

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.server.completion_service import CompletionServing
from colossalai.inference.server.utils import id_generator

from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa

TIMEOUT_KEEP_ALIVE = 5  # seconds.
app = FastAPI()
engine = None
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
prompt_template_choices = ["llama", "vicuna"]


@app.get("/v0/models")
def get_available_models() -> Response:
    return JSONResponse(supported_models_dict)
@app.post("/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
"""Generate completion for the request.
|
||||
|
||||
A request should be a JSON object with the following fields:
|
||||
- prompts: the prompts to use for the generation.
|
||||
- stream: whether to stream the results or not.
|
||||
- other fields:
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
prompt = request_dict.pop("prompt")
|
||||
stream = request_dict.pop("stream", None)
|
||||
|
||||
request_id = id_generator()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
results = engine.generate(request_id, prompt, generation_config=generation_config)
|
||||
|
||||
# Streaming case
|
||||
def stream_results():
|
||||
for request_output in results:
|
||||
ret = {"text": request_output}
|
||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
final_output = None
|
||||
for request_output in results:
|
||||
if request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
engine.abort(request_id)
|
||||
return Response(status_code=499)
|
||||
final_output = request_output
|
||||
|
||||
assert final_output is not None
|
||||
ret = {"text": final_output}
|
||||
return JSONResponse(ret)
|
||||
|
||||
|
||||
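For reference, a small Python client for the /generate route above. It assumes the server is running locally on the default host and port, that the requests library is available, and it splits the stream on the NUL byte emitted by stream_results.

import json

import requests  # assumed available; any HTTP client works

URL = "http://127.0.0.1:8000/generate"

# Non-streaming request.
resp = requests.post(URL, json={"prompt": "hello, who are you?", "stream": False})
print(resp.json()["text"])

# Streaming request: chunks are JSON objects separated by NUL bytes.
with requests.post(URL, json={"prompt": "hello, who are you?", "stream": True}, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk)["text"])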
@app.post("/v1/completion")
|
||||
async def create_completion(request: Request):
|
||||
request_dict = await request.json()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
generator = await completion_serving.create_completion(request, generation_config)
|
||||
output = tokenizer.decode(generator.output_token_id)
|
||||
ret = {"request_id": generator.request_id, "text": output}
|
||||
return ret
|
||||
|
||||
|
||||
def get_generation_config(request):
|
||||
generation_config = async_engine.engine.generation_config
|
||||
for arg in request:
|
||||
if hasattr(generation_config, arg):
|
||||
generation_config[arg] = request[arg]
|
||||
return generation_config
|
||||
|
||||
|
||||

def add_engine_config(parser):
    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")

    parser.add_argument(
        "--max-model-len",
        type=int,
        default=None,
        help="model context length. If unspecified, " "will be automatically derived from the model.",
    )
    # Parallel arguments
    parser.add_argument(
        "--worker-use-ray",
        action="store_true",
        help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
    )

    parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")

    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")

    # KV cache arguments
    parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")

    parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")

    # generation arguments
    parser.add_argument(
        "--prompt_template",
        choices=prompt_template_choices,
        default=None,
        help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
    )

    # Quantization settings.
    parser.add_argument(
        "--quantization",
        "-q",
        type=str,
        choices=["awq", "gptq", "squeezellm", None],
        default=None,
        help="Method used to quantize the weights. If "
        "None, we first check the `quantization_config` "
        "attribute in the model config file. If that is "
        "None, we assume the model weights are not "
        "quantized and use `dtype` to determine the data "
        "type of the weights.",
    )
    parser.add_argument(
        "--enforce-eager",
        action="store_true",
        help="Always use eager-mode PyTorch. If False, "
        "will use eager mode and CUDA graph in hybrid "
        "for maximal performance and flexibility.",
    )
    return parser

def parse_args():
    parser = argparse.ArgumentParser(description="Colossal-Inference API server.")

    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help="The model name used in the API. If not "
        "specified, the model name will be the same as "
        "the huggingface name.",
    )
    parser = add_engine_config(parser)

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    inference_config = InferenceConfig.from_dict(vars(args))
    model = AutoModelForCausalLM.from_pretrained(args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    async_engine = AsyncInferenceEngine(
        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
    )
    engine = async_engine.engine
    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)

    app.root_path = args.root_path
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )
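As a companion to the curl example in the module docstring, a hedged Python client for the /v1/completion route, assuming the server started above is reachable on the default host and port and that the requests library is available.

import requests  # assumed available; mirrors the curl example in the module docstring

resp = requests.post(
    "http://127.0.0.1:8000/v1/completion",
    json={"prompt": "hello, who are you? "},
)
body = resp.json()
print(body["request_id"], body["text"])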
@@ -0,0 +1,35 @@
import asyncio

from colossalai.inference.core.async_engine import AsyncInferenceEngine

from .utils import id_generator


class CompletionServing:
    def __init__(self, engine: AsyncInferenceEngine, served_model: str):
        self.engine = engine
        self.served_model = served_model

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass

    async def create_completion(self, request, generation_config):
        request_dict = await request.json()
        request_id = id_generator()
        prompt = request_dict.pop("prompt")

        # this is not an intuitive way
        self.engine.engine.generation_config = generation_config
        result_generator = self.engine.generate(request_id, prompt=prompt)

        final_res = None
        async for res in result_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return {"error_msg": "Client disconnected"}
            final_res = res

        return final_res
@@ -0,0 +1,16 @@
# make it singleton
class NumericIDGenerator:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(NumericIDGenerator, cls).__new__(cls)
            cls._instance.current_id = 0
        return cls._instance

    def __call__(self):
        self.current_id += 1
        return self.current_id


id_generator = NumericIDGenerator()
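A quick sketch of the singleton behavior: every NumericIDGenerator() call returns the same instance, so ids issued anywhere in the process increase monotonically. It assumes the module path colossalai.inference.server.utils used by the api_server above.

from colossalai.inference.server.utils import NumericIDGenerator, id_generator

print(id_generator())  # 1
print(id_generator())  # 2

# Constructing the class again yields the very same object and counter.
another = NumericIDGenerator()
print(another is id_generator)  # True
print(another())  # 3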
@@ -164,6 +164,7 @@ class Sequence:
        return (
            f"(request_id={self.request_id}, "
            f"prompt={self.prompt}, "
            f"output_token_id={self.output_token_id},"
            f"status={self.status.name}, "
            f"sample_params={self.sample_params}, "
            f"input_len={self.input_len},"
@@ -1,6 +1,7 @@
import os
from typing import Dict, List, Tuple

import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
@@ -36,7 +37,11 @@ class ShardFormer:
    """

    def __init__(self, shard_config: ShardConfig):
        self.coordinator = DistCoordinator()
        self.is_distributed = dist.is_initialized()
        if self.is_distributed:
            self.coordinator = DistCoordinator()
        else:
            self.coordinator = None
        self.shard_config = shard_config

    def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
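The ShardFormer change above guards the DistCoordinator behind torch.distributed initialization so the class can be constructed in a single-process server. A standalone sketch of that guard pattern, where DummyCoordinator is a stand-in for DistCoordinator:

import torch.distributed as dist


class DummyCoordinator:
    # Stand-in for a coordinator that requires an initialized process group.
    def __init__(self):
        self.rank = dist.get_rank()


class Shim:
    def __init__(self):
        self.is_distributed = dist.is_initialized()
        # Only touch process-group state when a group actually exists,
        # e.g. when the module is used inside a non-distributed API server.
        self.coordinator = DummyCoordinator() if self.is_distributed else None


shim = Shim()
print(shim.is_distributed, shim.coordinator)  # False None when run without torchrun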
@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass

import pytest

from colossalai.inference.core.async_engine import AsyncInferenceEngine


@dataclass
class SequenceTpye:
    request_id: int


class MockEngine:
    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None

    async def async_step(self):
        self.step_calls += 1
        return [SequenceTpye(request_id=self.request_id)] if self.request_id else []

    def generate(self, request_id):
        self.request_id = request_id

    def stop_generating(self):
        self.request_id = None

    def add_request(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        del request_id  # Unused
        self.abort_request_calls += 1


class MockAsyncLLMEngine(AsyncInferenceEngine):
    def _init_engine(self, *args, **kwargs):
        return MockEngine()


@pytest.mark.asyncio
async def test_new_requests_event():
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    assert engine.engine.step_calls == 0

    await engine.add_request(1, "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request(2, "", None)
    engine.engine.generate(2)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls == 2
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 3
    engine.engine.stop_generating()
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4

    await engine.add_request(3, "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5


if __name__ == "__main__":
    test_new_requests_event()
@@ -0,0 +1,77 @@
import pytest

from colossalai.inference.core.async_engine import RequestTracker
from colossalai.inference.struct import Sequence


class SampleEvent:
    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False


def test_request_tracker():
    tracker = RequestTracker()
    tracker.new_requests_event = SampleEvent()
    stream_1 = tracker.add_request(1)
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 1
    assert not finished
    assert not stream_1.finished

    stream_2 = tracker.add_request(2)
    stream_3 = tracker.add_request(3)
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == 2
    assert new[1]["request_id"] == 3
    assert not finished
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request(1)
    assert not tracker.new_requests_event.flag

    tracker.abort_request(1)
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert 1 in finished
    assert not new
    assert stream_1.finished

    stream_4 = tracker.add_request(4)
    tracker.abort_request(4)
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert 4 in finished
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request(5)
    assert tracker.new_requests_event.flag
    tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(finished) == 1
    assert 2 in finished
    assert len(new) == 1
    assert new[0]["request_id"] == 5
    assert stream_2.finished
    assert not stream_5.finished


if __name__ == "__main__":
    test_request_tracker()