mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Fix bugs and docs for feat/online-server (#5598)
* fix test bugs
* add do sample test
* del useless lines
* fix comments
* fix tests
* delete version tag
* delete version tag
* add
* del test server
* fix test
* fix
* Revert "add"
This reverts commit b9305fb024.
feat/online-serving
parent 7bbb28e48b
commit 61a1b2e798
@@ -1,9 +1,8 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
-import dataclasses
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Dict, Optional, Union

 import torch
@@ -218,7 +217,7 @@ class InferenceConfig:
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
         # Get the list of attributes of this dataclass.
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        attrs = [attr.name for attr in fields(cls)]
         inference_config_args = {}
         for attr in attrs:
             if attr in config_dict:
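Note: the from_dict change above only switches from dataclasses.fields to the directly imported fields helper; the dict-filtering pattern itself stays the same. A minimal, self-contained sketch of that pattern, using a hypothetical ServerConfig dataclass rather than the real InferenceConfig, looks like this:

from dataclasses import dataclass, fields
from typing import Any, Dict


@dataclass
class ServerConfig:
    # Hypothetical stand-in for InferenceConfig, used only to illustrate the pattern.
    max_batch_size: int = 8
    max_input_len: int = 64

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "ServerConfig":
        # Keep only keys that match declared dataclass fields; ignore everything else.
        attrs = [attr.name for attr in fields(cls)]
        kwargs = {k: v for k, v in config_dict.items() if k in attrs}
        return cls(**kwargs)


print(ServerConfig.from_dict({"max_batch_size": 16, "unknown_key": 1}))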
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine

@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
 logger = logging.getLogger("colossalai-inference")


-def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
     msg = "Task finished unexpectedly. This should never happen! "
     try:
         try:
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac


 class RequstStream:
-    """A stream of Output for a request that can be
-    iterated over asynchronously."""
+    """
+    A stream of Output for a request that can be iterated over asynchronously.
+    Attributes: 1.request_id: The id of the request.
+                2._future: A future that will be set when the request is finished.
+    Methods: set_result and get_result, results will be set when finished, for once, and
+    the `self.future` will be set to done.
+
+    """

     def __init__(self, request_id: int) -> None:
         self.request_id = request_id
@@ -51,6 +57,10 @@ class RequstStream:
 class Tracer:
     """
     Recording new requests and finished requests.
+    Attributes: 1._request_streams: We create one stream for each request to trace the output.
+                2._finished_requests: A queue to store the finished requests.
+                3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
+                4.new_requests_event: An event to notify the engine that there are new requests.
     """

     def __init__(self) -> None:
@@ -93,8 +103,8 @@ class Tracer:
             raise KeyError(f"Request {request_id} already exists.")

         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
-
         self.new_requests_event.set()

         return stream
@@ -108,6 +118,7 @@ class Tracer:

         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # The requests in new_requests will be aborted when try to get them(if marked aborted)
             return

         self._request_streams[request_id].set_result(None)
@@ -117,9 +128,18 @@ class Tracer:
         Get new requests from http server.
         """
         new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)

         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)

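Note: the added lines make get_new_requests drain the finished/aborted ids first, then silently resolve and skip any queued request whose id is already in that set, so an aborted request never reaches the engine. A toy, self-contained version of that drain-and-filter step (names invented for illustration, not the actual Tracer API) would be:

import asyncio
from typing import Dict, List, Set


async def drain_new_requests(pending: asyncio.Queue, finished: asyncio.Queue) -> List[Dict]:
    # Collect ids of finished/aborted requests first.
    finished_ids: Set[int] = set()
    while not finished.empty():
        finished_ids.add(finished.get_nowait())

    # Then hand over only the requests that were not aborted in the meantime.
    new_requests: List[Dict] = []
    while not pending.empty():
        request = pending.get_nowait()
        if request["request_id"] in finished_ids:
            continue  # aborted before the engine ever saw it
        new_requests.append(request)
    return new_requests


async def main() -> None:
    pending, finished = asyncio.Queue(), asyncio.Queue()
    pending.put_nowait({"request_id": 1, "prompt": "hello"})
    finished.put_nowait(1)  # request 1 was aborted while still queued
    print(await drain_new_requests(pending, finished))  # -> []


asyncio.run(main())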
@@ -133,7 +153,8 @@ class Tracer:

 class _AsyncInferenceEngine(InferenceEngine):
     """
-    Async methods for Inference Engine.
+    Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
+    Methods: 1. async_step: The async version of Engine.step()
     """

     async def async_step(self) -> List[str]:
@@ -161,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-        # Return: List[Sequence]
+
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)

-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0


 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.

     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold model directly, when incoming a new
+    request, it first called `add_request` and the Tracer will record the request, putting
+    it to the background `InferenceEngine`(done in background loop) to process. You can
+    consider this engine as an interface.
     """

     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
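Note: the reworded docstring describes a producer/consumer pattern: callers only add requests and await results, while a background task drives the real engine. The sketch below shows that pattern in isolation with an invented ToyAsyncEngine class; it is not the ColossalAI API, just an illustration of the background-loop idea using the same asyncio primitives.

import asyncio


class ToyAsyncEngine:
    # Invented stand-in: requests are queued, a background loop "processes" them,
    # and each caller awaits a future tied to its own request.
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        self._loop_task = None

    def start_background_loop(self) -> None:
        self._loop_task = asyncio.get_running_loop().create_task(self._run())

    async def _run(self) -> None:
        while True:
            prompt, future = await self._queue.get()
            await asyncio.sleep(0.01)  # stand-in for one engine step
            future.set_result(f"echo: {prompt}")

    async def generate(self, prompt: str) -> str:
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((prompt, future))
        return await future


async def main() -> None:
    engine = ToyAsyncEngine()
    engine.start_background_loop()
    print(await engine.generate("hello"))


asyncio.run(main())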
@@ -253,7 +275,7 @@ class AsyncInferenceEngine:
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker(waiting queue), start the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -276,14 +298,12 @@ class AsyncInferenceEngine:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
-
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()

         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or coroutine is cancelled, abort the request.
             self._abort(request_id)
             raise e

@@ -527,10 +527,15 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
+<<<<<<< HEAD

             if isinstance(prompts, str) and isinstance(request_ids, int):
                 prompts = [prompts]
                 request_ids = [request_ids]
+=======
+            if prompts is not None or prompts_token_ids is not None:
+                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
+>>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598)

             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
@@ -612,6 +617,9 @@ class InferenceEngine:

         block_size = self.inference_config.block_size

+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]

@@ -621,9 +629,10 @@ class InferenceEngine:
                 "input_ids"
             ]

+        # list of torch Tensor
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
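Note: the hunk above only renames the comprehension variable and tags the branch with a comment; the intent is that prompts_token_ids always ends up as plain Python lists, whether the caller passed a list of tensors, a single tensor, or a NumPy array. A standalone, illustrative version of that normalization (hypothetical helper name, not part of the engine) is:

from typing import List, Union

import numpy as np
import torch


def normalize_token_ids(prompts_token_ids: Union[List, torch.Tensor, np.ndarray]) -> List[List[int]]:
    # Accept a 2D tensor/array, a list of tensors/arrays, or a list of lists,
    # and always return a list of Python int lists.
    if isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        return prompts_token_ids.tolist()
    if isinstance(prompts_token_ids, list):
        if isinstance(prompts_token_ids[0], (torch.Tensor, np.ndarray)):
            return [ids.tolist() for ids in prompts_token_ids]
        return prompts_token_ids
    raise TypeError(f"Unsupported type {type(prompts_token_ids)} for prompts_token_ids")


print(normalize_token_ids(torch.tensor([[1, 2, 3]])))  # -> [[1, 2, 3]]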
@@ -738,8 +747,6 @@ class InferenceEngine:
             logits = logits[:, -1, :]
-        next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
-        self.request_handler.append_next_tokens(next_tokens)

         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()

         return finished_sequences

@@ -328,7 +328,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()

-    def current_requests_in_batch(self) -> int:
+    def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size

     def search_tokens(self, generation_config: GenerationConfig, logits):

@@ -6,9 +6,10 @@ Doc:
     Usage: (for local user)
         - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
         - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api
-        - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
+        - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
            -H 'Content-Type: application/json' \
            -d '{"prompt":"hello, who are you? ","stream":"False"}'`
     Version: V1.0
 """

 import argparse
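Note: the completion route drops its /v1 prefix in this PR. As a companion to the curl example in the docstring, the same call can be made from Python; this is a sketch that assumes the api_server from this PR is already running locally on port 8000 with the non-versioned /completion route:

import requests

# Assumes `python3 -m colossalai.inference.server.api_server --model <model path>` is running.
response = requests.post(
    "http://127.0.0.1:8000/completion",
    json={"prompt": "hello, who are you? ", "stream": "False"},
    timeout=60,
)
print(response.status_code)
print(response.text)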
@@ -36,7 +37,8 @@ completion_serving = None
 app = FastAPI()


-@app.get("/v0/models")
+# NOTE: (CjhHa1) models are still under development, need to be updated
+@app.get("/models")
 def get_available_models() -> Response:
     return JSONResponse(supported_models_dict)

@@ -81,7 +83,7 @@ async def generate(request: Request) -> Response:
     return JSONResponse(ret)


-@app.post("/v1/completion")
+@app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
     stream = request_dict.pop("stream", "false").lower()
@@ -95,7 +97,7 @@ async def create_completion(request: Request):
     return JSONResponse(content=ret)


-@app.post("/v1/chat")
+@app.post("/chat")
 async def create_chat(request: Request):
     request_dict = await request.json()

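Note: the chat route likewise moves from /v1/chat to /chat. The exact request schema is not shown in this hunk; the locustfile later in this diff posts a "converation" list, so a hedged Python sketch following that shape would be:

import requests

# Payload shape copied from the locustfile in this PR ("converation" is spelled as in the source);
# treat it as illustrative rather than a schema guarantee.
response = requests.post(
    "http://127.0.0.1:8000/chat",
    json={
        "converation": [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": "what is 1+1?"},
        ],
        "stream": "False",
    },
    timeout=60,
)
print(response.text)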
@@ -127,14 +129,6 @@ def add_engine_config(parser):
         help="model context length. If unspecified, " "will be automatically derived from the model.",
     )
     # Parallel arguments
-    parser.add_argument(
-        "--worker-use-ray",
-        action="store_true",
-        help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
-    )
-
-    parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
-
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")

     # KV cache arguments
@@ -149,28 +143,6 @@ def add_engine_config(parser):
         default=None,
         help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
     )
-
-    # Quantization settings.
-    parser.add_argument(
-        "--quantization",
-        "-q",
-        type=str,
-        choices=["awq", "gptq", "squeezellm", None],
-        default=None,
-        help="Method used to quantize the weights. If "
-        "None, we first check the `quantization_config` "
-        "attribute in the model config file. If that is "
-        "None, we assume the model weights are not "
-        "quantized and use `dtype` to determine the data "
-        "type of the weights.",
-    )
-    parser.add_argument(
-        "--enforce-eager",
-        action="store_true",
-        help="Always use eager-mode PyTorch. If False, "
-        "will use eager mode and CUDA graph in hybrid "
-        "for maximal performance and flexibility.",
-    )
     return parser


@@ -248,7 +248,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
             he initializer of weight, defaults to normal initializer.

         The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
         ::

         max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
             renormalized to have norm max_norm. Note: this will modify weight in-place.
         norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.

@@ -7,18 +7,18 @@ class QuickstartUser(HttpUser):
     @tag("online-generation")
     @task(5)
     def completion(self):
-        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
+        self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"})

     @tag("online-generation")
     @task(5)
     def completion_streaming(self):
-        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
+        self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"})

     @tag("online-chat")
     @task(5)
     def chat(self):
         self.client.post(
-            "v1/chat",
+            "/chat",
             json={
                 "converation": [
                     {"role": "system", "content": "you are a helpful assistant"},
@@ -32,7 +32,7 @@ class QuickstartUser(HttpUser):
     @task(5)
     def chat_streaming(self):
         self.client.post(
-            "v1/chat",
+            "/chat",
             json={
                 "converation": [
                     {"role": "system", "content": "you are a helpful assistant"},
@@ -55,4 +55,4 @@ class QuickstartUser(HttpUser):
     @tag("online-generation", "offline-generation")
     @task
     def get_models(self):
-        self.client.get("/v0/models")
+        self.client.get("/models")

@@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine


 @dataclass
-class SequenceTpye:
+class MockSequence:
     request_id: int


@@ -20,7 +20,11 @@ class MockEngine:

     async def async_step(self):
         self.step_calls += 1
-        return [SequenceTpye(request_id=self.request_id)] if self.request_id else []
+        return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)

+    def add_single_request(self, **kwargs):
+        del kwargs
+        self.add_request_calls += 1
+
     def generate(self, request_id):
         self.request_id = request_id
@@ -37,14 +41,14 @@ class MockEngine:
         self.abort_request_calls += 1


-class MockAsyncLLMEngine(AsyncInferenceEngine):
+class MockAsyncInferenceEngine(AsyncInferenceEngine):
     def _init_engine(self, *args, **kwargs):
         return MockEngine()


 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
+    engine = MockAsyncInferenceEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -74,7 +78,3 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == 5
-
-
-if __name__ == "__main__":
-    test_new_requests_event()

@@ -1,6 +1,6 @@
 import pytest

-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence


@@ -15,27 +15,25 @@ class SampleEvent:
         self.flag = False


-def test_request_tracker():
-    tracker = RequestTracker()
+def test_request_tracer():
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 1
     assert new[0]["request_id"] == 1
-    assert not finished
     assert not stream_1.finished

     stream_2 = tracker.add_request(2)
     stream_3 = tracker.add_request(3)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
     assert len(new) == 2
     assert new[0]["request_id"] == 2
     assert new[1]["request_id"] == 3
-    assert not finished
     assert not stream_2.finished
     assert not stream_3.finished

@@ -45,28 +43,21 @@ def test_request_tracker():
     assert not tracker.new_requests_event.flag

     tracker.abort_request(1)
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 1 in finished
+    new = tracker.get_new_requests()
     assert not new
     assert stream_1.finished

     stream_4 = tracker.add_request(4)
     tracker.abort_request(4)
     assert tracker.new_requests_event.flag
-    new, finished = tracker.get_new_and_finished_requests()
-    assert len(finished) == 1
-    assert 4 in finished
+    new = tracker.get_new_requests()
     assert not new
     assert stream_4.finished

     stream_5 = tracker.add_request(5)
     assert tracker.new_requests_event.flag
     tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
-    new, finished = tracker.get_new_and_finished_requests()
+    new = tracker.get_new_requests()
     assert not tracker.new_requests_event.flag
-    assert len(finished) == 1
-    assert 2 in finished
     assert len(new) == 1
     assert new[0]["request_id"] == 5
     assert stream_2.finished
@@ -74,4 +65,4 @@ def test_request_tracker():


 if __name__ == "__main__":
-    test_request_tracker()
+    test_request_tracer()

@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):


 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()

@@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         )
     ).cuda()
     model = model.eval()
-
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
@@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,

@@ -1,79 +0,0 @@
-# inspired by vLLM
-import subprocess
-import sys
-import time
-
-import pytest
-import ray
-import requests
-
-MAX_WAITING_TIME = 300
-
-pytestmark = pytest.mark.asyncio
-
-
-@ray.remote(num_gpus=1)
-class ServerRunner:
-    def __init__(self, args):
-        self.proc = subprocess.Popen(
-            ["python3", "-m", "colossalai.inference.server.api_server"] + args,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        self._wait_for_server()
-
-    def ready(self):
-        return True
-
-    def _wait_for_server(self):
-        # run health check
-        start = time.time()
-        while True:
-            try:
-                if requests.get("http://localhost:8000/v0/models").status_code == 200:
-                    break
-            except Exception as err:
-                if self.proc.poll() is not None:
-                    raise RuntimeError("Server exited unexpectedly.") from err
-
-                time.sleep(0.5)
-                if time.time() - start > MAX_WAITING_TIME:
-                    raise RuntimeError("Server failed to start in time.") from err
-
-    def __del__(self):
-        if hasattr(self, "proc"):
-            self.proc.terminate()
-
-
-@pytest.fixture(scope="session")
-def server():
-    ray.init()
-    server_runner = ServerRunner.remote(
-        [
-            "--model",
-            "/home/chenjianghai/data/llama-7b-hf",
-        ]
-    )
-    ray.get(server_runner.ready.remote())
-    yield server_runner
-    ray.shutdown()
-
-
-async def test_completion(server):
-    data = {"prompt": "How are you?", "stream": "False"}
-    response = await server.post("v1/completion", json=data)
-    assert response is not None
-
-
-async def test_chat(server):
-    messages = [
-        {"role": "system", "content": "you are a helpful assistant"},
-        {"role": "user", "content": "what is 1+1?"},
-    ]
-    data = {"messages": messages, "stream": "False"}
-    response = await server.post("v1/chat", data)
-    assert response is not None
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])