mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Finish Online Serving Test, add streaming output api, continuous batching test and example (#5432)
* finish online test and add examples
* fix test_continuous_batching
* fix some bugs
* fix bash
* fix
* fix inference
* finish revision
* fix typos
* revision

Branch: pull/5584/head
parent 1572af2432
commit 3d211ff81b
@@ -1,13 +1,13 @@
 import asyncio
 import logging
 from functools import partial
 from logging import Logger
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
+from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine


-class AsyncEngineDeadError(RuntimeError):
-    pass
+# CLI logger
+logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logger = logging.getLogger("colossalai-inference")


 def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
@@ -18,54 +18,45 @@ def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
         except asyncio.CancelledError:
             return
         except Exception as exc:
-            raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc
-        raise AsyncEngineDeadError(msg)
+            raise RuntimeError(msg + " See stack trace above for the actual cause.") from exc
+        raise RuntimeError(msg)
     except Exception as exc:
         request_tracker.propagate_exception(exc)
         raise exc


-class AsyncStream:
+class RequstStream:
     """A stream of Output for a request that can be
     iterated over asynchronously."""

-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: int) -> None:
         self.request_id = request_id
-        self._queue = asyncio.Queue()
-        self._finished = False
+        self._future = asyncio.Future()

-    def put(self, item) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+    def set_result(self, result) -> None:
+        """Set final result and signal taht it's ready"""
+        if not self._future.done():
+            self._future.set_result(result)

-    def finish(self) -> None:
-        self._queue.put_nowait(StopIteration)
-        self._finished = True
+    async def get_result(self):
+        """Wait for the result to be set and return it."""
+        return await self._future

     @property
     def finished(self) -> bool:
-        return self._finished
-
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        result = await self._queue.get()
-        if result is StopIteration:
-            raise StopAsyncIteration
-        elif isinstance(result, Exception):
-            raise result
-        return result
+        """Check if the stream has finished by checking if the future is done."""
+        return self._future.done()


-class RequestTracker:
-    """Synchronous abstraction for tracking requests."""
+class Tracer:
+    """
+    Recording new requests and finished requests.
+    """

     def __init__(self) -> None:
-        self._request_streams: Dict[str, AsyncStream] = {}
+        self._request_streams: Dict[int, RequstStream] = {}
         self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
-        self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[RequstStream, dict]] = asyncio.Queue()
         self.new_requests_event = None

     def __contains__(self, item):
@@ -79,19 +70,21 @@ class RequestTracker:
         Propagate an exception to request streams (all if request_id is None).
         """
         if request_id is not None:
-            self._request_streams[request_id].put(exc)
+            self._request_streams[request_id].set_result(exc)
         else:
             for stream in self._request_streams.values():
-                stream.put(exc)
+                stream.set_result(exc)

     def process_finished_request(self, finished_request) -> None:
         """Process a finished request from the engine."""
         request_id = finished_request.request_id

-        self._request_streams[request_id].put(finished_request)
+        try:
+            self._request_streams[request_id].set_result(finished_request)
+        except:
+            raise RuntimeError(f"The request_id {request_id} is not found in our stream, please check")
         self.abort_request(request_id)

-    def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream:
+    def add_request(self, request_id: int, **engine_add_request_kwargs) -> RequstStream:
         """
         Add a request to be sent to the engine on the next background
         loop iteration.
@@ -99,7 +92,7 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")

-        stream = AsyncStream(request_id)
+        stream = RequstStream(request_id)
         self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))

         self.new_requests_event.set()
@@ -109,7 +102,7 @@ class RequestTracker:
     def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
-            Logger.info(f"Aborted request {request_id}.")
+            logger.info(f"Aborted request {request_id}.")

         self._finished_requests.put_nowait(request_id)

@@ -117,7 +110,7 @@ class RequestTracker:
             # The request has already finished or been aborted.
             return

-        self._request_streams[request_id].finish()
+        self._request_streams[request_id].set_result(None)

     def get_new_requests(self):
         """
@@ -134,30 +127,6 @@ class RequestTracker:

         return new_requests

-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
-        """Get the new requests and finished requests to be
-        sent to the engine."""
-        new_requests: List[Dict] = []
-        finished_requests: Set[int] = set()
-
-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
-            finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)
-
-        while not self._new_requests.empty():
-            stream, new_request = self._new_requests.get_nowait()
-            if stream.request_id in finished_requests:
-                # The request has already been aborted.
-                stream.finish()
-                continue
-            self._request_streams[stream.request_id] = stream
-            new_requests.append(new_request)
-
-        self.new_requests_event.clear()
-
-        return new_requests, finished_requests
-
     async def wait_for_new_requests(self):
         await self.new_requests_event.wait()

@@ -194,6 +163,8 @@ class _AsyncInferenceEngine(InferenceEngine):
         self.request_handler.search_tokens(self.generation_config, logits)
         # Return: List[Sequence]
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)

         return finished_sequences, self.request_handler.current_requests_in_batch() > 0

@@ -216,7 +187,7 @@ class AsyncInferenceEngine:
         # reference to the unshielded loop
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracer = Tracer()

     @property
     def background_loop_status(self):
@@ -226,11 +197,11 @@ class AsyncInferenceEngine:
         if self.background_loop_status:
             raise RuntimeError("Existing loop is running")

-        self._request_tracker.init_event()
+        self._request_tracer.init_event()

         self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
-            partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
+            partial(_raise_exception_on_finish, request_tracker=self._request_tracer)
         )
         self.background_loop = asyncio.shield(self._background_loop_unshielded)

@@ -243,12 +214,13 @@ class AsyncInferenceEngine:

        Returns True if there are in-progress requests.
        """
-        new_requests = self._request_tracker.get_new_requests()
+        new_requests = self._request_tracer.get_new_requests()
         for new_request in new_requests:
             self.engine.add_single_request(**new_request)
         newly_finished_seqs, has_running_requests = await self.engine.async_step()

         for seq in newly_finished_seqs:
-            self._request_tracker.process_finished_request(seq)
+            self._request_tracer.process_finished_request(seq)

         return has_running_requests

@@ -264,13 +236,13 @@ class AsyncInferenceEngine:
         return self._abort(request_id)

     def _abort(self, request_id: int):
-        self._request_tracker.abort_request(request_id)
+        self._request_tracer.abort_request(request_id)

     async def run_engine_loop(self):
         processing_requests = False
         while True:
             if not processing_requests:
-                await self._request_tracker.wait_for_new_requests()
+                await self._request_tracer.wait_for_new_requests()
             processing_requests = await self.step()
             await asyncio.sleep(0)

@@ -279,7 +251,7 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
-    ) -> AsyncStream:
+    ) -> RequstStream:
         """
         Add a request to the background tracker(waitting queue), start the background loop if needed.
         """
@@ -288,7 +260,7 @@ class AsyncInferenceEngine:
                 self.start_background_loop()
             else:
                 raise RuntimeError("Background loop is not running.")
-        stream = self._request_tracker.add_request(
+        stream = self._request_tracer.add_request(
             request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
@@ -308,8 +280,7 @@ class AsyncInferenceEngine:
         """
         try:
             stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
-            async for request_output in stream:
-                yield request_output
+            return await stream.get_result()

         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the
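For context, the new RequstStream/Tracer code above replaces queue-based async iteration with a per-request asyncio.Future that the background loop resolves once and the caller simply awaits. The snippet below is a minimal, self-contained sketch of that pattern; the class and function names are illustrative and not part of the commit.

import asyncio


class SimpleRequestStream:
    """Illustrative stand-in for the future-based RequstStream above."""

    def __init__(self, request_id: int) -> None:
        self.request_id = request_id
        self._future = asyncio.Future()

    def set_result(self, result) -> None:
        # Only the first result is kept; later calls are ignored.
        if not self._future.done():
            self._future.set_result(result)

    async def get_result(self):
        # The caller awaits the final output instead of iterating a queue.
        return await self._future


async def demo():
    stream = SimpleRequestStream(request_id=0)

    async def background_step():
        # Stands in for the engine loop finishing the request.
        await asyncio.sleep(0.1)
        stream.set_result("generated text")

    asyncio.create_task(background_step())
    print(await stream.get_result())


if __name__ == "__main__":
    asyncio.run(demo())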
@@ -535,10 +535,10 @@ class InferenceEngine:
             prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                 "input_ids"
             ]
-        print(prompts_token_ids)

         if isinstance(prompts_token_ids, list):
-            pass
+            if isinstance(prompts_token_ids[0], torch.Tensor):
+                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
@@ -565,7 +565,6 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
-
             sequence = Sequence(
                 request_id,
                 prompt,
@@ -646,8 +645,6 @@ class InferenceEngine:
         next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
         self.request_handler.append_next_tokens(next_tokens)

-        print("in step", logits)

-        self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()

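The engine.py hunks above drop two debug prints and replace a placeholder `pass` with normalization of `prompts_token_ids`. The helper below restates that normalization as a standalone function to show the accepted input types; the function name and the error raised in the final branch are assumptions, since the original `else:` body is cut off by the hunk.

import numpy as np
import torch


def normalize_prompts_token_ids(prompts_token_ids):
    """Return token ids as a list of Python lists, mirroring the branch logic above."""
    if isinstance(prompts_token_ids, list):
        if isinstance(prompts_token_ids[0], torch.Tensor):
            prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()
    else:
        raise TypeError(f"Unsupported prompts_token_ids type: {type(prompts_token_ids)}")
    return prompts_token_ids


# Each of these calls yields [[1, 2, 3]]:
# normalize_prompts_token_ids([torch.tensor([1, 2, 3])])
# normalize_prompts_token_ids(torch.tensor([[1, 2, 3]]))
# normalize_prompts_token_ids(np.array([[1, 2, 3]]))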
@@ -209,6 +209,7 @@ class RequestHandler:
                        break

                num_seqs_to_add = min(len(lst), self.max_batch_size - self.running_list.total_seq_num)
+                # for now the recycle logic is not working
                remove_list.extend(lst[:num_seqs_to_add])
                self.running_list.extend(lst[:num_seqs_to_add])

@@ -58,7 +58,7 @@ async def generate(request: Request) -> Response:
     # Streaming case
     def stream_results():
         for request_output in results:
-            ret = {"text": request_output}
+            ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")

     if stream:
@@ -71,7 +71,7 @@ async def generate(request: Request) -> Response:
             # Abort the request if the client disconnects.
             engine.abort(request_id)
             return Response(status_code=499)
-        final_output = request_output
+        final_output = request_output[len(prompt) :]

     assert final_output is not None
     ret = {"text": final_output}
@@ -81,11 +81,15 @@ async def generate(request: Request) -> Response:
 @app.post("/v1/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
+    stream = request_dict.pop("stream", False)
     generation_config = get_generation_config(request_dict)
-    generator = await completion_serving.create_completion(request, generation_config)
-    output = tokenizer.decode(generator.output_token_id)
-    ret = {"request_id": generator.request_id, "text": output}
-    return ret
+    result = await completion_serving.create_completion(request, generation_config)
+
+    ret = {"request_id": result.request_id, "text": result.output}
+    if stream:
+        return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
+    else:
+        return JSONResponse(content=ret)


 def get_generation_config(request):
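The api_server.py hunks above add a `stream` flag to `/v1/completion` and return either a StreamingResponse of NUL-terminated JSON chunks or a plain JSONResponse. A possible client call is sketched below; the host, port, and prompt text are assumptions and not taken from the diff.

import json

import requests

# Non-streaming completion: a single JSON body like {"request_id": ..., "text": ...}.
response = requests.post(
    "http://127.0.0.1:8000/v1/completion",
    json={"prompt": "hello, who are you? ", "stream": False},
)
print(response.json())

# Streaming variant: chunks are JSON objects terminated by "\0", matching the delimiter above.
with requests.post(
    "http://127.0.0.1:8000/v1/completion",
    json={"prompt": "hello, who are you? ", "stream": True},
    stream=True,
) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk.decode("utf-8")))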
@@ -18,18 +18,17 @@ class CompletionServing:
     async def create_completion(self, request, generation_config):
         request_dict = await request.json()
         request_id = id_generator()

         prompt = request_dict.pop("prompt")

         # it is not a intuitive way
         self.engine.engine.generation_config = generation_config
         result_generator = self.engine.generate(request_id, prompt=prompt)

-        final_res = None
-        async for res in result_generator:
-            if await request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
-                return {"error_msg": "Client disconnected"}
-            final_res = res
+        if await request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await self.engine.abort(request_id)
+            raise RuntimeError("Client disconnected")
+
+        final_res = await result_generator
         return final_res
@@ -64,6 +64,7 @@ class Sequence:
         eos_token_id (int): The eos token id for this inference process.
         pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
+        output(str): The output of sequence
     """

     request_id: int
@@ -74,6 +75,7 @@ class Sequence:
     eos_token_id: int
     pad_token_id: int
     max_output_len: int = 256
+    output: str = None

     def __post_init__(self):
         self.output_token_id = []
@@ -635,7 +635,7 @@ def decoding_fused_rotary_embedding(
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0) == v.size(0)
-    assert q.size(1) == k.size(1) == v.size(1)
+    assert k.size(1) == v.size(1)
     assert k_cache.size(-1) == v_cache.size(-1)

     if head_dim >= 1024:
@@ -0,0 +1,30 @@
+from locust import HttpUser, between, tag, task
+
+
+class QuickstartUser(HttpUser):
+    wait_time = between(1, 5)
+
+    @tag("online-generation")
+    @task(5)
+    def completion(self):
+        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "False"})
+
+    @tag("online-generation")
+    @task(5)
+    def completion_streaming(self):
+        self.client.post("/v1/completion", json={"prompt": "hello, who are you? ", "stream": "True"})
+
+    @tag("offline-generation")
+    @task(5)
+    def generate_stream(self):
+        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "True"})
+
+    @tag("offline-generation")
+    @task(5)
+    def generate(self):
+        self.client.post("/generate", json={"prompt": "Can you help me? ", "stream": "False"})
+
+    @tag("online-generation", "offline-generation")
+    @task
+    def get_models(self):
+        self.client.get("/v0/models")
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#argument1: model_path
+
+# launch server
+model_path=${1:-"lmsys/vicuna-7b-v1.3"}
+echo "Model Path: $model_path"
+echo "Starting server..."
+python -m colossalai.inference.server.api_server --model $model_path &
+SERVER_PID=$!
+
+# waiting time
+sleep 60
+
+# Run Locust
+echo "Starting Locust..."
+echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
+locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+
+# kill Server
+echo "Stopping server..."
+kill $SERVER_PID
+
+echo "Test and server shutdown completely"
@@ -1,98 +0,0 @@
-import argparse
-
-import torch
-import torch.distributed as dist
-from transformers import LlamaForCausalLM, LlamaTokenizer
-
-import colossalai
-from colossalai.accelerator import get_accelerator
-from colossalai.inference import InferenceEngine
-from colossalai.testing import spawn
-
-INPUT_TEXTS = [
-    "What is the longest river in the world?",
-    "Explain the difference between process and thread in compouter science.",
-]
-
-
-def run_inference(args):
-    llama_model_path = args.model_path
-    llama_tokenize_path = args.tokenizer_path or args.model_path
-
-    max_input_len = args.max_input_len
-    max_output_len = args.max_output_len
-    max_batch_size = args.batch_size
-    micro_batch_size = args.micro_batch_size
-    tp_size = args.tp_size
-    pp_size = args.pp_size
-    rank = dist.get_rank()
-
-    tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
-    if args.quant is None:
-        model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id)
-    elif args.quant == "gptq":
-        from auto_gptq import AutoGPTQForCausalLM
-
-        model = AutoGPTQForCausalLM.from_quantized(
-            llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device()
-        )
-    elif args.quant == "smoothquant":
-        from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
-
-        model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name)
-        model = model.cuda()
-
-    engine = InferenceEngine(
-        tp_size=tp_size,
-        pp_size=pp_size,
-        model=model,
-        max_input_len=max_input_len,
-        max_output_len=max_output_len,
-        max_batch_size=max_batch_size,
-        micro_batch_size=micro_batch_size,
-        quant=args.quant,
-        dtype=args.dtype,
-    )
-
-    inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
-    inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()}
-    outputs = engine.generate(inputs)
-
-    if rank == 0:
-        output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-        for input_text, output_text in zip(INPUT_TEXTS, output_texts):
-            print(f"Input: {input_text}")
-            print(f"Output: {output_text}")
-
-
-def run_tp_pipeline_inference(rank, world_size, port, args):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    run_inference(args)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
-    parser.add_argument("-i", "--input", default="What is the longest river in the world?")
-    parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
-    parser.add_argument(
-        "-q",
-        "--quant",
-        type=str,
-        choices=["gptq", "smoothquant"],
-        default=None,
-        help="quantization type: 'gptq' or 'smoothquant'",
-    )
-    parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
-    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
-    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
-    parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
-    parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length")
-    parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length")
-    parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
-    parser.add_argument("--dtype", default="fp16", type=str)
-
-    args = parser.parse_args()
-    spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
@@ -0,0 +1,89 @@
+import random
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+
+import colossalai
+from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.core.engine import InferenceEngine
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+
+
+def setup_seed(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def generate_inputs(num_sequences, min_length, max_length):
+    sequences = []
+    for _ in range(num_sequences):
+        length = torch.randint(low=min_length, high=max_length + 1, size=(1,)).item()
+        # generating randomly lengthed sequences
+        sequence = torch.randint(10, 30000, size=(length,))
+        sequences.append(sequence)
+    return sequences
+
+
+@parameterize(
+    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+)
+def check_inference_engine(use_engine=False, prompt_template=None):
+    setup_seed(20)
+    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
+    model = model.eval()
+
+    inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
+
+    if use_engine:
+        inference_config = InferenceConfig(
+            max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
+        )
+        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        assert inference_engine.generation_config.max_new_tokens == max_output_len
+        inference_engine.add_request(prompts_token_ids=inputs_token_ids)
+        assert inference_engine.request_handler._has_waiting()
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        outputs = inference_engine.generate(generation_config=generation_config)
+    else:
+        if prompt_template:
+            # apply prompt template
+            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
+        inputs = inputs.cuda()
+        generation_config = GenerationConfig(
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=max_output_len,
+        )
+        outputs = model.generate(inputs, generation_config=generation_config)
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    assert len(outputs) == 10 * max_batch_size
+
+
+@parameterize("prompt_template", [None, "llama"])
+def check_continuous_batching(prompt_template):
+    check_inference_engine(use_engine=True, prompt_template=prompt_template)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    check_continuous_batching()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_continuous_batching():
+    spawn(run_dist, 1)
+
+
+if __name__ == "__main__":
+    test_continuous_batching()