ColossalAI/colossalai/inference/async_engine.py

134 lines
4.5 KiB
Python

import asyncio
from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams
class RequestTracker:
"""
A class for trace down all the requests, abstraction for async
"""
def __init__(self) -> None:
self._requests: asyncio.Queue[str] = asyncio.Queue()
self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
return item in self._requests
def init_event(self):
self.new_requests_event = asyncio.Event()
def add_request(self, request_id: str):
"""Add a request to be sent to the engine on the next background
loop iteration."""
self._requests.put_nowait(request_id)
self.new_requests_event.set() # NOTE: we may find a better way to clear this event
def add_stop(self):
"""
Add a StopIteration flag to stop async generator.
"""
self._finished_requests.put_nowait(StopIteration)
self.new_requests_event.clear()
def process_request_output(self, request_output: RequestOutput) -> None:
"""Process a request output from the engine."""
self._finished_requests.put_nowait(request_output)
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
def __aiter__(self):
return self
async def __anext__(self) -> RequestOutput:
result = await self._finished_requests.get()
# print("result of ", result)
if result is StopIteration:
raise StopAsyncIteration
return result
class Async_Engine:
"""
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
Background loop: inference reqs in waiting list (Listen)
Request Tracker: manage incoming requests and restore finished ones
Generate: exposed func for add new input and return finished ones
"""
def __init__(
self,
router_config,
engine_config,
start_engine_loop: bool = True,
) -> None:
self.driver = Driver(router_config=router_config, engine_config=engine_config)
self.background_loop = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
def _step(self):
"""
Logic for handling requests
"""
request_outputs = self.driver.step()
if request_outputs is not None:
for request_output in request_outputs:
self._request_tracker.process_request_output(request_output)
self._request_tracker.add_stop()
def abort_request(self, request_id: str):
self.driver.abort(request_id)
def _has_requests_in_progress(self):
return self.driver.is_running()
async def run_loop_fwd(self):
has_requests_in_progress = self._has_requests_in_progress()
while True:
if not has_requests_in_progress:
await self._request_tracker.wait_for_new_requests()
self._step()
await asyncio.sleep(0)
@property
def is_running(self):
return self.background_loop is not None and not self.background_loop.done()
def start_background_loop(self):
if self.is_running:
raise RuntimeError("Background loop is already running.")
self._request_tracker.init_event()
self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
self.background_loop = asyncio.shield(self.background_loop_unshielded)
async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.driver.add_input(request_id, prompt, sampling_params)
self._request_tracker.add_request(request_id)
async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
"""
The only exposed func, adding new request and return a async generator that yields the existing results.
"""
try:
if not self.is_running:
self.start_background_loop()
await self.add_request(request_id, prompt, sampling_params)
async for request_output in self._request_tracker:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the request.
self.abort_request(request_id)
raise e