import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
    """
    Tracks all in-flight requests; an async abstraction between callers and the engine.
    """

    def __init__(self) -> None:
        self._requests: asyncio.Queue[str] = asyncio.Queue()
        self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        # NOTE: asyncio.Queue has no public membership test; peek at its internal
        # deque (a CPython implementation detail) instead.
        return item in self._requests._queue

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str):
        """Add a request to be sent to the engine on the next background loop iteration."""
        self._requests.put_nowait(request_id)
        self.new_requests_event.set()  # NOTE: we may find a better way to clear this event

    def add_stop(self):
        """
        Add a StopIteration sentinel to stop the async generator.
        """
        self._finished_requests.put_nowait(StopIteration)
        self.new_requests_event.clear()

    def process_request_output(self, request_output: RequestOutput) -> None:
        """Store a request output produced by the engine."""
        self._finished_requests.put_nowait(request_output)

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._finished_requests.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result


class Async_Engine:
    """
    Uses a Driver to launch: RAY Driver --> RAY Worker --> Async_Manager.
    Background loop: processes inference requests in the waiting list (listens for new work).
    Request tracker: manages incoming requests and stores finished ones.
    generate: the exposed coroutine for adding new input and returning finished outputs.
    """

    def __init__(
        self,
        router_config,
        engine_config,
        start_engine_loop: bool = True,
    ) -> None:
        self.driver = Driver(router_config=router_config, engine_config=engine_config)
        self.background_loop = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    def _step(self):
        """
        Run one engine step and hand its outputs to the request tracker.
        """
        request_outputs = self.driver.step()
        if request_outputs is not None:
            for request_output in request_outputs:
                self._request_tracker.process_request_output(request_output)
        self._request_tracker.add_stop()

    def abort_request(self, request_id: str):
        self.driver.abort(request_id)

    def _has_requests_in_progress(self):
        return self.driver.is_running()

    async def run_loop_fwd(self):
        while True:
            # Re-check on every iteration; if the driver is idle, block until a new request arrives.
            if not self._has_requests_in_progress():
                await self._request_tracker.wait_for_new_requests()
            self._step()
            await asyncio.sleep(0)

    @property
    def is_running(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        self._request_tracker.init_event()

        self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
        self.background_loop = asyncio.shield(self.background_loop_unshielded)

    async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        self.driver.add_input(request_id, prompt, sampling_params)
        self._request_tracker.add_request(request_id)

    async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        """
        The only exposed entry point: add a new request and return an async generator that yields the generated results.
""" try: if not self.is_running: self.start_background_loop() await self.add_request(request_id, prompt, sampling_params) async for request_output in self._request_tracker: yield request_output except (Exception, asyncio.CancelledError) as e: # If there is an exception or coroutine is cancelled, abort the request. self.abort_request(request_id) raise e