Browse Source

[Inference] Fix bugs and docs for feat/online-server (#5598)

* fix test bugs

* add do sample test

* del useless lines

* fix comments

* fix tests

* delete version tag

* delete version tag

* add

* del test sever

* fix test

* fix

* Revert "add"

This reverts commit b9305fb024.
feat/online-serving
Jianghai 7 months ago committed by CjhHa1
parent
commit
61a1b2e798
  1. 5
      colossalai/inference/config.py
  2. 52
      colossalai/inference/core/async_engine.py
  3. 13
      colossalai/inference/core/engine.py
  4. 2
      colossalai/inference/core/request_handler.py
  5. 40
      colossalai/inference/server/api_server.py
  6. 2
      colossalai/shardformer/layer/embedding.py
  7. 10
      examples/inference/client/locustfile.py
  8. 16
      tests/test_infer/test_async_engine/test_async_engine.py
  9. 27
      tests/test_infer/test_async_engine/test_request_tracer.py
  10. 18
      tests/test_infer/test_continuous_batching.py
  11. 6
      tests/test_infer/test_inference_engine.py
  12. 79
      tests/test_infer/test_server.py

5
colossalai/inference/config.py

@ -1,9 +1,8 @@
""" """
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference. Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
""" """
import dataclasses
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
import torch import torch
@ -218,7 +217,7 @@ class InferenceConfig:
@classmethod @classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig": def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Get the list of attributes of this dataclass. # Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)] attrs = [attr.name for attr in fields(cls)]
inference_config_args = {} inference_config_args = {}
for attr in attrs: for attr in attrs:
if attr in config_dict: if attr in config_dict:

52
colossalai/inference/core/async_engine.py

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
from functools import partial from functools import partial
from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
from colossalai.inference.core.engine import InferenceEngine from colossalai.inference.core.engine import InferenceEngine
@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
logger = logging.getLogger("colossalai-inference") logger = logging.getLogger("colossalai-inference")
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None: def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
msg = "Task finished unexpectedly. This should never happen! " msg = "Task finished unexpectedly. This should never happen! "
try: try:
try: try:
@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
class RequstStream: class RequstStream:
"""A stream of Output for a request that can be """
iterated over asynchronously.""" A stream of Output for a request that can be iterated over asynchronously.
Attributes: 1.request_id: The id of the request.
2._future: A future that will be set when the request is finished.
Methods: set_result and get_result, results will be set when finished, for once, and
the `self.future` will be set to done.
"""
def __init__(self, request_id: int) -> None: def __init__(self, request_id: int) -> None:
self.request_id = request_id self.request_id = request_id
@ -51,6 +57,10 @@ class RequstStream:
class Tracer: class Tracer:
""" """
Recording new requests and finished requests. Recording new requests and finished requests.
Attributes: 1._request_streams: We create one stream for each request to trace the output.
2._finished_requests: A queue to store the finished requests.
3._new_requests: New requests will be stored in this queue first, before sending them to the engine.
4.new_requests_event: An event to notify the engine that there are new requests.
""" """
def __init__(self) -> None: def __init__(self) -> None:
@ -93,8 +103,8 @@ class Tracer:
raise KeyError(f"Request {request_id} already exists.") raise KeyError(f"Request {request_id} already exists.")
stream = RequstStream(request_id) stream = RequstStream(request_id)
logger.info(f"Added request {request_id}.")
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs})) self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
self.new_requests_event.set() self.new_requests_event.set()
return stream return stream
@ -108,6 +118,7 @@ class Tracer:
if request_id not in self._request_streams or self._request_streams[request_id].finished: if request_id not in self._request_streams or self._request_streams[request_id].finished:
# The request has already finished or been aborted. # The request has already finished or been aborted.
# The requests in new_requests will be aborted when try to get them(if marked aborted)
return return
self._request_streams[request_id].set_result(None) self._request_streams[request_id].set_result(None)
@ -117,9 +128,18 @@ class Tracer:
Get new requests from http server. Get new requests from http server.
""" """
new_requests: List[Dict] = [] new_requests: List[Dict] = []
finished_requests: Set[int] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
while not self._new_requests.empty(): while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait() stream, new_request = self._new_requests.get_nowait()
if new_request["request_id"] in finished_requests:
# The request has been aborted.
stream.set_result(None)
continue
self._request_streams[stream.request_id] = stream self._request_streams[stream.request_id] = stream
new_requests.append(new_request) new_requests.append(new_request)
@ -133,7 +153,8 @@ class Tracer:
class _AsyncInferenceEngine(InferenceEngine): class _AsyncInferenceEngine(InferenceEngine):
""" """
Async methods for Inference Engine. Async methods for Inference Engine. This engine is an extension for InferenceEngine, and the additional methods will only be used for
Methods: 1. async_step: The async version of Engine.step()
""" """
async def async_step(self) -> List[str]: async def async_step(self) -> List[str]:
@ -161,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
if self.inference_config.pad_input: if self.inference_config.pad_input:
logits = logits[:, -1, :] logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits) self.request_handler.search_tokens(self.generation_config, logits)
# Return: List[Sequence]
finished_sequences = self.request_handler.update() finished_sequences = self.request_handler.update()
for sequence in finished_sequences: for sequence in finished_sequences:
sequence.output = self.tokenizer.decode(sequence.output_token_id) sequence.output = self.tokenizer.decode(sequence.output_token_id)
return finished_sequences, self.request_handler.current_requests_in_batch() > 0 return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
class AsyncInferenceEngine: class AsyncInferenceEngine:
"""An asynchronous wrapper for LLMEngine. """An asynchronous wrapper for the InferenceEngine class.
This class is used to wrap the InferenceEngine class to make it asynchronous. This class is used to wrap the InferenceEngine class to make it asynchronous.
It uses asyncio to create a background loop that keeps processing incoming It uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there are requests. Note that this class does not hold model directly, when incoming a new
requests in the waiting queue. The generate method yields the outputs request, it first called `add_request` and the Tracer will record the request, putting
from the InferenceEngine to the caller. it to the background `InferenceEngine`(done in background loop) to process. You can
consider this engine as an interface.
""" """
_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
@ -253,7 +275,7 @@ class AsyncInferenceEngine:
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
) -> RequstStream: ) -> RequstStream:
""" """
Add a request to the background tracker(waitting queue), start the background loop if needed. Add a request to the background tracker(waiting queue), start the background loop if needed.
""" """
if not self.background_loop_status: if not self.background_loop_status:
if self.start_engine_loop: if self.start_engine_loop:
@ -276,14 +298,12 @@ class AsyncInferenceEngine:
""" """
Generate output from a request. It receives the request from http server, adds it into the Generate output from a request. It receives the request from http server, adds it into the
waitting queue of Async Engine and streams the output sequence. waitting queue of Async Engine and streams the output sequence.
""" """
try: try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids) stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
return await stream.get_result() return await stream.get_result()
except (Exception, asyncio.CancelledError) as e: except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the # If there is an exception or coroutine is cancelled, abort the request.
# request.
self._abort(request_id) self._abort(request_id)
raise e raise e

