mirror of https://github.com/hpcaitech/ColossalAI
commit
492520dbdb
|
@ -62,6 +62,9 @@ class BatchBucket:
|
|||
def current_batch_size(self):
|
||||
return self._current_batch_size
|
||||
|
||||
def __len__(self):
|
||||
return self._current_batch_size
|
||||
|
||||
@property
|
||||
def available_batch_size(self):
|
||||
return self.max_batch_size - self._current_batch_size
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
"""
|
||||
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -214,3 +213,18 @@ class InferenceConfig:
|
|||
meta_config[type] = getattr(model_config, type)
|
||||
|
||||
return GenerationConfig.from_dict(meta_config)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
|
||||
# Get the list of attributes of this dataclass.
|
||||
attrs = [attr.name for attr in fields(cls)]
|
||||
inference_config_args = {}
|
||||
for attr in attrs:
|
||||
if attr in config_dict:
|
||||
inference_config_args[attr] = config_dict[attr]
|
||||
else:
|
||||
inference_config_args[attr] = getattr(cls, attr)
|
||||
|
||||
# Set the attributes from the parsed arguments.
|
||||
inference_config = cls(**inference_config_args)
|
||||
return inference_config
|
||||
|
|
|
@ -0,0 +1,309 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
|
||||
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
|
||||
# CLI logger
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger("colossalai-inference")
|
||||
|
||||
|
||||
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
|
||||
msg = "Task finished unexpectedly. This should never happen! "
|
||||
try:
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
|
||||
raise RuntimeError(msg)
|
||||
except Exception as exc:
|
||||
request_tracker.propagate_exception(exc)
|
||||
raise exc
|
||||
|
||||
|
||||
class RequstStream:
|
||||
"""
|
||||
A stream of Output for a request that can be iterated over asynchronously.
|
||||
Attributes: 1.request_id: The id of the request.
|
||||
2._future: A future that will be set when the request is finished.
|
||||
Methods: set_result and get_result, results will be set when finished, for once, and
|
||||
the `self.future` will be set to done.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, request_id: int) -> None:
|
||||
self.request_id = request_id
|
||||
self._future = asyncio.Future()
|
||||
|
||||
def set_result(self, result) -> None:
|
||||
"""Set final result and signal taht it's ready"""
|
||||
if not self._future.done():
|
||||
self._future.set_result(result)
|
||||
|
||||
async def get_result(self):
|
||||
"""Wait for the result to be set and return it."""
|
||||
return await self._future
|
||||
|
||||
@property
|
||||
def finished(self) -> bool:
|
||||
"""Check if the stream has finished by checking if the future is done."""
|
||||
return self._future.done()
|
||||
|
||||
|
||||
class Tracer:
|
||||
"""
|
||||
Recording new requests and finished requests.
|
||||
Attributes: 1._request_streams: We create one stream for each request to trace the output.
|
||||
2._finished_requests: A queue to store the finished requests.
|
||||
3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
|
||||
4.new_requests_event: An event to notify the engine that there are new requests.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._request_streams: Dict[int, RequstStream] = {}
|
||||
self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
|
||||
self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
|
||||
self.new_requests_event = None
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._request_streams
|
||||
|
||||
def init_event(self):
|
||||
self.new_requests_event = asyncio.Event()
|
||||
|
||||
def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
Propagate an exception to request streams (all if request_id is None).
|
||||
"""
|
||||
if request_id is not None:
|
||||
self._request_streams[request_id].set_result(exc)
|
||||
else:
|
||||
for stream in self._request_streams.values():
|
||||
stream.set_result(exc)
|
||||
|
||||
def process_finished_request(self, finished_request) -> None:
|
||||
"""Process a finished request from the engine."""
|
||||
request_id = finished_request.request_id
|
||||
try:
|
||||
self._request_streams[request_id].set_result(finished_request)
|
||||
except:
|
||||
raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
|
||||
self.abort_request(request_id)
|
||||
|
||||
def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
|
||||
"""
|
||||
Add a request to be sent to the engine on the next background
|
||||
loop iteration.
|
||||
"""
|
||||
if request_id in self._request_streams:
|
||||
raise KeyError(f"Request {request_id} already exists.")
|
||||
|
||||
stream = RequstStream(request_id)
|
||||
logger.info(f"Added request {request_id}.")
|
||||
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
|
||||
self.new_requests_event.set()
|
||||
|
||||
return stream
|
||||
|
||||
def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
|
||||
"""Abort a request during next background loop iteration."""
|
||||
if verbose:
|
||||
logger.info(f"Aborted request {request_id}.")
|
||||
|
||||
self._finished_requests.put_nowait(request_id)
|
||||
|
||||
if request_id not in self._request_streams or self._request_streams[request_id].finished:
|
||||
# The request has already finished or been aborted.
|
||||
# The requests in new_requests will be aborted when try to get them(if marked aborted)
|
||||
return
|
||||
|
||||
self._request_streams[request_id].set_result(None)
|
||||
|
||||
def get_new_requests(self):
|
||||
"""
|
||||
Get new requests from http server.
|
||||
"""
|
||||
new_requests: List[Dict] = []
|
||||
finished_requests: Set[int] = set()
|
||||
|
||||
while not self._finished_requests.empty():
|
||||
request_id = self._finished_requests.get_nowait()
|
||||
finished_requests.add(request_id)
|
||||
|
||||
while not self._new_requests.empty():
|
||||
stream, new_request = self._new_requests.get_nowait()
|
||||
if new_request["request_id"] in finished_requests:
|
||||
# The request has been aborted.
|
||||
stream.set_result(None)
|
||||
continue
|
||||
self._request_streams[stream.request_id] = stream
|
||||
new_requests.append(new_request)
|
||||
|
||||
self.new_requests_event.clear()
|
||||
|
||||
return new_requests
|
||||
|
||||
async def wait_for_new_requests(self):
|
||||
await self.new_requests_event.wait()
|
||||
|
||||
|
||||
class _AsyncInferenceEngine(InferenceEngine):
|
||||
"""
|
||||
Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
|
||||
Methods: 1. async_step: The async version of Engine.step()
|
||||
"""
|
||||
|
||||
async def async_step(self) -> List[str]:
|
||||
"""
|
||||
The async version of Engine.step()
|
||||
Performs one decoding iteration and returns newly generated results.
|
||||
|
||||
It first schedules the sequences to be executed in the next iteration.
|
||||
Then, it executes the model and updates the scheduler with the model
|
||||
outputs. Finally, it decodes the sequences and returns the newly
|
||||
generated results.
|
||||
"""
|
||||
batch = self.request_handler.schedule()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Use run_in_executor to asyncally run the sync method model.forward().
|
||||
logits = await loop.run_in_executor(
|
||||
None,
|
||||
self.model,
|
||||
batch,
|
||||
self.k_cache,
|
||||
self.v_cache,
|
||||
)
|
||||
|
||||
if self.inference_config.pad_input:
|
||||
logits = logits[:, -1, :]
|
||||
self.request_handler.search_tokens(self.generation_config, logits)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
for sequence in finished_sequences:
|
||||
sequence.output = self.tokenizer.decode(sequence.output_token_id)
|
||||
|
||||
return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
|
||||
|
||||
|
||||
class AsyncInferenceEngine:
|
||||
"""An asynchronous wrapper for the InferenceEngine class.
|
||||
|
||||
This class is used to wrap the InferenceEngine class to make it asynchronous.
|
||||
It uses asyncio to create a background loop that keeps processing incoming
|
||||
requests. Note that this class does not hold model directly, when incoming a new
|
||||
request, it first called `add_request` and the Tracer will record the request, putting
|
||||
it to the background `InferenceEngine`(done in background loop) to process. You can
|
||||
consider this engine as an interface.
|
||||
"""
|
||||
|
||||
_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
|
||||
|
||||
def __init__(self, start_engine_loop: bool = True, **kwargs):
|
||||
self.engine = self._init_engine(**kwargs)
|
||||
self.background_loop = None
|
||||
# reference to the unshielded loop
|
||||
self._background_loop_unshielded = None
|
||||
self.start_engine_loop = start_engine_loop
|
||||
self._request_tracer = Tracer()
|
||||
|
||||
@property
|
||||
def background_loop_status(self):
|
||||
return self.background_loop is not None and not self.background_loop.done()
|
||||
|
||||
def start_background_loop(self):
|
||||
if self.background_loop_status:
|
||||
raise RuntimeError("Existing loop is running")
|
||||
|
||||
self._request_tracer.init_event()
|
||||
|
||||
self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
|
||||
self._background_loop_unshielded.add_done_callback(
|
||||
partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
|
||||
)
|
||||
self.background_loop = asyncio.shield(self._background_loop_unshielded)
|
||||
|
||||
def _init_engine(self, **kwargs):
|
||||
return self._engine_class(**kwargs)
|
||||
|
||||
async def step(self):
|
||||
"""
|
||||
Run engine to process requests
|
||||
|
||||
Returns True if there are in-progress requests.
|
||||
"""
|
||||
new_requests = self._request_tracer.get_new_requests()
|
||||
for new_request in new_requests:
|
||||
self.engine.add_single_request(**new_request)
|
||||
newly_finished_seqs, has_running_requests = await self.engine.async_step()
|
||||
|
||||
for seq in newly_finished_seqs:
|
||||
self._request_tracer.process_finished_request(seq)
|
||||
|
||||
return has_running_requests
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[int]):
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def abort(self, request_id: int):
|
||||
"""
|
||||
Abort a single request
|
||||
"""
|
||||
if not self.background_loop_status:
|
||||
raise RuntimeError("Background loop is not running or launched correctly.")
|
||||
return self._abort(request_id)
|
||||
|
||||
def _abort(self, request_id: int):
|
||||
self._request_tracer.abort_request(request_id)
|
||||
|
||||
async def run_engine_loop(self):
|
||||
processing_requests = False
|
||||
while True:
|
||||
if not processing_requests:
|
||||
await self._request_tracer.wait_for_new_requests()
|
||||
processing_requests = await self.step()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: int,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
) -> RequstStream:
|
||||
"""
|
||||
Add a request to the background tracker(waiting queue), start the background loop if needed.
|
||||
"""
|
||||
if not self.background_loop_status:
|
||||
if self.start_engine_loop:
|
||||
self.start_background_loop()
|
||||
else:
|
||||
raise RuntimeError("Background loop is not running.")
|
||||
stream = self._request_tracer.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
return stream
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
request_id: int,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""
|
||||
Generate output from a request. It receives the request from http server, adds it into the
|
||||
waitting queue of Async Engine and streams the output sequence.
|
||||
"""
|
||||
try:
|
||||
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
|
||||
return await stream.get_result()
|
||||
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
# If there is an exception or coroutine is cancelled, abort the request.
|
||||
self._abort(request_id)
|
||||
raise e
|
|
@ -507,9 +507,9 @@ class InferenceEngine:
|
|||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str] = None,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: Union[List[str], str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
request_ids: List[int] = None,
|
||||
return_token_ids: bool = False,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> List[str]:
|
||||
|
@ -527,6 +527,9 @@ class InferenceEngine:
|
|||
List[str]: Inference result returned by one generation.
|
||||
"""
|
||||
with torch.inference_mode():
|
||||
if isinstance(prompts, str) and isinstance(request_ids, int):
|
||||
prompts = [prompts]
|
||||
request_ids = [request_ids]
|
||||
if prompts is not None or prompts_token_ids is not None:
|
||||
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
|
||||
self.add_request(
|
||||
|
@ -580,13 +583,13 @@ class InferenceEngine:
|
|||
if isinstance(prompts, (list, tuple)):
|
||||
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
|
||||
elif isinstance(prompts, str):
|
||||
return self.inference_config.rompt_template.format(input_text=prompts)
|
||||
return self.inference_config.prompt_template.format(input_text=prompts)
|
||||
else:
|
||||
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request_ids: List[int] = None,
|
||||
request_ids: Union[List[int], int] = None,
|
||||
prompts: List[str] = None,
|
||||
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
|
||||
**kwargs,
|
||||
|
@ -601,11 +604,15 @@ class InferenceEngine:
|
|||
"""
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if request_ids is not None and not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
|
||||
if prompts is not None and not isinstance(prompts, list):
|
||||
prompts = [prompts]
|
||||
|
||||
|
@ -615,8 +622,10 @@ class InferenceEngine:
|
|||
"input_ids"
|
||||
]
|
||||
|
||||
# list of torch Tensor
|
||||
if isinstance(prompts_token_ids, list):
|
||||
pass
|
||||
if isinstance(prompts_token_ids[0], torch.Tensor):
|
||||
prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
|
||||
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
|
||||
prompts_token_ids = prompts_token_ids.tolist()
|
||||
else:
|
||||
|
@ -632,8 +641,6 @@ class InferenceEngine:
|
|||
|
||||
for i in range(prompts_num):
|
||||
if request_ids:
|
||||
if not isinstance(request_ids, list):
|
||||
request_ids = [request_ids]
|
||||
assert isinstance(
|
||||
request_ids[0], int
|
||||
), f"The request_id type must be int, but got {type(request_ids[0])}"
|
||||
|
@ -733,7 +740,6 @@ class InferenceEngine:
|
|||
logits = logits[:, -1, :]
|
||||
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||
self.request_handler.append_next_tokens(next_tokens)
|
||||
|
||||
finished_sequences = self.request_handler.update()
|
||||
|
||||
return finished_sequences
|
||||
|
|
|
@ -209,6 +209,7 @@ class RequestHandler:
|
|||
break
|
||||
|
||||
num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
|
||||
# for now the recycle logic is not working
|
||||
remove_list.extend(lst[:num_seqs_to_add])
|
||||
self.running_list.extend(lst[:num_seqs_to_add])
|
||||
|
||||
|
@ -263,24 +264,27 @@ class RequestHandler:
|
|||
), f"Sequence {req.request_id} exceeds input length limit"
|
||||
self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
|
||||
|
||||
def abort_sequence(self, request_id: str):
|
||||
def abort_sequence(self, request_id: int):
|
||||
"""
|
||||
Abort the request.
|
||||
"""
|
||||
seq, priority = self._find_sequence(request_id)
|
||||
if seq.status == RequestStatus.WAITING:
|
||||
seq.mark_aborted()
|
||||
self.waiting_list[priority].remove(seq)
|
||||
elif seq.status.is_running():
|
||||
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
|
||||
self.running_list.remove(seq)
|
||||
else:
|
||||
try:
|
||||
self.done_list.remove(seq)
|
||||
except:
|
||||
return
|
||||
result = self._find_sequence(request_id)
|
||||
if result is not None:
|
||||
seq, priority = result
|
||||
if seq.status == RequestStatus.WAITING:
|
||||
seq.mark_aborted()
|
||||
self.waiting_list[priority].remove(seq)
|
||||
elif seq.status.is_running():
|
||||
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
|
||||
self.running_list.remove(seq)
|
||||
else:
|
||||
try:
|
||||
self.done_list.remove(seq)
|
||||
except:
|
||||
return
|
||||
return
|
||||
|
||||
def _find_sequence(self, request_id: str) -> Sequence:
|
||||
def _find_sequence(self, request_id: int) -> Sequence:
|
||||
"""
|
||||
Find the request by request_id.
|
||||
"""
|
||||
|
@ -324,6 +328,9 @@ class RequestHandler:
|
|||
def check_unfinished_seqs(self) -> bool:
|
||||
return self._has_waiting() or not self.running_list.is_empty()
|
||||
|
||||
def total_requests_in_batch_bucket(self) -> int:
|
||||
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
|
||||
|
||||
def search_tokens(self, generation_config: GenerationConfig, logits):
|
||||
"""
|
||||
Sample tokens for finished requests.
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
# Online Service
|
||||
Colossal-Inference supports fast-api based online service. Simple completion and chat are both supported. Follow the commands below and
|
||||
you can simply construct a server with both completion and chat functionalities. For now we only support `Llama` model, we will fullfill
|
||||
the blank quickly.
|
||||
|
||||
# Usage
|
||||
```bash
|
||||
# First, Lauch an API locally.
|
||||
python3 -m colossalai.inference.server.api_server --model path of your llama2 model --chat_template "{% for message in messages %}
|
||||
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
|
||||
|
||||
|
||||
# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api
|
||||
|
||||
# For completion service, you can invoke it
|
||||
curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'
|
||||
|
||||
# For chat service, you can invoke it
|
||||
curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"converation":
|
||||
[{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "what is 1+1?"},],
|
||||
"stream": "False",}'
|
||||
# If you just want to test a simple generation, turn to generate api
|
||||
curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'
|
||||
|
||||
```
|
||||
We also support streaming output, simply change the `stream` to `True` in the request body.
|
|
@ -0,0 +1,210 @@
|
|||
"""
|
||||
Doc:
|
||||
Feature:
|
||||
- FastAPI based http server for Colossal-Inference
|
||||
- Completion Service Supported
|
||||
Usage: (for local user)
|
||||
- First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
|
||||
- Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api
|
||||
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
|
||||
Version: V1.0
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.server.chat_service import ChatServing
|
||||
from colossalai.inference.server.completion_service import CompletionServing
|
||||
from colossalai.inference.server.utils import id_generator
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
|
||||
prompt_template_choices = ["llama", "vicuna"]
|
||||
async_engine = None
|
||||
chat_serving = None
|
||||
completion_serving = None
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# NOTE: (CjhHa1) models are still under development, need to be updated
|
||||
@app.get("/models")
|
||||
def get_available_models() -> Response:
|
||||
return JSONResponse(supported_models_dict)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
"""Generate completion for the request.
|
||||
|
||||
A request should be a JSON object with the following fields:
|
||||
- prompts: the prompts to use for the generation.
|
||||
- stream: whether to stream the results or not.
|
||||
- other fields:
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
prompt = request_dict.pop("prompt")
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
|
||||
request_id = id_generator()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
results = engine.generate(request_id, prompt, generation_config=generation_config)
|
||||
|
||||
# Streaming case
|
||||
def stream_results():
|
||||
for request_output in results:
|
||||
ret = {"text": request_output[len(prompt) :]}
|
||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||
|
||||
if stream == "true":
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
final_output = None
|
||||
for request_output in results:
|
||||
if request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
engine.abort(request_id)
|
||||
return Response(status_code=499)
|
||||
final_output = request_output[len(prompt) :]
|
||||
|
||||
assert final_output is not None
|
||||
ret = {"text": final_output}
|
||||
return JSONResponse(ret)
|
||||
|
||||
|
||||
@app.post("/completion")
|
||||
async def create_completion(request: Request):
|
||||
request_dict = await request.json()
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
result = await completion_serving.create_completion(request, generation_config)
|
||||
|
||||
ret = {"request_id": result.request_id, "text": result.output}
|
||||
if stream == "true":
|
||||
return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
|
||||
else:
|
||||
return JSONResponse(content=ret)
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def create_chat(request: Request):
|
||||
request_dict = await request.json()
|
||||
|
||||
stream = request_dict.get("stream", "false").lower()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
message = await chat_serving.create_chat(request, generation_config)
|
||||
if stream == "true":
|
||||
return StreamingResponse(content=message, media_type="text/event-stream")
|
||||
else:
|
||||
ret = {"role": message.role, "text": message.content}
|
||||
return ret
|
||||
|
||||
|
||||
def get_generation_config(request):
|
||||
generation_config = async_engine.engine.generation_config
|
||||
for arg in request:
|
||||
if hasattr(generation_config, arg):
|
||||
generation_config[arg] = request[arg]
|
||||
return generation_config
|
||||
|
||||
|
||||
def add_engine_config(parser):
|
||||
parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
|
||||
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="model context length. If unspecified, " "will be automatically derived from the model.",
|
||||
)
|
||||
# Parallel arguments
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
|
||||
|
||||
# KV cache arguments
|
||||
parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
|
||||
|
||||
parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
|
||||
|
||||
# generation arguments
|
||||
parser.add_argument(
|
||||
"--prompt_template",
|
||||
choices=prompt_template_choices,
|
||||
default=None,
|
||||
help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
|
||||
|
||||
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--ssl-keyfile", type=str, default=None)
|
||||
parser.add_argument("--ssl-certfile", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. If not "
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to the chat template, " "or the template in single-line form " "for the specified model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--response-role",
|
||||
type=str,
|
||||
default="assistant",
|
||||
help="The role name to return if " "`request.add_generation_prompt=true`.",
|
||||
)
|
||||
parser = add_engine_config(parser)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
inference_config = InferenceConfig.from_dict(vars(args))
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
async_engine = AsyncInferenceEngine(
|
||||
start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
|
||||
)
|
||||
engine = async_engine.engine
|
||||
completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
|
||||
chat_serving = ChatServing(
|
||||
async_engine,
|
||||
served_model=model.__class__.__name__,
|
||||
tokenizer=tokenizer,
|
||||
response_role=args.response_role,
|
||||
chat_template=args.chat_template,
|
||||
)
|
||||
app.root_path = args.root_path
|
||||
uvicorn.run(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
)
|
|
@ -0,0 +1,142 @@
|
|||
import asyncio
|
||||
import codecs
|
||||
import logging
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||||
|
||||
from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator
|
||||
|
||||
logger = logging.getLogger("colossalai-inference")
|
||||
|
||||
|
||||
class ChatServing:
|
||||
def __init__(
|
||||
self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None
|
||||
):
|
||||
self.engine = engine
|
||||
self.served_model = served_model
|
||||
self.tokenizer = tokenizer
|
||||
self.response_role = response_role
|
||||
self._load_chat_template(chat_template)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def create_chat(self, request: Request, generation_config):
|
||||
request_dict = await request.json()
|
||||
messages = request_dict["messages"]
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
add_generation_prompt = request_dict.pop("add_generation_prompt", False)
|
||||
request_id = id_generator()
|
||||
try:
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
conversation=messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in applying chat template from request: {str(e)}")
|
||||
|
||||
# it is not a intuitive way
|
||||
self.engine.engine.generation_config = generation_config
|
||||
result_generator = self.engine.generate(request_id, prompt=prompt)
|
||||
|
||||
if stream == "true":
|
||||
return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id)
|
||||
else:
|
||||
return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id)
|
||||
|
||||
async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int):
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = self.get_chat_request_role(request, request_dict)
|
||||
n = request_dict.get("n", 1)
|
||||
echo = request_dict.get("echo", "false").lower()
|
||||
for i in range(n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role))
|
||||
data = choice_data.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if echo == "true":
|
||||
last_msg_content = ""
|
||||
if (
|
||||
request_dict["messages"]
|
||||
and isinstance(request_dict["messages"], list)
|
||||
and request_dict["messages"][-1].get("content")
|
||||
and request_dict["messages"][-1].get("role") == role
|
||||
):
|
||||
last_msg_content = request_dict["messages"][-1]["content"]
|
||||
if last_msg_content:
|
||||
for i in range(n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, message=DeltaMessage(content=last_msg_content)
|
||||
)
|
||||
data = choice_data.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
result = await result_generator
|
||||
choice_data = DeltaMessage(content=result.output)
|
||||
data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def chat_completion_full_generator(
|
||||
self,
|
||||
request: Request,
|
||||
request_dict: dict,
|
||||
result_generator,
|
||||
request_id,
|
||||
):
|
||||
if await request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
return {"error_msg": "Client disconnected"}
|
||||
|
||||
result = await result_generator
|
||||
assert result is not None
|
||||
role = self.get_chat_request_role(request, request_dict)
|
||||
choice_data = ChatMessage(role=role, content=result.output)
|
||||
echo = request_dict.get("echo", "false").lower()
|
||||
|
||||
if echo == "true":
|
||||
last_msg_content = ""
|
||||
if (
|
||||
request.messages
|
||||
and isinstance(request.messages, list)
|
||||
and request.messages[-1].get("content")
|
||||
and request.messages[-1].get("role") == role
|
||||
):
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
full_message = last_msg_content + choice_data.content
|
||||
choice_data.content = full_message
|
||||
|
||||
return choice_data
|
||||
|
||||
def get_chat_request_role(self, request: Request, request_dict: dict) -> str:
|
||||
add_generation_prompt = request_dict.get("add_generation_prompt", False)
|
||||
if add_generation_prompt:
|
||||
return self.response_role
|
||||
else:
|
||||
return request_dict["messages"][-1]["role"]
|
||||
|
||||
def _load_chat_template(self, chat_template):
|
||||
if chat_template is not None:
|
||||
try:
|
||||
with open(chat_template, "r") as f:
|
||||
self.tokenizer.chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape")
|
||||
|
||||
logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}")
|
||||
elif self.tokenizer.chat_template is not None:
|
||||
logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}")
|
||||
else:
|
||||
logger.warning("No chat template provided. Chat API will not work.")
|
|
@ -0,0 +1,34 @@
|
|||
import asyncio
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||||
|
||||
from .utils import id_generator
|
||||
|
||||
|
||||
class CompletionServing:
|
||||
def __init__(self, engine: AsyncInferenceEngine, served_model: str):
|
||||
self.engine = engine
|
||||
self.served_model = served_model
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def create_completion(self, request, generation_config):
|
||||
request_dict = await request.json()
|
||||
request_id = id_generator()
|
||||
|
||||
prompt = request_dict.pop("prompt")
|
||||
|
||||
# it is not a intuitive way
|
||||
self.engine.engine.generation_config = generation_config
|
||||
result_generator = self.engine.generate(request_id, prompt=prompt)
|
||||
|
||||
if await request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
raise RuntimeError("Client disconnected")
|
||||
|
||||
final_res = await result_generator
|
||||
return final_res
|
|
@ -0,0 +1,36 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# make it singleton
|
||||
class NumericIDGenerator:
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(NumericIDGenerator, cls).__new__(cls)
|
||||
cls._instance.current_id = 0
|
||||
return cls._instance
|
||||
|
||||
def __call__(self):
|
||||
self.current_id += 1
|
||||
return self.current_id
|
||||
|
||||
|
||||
id_generator = NumericIDGenerator()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: Any
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[Any] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
message: DeltaMessage
|
|
@ -61,6 +61,7 @@ class Sequence:
|
|||
pad_token_id (int): The pad token id for this inference process.
|
||||
max_output_len (int): Maximum output length.
|
||||
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
|
||||
output(str): The output of sequence
|
||||
"""
|
||||
|
||||
request_id: int
|
||||
|
@ -73,6 +74,7 @@ class Sequence:
|
|||
max_output_len: int = 256
|
||||
# NOTE(caidi) This is a temporary solution. It's better to move the logic to turn on or off the flag in sampling module in future.
|
||||
ignore_eos: bool = False
|
||||
output: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.output_token_id = []
|
||||
|
@ -163,11 +165,13 @@ class Sequence:
|
|||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt}, "
|
||||
f"status={self.status.name}, "
|
||||
f"sample_params={self.sample_params}, "
|
||||
f"input_len={self.input_len},"
|
||||
f"output_len={self.output_len})"
|
||||
f"prompt={self.prompt},\n"
|
||||
f"output_token_id={self.output_token_id},\n"
|
||||
f"output={self.output},\n"
|
||||
f"status={self.status.name},\n"
|
||||
f"sample_params={self.sample_params},\n"
|
||||
f"input_len={self.input_len},\n"
|
||||
f"output_len={self.output_len})\n"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -249,7 +249,6 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
|
|||
|
||||
The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
|
||||
::
|
||||
|
||||
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
|
||||
renormalized to have norm max_norm. Note: this will modify weight in-place.
|
||||
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -36,7 +37,11 @@ class ShardFormer:
|
|||
"""
|
||||
|
||||
def __init__(self, shard_config: ShardConfig):
|
||||
self.coordinator = DistCoordinator()
|
||||
self.is_distributed = dist.is_initialized()
|
||||
if self.is_distributed:
|
||||
self.coordinator = DistCoordinator()
|
||||
else:
|
||||
self.coordinator = None
|
||||
self.shard_config = shard_config
|
||||
|
||||
def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
from locust import HttpUser, between, tag, task
|
||||
|
||||
|
||||
class QuickstartUser(HttpUser):
|
||||
wait_time = between(1, 5)
|
||||
|
||||
@tag("online-generation")
|
||||
@task(5)
|
||||
def completion(self):
|
||||
self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
|
||||
|
||||
@tag("online-generation")
|
||||
@task(5)
|
||||
def completion_streaming(self):
|
||||
self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
|
||||
|
||||
@tag("online-chat")
|
||||
@task(5)
|
||||
def chat(self):
|
||||
self.client.post(
|
||||
"/chat",
|
||||
json={
|
||||
"converation": [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "what is 1+1?"},
|
||||
],
|
||||
"stream": "False",
|
||||
},
|
||||
)
|
||||
|
||||
@tag("online-chat")
|
||||
@task(5)
|
||||
def chat_streaming(self):
|
||||
self.client.post(
|
||||
"/chat",
|
||||
json={
|
||||
"converation": [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "what is 1+1?"},
|
||||
],
|
||||
"stream": "True",
|
||||
},
|
||||
)
|
||||
|
||||
@tag("offline-generation")
|
||||
@task(5)
|
||||
def generate_streaming(self):
|
||||
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
|
||||
|
||||
@tag("offline-generation")
|
||||
@task(5)
|
||||
def generate(self):
|
||||
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"})
|
||||
|
||||
@tag("online-generation", "offline-generation")
|
||||
@task
|
||||
def get_models(self):
|
||||
self.client.get("/models")
|
|
@ -0,0 +1,27 @@
|
|||
#!/bin/bash
|
||||
|
||||
#argument1: model_path
|
||||
|
||||
# launch server
|
||||
model_path=${1:-"lmsys/vicuna-7b-v1.3"}
|
||||
chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
|
||||
echo "Model Path: $model_path"
|
||||
echo "Starting server..."
|
||||
python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template &
|
||||
SERVER_PID=$!
|
||||
|
||||
# waiting time
|
||||
sleep 60
|
||||
|
||||
# Run Locust
|
||||
echo "Starting Locust..."
|
||||
echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
|
||||
echo "Test completion api first"
|
||||
locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
|
||||
echo "Test chat api"
|
||||
locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
|
||||
# kill Server
|
||||
echo "Stopping server..."
|
||||
kill $SERVER_PID
|
||||
|
||||
echo "Test and server shutdown completely"
|
|
@ -0,0 +1,4 @@
|
|||
#!/bin/bash
|
||||
echo "Skip the test (this test is slow)"
|
||||
|
||||
# bash ./run_benchmark.sh
|
|
@ -0,0 +1,80 @@
|
|||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockSequence:
|
||||
request_id: int
|
||||
|
||||
|
||||
class MockEngine:
|
||||
def __init__(self):
|
||||
self.step_calls = 0
|
||||
self.add_request_calls = 0
|
||||
self.abort_request_calls = 0
|
||||
self.request_id = None
|
||||
|
||||
async def async_step(self):
|
||||
self.step_calls += 1
|
||||
return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
|
||||
|
||||
def add_single_request(self, **kwargs):
|
||||
del kwargs
|
||||
self.add_request_calls += 1
|
||||
|
||||
def generate(self, request_id):
|
||||
self.request_id = request_id
|
||||
|
||||
def stop_generating(self):
|
||||
self.request_id = None
|
||||
|
||||
def add_request(self, **kwargs):
|
||||
del kwargs # Unused
|
||||
self.add_request_calls += 1
|
||||
|
||||
def abort_request(self, request_id):
|
||||
del request_id # Unused
|
||||
self.abort_request_calls += 1
|
||||
|
||||
|
||||
class MockAsyncInferenceEngine(AsyncInferenceEngine):
|
||||
def _init_engine(self, *args, **kwargs):
|
||||
return MockEngine()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
engine = MockAsyncInferenceEngine()
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
|
||||
await engine.add_request(1, "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 1
|
||||
assert engine.engine.step_calls == 1
|
||||
|
||||
await engine.add_request(2, "", None)
|
||||
engine.engine.generate(2)
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.add_request_calls == 2
|
||||
assert engine.engine.step_calls == 2
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 3
|
||||
engine.engine.stop_generating()
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
|
||||
await engine.add_request(3, "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
|
@ -0,0 +1,68 @@
|
|||
import pytest
|
||||
|
||||
from colossalai.inference.core.async_engine import Tracer
|
||||
from colossalai.inference.struct import Sequence
|
||||
|
||||
|
||||
class SampleEvent:
|
||||
def __init__(self):
|
||||
self.flag = False
|
||||
|
||||
def set(self):
|
||||
self.flag = True
|
||||
|
||||
def clear(self):
|
||||
self.flag = False
|
||||
|
||||
|
||||
def test_request_tracer():
|
||||
tracker = Tracer()
|
||||
tracker.new_requests_event = SampleEvent()
|
||||
stream_1 = tracker.add_request(1)
|
||||
assert tracker.new_requests_event.flag
|
||||
new = tracker.get_new_requests()
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == 1
|
||||
assert not stream_1.finished
|
||||
|
||||
stream_2 = tracker.add_request(2)
|
||||
stream_3 = tracker.add_request(3)
|
||||
assert tracker.new_requests_event.flag
|
||||
new = tracker.get_new_requests()
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == 2
|
||||
assert new[1]["request_id"] == 3
|
||||
assert not stream_2.finished
|
||||
assert not stream_3.finished
|
||||
|
||||
# request_ids must be unique
|
||||
with pytest.raises(KeyError):
|
||||
tracker.add_request(1)
|
||||
assert not tracker.new_requests_event.flag
|
||||
|
||||
tracker.abort_request(1)
|
||||
new = tracker.get_new_requests()
|
||||
assert not new
|
||||
|
||||
stream_4 = tracker.add_request(4)
|
||||
tracker.abort_request(4)
|
||||
assert tracker.new_requests_event.flag
|
||||
new = tracker.get_new_requests()
|
||||
assert not new
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request(5)
|
||||
assert tracker.new_requests_event.flag
|
||||
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
|
||||
new = tracker.get_new_requests()
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == 5
|
||||
assert stream_2.finished
|
||||
assert not stream_5.finished
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_request_tracer()
|
|
@ -0,0 +1,103 @@
|
|||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
def generate_inputs(num_sequences, min_length, max_length):
|
||||
sequences = []
|
||||
for _ in range(num_sequences):
|
||||
length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
|
||||
# generating randomly lengthed sequences
|
||||
sequence = torch.randint(10, 30000, size=(length,))
|
||||
sequences.append(sequence)
|
||||
return sequences
|
||||
|
||||
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
{
|
||||
"max_batch_size": 8,
|
||||
"max_output_len": 512,
|
||||
"max_input_len": 64,
|
||||
"do_sample": False,
|
||||
}
|
||||
],
|
||||
)
|
||||
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
|
||||
setup_seed(20)
|
||||
max_batch_size = test_config["max_batch_size"]
|
||||
max_input_len = test_config["max_input_len"]
|
||||
max_output_len = test_config["max_output_len"]
|
||||
do_sample = test_config["do_sample"]
|
||||
top_p = 0.5
|
||||
top_k = 50
|
||||
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
|
||||
model = model.eval()
|
||||
|
||||
inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
|
||||
|
||||
if use_engine:
|
||||
inference_config = InferenceConfig(
|
||||
max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
|
||||
)
|
||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
assert inference_engine.generation_config.max_new_tokens == max_output_len
|
||||
inference_engine.add_request(prompts_token_ids=inputs_token_ids)
|
||||
assert inference_engine.request_handler._has_waiting()
|
||||
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
|
||||
outputs = inference_engine.generate(generation_config=generation_config)
|
||||
else:
|
||||
if prompt_template:
|
||||
# apply prompt template
|
||||
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
|
||||
inputs = inputs.cuda()
|
||||
generation_config = GenerationConfig(
|
||||
do_sample=do_sample,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
max_new_tokens=max_output_len,
|
||||
)
|
||||
outputs = model.generate(inputs, generation_config=generation_config)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
assert len(outputs) == 10 * max_batch_size
|
||||
|
||||
|
||||
@parameterize("prompt_template", [None, "llama"])
|
||||
def check_continuous_batching(prompt_template):
|
||||
check_inference_engine(use_engine=True, prompt_template=prompt_template)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_continuous_batching()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_continuous_batching():
|
||||
spawn(run_dist, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_continuous_batching()
|
|
@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
|
|||
)
|
||||
).cuda()
|
||||
model = model.eval()
|
||||
|
||||
inputs = [
|
||||
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
|
||||
"介绍一下武汉,",
|
||||
|
@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
|
|||
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||
inference_engine.add_request(prompts=inputs)
|
||||
assert inference_engine.request_handler._has_waiting()
|
||||
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
|
||||
)
|
||||
outputs = inference_engine.generate(generation_config=generation_config)
|
||||
else:
|
||||
if prompt_template:
|
||||
|
@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
|
|||
inputs = inputs.cuda()
|
||||
generation_config = GenerationConfig(
|
||||
do_sample=do_sample,
|
||||
dtype="fp32",
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
|
|
Loading…
Reference in New Issue