mirror of https://github.com/hpcaitech/ColossalAI
commit 492520dbdb
@@ -62,6 +62,9 @@ class BatchBucket:
     def current_batch_size(self):
         return self._current_batch_size

+    def __len__(self):
+        return self._current_batch_size
+
     @property
     def available_batch_size(self):
         return self.max_batch_size - self._current_batch_size
@@ -1,10 +1,9 @@
 """
 Our config contains various options for inference optimization; it is a unified API that wraps all the configurations for inference.
 """

 import logging
-from dataclasses import dataclass
-from typing import Optional, Union
+from dataclasses import dataclass, fields
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -214,3 +213,18 @@ class InferenceConfig:
             meta_config[type] = getattr(model_config, type)

         return GenerationConfig.from_dict(meta_config)
+
+    @classmethod
+    def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in fields(cls)]
+        inference_config_args = {}
+        for attr in attrs:
+            if attr in config_dict:
+                inference_config_args[attr] = config_dict[attr]
+            else:
+                inference_config_args[attr] = getattr(cls, attr)
+
+        # Set the attributes from the parsed arguments.
+        inference_config = cls(**inference_config_args)
+        return inference_config
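A quick usage sketch (not part of the diff) of how `from_dict` is meant to be used: keys that are not fields of the dataclass are ignored and missing fields fall back to the class-level defaults, which is what lets the API server below pass `vars(args)` straight in. The field names used here are assumptions for illustration.

```python
from colossalai.inference.config import InferenceConfig

# e.g. the dict produced by vars(parse_args()) in the API server; extra CLI-only keys are ignored
cli_args = {"max_batch_size": 8, "block_size": 16, "ssl_keyfile": None}
config = InferenceConfig.from_dict(cli_args)
print(config.max_batch_size)  # 8
```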
@@ -0,0 +1,309 @@
import asyncio
import logging
from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

from colossalai.inference.core.engine import InferenceEngine

# CLI logger
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("colossalai-inference")


def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
    msg = "Task finished unexpectedly. This should never happen! "
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            return
        except Exception as exc:
            raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
        raise RuntimeError(msg)
    except Exception as exc:
        request_tracker.propagate_exception(exc)
        raise exc


class RequstStream:
    """
    A stream of outputs for a request that can be iterated over asynchronously.
    Attributes:
        1. request_id: The id of the request.
        2. _future: A future that will be set when the request is finished.
    Methods: set_result and get_result; the result is set exactly once, when the request finishes,
    and `self._future` is then marked as done.
    """

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future = asyncio.Future()

    def set_result(self, result) -> None:
        """Set the final result and signal that it's ready."""
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self):
        """Wait for the result to be set and return it."""
        return await self._future

    @property
    def finished(self) -> bool:
        """Check if the stream has finished by checking if the future is done."""
        return self._future.done()


class Tracer:
    """
    Records new requests and finished requests.
    Attributes:
        1. _request_streams: We create one stream for each request to trace the output.
        2. _finished_requests: A queue to store the finished requests.
        3. _new_requests: New requests are stored in this queue first, before being sent to the engine.
        4. new_requests_event: An event to notify the engine that there are new requests.
    """

    def __init__(self) -> None:
        self._request_streams: Dict[int, RequstStream] = {}
        self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        return item in self._request_streams

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
        """
        Propagate an exception to request streams (all if request_id is None).
        """
        if request_id is not None:
            self._request_streams[request_id].set_result(exc)
        else:
            for stream in self._request_streams.values():
                stream.set_result(exc)

    def process_finished_request(self, finished_request) -> None:
        """Process a finished request from the engine."""
        request_id = finished_request.request_id
        try:
            self._request_streams[request_id].set_result(finished_request)
        except:
            raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
        self.abort_request(request_id)

    def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
        """
        Add a request to be sent to the engine on the next background
        loop iteration.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = RequstStream(request_id)
        logger.info(f"Added request {request_id}.")
        self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
        """Abort a request during the next background loop iteration."""
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[request_id].finished:
            # The request has already finished or been aborted.
            # Requests still sitting in new_requests are aborted when we try to get them (if marked aborted).
            return

        self._request_streams[request_id].set_result(None)

    def get_new_requests(self):
        """
        Get new requests from the HTTP server.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[int] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if new_request["request_id"] in finished_requests:
                # The request has been aborted.
                stream.set_result(None)
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        self.new_requests_event.clear()

        return new_requests

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()


class _AsyncInferenceEngine(InferenceEngine):
    """
    Async methods for the Inference Engine. This engine is an extension of InferenceEngine;
    the additional methods are only used for asynchronous serving.
    Methods: 1. async_step: The async version of Engine.step()
    """

    async def async_step(self) -> List[str]:
        """
        The async version of Engine.step().
        Performs one decoding iteration and returns newly generated results.

        It first schedules the sequences to be executed in the next iteration.
        Then, it executes the model and updates the scheduler with the model
        outputs. Finally, it decodes the sequences and returns the newly
        generated results.
        """
        batch = self.request_handler.schedule()
        loop = asyncio.get_running_loop()

        # Use run_in_executor to asynchronously run the synchronous method model.forward().
        logits = await loop.run_in_executor(
            None,
            self.model,
            batch,
            self.k_cache,
            self.v_cache,
        )

        if self.inference_config.pad_input:
            logits = logits[:, -1, :]
        self.request_handler.search_tokens(self.generation_config, logits)

        finished_sequences = self.request_handler.update()
        for sequence in finished_sequences:
            sequence.output = self.tokenizer.decode(sequence.output_token_id)

        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


class AsyncInferenceEngine:
    """An asynchronous wrapper for the InferenceEngine class.

    This class wraps the InferenceEngine class to make it asynchronous.
    It uses asyncio to create a background loop that keeps processing incoming
    requests. Note that this class does not hold the model directly: when a new
    request comes in, `add_request` is called first and the Tracer records the request,
    handing it to the background `InferenceEngine` (run in the background loop) for processing.
    You can consider this engine as an interface.
    """

    _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine

    def __init__(self, start_engine_loop: bool = True, **kwargs):
        self.engine = self._init_engine(**kwargs)
        self.background_loop = None
        # reference to the unshielded loop
        self._background_loop_unshielded = None
        self.start_engine_loop = start_engine_loop
        self._request_tracer = Tracer()

    @property
    def background_loop_status(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.background_loop_status:
            raise RuntimeError("Existing loop is running")

        self._request_tracer.init_event()

        self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
        )
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, **kwargs):
        return self._engine_class(**kwargs)

    async def step(self):
        """
        Run the engine to process requests.

        Returns True if there are in-progress requests.
        """
        new_requests = self._request_tracer.get_new_requests()
        for new_request in new_requests:
            self.engine.add_single_request(**new_request)
        newly_finished_seqs, has_running_requests = await self.engine.async_step()

        for seq in newly_finished_seqs:
            self._request_tracer.process_finished_request(seq)

        return has_running_requests

    async def _engine_abort(self, request_ids: Iterable[int]):
        self.engine.abort_request(request_ids)

    async def abort(self, request_id: int):
        """
        Abort a single request.
        """
        if not self.background_loop_status:
            raise RuntimeError("Background loop is not running or launched correctly.")
        return self._abort(request_id)

    def _abort(self, request_id: int):
        self._request_tracer.abort_request(request_id)

    async def run_engine_loop(self):
        processing_requests = False
        while True:
            if not processing_requests:
                await self._request_tracer.wait_for_new_requests()
            processing_requests = await self.step()
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
    ) -> RequstStream:
        """
        Add a request to the background tracker (waiting queue) and start the background loop if needed.
        """
        if not self.background_loop_status:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise RuntimeError("Background loop is not running.")
        stream = self._request_tracer.add_request(
            request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
        )
        return stream

    async def generate(
        self,
        request_id: int,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]] = None,
    ) -> AsyncIterator[str]:
        """
        Generate output from a request. It receives the request from the http server, adds it into the
        waiting queue of the Async Engine and streams the output sequence.
        """
        try:
            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
            return await stream.get_result()

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self._abort(request_id)
            raise e
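A rough usage sketch (not part of the commit), assuming a `model`, `tokenizer` and `inference_config` have already been prepared as in the API server below: `generate` registers the request with the Tracer, the background loop feeds it to the wrapped engine, and the awaited stream resolves with the finished sequence.

```python
import asyncio

from colossalai.inference.core.async_engine import AsyncInferenceEngine


async def main(model, tokenizer, inference_config):
    engine = AsyncInferenceEngine(
        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
    )
    # The first add_request starts the background loop; the awaited future is set
    # by Tracer.process_finished_request once the sequence finishes decoding.
    finished_seq = await engine.generate(request_id=1, prompt="hello, who are you? ")
    print(finished_seq.output)


# asyncio.run(main(model, tokenizer, inference_config))
```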
@@ -507,9 +507,9 @@ class InferenceEngine:

     def generate(
         self,
-        prompts: List[str] = None,
+        request_ids: Union[List[int], int] = None,
+        prompts: Union[List[str], str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
-        request_ids: List[int] = None,
         return_token_ids: bool = False,
         generation_config: Optional[GenerationConfig] = None,
     ) -> List[str]:

@@ -527,6 +527,9 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+            if isinstance(prompts, str) and isinstance(request_ids, int):
+                prompts = [prompts]
+                request_ids = [request_ids]
             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(

@@ -580,13 +583,13 @@ class InferenceEngine:
         if isinstance(prompts, (list, tuple)):
             return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
         elif isinstance(prompts, str):
-            return self.inference_config.rompt_template.format(input_text=prompts)
+            return self.inference_config.prompt_template.format(input_text=prompts)
         else:
             raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")

     def add_request(
         self,
-        request_ids: List[int] = None,
+        request_ids: Union[List[int], int] = None,
         prompts: List[str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,

@@ -601,11 +604,15 @@ class InferenceEngine:
         """

         # apply the prompt template to the input prompts

         if self.has_prompt_template and prompts is not None:
             prompts = self.format_prompt(prompts)

         block_size = self.inference_config.block_size

+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]

@@ -615,8 +622,10 @@ class InferenceEngine:
                 "input_ids"
             ]

+        # list of torch Tensor
         if isinstance(prompts_token_ids, list):
-            pass
+            if isinstance(prompts_token_ids[0], torch.Tensor):
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:

@@ -632,8 +641,6 @@ class InferenceEngine:

         for i in range(prompts_num):
             if request_ids:
-                if not isinstance(request_ids, list):
-                    request_ids = [request_ids]
                 assert isinstance(
                     request_ids[0], int
                 ), f"The request_id type must be int, but got {type(request_ids[0])}"

@@ -733,7 +740,6 @@ class InferenceEngine:
             logits = logits[:, -1, :]
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)
-
         finished_sequences = self.request_handler.update()

         return finished_sequences
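A small sketch (not from the diff) of what the relaxed `generate`/`add_request` signatures allow, assuming `engine` is an `InferenceEngine` and `generation_config` a `GenerationConfig` as in the tests below: a single string prompt and a single int request id are now wrapped into lists internally.

```python
# Single-prompt call; request_ids and prompts may now be scalars instead of lists.
outputs = engine.generate(
    request_ids=0,
    prompts="hello, who are you? ",
    generation_config=generation_config,
)
print(outputs[0])
```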
@@ -209,6 +209,7 @@ class RequestHandler:
                     break

                 num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
+                # for now the recycle logic is not working
                 remove_list.extend(lst[:num_seqs_to_add])
                 self.running_list.extend(lst[:num_seqs_to_add])

@@ -263,24 +264,27 @@ class RequestHandler:
             ), f"Sequence {req.request_id} exceeds input length limit"
             self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)

-    def abort_sequence(self, request_id: str):
+    def abort_sequence(self, request_id: int):
         """
         Abort the request.
         """
-        seq, priority = self._find_sequence(request_id)
-        if seq.status == RequestStatus.WAITING:
-            seq.mark_aborted()
-            self.waiting_list[priority].remove(seq)
-        elif seq.status.is_running():
-            self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
-            self.running_list.remove(seq)
-        else:
-            try:
-                self.done_list.remove(seq)
-            except:
-                return
+        result = self._find_sequence(request_id)
+        if result is not None:
+            seq, priority = result
+            if seq.status == RequestStatus.WAITING:
+                seq.mark_aborted()
+                self.waiting_list[priority].remove(seq)
+            elif seq.status.is_running():
+                self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
+                self.running_list.remove(seq)
+            else:
+                try:
+                    self.done_list.remove(seq)
+                except:
+                    return
+        return

-    def _find_sequence(self, request_id: str) -> Sequence:
+    def _find_sequence(self, request_id: int) -> Sequence:
         """
         Find the request by request_id.
         """

@@ -324,6 +328,9 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()

+    def total_requests_in_batch_bucket(self) -> int:
+        return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
+
     def search_tokens(self, generation_config: GenerationConfig, logits):
         """
         Sample tokens for finished requests.
@@ -0,0 +1,27 @@
# Online Service
Colossal-Inference supports a FastAPI-based online service. Both simple completion and chat are supported. Follow the commands below and
you can easily set up a server with both completion and chat functionalities. For now we only support the `Llama` model; more models will be added soon.

# Usage
```bash
# First, launch an API server locally.
python3 -m colossalai.inference.server.api_server --model <path of your llama2 model> --chat-template "{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"


# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the API.

# For the completion service, you can invoke it like this:
curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'

# For the chat service, you can invoke it like this:
curl -X POST http://127.0.0.1:8000/chat -H 'Content-Type: application/json' -d '{"messages":
[{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"}],
"stream": "False"}'
# If you just want to test simple generation, turn to the generate API:
curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'

```
We also support streaming output; simply change `stream` to `True` in the request body.
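For reference, a hedged Python equivalent of the completion call above, using the `requests` package (not part of the commit):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8000/completion",
    json={"prompt": "hello, who are you? ", "stream": "False"},
)
print(resp.json())  # {"request_id": ..., "text": ...}
```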
@@ -0,0 +1,210 @@
"""
Doc:
    Feature:
        - FastAPI-based HTTP server for Colossal-Inference
        - Completion Service Supported
    Usage: (for local user)
        - First, launch an API server locally. `python3 -m colossalai.inference.server.api_server --model <path of your llama2 model>`
        - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the API.
        - For the completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
            -H 'Content-Type: application/json' \
            -d '{"prompt":"hello, who are you? ","stream":"False"}'`
    Version: V1.0
"""

import argparse
import json

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.server.chat_service import ChatServing
from colossalai.inference.server.completion_service import CompletionServing
from colossalai.inference.server.utils import id_generator

from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa

TIMEOUT_KEEP_ALIVE = 5  # seconds.
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
prompt_template_choices = ["llama", "vicuna"]
async_engine = None
chat_serving = None
completion_serving = None

app = FastAPI()


# NOTE: (CjhHa1) models are still under development, need to be updated
@app.get("/models")
def get_available_models() -> Response:
    return JSONResponse(supported_models_dict)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    A request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: additional parameters matched against the generation config.
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", "false").lower()

    request_id = id_generator()
    generation_config = get_generation_config(request_dict)
    results = engine.generate(request_id, prompt, generation_config=generation_config)

    # Streaming case
    def stream_results():
        for request_output in results:
            ret = {"text": request_output[len(prompt) :]}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream == "true":
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    for request_output in results:
        if request.is_disconnected():
            # Abort the request if the client disconnects.
            engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output[len(prompt) :]

    assert final_output is not None
    ret = {"text": final_output}
    return JSONResponse(ret)


@app.post("/completion")
async def create_completion(request: Request):
    request_dict = await request.json()
    stream = request_dict.pop("stream", "false").lower()
    generation_config = get_generation_config(request_dict)
    result = await completion_serving.create_completion(request, generation_config)

    ret = {"request_id": result.request_id, "text": result.output}
    if stream == "true":
        return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
    else:
        return JSONResponse(content=ret)


@app.post("/chat")
async def create_chat(request: Request):
    request_dict = await request.json()

    stream = request_dict.get("stream", "false").lower()
    generation_config = get_generation_config(request_dict)
    message = await chat_serving.create_chat(request, generation_config)
    if stream == "true":
        return StreamingResponse(content=message, media_type="text/event-stream")
    else:
        ret = {"role": message.role, "text": message.content}
        return ret


def get_generation_config(request):
    generation_config = async_engine.engine.generation_config
    for arg in request:
        if hasattr(generation_config, arg):
            generation_config[arg] = request[arg]
    return generation_config


def add_engine_config(parser):
    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")

    parser.add_argument(
        "--max-model-len",
        type=int,
        default=None,
        help="model context length. If unspecified, will be automatically derived from the model.",
    )
    # Parallel arguments
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")

    # KV cache arguments
    parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")

    parser.add_argument("--max_batch_size", type=int, default=8, help="maximum batch size")

    # generation arguments
    parser.add_argument(
        "--prompt_template",
        choices=prompt_template_choices,
        default=None,
        help=f"Allowed choices are {','.join(prompt_template_choices)}. Defaults to None.",
    )
    return parser


def parse_args():
    parser = argparse.ArgumentParser(description="Colossal-Inference API server.")

    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help="The model name used in the API. If not "
        "specified, the model name will be the same as "
        "the huggingface name.",
    )
    parser.add_argument(
        "--chat-template",
        type=str,
        default=None,
        help="The file path to the chat template, or the template in single-line form for the specified model",
    )
    parser.add_argument(
        "--response-role",
        type=str,
        default="assistant",
        help="The role name to return if `request.add_generation_prompt=true`.",
    )
    parser = add_engine_config(parser)

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    inference_config = InferenceConfig.from_dict(vars(args))
    model = AutoModelForCausalLM.from_pretrained(args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    async_engine = AsyncInferenceEngine(
        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
    )
    engine = async_engine.engine
    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
    chat_serving = ChatServing(
        async_engine,
        served_model=model.__class__.__name__,
        tokenizer=tokenizer,
        response_role=args.response_role,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path
    uvicorn.run(
        app=app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
    )
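A hedged client-side sketch (not part of the commit) for the streaming branch of `/generate` above: each chunk is a JSON object terminated by a NUL byte, as produced by `stream_results`.

```python
import json

import requests

with requests.post(
    "http://127.0.0.1:8000/generate",
    json={"prompt": "hello, who are you? ", "stream": "True"},
    stream=True,
) as resp:
    # Split the response body on the NUL delimiter emitted by the server.
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk)["text"])
```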
@@ -0,0 +1,142 @@
import asyncio
import codecs
import logging

from fastapi import Request

from colossalai.inference.core.async_engine import AsyncInferenceEngine

from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator

logger = logging.getLogger("colossalai-inference")


class ChatServing:
    def __init__(
        self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None
    ):
        self.engine = engine
        self.served_model = served_model
        self.tokenizer = tokenizer
        self.response_role = response_role
        self._load_chat_template(chat_template)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass

    async def create_chat(self, request: Request, generation_config):
        request_dict = await request.json()
        messages = request_dict["messages"]
        stream = request_dict.pop("stream", "false").lower()
        add_generation_prompt = request_dict.pop("add_generation_prompt", False)
        request_id = id_generator()
        try:
            prompt = self.tokenizer.apply_chat_template(
                conversation=messages,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
        except Exception as e:
            raise RuntimeError(f"Error in applying chat template from request: {str(e)}")

        # Not an intuitive way to pass the generation config; kept for now.
        self.engine.engine.generation_config = generation_config
        result_generator = self.engine.generate(request_id, prompt=prompt)

        if stream == "true":
            return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id)
        else:
            return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id)

    async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int):
        # Send the first response for each request.n (index) with the role
        role = self.get_chat_request_role(request, request_dict)
        n = request_dict.get("n", 1)
        echo = request_dict.get("echo", "false").lower()
        for i in range(n):
            choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role))
            data = choice_data.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"

        # Send a response to echo the input portion of the last message
        if echo == "true":
            last_msg_content = ""
            if (
                request_dict["messages"]
                and isinstance(request_dict["messages"], list)
                and request_dict["messages"][-1].get("content")
                and request_dict["messages"][-1].get("role") == role
            ):
                last_msg_content = request_dict["messages"][-1]["content"]
            if last_msg_content:
                for i in range(n):
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i, message=DeltaMessage(content=last_msg_content)
                    )
                    data = choice_data.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

        result = await result_generator
        choice_data = DeltaMessage(content=result.output)
        data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True)
        yield f"data: {data}\n\n"

        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: Request,
        request_dict: dict,
        result_generator,
        request_id,
    ):
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await self.engine.abort(request_id)
            return {"error_msg": "Client disconnected"}

        result = await result_generator
        assert result is not None
        role = self.get_chat_request_role(request, request_dict)
        choice_data = ChatMessage(role=role, content=result.output)
        echo = request_dict.get("echo", "false").lower()

        if echo == "true":
            last_msg_content = ""
            if (
                request.messages
                and isinstance(request.messages, list)
                and request.messages[-1].get("content")
                and request.messages[-1].get("role") == role
            ):
                last_msg_content = request.messages[-1]["content"]

            full_message = last_msg_content + choice_data.content
            choice_data.content = full_message

        return choice_data

    def get_chat_request_role(self, request: Request, request_dict: dict) -> str:
        add_generation_prompt = request_dict.get("add_generation_prompt", False)
        if add_generation_prompt:
            return self.response_role
        else:
            return request_dict["messages"][-1]["role"]

    def _load_chat_template(self, chat_template):
        if chat_template is not None:
            try:
                with open(chat_template, "r") as f:
                    self.tokenizer.chat_template = f.read()
            except OSError:
                # If opening a file fails, treat chat_template as the template string itself and
                # decode it so that escape sequences are interpreted correctly.
                self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape")

            logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}")
        elif self.tokenizer.chat_template is not None:
            logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}")
        else:
            logger.warning("No chat template provided. Chat API will not work.")
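A short sketch (not part of the commit) of the prompt-construction step `ChatServing.create_chat` relies on: `transformers`' `apply_chat_template` flattens the request's `messages` list into the single prompt string handed to `engine.generate`. The TinyLlama checkpoint is borrowed from the tests in this commit and ships its own chat template.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "system", "content": "you are a helpful assistant"},
    {"role": "user", "content": "what is 1+1?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # the flattened conversation passed to the async engine
```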
@@ -0,0 +1,34 @@
import asyncio

from colossalai.inference.core.async_engine import AsyncInferenceEngine

from .utils import id_generator


class CompletionServing:
    def __init__(self, engine: AsyncInferenceEngine, served_model: str):
        self.engine = engine
        self.served_model = served_model

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass

    async def create_completion(self, request, generation_config):
        request_dict = await request.json()
        request_id = id_generator()

        prompt = request_dict.pop("prompt")

        # Not an intuitive way to pass the generation config; kept for now.
        self.engine.engine.generation_config = generation_config
        result_generator = self.engine.generate(request_id, prompt=prompt)

        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await self.engine.abort(request_id)
            raise RuntimeError("Client disconnected")

        final_res = await result_generator
        return final_res
@@ -0,0 +1,36 @@
from typing import Any, Optional

from pydantic import BaseModel


# make it a singleton
class NumericIDGenerator:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(NumericIDGenerator, cls).__new__(cls)
            cls._instance.current_id = 0
        return cls._instance

    def __call__(self):
        self.current_id += 1
        return self.current_id


id_generator = NumericIDGenerator()


class ChatMessage(BaseModel):
    role: str
    content: Any


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[Any] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    message: DeltaMessage
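A quick illustration (not part of the commit) of the helpers above: the singleton generator hands out increasing integer request ids, and the pydantic models serialize exactly the way the chat service streams them.

```python
from colossalai.inference.server.utils import ChatMessage, DeltaMessage, id_generator

first, second = id_generator(), id_generator()
assert second == first + 1  # monotonically increasing request ids

delta = DeltaMessage(role="assistant")
print(delta.model_dump_json(exclude_unset=True))  # {"role":"assistant"}
print(ChatMessage(role="assistant", content="1+1 is 2.").model_dump_json())
```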
@@ -61,6 +61,7 @@ class Sequence:
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
         ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
+        output(str): The output of the sequence.
     """

     request_id: int

@@ -73,6 +74,7 @@ class Sequence:
     max_output_len: int = 256
     # NOTE(caidi) This is a temporary solution. It's better to move the logic that turns the flag on or off into the sampling module in the future.
     ignore_eos: bool = False
+    output: str = None

     def __post_init__(self):
         self.output_token_id = []

@@ -163,11 +165,13 @@ class Sequence:
     def __repr__(self) -> str:
         return (
             f"(request_id={self.request_id}, "
-            f"prompt={self.prompt}, "
-            f"status={self.status.name}, "
-            f"sample_params={self.sample_params}, "
-            f"input_len={self.input_len},"
-            f"output_len={self.output_len})"
+            f"prompt={self.prompt},\n"
+            f"output_token_id={self.output_token_id},\n"
+            f"output={self.output},\n"
+            f"status={self.status.name},\n"
+            f"sample_params={self.sample_params},\n"
+            f"input_len={self.input_len},\n"
+            f"output_len={self.output_len})\n"
         )
@@ -249,7 +249,6 @@ class VocabParallelEmbedding1D(PaddingParallelModule):

     The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
     ::
-
         max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
                     renormalized to have norm max_norm. Note: this will modify weight in-place.
         norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.
@@ -1,6 +1,7 @@
 import os
 from typing import Dict, List, Tuple

+import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor

@@ -36,7 +37,11 @@ class ShardFormer:
     """

     def __init__(self, shard_config: ShardConfig):
-        self.coordinator = DistCoordinator()
+        self.is_distributed = dist.is_initialized()
+        if self.is_distributed:
+            self.coordinator = DistCoordinator()
+        else:
+            self.coordinator = None
         self.shard_config = shard_config

     def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
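A tiny check (not part of the commit) of the condition the new guard relies on: in a plain single-process run `torch.distributed` is not initialized, so `ShardFormer` now skips creating a `DistCoordinator` instead of failing.

```python
import torch.distributed as dist

# False in a plain `python script.py` run; True after colossalai.launch / init_process_group
print(dist.is_initialized())
```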
@@ -0,0 +1,58 @@
from locust import HttpUser, between, tag, task


class QuickstartUser(HttpUser):
    wait_time = between(1, 5)

    @tag("online-generation")
    @task(5)
    def completion(self):
        self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"})

    @tag("online-generation")
    @task(5)
    def completion_streaming(self):
        self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"})

    @tag("online-chat")
    @task(5)
    def chat(self):
        self.client.post(
            "/chat",
            json={
                "messages": [
                    {"role": "system", "content": "you are a helpful assistant"},
                    {"role": "user", "content": "what is 1+1?"},
                ],
                "stream": "False",
            },
        )

    @tag("online-chat")
    @task(5)
    def chat_streaming(self):
        self.client.post(
            "/chat",
            json={
                "messages": [
                    {"role": "system", "content": "you are a helpful assistant"},
                    {"role": "user", "content": "what is 1+1?"},
                ],
                "stream": "True",
            },
        )

    @tag("offline-generation")
    @task(5)
    def generate_streaming(self):
        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})

    @tag("offline-generation")
    @task(5)
    def generate(self):
        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"})

    @tag("online-generation", "offline-generation")
    @task
    def get_models(self):
        self.client.get("/models")
@@ -0,0 +1,27 @@
#!/bin/bash

# argument 1: model_path

# launch server
model_path=${1:-"lmsys/vicuna-7b-v1.3"}
chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
echo "Model Path: $model_path"
echo "Starting server..."
python -m colossalai.inference.server.api_server --model $model_path --chat-template "$chat_template" &
SERVER_PID=$!

# waiting time
sleep 60

# Run Locust
echo "Starting Locust..."
echo "The test will begin automatically; turn to http://0.0.0.0:8089 for more information."
echo "Test completion api first"
locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
echo "Test chat api"
locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
# kill server
echo "Stopping server..."
kill $SERVER_PID

echo "Test finished and server shut down completely"
@@ -0,0 +1,4 @@
#!/bin/bash
echo "Skip the test (this test is slow)"

# bash ./run_benchmark.sh
@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass

import pytest

from colossalai.inference.core.async_engine import AsyncInferenceEngine


@dataclass
class MockSequence:
    request_id: int


class MockEngine:
    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None

    async def async_step(self):
        self.step_calls += 1
        return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)

    def add_single_request(self, **kwargs):
        del kwargs
        self.add_request_calls += 1

    def generate(self, request_id):
        self.request_id = request_id

    def stop_generating(self):
        self.request_id = None

    def add_request(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        del request_id  # Unused
        self.abort_request_calls += 1


class MockAsyncInferenceEngine(AsyncInferenceEngine):
    def _init_engine(self, *args, **kwargs):
        return MockEngine()


@pytest.mark.asyncio
async def test_new_requests_event():
    engine = MockAsyncInferenceEngine()
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    assert engine.engine.step_calls == 0

    await engine.add_request(1, "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request(2, "", None)
    engine.engine.generate(2)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls == 2
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 3
    engine.engine.stop_generating()
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4
    await asyncio.sleep(0)
    assert engine.engine.step_calls == 4

    await engine.add_request(3, "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == 5
@@ -0,0 +1,68 @@
import pytest

from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence


class SampleEvent:
    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False


def test_request_tracer():
    tracker = Tracer()
    tracker.new_requests_event = SampleEvent()
    stream_1 = tracker.add_request(1)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 1
    assert not stream_1.finished

    stream_2 = tracker.add_request(2)
    stream_3 = tracker.add_request(3)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == 2
    assert new[1]["request_id"] == 3
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request(1)
    assert not tracker.new_requests_event.flag

    tracker.abort_request(1)
    new = tracker.get_new_requests()
    assert not new

    stream_4 = tracker.add_request(4)
    tracker.abort_request(4)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request(5)
    assert tracker.new_requests_event.flag
    tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 5
    assert stream_2.finished
    assert not stream_5.finished


if __name__ == "__main__":
    test_request_tracer()
@@ -0,0 +1,103 @@
import random

import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def generate_inputs(num_sequences, min_length, max_length):
    sequences = []
    for _ in range(num_sequences):
        length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
        # generate sequences of random length
        sequence = torch.randint(10, 30000, size=(length,))
        sequences.append(sequence)
    return sequences


@parameterize(
    "test_config",
    [
        {
            "max_batch_size": 8,
            "max_output_len": 512,
            "max_input_len": 64,
            "do_sample": False,
        }
    ],
)
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
    setup_seed(20)
    max_batch_size = test_config["max_batch_size"]
    max_input_len = test_config["max_input_len"]
    max_output_len = test_config["max_output_len"]
    do_sample = test_config["do_sample"]
    top_p = 0.5
    top_k = 50
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
    model = model.eval()

    inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)

    if use_engine:
        inference_config = InferenceConfig(
            max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
        )
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
        assert inference_engine.generation_config.max_new_tokens == max_output_len
        inference_engine.add_request(prompts_token_ids=inputs_token_ids)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:
            # apply the prompt template
            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_output_len,
        )
        outputs = model.generate(inputs, generation_config=generation_config)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    assert len(outputs) == 10 * max_batch_size


@parameterize("prompt_template", [None, "llama"])
def check_continuous_batching(prompt_template):
    check_inference_engine(use_engine=True, prompt_template=prompt_template)


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_continuous_batching()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_continuous_batching():
    spawn(run_dist, 1)


if __name__ == "__main__":
    test_continuous_batching()
@@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True):
         )
     ).cuda()
     model = model.eval()
-
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",

@@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True):
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:

@@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True):
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,