13
colossalai/inference/core/engine.py

@ -527,10 +527,15 @@ class InferenceEngine:
List[str]: Inference result returned by one generation. List[str]: Inference result returned by one generation.
""" """
with torch.inference_mode(): with torch.inference_mode():
<<<<<<< HEAD
if isinstance(prompts, str) and isinstance(request_ids, int): if isinstance(prompts, str) and isinstance(request_ids, int):
prompts = [prompts] prompts = [prompts]
request_ids = [request_ids] request_ids = [request_ids]
=======
if prompts is not None or prompts_token_ids is not None:
self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
>>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598)
if prompts is not None or prompts_token_ids is not None: if prompts is not None or prompts_token_ids is not None:
gen_config_dict = generation_config.to_dict() if generation_config is not None else {} gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
@ -612,6 +617,9 @@ class InferenceEngine:
block_size = self.inference_config.block_size block_size = self.inference_config.block_size
if request_ids is not None and not isinstance(request_ids, list):
request_ids = [request_ids]
if prompts is not None and not isinstance(prompts, list): if prompts is not None and not isinstance(prompts, list):
prompts = [prompts] prompts = [prompts]
@ -621,9 +629,10 @@ class InferenceEngine:
"input_ids" "input_ids"
] ]
# list of torch Tensor
if isinstance(prompts_token_ids, list): if isinstance(prompts_token_ids, list):
if isinstance(prompts_token_ids[0], torch.Tensor): if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids] prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray): elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist() prompts_token_ids = prompts_token_ids.tolist()
else: else:
@ -738,8 +747,6 @@ class InferenceEngine:
logits = logits[:, -1, :] logits = logits[:, -1, :]
next_tokens = self.request_handler.search_tokens(self.generation_config, logits) next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens) self.request_handler.append_next_tokens(next_tokens)
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update() finished_sequences = self.request_handler.update()
return finished_sequences return finished_sequences

