diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 421c6b589..ee1cd7cfb 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -1,9 +1,8 @@
 """
 Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
 """
-import dataclasses
 import logging
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from typing import Any, Dict, Optional, Union
 
 import torch
@@ -218,7 +217,7 @@ class InferenceConfig:
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
         # Get the list of attributes of this dataclass.
-        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        attrs = [attr.name for attr in fields(cls)]
         inference_config_args = {}
         for attr in attrs:
             if attr in config_dict:
diff --git a/colossalai/inference/core/async_engine.py b/colossalai/inference/core/async_engine.py
index e23d0b90f..6f7ab15d8 100644
--- a/colossalai/inference/core/async_engine.py
+++ b/colossalai/inference/core/async_engine.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
 
 from colossalai.inference.core.engine import InferenceEngine
 
@@ -10,7 +10,7 @@ logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(leve
 logger = logging.getLogger("colossalai-inference")
 
 
-def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "Tracer") -> None:
     msg = "Task finished unexpectedly. This should never happen! "
     try:
         try:
@@ -26,8 +26,14 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
 
 
 class RequstStream:
-    """A stream of Output for a request that can be
-    iterated over asynchronously."""
+    """
+    A stream of outputs for a request that can be iterated over asynchronously.
+
+    Attributes:
+        request_id: The id of the request.
+        _future: A future that is set exactly once, when the request finishes;
+    `set_result` stores the final output in the future and `get_result` awaits it.
+    """
 
     def __init__(self, request_id: int) -> None:
         self.request_id = request_id
@@ -51,6 +57,10 @@ class RequstStream:
 class Tracer:
     """
     Recording new requests and finished requests.
+    Attributes:
+        _request_streams: One stream per request, used to trace its output.
+        _finished_requests: A queue storing the ids of finished requests.
+        _new_requests: New requests are queued here first, before being sent to the engine; new_requests_event notifies the engine that new requests have arrived.
     """
 
     def __init__(self) -> None:
@@ -93,8 +103,8 @@ class Tracer:
             raise KeyError(f"Request {request_id} already exists.")
 
         stream = RequstStream(request_id)
+        logger.info(f"Added request {request_id}.")
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
-
         self.new_requests_event.set()
 
         return stream
@@ -108,6 +118,7 @@ class Tracer:
 
         if request_id not in self._request_streams or self._request_streams[request_id].finished:
             # The request has already finished or been aborted.
+            # Requests still waiting in new_requests are aborted when we try to fetch them (if they are marked as aborted).
             return
 
         self._request_streams[request_id].set_result(None)
@@ -117,9 +128,18 @@ class Tracer:
         Get new requests from http server.
         """
         new_requests: List[Dict] = []
+        finished_requests: Set[int] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
 
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
+            if new_request["request_id"] in finished_requests:
+                # The request has been aborted.
+                stream.set_result(None)
+                continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
@@ -133,7 +153,8 @@ class Tracer:
 
 class _AsyncInferenceEngine(InferenceEngine):
     """
-    Async methods for Inference Engine.
+    Async methods for the Inference Engine. This class extends InferenceEngine with coroutine variants that are only used by the asynchronous serving path.
+    Methods: 1. async_step: the async version of InferenceEngine.step().
     """
 
     async def async_step(self) -> List[str]:
@@ -161,22 +182,23 @@ class _AsyncInferenceEngine(InferenceEngine):
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
         self.request_handler.search_tokens(self.generation_config, logits)
-        # Return: List[Sequence]
+
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)
 
-        return finished_sequences, self.request_handler.current_requests_in_batch() > 0
+        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
 
 
 class AsyncInferenceEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for the InferenceEngine class.
 
     This class is used to wrap the InferenceEngine class to make it asynchronous.
     It uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there are
-    requests in the waiting queue. The generate method yields the outputs
-    from the InferenceEngine to the caller.
+    requests. Note that this class does not hold the model directly; when a new request
+    comes in, it first calls `add_request`, and the Tracer records the request and hands it
+    to the background `InferenceEngine` (driven by the background loop) for processing. You
+    can consider this engine as an interface to the underlying InferenceEngine.
     """
 
     _engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
@@ -253,7 +275,7 @@ class AsyncInferenceEngine:
         prompt_token_ids: Optional[List[int]] = None,
     ) -> RequstStream:
         """
-        Add a request to the background tracker(waitting queue), start the background loop if needed.
+        Add a request to the background tracker (waiting queue) and start the background loop if needed.
         """
         if not self.background_loop_status:
             if self.start_engine_loop:
@@ -276,14 +298,12 @@ class AsyncInferenceEngine:
         """
         Generate output from a request. It receives the request from http server, adds it
         into the waitting queue of Async Engine and streams the output sequence.
-
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
             return await stream.get_result()
 
         except (Exception, asyncio.CancelledError) as e:
-            # If there is an exception or coroutine is cancelled, abort the
-            # request.
+            # If there is an exception or the coroutine is cancelled, abort the request.
            self._abort(request_id)
            raise e
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 3f456e1f9..02a8c92a2 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -612,6 +612,9 @@ class InferenceEngine:
 
         block_size = self.inference_config.block_size
 
+        if request_ids is not None and not isinstance(request_ids, list):
+            request_ids = [request_ids]
+
         if prompts is not None and not isinstance(prompts, list):
             prompts = [prompts]
 
@@ -621,9 +624,10 @@ class InferenceEngine:
                 "input_ids"
             ]
 
+        # list of torch Tensors
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
@@ -738,8 +742,6 @@ class InferenceEngine:
             logits = logits[:, -1, :]
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)
-
-        self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
 
         return finished_sequences
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 12c9cebf7..03b4d2305 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -328,7 +328,7 @@ class RequestHandler:
     def check_unfinished_seqs(self) -> bool:
         return self._has_waiting() or not self.running_list.is_empty()
 
-    def current_requests_in_batch(self) -> int:
+    def total_requests_in_batch_bucket(self) -> int:
         return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
 
     def search_tokens(self, generation_config: GenerationConfig, logits):
diff --git a/colossalai/inference/server/api_server.py b/colossalai/inference/server/api_server.py
index 60ccf15fc..dfbd2c906 100644
--- a/colossalai/inference/server/api_server.py
+++ b/colossalai/inference/server/api_server.py
@@ -6,9 +6,10 @@ Doc:
 Usage: (for local user)
     - First, Lauch an API locally. `python3 -m colossalai.inference.server.api_server --model path of your llama2 model`
    - Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api
-    - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
+    - For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/completion \
     -H 'Content-Type: application/json' \
     -d '{"prompt":"hello, who are you? 
","stream":"False"}'` + Version: V1.0 """ import argparse @@ -36,7 +37,8 @@ completion_serving = None app = FastAPI() -@app.get("/v0/models") +# NOTE: (CjhHa1) models are still under development, need to be updated +@app.get("/models") def get_available_models() -> Response: return JSONResponse(supported_models_dict) @@ -81,7 +83,7 @@ async def generate(request: Request) -> Response: return JSONResponse(ret) -@app.post("/v1/completion") +@app.post("/completion") async def create_completion(request: Request): request_dict = await request.json() stream = request_dict.pop("stream", "false").lower() @@ -95,7 +97,7 @@ async def create_completion(request: Request): return JSONResponse(content=ret) -@app.post("/v1/chat") +@app.post("/chat") async def create_chat(request: Request): request_dict = await request.json() @@ -127,14 +129,6 @@ def add_engine_config(parser): help="model context length. If unspecified, " "will be automatically derived from the model.", ) # Parallel arguments - parser.add_argument( - "--worker-use-ray", - action="store_true", - help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU", - ) - - parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages") - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas") # KV cache arguments @@ -149,28 +143,6 @@ def add_engine_config(parser): default=None, help=f"Allowed choices are {','.join(prompt_template_choices)}. Default to None.", ) - - # Quantization settings. - parser.add_argument( - "--quantization", - "-q", - type=str, - choices=["awq", "gptq", "squeezellm", None], - default=None, - help="Method used to quantize the weights. If " - "None, we first check the `quantization_config` " - "attribute in the model config file. If that is " - "None, we assume the model weights are not " - "quantized and use `dtype` to determine the data " - "type of the weights.", - ) - parser.add_argument( - "--enforce-eager", - action="store_true", - help="Always use eager-mode PyTorch. If False, " - "will use eager mode and CUDA graph in hybrid " - "for maximal performance and flexibility.", - ) return parser diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 93df5e522..9b77774aa 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -248,7 +248,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule): he initializer of weight, defaults to normal initializer. The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - + :: max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place. norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. diff --git a/examples/inference/client/locustfile.py b/examples/inference/client/locustfile.py index af00f3c91..a65c8b667 100644 --- a/examples/inference/client/locustfile.py +++ b/examples/inference/client/locustfile.py @@ -7,18 +7,18 @@ class QuickstartUser(HttpUser): @tag("online-generation") @task(5) def completion(self): - self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"}) + self.client.post("/completion", json={"prompt": "hello, who are you? 
", "stream": "False"}) @tag("online-generation") @task(5) def completion_streaming(self): - self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) + self.client.post("/completion", json={"prompt": "hello, who are you? ", "stream": "True"}) @tag("online-chat") @task(5) def chat(self): self.client.post( - "v1/chat", + "/chat", json={ "converation": [ {"role": "system", "content": "you are a helpful assistant"}, @@ -32,7 +32,7 @@ class QuickstartUser(HttpUser): @task(5) def chat_streaming(self): self.client.post( - "v1/chat", + "/chat", json={ "converation": [ {"role": "system", "content": "you are a helpful assistant"}, @@ -55,4 +55,4 @@ class QuickstartUser(HttpUser): @tag("online-generation", "offline-generation") @task def get_models(self): - self.client.get("/v0/models") + self.client.get("/models") diff --git a/tests/test_infer/test_async_engine/test_async_engine.py b/tests/test_infer/test_async_engine/test_async_engine.py index ebca11c72..ac532b1b1 100644 --- a/tests/test_infer/test_async_engine/test_async_engine.py +++ b/tests/test_infer/test_async_engine/test_async_engine.py @@ -7,7 +7,7 @@ from colossalai.inference.core.async_engine import AsyncInferenceEngine @dataclass -class SequenceTpye: +class MockSequence: request_id: int @@ -20,7 +20,11 @@ class MockEngine: async def async_step(self): self.step_calls += 1 - return [SequenceTpye(request_id=self.request_id)] if self.request_id else [] + return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False) + + def add_single_request(self, **kwargs): + del kwargs + self.add_request_calls += 1 def generate(self, request_id): self.request_id = request_id @@ -37,14 +41,14 @@ class MockEngine: self.abort_request_calls += 1 -class MockAsyncLLMEngine(AsyncInferenceEngine): +class MockAsyncInferenceEngine(AsyncInferenceEngine): def _init_engine(self, *args, **kwargs): return MockEngine() @pytest.mark.asyncio async def test_new_requests_event(): - engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False) + engine = MockAsyncInferenceEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 @@ -74,7 +78,3 @@ async def test_new_requests_event(): await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == 5 - - -if __name__ == "__main__": - test_new_requests_event() diff --git a/tests/test_infer/test_async_engine/test_request_tracker.py b/tests/test_infer/test_async_engine/test_request_tracer.py similarity index 69% rename from tests/test_infer/test_async_engine/test_request_tracker.py rename to tests/test_infer/test_async_engine/test_request_tracer.py index 9a797a862..14bcb9628 100644 --- a/tests/test_infer/test_async_engine/test_request_tracker.py +++ b/tests/test_infer/test_async_engine/test_request_tracer.py @@ -1,6 +1,6 @@ import pytest -from colossalai.inference.core.async_engine import RequestTracker +from colossalai.inference.core.async_engine import Tracer from colossalai.inference.struct import Sequence @@ -15,27 +15,25 @@ class SampleEvent: self.flag = False -def test_request_tracker(): - tracker = RequestTracker() +def test_request_tracer(): + tracker = Tracer() tracker.new_requests_event = SampleEvent() stream_1 = tracker.add_request(1) assert tracker.new_requests_event.flag - new, finished = tracker.get_new_and_finished_requests() + new = tracker.get_new_requests() assert not tracker.new_requests_event.flag assert len(new) == 1 assert new[0]["request_id"] == 1 - 
assert not finished assert not stream_1.finished stream_2 = tracker.add_request(2) stream_3 = tracker.add_request(3) assert tracker.new_requests_event.flag - new, finished = tracker.get_new_and_finished_requests() + new = tracker.get_new_requests() assert not tracker.new_requests_event.flag assert len(new) == 2 assert new[0]["request_id"] == 2 assert new[1]["request_id"] == 3 - assert not finished assert not stream_2.finished assert not stream_3.finished @@ -45,28 +43,21 @@ def test_request_tracker(): assert not tracker.new_requests_event.flag tracker.abort_request(1) - new, finished = tracker.get_new_and_finished_requests() - assert len(finished) == 1 - assert 1 in finished + new = tracker.get_new_requests() assert not new - assert stream_1.finished stream_4 = tracker.add_request(4) tracker.abort_request(4) assert tracker.new_requests_event.flag - new, finished = tracker.get_new_and_finished_requests() - assert len(finished) == 1 - assert 4 in finished + new = tracker.get_new_requests() assert not new assert stream_4.finished stream_5 = tracker.add_request(5) assert tracker.new_requests_event.flag tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0)) - new, finished = tracker.get_new_and_finished_requests() + new = tracker.get_new_requests() assert not tracker.new_requests_event.flag - assert len(finished) == 1 - assert 2 in finished assert len(new) == 1 assert new[0]["request_id"] == 5 assert stream_2.finished @@ -74,4 +65,4 @@ def test_request_tracker(): if __name__ == "__main__": - test_request_tracker() + test_request_tracer() diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py index 0b0d92c7c..350ed473e 100644 --- a/tests/test_infer/test_continuous_batching.py +++ b/tests/test_infer/test_continuous_batching.py @@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length): @parameterize( - "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50 + "test_config", + [ + { + "max_batch_size": 8, + "max_output_len": 512, + "max_input_len": 64, + "do_sample": False, + } + ], ) -def check_inference_engine(use_engine=False, prompt_template=None): +def check_inference_engine(test_config, use_engine=False, prompt_template=None): setup_seed(20) + max_batch_size = test_config["max_batch_size"] + max_input_len = test_config["max_input_len"] + max_output_len = test_config["max_output_len"] + do_sample = test_config["do_sample"] + top_p = 0.5 + top_k = 50 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half() model = model.eval() diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index be1330898..919a10077 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru ) ).cuda() model = model.eval() - inputs = [ "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,", "介绍一下武汉,", @@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) + generation_config = GenerationConfig( + 
max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k + ) outputs = inference_engine.generate(generation_config=generation_config) else: if prompt_template: @@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru inputs = inputs.cuda() generation_config = GenerationConfig( do_sample=do_sample, + dtype="fp32", top_p=top_p, top_k=top_k, pad_token_id=tokenizer.pad_token_id, diff --git a/tests/test_infer/test_server.py b/tests/test_infer/test_server.py deleted file mode 100644 index 05ac5a264..000000000 --- a/tests/test_infer/test_server.py +++ /dev/null @@ -1,79 +0,0 @@ -# inspired by vLLM -import subprocess -import sys -import time - -import pytest -import ray -import requests - -MAX_WAITING_TIME = 300 - -pytestmark = pytest.mark.asyncio - - -@ray.remote(num_gpus=1) -class ServerRunner: - def __init__(self, args): - self.proc = subprocess.Popen( - ["python3", "-m", "colossalai.inference.server.api_server"] + args, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get("http://localhost:8000/v0/models").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_WAITING_TIME: - raise RuntimeError("Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() - - -@pytest.fixture(scope="session") -def server(): - ray.init() - server_runner = ServerRunner.remote( - [ - "--model", - "/home/chenjianghai/data/llama-7b-hf", - ] - ) - ray.get(server_runner.ready.remote()) - yield server_runner - ray.shutdown() - - -async def test_completion(server): - data = {"prompt": "How are you?", "stream": "False"} - response = await server.post("v1/completion", json=data) - assert response is not None - - -async def test_chat(server): - messages = [ - {"role": "system", "content": "you are a helpful assistant"}, - {"role": "user", "content": "what is 1+1?"}, - ] - data = {"messages": messages, "stream": "False"} - response = await server.post("v1/chat", data) - assert response is not None - - -if __name__ == "__main__": - pytest.main([__file__])
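
With tests/test_infer/test_server.py removed above, the renamed routes can still be exercised by hand. The sketch below is illustrative only and not part of this patch: it assumes an api_server instance has already been launched locally on the default 127.0.0.1:8000, and the payload fields mirror examples/inference/client/locustfile.py; the helper names are hypothetical.

```python
# Hypothetical smoke check for the renamed routes; not part of this patch.
# Assumes a server is already running, e.g.:
#   python3 -m colossalai.inference.server.api_server --model <path of your llama2 model>
import requests

BASE_URL = "http://127.0.0.1:8000"  # default host/port used in the docstring above


def check_models() -> None:
    # /v0/models was renamed to /models in this patch.
    resp = requests.get(f"{BASE_URL}/models", timeout=30)
    resp.raise_for_status()
    print(resp.json())


def check_completion(prompt: str = "hello, who are you? ") -> None:
    # /v1/completion was renamed to /completion; payload fields follow locustfile.py.
    resp = requests.post(
        f"{BASE_URL}/completion",
        json={"prompt": prompt, "stream": "False"},
        timeout=300,
    )
    resp.raise_for_status()
    print(resp.text)


if __name__ == "__main__":
    check_models()
    check_completion()
```

Unlike the deleted test, this sketch does not depend on ray and does not manage the server process itself.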