[Inference] Finish Online Serving Test, add streaming output api, continuous batching test and example (#5432)

* finish online test and add examples

* fix test_continuous_batching

* fix some bugs

* fix bash

* fix

* fix inference

* finish revision

* fix typos

* revision
pull/5584/head
Jianghai 2024-03-18 17:06:05 +08:00 committed by CjhHa1
parent 1572af2432
commit 3d211ff81b
11 changed files with 213 additions and 194 deletions

View File

@ -1,13 +1,13 @@
import asyncio
import logging
from functools import partial
from logging import Logger
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type
from colossalai.inference.core.engine import InferenceEngine
class AsyncEngineDeadError(RuntimeError):
pass
# CLI logger
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("colossalai-inference")
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
@ -18,54 +18,45 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTrac
except asyncio.CancelledError:
return
except Exception as exc:
raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc
raise AsyncEngineDeadError(msg)
raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
raise RuntimeError(msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
raise exc
class AsyncStream:
class RequstStream:
"""A stream of Output for a request that can be
iterated over asynchronously."""
def __init__(self, request_id: str) -> None:
def __init__(self, request_id: int) -> None:
self.request_id = request_id
self._queue = asyncio.Queue()
self._finished = False
self._future = asyncio.Future()
def put(self, item) -> None:
if self._finished:
return
self._queue.put_nowait(item)
def set_result(self, result) -> None:
"""Set final result and signal taht it's ready"""
if not self._future.done():
self._future.set_result(result)
def finish(self) -> None:
self._queue.put_nowait(StopIteration)
self._finished = True
async def get_result(self):
"""Wait for the result to be set and return it."""
return await self._future
@property
def finished(self) -> bool:
return self._finished
def __aiter__(self):
return self
async def __anext__(self):
result = await self._queue.get()
if result is StopIteration:
raise StopAsyncIteration
elif isinstance(result, Exception):
raise result
return result
"""Check if the stream has finished by checking if the future is done."""
return self._future.done()
class RequestTracker:
"""Synchronous abstraction for tracking requests."""
class Tracer:
"""
Tracks new requests and finished requests.
"""
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._request_streams: Dict[int, RequstStream] = {}
self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
@ -79,19 +70,21 @@ class RequestTracker:
Propagate an exception to request streams (all if request_id is None).
"""
if request_id is not None:
self._request_streams[request_id].put(exc)
self._request_streams[request_id].set_result(exc)
else:
for stream in self._request_streams.values():
stream.put(exc)
stream.set_result(exc)
def process_finished_request(self, finished_request) -> None:
"""Process a finished request from the engine."""
request_id = finished_request.request_id
self._request_streams[request_id].put(finished_request)
try:
self._request_streams[request_id].set_result(finished_request)
except:
raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
self.abort_request(request_id)
def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream:
def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
"""
Add a request to be sent to the engine on the next background
loop iteration.
@ -99,7 +92,7 @@ class RequestTracker:
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
stream = RequstStream(request_id)
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
self.new_requests_event.set()
@ -109,7 +102,7 @@ class RequestTracker:
def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
Logger.info(f"Aborted request {request_id}.")
logger.info(f"Aborted request {request_id}.")
self._finished_requests.put_nowait(request_id)
@ -117,7 +110,7 @@ class RequestTracker:
# The request has already finished or been aborted.
return
self._request_streams[request_id].finish()
self._request_streams[request_id].set_result(None)
def get_new_requests(self):
"""
@ -134,30 +127,6 @@ class RequestTracker:
return new_requests
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[Dict] = []
finished_requests: Set[int] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
# The request has already been aborted.
stream.finish()
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
@ -194,6 +163,8 @@ class _AsyncInferenceEngine(InferenceEngine):
self.request_handler.search_tokens(self.generation_config, logits)
# Return: List[Sequence]
finished_sequences = self.request_handler.update()
for sequence in finished_sequences:
sequence.output = self.tokenizer.decode(sequence.output_token_id)
return finished_sequences, self.request_handler.current_requests_in_batch() > 0
@ -216,7 +187,7 @@ class AsyncInferenceEngine:
# reference to the unshielded loop
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
self._request_tracer = Tracer()
@property
def background_loop_status(self):
@ -226,11 +197,11 @@ class AsyncInferenceEngine:
if self.background_loop_status:
raise RuntimeError("Existing loop is running")
self._request_tracker.init_event()
self._request_tracer.init_event()
self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
)
self.background_loop = asyncio.shield(self._background_loop_unshielded)
@ -243,12 +214,13 @@ class AsyncInferenceEngine:
Returns True if there are in-progress requests.
"""
new_requests = self._request_tracker.get_new_requests()
new_requests = self._request_tracer.get_new_requests()
for new_request in new_requests:
self.engine.add_single_request(**new_request)
newly_finished_seqs, has_running_requests = await self.engine.async_step()
for seq in newly_finished_seqs:
self._request_tracker.process_finished_request(seq)
self._request_tracer.process_finished_request(seq)
return has_running_requests
@ -264,13 +236,13 @@ class AsyncInferenceEngine:
return self._abort(request_id)
def _abort(self, request_id: int):
self._request_tracker.abort_request(request_id)
self._request_tracer.abort_request(request_id)
async def run_engine_loop(self):
processing_requests = False
while True:
if not processing_requests:
await self._request_tracker.wait_for_new_requests()
await self._request_tracer.wait_for_new_requests()
processing_requests = await self.step()
await asyncio.sleep(0)
@ -279,7 +251,7 @@ class AsyncInferenceEngine:
request_id: int,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
) -> AsyncStream:
) -> RequstStream:
"""
Add a request to the background tracker (waiting queue) and start the background loop if needed.
"""
@ -288,7 +260,7 @@ class AsyncInferenceEngine:
self.start_background_loop()
else:
raise RuntimeError("Background loop is not running.")
stream = self._request_tracker.add_request(
stream = self._request_tracer.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
@ -308,8 +280,7 @@ class AsyncInferenceEngine:
"""
try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
async for request_output in stream:
yield request_output
return await stream.get_result()
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the

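Taken together, the changes above replace the iterator-style AsyncStream with a future-backed RequstStream and rename RequestTracker to Tracer, so a caller now awaits one final result per request instead of iterating token by token. A minimal consumer sketch, assuming an already constructed AsyncInferenceEngine (its constructor arguments are not shown in this diff):

# Minimal consumer sketch (not part of the diff). Names mirror the code above;
# engine construction and event-loop setup are assumed.
async def run_one_request(async_engine, request_id: int, prompt: str) -> str:
    # generate() registers the request with the Tracer and starts the
    # background loop if it is not already running; it resolves once the
    # sequence finishes.
    finished_sequence = await async_engine.generate(request_id, prompt=prompt)
    # async_step() decodes output_token_id into text before the request is
    # marked finished, so .output already holds the decoded string.
    return finished_sequence.output

The coroutine would be driven by whatever event loop hosts the engine, e.g. from a FastAPI handler as in the server code later in this commit.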
View File

@ -535,10 +535,10 @@ class InferenceEngine:
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]
print(prompts_token_ids)
if isinstance(prompts_token_ids, list):
pass
if isinstance(prompts_token_ids[0], torch.Tensor):
prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
prompts_token_ids = prompts_token_ids.tolist()
else:
@ -565,7 +565,6 @@ class InferenceEngine:
prompt = None
else:
prompt = prompts[i]
sequence = Sequence(
request_id,
prompt,
@ -646,8 +645,6 @@ class InferenceEngine:
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)
print("in step", logits)
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()

View File

@ -209,6 +209,7 @@ class RequestHandler:
break
num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
# for now the recycle logic is not working
remove_list.extend(lst[:num_seqs_to_add])
self.running_list.extend(lst[:num_seqs_to_add])

View File

@ -58,7 +58,7 @@ async def generate(request: Request) -> Response:
# Streaming case
def stream_results():
for request_output in results:
ret = {"text": request_output}
ret = {"text": request_output[len(prompt) :]}
yield (json.dumps(ret) + "\0").encode("utf-8")
if stream:
@ -71,7 +71,7 @@ async def generate(request: Request) -> Response:
# Abort the request if the client disconnects.
engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
final_output = request_output[len(prompt) :]
assert final_output is not None
ret = {"text": final_output}
@ -81,11 +81,15 @@ async def generate(request: Request) -> Response:
@app.post("/v1/completion")
async def create_completion(request: Request):
request_dict = await request.json()
stream = request_dict.pop("stream", False)
generation_config = get_generation_config(request_dict)
generator = await completion_serving.create_completion(request, generation_config)
output = tokenizer.decode(generator.output_token_id)
ret = {"request_id": generator.request_id, "text": output}
return ret
result = await completion_serving.create_completion(request, generation_config)
ret = {"request_id": result.request_id, "text": result.output}
if stream:
return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
else:
return JSONResponse(content=ret)
def get_generation_config(request):

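For reference, a hedged client-side sketch of the two endpoints touched above. It assumes the API server is reachable at http://127.0.0.1:8000 (the address used by the benchmark script added later in this commit); payloads mirror the new locustfile, except that the "stream" flag is sent as a JSON boolean here.

# Minimal client sketch (assumption: server running on 127.0.0.1:8000).
import json

import requests

BASE_URL = "http://127.0.0.1:8000"

# Non-streaming completion: the handler returns a single JSON body with
# "request_id" and "text".
resp = requests.post(f"{BASE_URL}/v1/completion", json={"prompt": "hello, who are you? ", "stream": False})
print(resp.json()["text"])

# Streaming generation: each chunk is a JSON object terminated by "\0".
with requests.post(f"{BASE_URL}/generate", json={"prompt": "Can you help me? ", "stream": True}, stream=True) as r:
    for chunk in r.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk)["text"], end="", flush=True)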
View File

@ -18,18 +18,17 @@ class CompletionServing:
async def create_completion(self, request, generation_config):
request_dict = await request.json()
request_id = id_generator()
prompt = request_dict.pop("prompt")
# this is not an intuitive way to pass the generation config
self.engine.engine.generation_config = generation_config
result_generator = self.engine.generate(request_id, prompt=prompt)
final_res = None
async for res in result_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return {"error_msg": "Client disconnected"}
final_res = res
raise RuntimeError("Client disconnected")
final_res = await result_generator
return final_res

View File

@ -64,6 +64,7 @@ class Sequence:
eos_token_id (int): The eos token id for this inference process.
pad_token_id (int): The pad token id for this inference process.
max_output_len (int): Maximum output length.
output (str): The decoded output text of the sequence.
"""
request_id: int
@ -74,6 +75,7 @@ class Sequence:
eos_token_id: int
pad_token_id: int
max_output_len: int = 256
output: str = None
def __post_init__(self):
self.output_token_id = []

View File

@ -635,7 +635,7 @@ def decoding_fused_rotary_embedding(
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0) == v.size(0)
assert q.size(1) == k.size(1) == v.size(1)
assert k.size(1) == v.size(1)
assert k_cache.size(-1) == v_cache.size(-1)
if head_dim >= 1024:

View File

@ -0,0 +1,30 @@
from locust import HttpUser, between, tag, task
class QuickstartUser(HttpUser):
wait_time = between(1, 5)
@tag("online-generation")
@task(5)
def completion(self):
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
@tag("online-generation")
@task(5)
def completion_streaming(self):
self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
@tag("offline-generation")
@task(5)
def generate_stream(self):
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
@tag("offline-generation")
@task(5)
def generate(self):
self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"})
@tag("online-generation", "offline-generation")
@task
def get_models(self):
self.client.get("/v0/models")

View File

@ -0,0 +1,24 @@
#!/bin/bash
# argument 1: model_path
# launch server
model_path=${1:-"lmsys/vicuna-7b-v1.3"}
echo "Model Path: $model_path"
echo "Starting server..."
python -m colossalai.inference.server.api_server --model $model_path &
SERVER_PID=$!
# wait for the server to start up
sleep 60
# Run Locust
echo "Starting Locust..."
echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
# kill Server
echo "Stopping server..."
kill $SERVER_PID
echo "Test and server shutdown completely"

View File

@ -1,98 +0,0 @@
import argparse
import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.inference import InferenceEngine
from colossalai.testing import spawn
INPUT_TEXTS = [
"What is the longest river in the world?",
"Explain the difference between process and thread in compouter science.",
]
def run_inference(args):
llama_model_path = args.model_path
llama_tokenize_path = args.tokenizer_path or args.model_path
max_input_len = args.max_input_len
max_output_len = args.max_output_len
max_batch_size = args.batch_size
micro_batch_size = args.micro_batch_size
tp_size = args.tp_size
pp_size = args.pp_size
rank = dist.get_rank()
tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
tokenizer.pad_token_id = tokenizer.eos_token_id
if args.quant is None:
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id)
elif args.quant == "gptq":
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device()
)
elif args.quant == "smoothquant":
from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name)
model = model.cuda()
engine = InferenceEngine(
tp_size=tp_size,
pp_size=pp_size,
model=model,
max_input_len=max_input_len,
max_output_len=max_output_len,
max_batch_size=max_batch_size,
micro_batch_size=micro_batch_size,
quant=args.quant,
dtype=args.dtype,
)
inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()}
outputs = engine.generate(inputs)
if rank == 0:
output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for input_text, output_text in zip(INPUT_TEXTS, output_texts):
print(f"Input: {input_text}")
print(f"Output: {output_text}")
def run_tp_pipeline_inference(rank, world_size, port, args):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_inference(args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
parser.add_argument("-i", "--input", default="What is the longest river in the world?")
parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
parser.add_argument(
"-q",
"--quant",
type=str,
choices=["gptq", "smoothquant"],
default=None,
help="quantization type: 'gptq' or 'smoothquant'",
)
parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length")
parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length")
parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
parser.add_argument("--dtype", default="fp16", type=str)
args = parser.parse_args()
spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)

View File

@ -0,0 +1,89 @@
import random
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def generate_inputs(num_sequences, min_length, max_length):
sequences = []
for _ in range(num_sequences):
length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
# generate sequences of random length
sequence = torch.randint(10, 30000, size=(length,))
sequences.append(sequence)
return sequences
@parameterize(
"max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
)
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
model = model.eval()
inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
if use_engine:
inference_config = InferenceConfig(
max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == max_output_len
inference_engine.add_request(prompts_token_ids=inputs_token_ids)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=max_output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert len(outputs) == 10 * max_batch_size
@parameterize("prompt_template", [None, "llama"])
def check_continuous_batching(prompt_template):
check_inference_engine(use_engine=True, prompt_template=prompt_template)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_continuous_batching()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_continuous_batching():
spawn(run_dist, 1)
if __name__ == "__main__":
test_continuous_batching()