2
colossalai/inference/core/request_handler.py

@ -328,7 +328,7 @@ class RequestHandler:
def check_unfinished_seqs(self) -> bool: def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty() return self._has_waiting() or not self.running_list.is_empty()
def current_requests_in_batch(self) -> int: def total_requests_in_batch_bucket(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
def search_tokens(self, generation_config: GenerationConfig, logits): def search_tokens(self, generation_config: GenerationConfig, logits):

40
colossalai/inference/server/api_server.py

@ -6,9 +6,10 @@ Doc:
Usage: (for local user) Usage: (for local user)
- First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model` - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
- Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \ - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
-d '{"prompt":"hello, who are you? ","stream":"False"}'` -d '{"prompt":"hello, who are you? ","stream":"False"}'`
Version: V1.0
""" """
import argparse import argparse
@ -36,7 +37,8 @@ completion_serving = None
app = FastAPI() app = FastAPI()
@app.get("/v0/models") # NOTE: (CjhHa1) models are still under development, need to be updated
@app.get("/models")
def get_available_models() -> Response: def get_available_models() -> Response:
return JSONResponse(supported_models_dict) return JSONResponse(supported_models_dict)
@ -81,7 +83,7 @@ async def generate(request: Request) -> Response:
return JSONResponse(ret) return JSONResponse(ret)
@app.post("/v1/completion") @app.post("/completion")
async def create_completion(request: Request): async def create_completion(request: Request):
request_dict = await request.json() request_dict = await request.json()
stream = request_dict.pop("stream", "false").lower() stream = request_dict.pop("stream", "false").lower()
@ -95,7 +97,7 @@ async def create_completion(request: Request):
return JSONResponse(content=ret) return JSONResponse(content=ret)
@app.post("/v1/chat") @app.post("/chat")
async def create_chat(request: Request): async def create_chat(request: Request):
request_dict = await request.json() request_dict = await request.json()
@ -127,14 +129,6 @@ def add_engine_config(parser):
help="model context length. If unspecified, " "will be automatically derived from the model.", help="model context length. If unspecified, " "will be automatically derived from the model.",
) )
# Parallel arguments # Parallel arguments
parser.add_argument(
"--worker-use-ray",
action="store_true",
help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
)
parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas") parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
# KV cache arguments # KV cache arguments
@ -149,28 +143,6 @@ def add_engine_config(parser):
default=None, default=None,
help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.", help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.",
) )
# Quantization settings.
parser.add_argument(
"--quantization",
"-q",
type=str,
choices=["awq", "gptq", "squeezellm", None],
default=None,
help="Method used to quantize the weights. If "
"None, we first check the `quantization_config` "
"attribute in the model config file. If that is "
"None, we assume the model weights are not "
"quantized and use `dtype` to determine the data "
"type of the weights.",
)
parser.add_argument(
"--enforce-eager",
action="store_true",
help="Always use eager-mode PyTorch. If False, "
"will use eager mode and CUDA graph in hybrid "
"for maximal performance and flexibility.",
)
return parser return parser

2
colossalai/shardformer/layer/embedding.py

@ -248,7 +248,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
he initializer of weight, defaults to normal initializer. he initializer of weight, defaults to normal initializer.
The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain:
::
max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
renormalized to have norm max_norm. Note: this will modify weight in-place. renormalized to have norm max_norm. Note: this will modify weight in-place.
norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2.

10
examples/inference/client/locustfile.py

@ -7,18 +7,18 @@ class QuickstartUser(HttpUser):
@tag("online-generation") @tag("online-generation")
@task(5) @task(5)
def completion(self): def completion(self):
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"}) self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
@tag("online-generation") @tag("online-generation")
@task(5) @task(5)
def completion_streaming(self): def completion_streaming(self):
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
@tag("online-chat") @tag("online-chat")
@task(5) @task(5)
def chat(self): def chat(self):
self.client.post( self.client.post(
"v1/chat", "/chat",
json={ json={
"converation": [ "converation": [
{"role": "system", "content": "you are a helpful assistant"}, {"role": "system", "content": "you are a helpful assistant"},
@ -32,7 +32,7 @@ class QuickstartUser(HttpUser):
@task(5) @task(5)
def chat_streaming(self): def chat_streaming(self):
self.client.post( self.client.post(
"v1/chat", "/chat",
json={ json={
"converation": [ "converation": [
{"role": "system", "content": "you are a helpful assistant"}, {"role": "system", "content": "you are a helpful assistant"},
@ -55,4 +55,4 @@ class QuickstartUser(HttpUser):
@tag("online-generation", "offline-generation") @tag("online-generation", "offline-generation")
@task @task
def get_models(self): def get_models(self):
self.client.get("/v0/models") self.client.get("/models")

16
tests/test_infer/test_async_engine/test_async_engine.py

@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine
@dataclass @dataclass
class SequenceTpye: class MockSequence:
request_id: int request_id: int
@ -20,7 +20,11 @@ class MockEngine:
async def async_step(self): async def async_step(self):
self.step_calls += 1 self.step_calls += 1
return [SequenceTpye(request_id=self.request_id)] if self.request_id else [] return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
def add_single_request(self, **kwargs):
del kwargs
self.add_request_calls += 1
def generate(self, request_id): def generate(self, request_id):
self.request_id = request_id self.request_id = request_id
@ -37,14 +41,14 @@ class MockEngine:
self.abort_request_calls += 1 self.abort_request_calls += 1
class MockAsyncLLMEngine(AsyncInferenceEngine): class MockAsyncInferenceEngine(AsyncInferenceEngine):
def _init_engine(self, *args, **kwargs): def _init_engine(self, *args, **kwargs):
return MockEngine() return MockEngine()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_requests_event(): async def test_new_requests_event():
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False) engine = MockAsyncInferenceEngine()
engine.start_background_loop() engine.start_background_loop()
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0 assert engine.engine.step_calls == 0
@ -74,7 +78,3 @@ async def test_new_requests_event():
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3 assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5 assert engine.engine.step_calls == 5
if __name__ == "__main__":
test_new_requests_event()

27
tests/test_infer/test_async_engine/test_request_tracker.py → tests/test_infer/test_async_engine/test_request_tracer.py

@ -1,6 +1,6 @@
import pytest import pytest
from colossalai.inference.core.async_engine import RequestTracker from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence from colossalai.inference.struct import Sequence
@ -15,27 +15,25 @@ class SampleEvent:
self.flag = False self.flag = False
def test_request_tracker(): def test_request_tracer():
tracker = RequestTracker() tracker = Tracer()
tracker.new_requests_event = SampleEvent() tracker.new_requests_event = SampleEvent()
stream_1 = tracker.add_request(1) stream_1 = tracker.add_request(1)
assert tracker.new_requests_event.flag assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests() new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.flag
assert len(new) == 1 assert len(new) == 1
assert new[0]["request_id"] == 1 assert new[0]["request_id"] == 1
assert not finished
assert not stream_1.finished assert not stream_1.finished
stream_2 = tracker.add_request(2) stream_2 = tracker.add_request(2)
stream_3 = tracker.add_request(3) stream_3 = tracker.add_request(3)
assert tracker.new_requests_event.flag assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests() new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.flag
assert len(new) == 2 assert len(new) == 2
assert new[0]["request_id"] == 2 assert new[0]["request_id"] == 2
assert new[1]["request_id"] == 3 assert new[1]["request_id"] == 3
assert not finished
assert not stream_2.finished assert not stream_2.finished
assert not stream_3.finished assert not stream_3.finished
@ -45,28 +43,21 @@ def test_request_tracker():
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.flag
tracker.abort_request(1) tracker.abort_request(1)
new, finished = tracker.get_new_and_finished_requests() new = tracker.get_new_requests()
assert len(finished) == 1
assert 1 in finished
assert not new assert not new
assert stream_1.finished
stream_4 = tracker.add_request(4) stream_4 = tracker.add_request(4)
tracker.abort_request(4) tracker.abort_request(4)
assert tracker.new_requests_event.flag assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests() new = tracker.get_new_requests()
assert len(finished) == 1
assert 4 in finished
assert not new assert not new
assert stream_4.finished assert stream_4.finished
stream_5 = tracker.add_request(5) stream_5 = tracker.add_request(5)
assert tracker.new_requests_event.flag assert tracker.new_requests_event.flag
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0)) tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
new, finished = tracker.get_new_and_finished_requests() new = tracker.get_new_requests()
assert not tracker.new_requests_event.flag assert not tracker.new_requests_event.flag
assert len(finished) == 1
assert 2 in finished
assert len(new) == 1 assert len(new) == 1
assert new[0]["request_id"] == 5 assert new[0]["request_id"] == 5
assert stream_2.finished assert stream_2.finished
@ -74,4 +65,4 @@ def test_request_tracker():
if __name__ == "__main__": if __name__ == "__main__":
test_request_tracker() test_request_tracer()

18
tests/test_infer/test_continuous_batching.py

@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
@parameterize( @parameterize(
"max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50 "test_config",
[
{
"max_batch_size": 8,
"max_output_len": 512,
"max_input_len": 64,
"do_sample": False,
}
],
) )
def check_inference_engine(use_engine=False, prompt_template=None): def check_inference_engine(test_config, use_engine=False, prompt_template=None):
setup_seed(20) setup_seed(20)
max_batch_size = test_config["max_batch_size"]
max_input_len = test_config["max_input_len"]
max_output_len = test_config["max_output_len"]
do_sample = test_config["do_sample"]
top_p = 0.5
top_k = 50
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half() model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
model = model.eval() model = model.eval()

6
tests/test_infer/test_inference_engine.py

@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
) )
).cuda() ).cuda()
model = model.eval() model = model.eval()
inputs = [ inputs = [
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
"介绍一下武汉,", "介绍一下武汉,",
@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
assert inference_engine.generation_config.max_new_tokens == output_len assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs) inference_engine.add_request(prompts=inputs)
assert inference_engine.request_handler._has_waiting() assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) generation_config = GenerationConfig(
max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
)
outputs = inference_engine.generate(generation_config=generation_config) outputs = inference_engine.generate(generation_config=generation_config)
else: else:
if prompt_template: if prompt_template:
@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
inputs = inputs.cuda() inputs = inputs.cuda()
generation_config = GenerationConfig( generation_config = GenerationConfig(
do_sample=do_sample, do_sample=do_sample,
dtype="fp32",
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,

79
tests/test_infer/test_server.py

@ -1,79 +0,0 @@
# inspired by vLLM
import subprocess
import sys
import time
import pytest
import ray
import requests
MAX_WAITING_TIME = 300
pytestmark = pytest.mark.asyncio
@ray.remote(num_gpus=1)
class ServerRunner:
def __init__(self, args):
self.proc = subprocess.Popen(
["python3", "-m", "colossalai.inference.server.api_server"] + args,
stdout=sys.stdout,
stderr=sys.stderr,
)
self._wait_for_server()
def ready(self):
return True
def _wait_for_server(self):
# run health check
start = time.time()
while True:
try:
if requests.get("http://localhost:8000/v0/models").status_code == 200:
break
except Exception as err:
if self.proc.poll() is not None:
raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5)
if time.time() - start > MAX_WAITING_TIME:
raise RuntimeError("Server failed to start in time.") from err
def __del__(self):
if hasattr(self, "proc"):
self.proc.terminate()
@pytest.fixture(scope="session")
def server():
ray.init()
server_runner = ServerRunner.remote(
[
"--model",
"/home/chenjianghai/data/llama-7b-hf",
]
)
ray.get(server_runner.ready.remote())
yield server_runner
ray.shutdown()
async def test_completion(server):
data = {"prompt": "How are you?", "stream": "False"}
response = await server.post("v1/completion", json=data)
assert response is not None
async def test_chat(server):
messages = [
{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"},
]
data = {"messages": messages, "stream": "False"}
response = await server.post("v1/chat", data)
assert response is not None
if __name__ == "__main__":
pytest.main([__file__])
Loading…
Cancel
Save