mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
132 lines
4.5 KiB
132 lines
4.5 KiB
import asyncio |
|
|
|
from colossalai.inference.dynamic_batching.ray_dist_init import Driver |
|
|
|
from .dynamic_batching.io_struct import RequestOutput |
|
from .dynamic_batching.sampling_params import SamplingParams |
|
|
|
|
|
class RequestTracker:
    """
    A class for trace down all the requests, abstraction for async.

    Incoming request ids are queued in ``_requests`` and finished outputs in
    ``_finished_requests``; iterating the tracker asynchronously yields
    finished outputs until a ``StopIteration`` sentinel is dequeued.

    Note: ``init_event()`` must be called (from within the running event
    loop) before ``add_request()`` / ``add_stop()`` are used, since those
    touch ``new_requests_event``.
    """

    def __init__(self) -> None:
        self._requests: asyncio.Queue[str] = asyncio.Queue()
        self._finished_requests: asyncio.Queue["RequestOutput"] = asyncio.Queue()
        # Mirror of every id pushed to ``_requests``. asyncio.Queue is not a
        # container (no __contains__/__iter__), so membership tests must go
        # through this set instead of the queue itself.
        self._request_ids: set[str] = set()
        # Created lazily by init_event() so the Event binds to the running loop.
        self.new_requests_event = None

    def __contains__(self, item):
        # BUG FIX: the original tested ``item in self._requests``, which raises
        # TypeError because asyncio.Queue does not support ``in``.
        return item in self._request_ids

    def init_event(self):
        """Create the event used to signal that new requests have arrived."""
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str):
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        self._requests.put_nowait(request_id)
        self._request_ids.add(request_id)
        self.new_requests_event.set()  # NOTE: we may find a better way to clear this event

    def add_stop(self):
        """
        Add a StopIteration flag to stop async generator.
        """
        # The sentinel terminates any ``async for`` over this tracker.
        self._finished_requests.put_nowait(StopIteration)
        self.new_requests_event.clear()

    def process_request_output(self, request_output: "RequestOutput") -> None:
        """Process a request output from the engine."""
        self._finished_requests.put_nowait(request_output)

    async def wait_for_new_requests(self):
        """Block until ``add_request`` signals that new work is available."""
        await self.new_requests_event.wait()

    def __aiter__(self):
        return self

    async def __anext__(self) -> "RequestOutput":
        """Yield the next finished output; stop on the StopIteration sentinel."""
        result = await self._finished_requests.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result
|
|
|
|
|
class Async_Engine:
    """
    Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
    Background loop: inference reqs in waiting list (Listen)
    Request Tracker: manage incoming requests and restore finished ones
    Generate: exposed func for add new input and return finished ones
    """

    def __init__(
        self,
        router_config,
        engine_config,
        start_engine_loop: bool = True,
    ) -> None:
        self.driver = Driver(router_config=router_config, engine_config=engine_config)
        self.background_loop = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    def _step(self):
        """
        Logic for handling requests: run one driver step and funnel any
        finished outputs into the request tracker, then enqueue a stop
        sentinel so consumers' ``async for`` loops can terminate.
        """
        request_outputs = self.driver.step()
        if request_outputs is not None:
            for request_output in request_outputs:
                self._request_tracker.process_request_output(request_output)
            self._request_tracker.add_stop()

    def abort_request(self, request_id: str):
        """Ask the driver to abort the given request."""
        self.driver.abort(request_id)

    def _has_requests_in_progress(self):
        # Delegates to the driver; presumably True while any request is still
        # being processed — TODO confirm against Driver.is_running().
        return self.driver.is_running()

    async def run_loop_fwd(self):
        """Background loop: sleep while idle, otherwise keep stepping the engine."""
        # BUG FIX: the original sampled _has_requests_in_progress() once before
        # entering the loop, so the flag went stale — if it started out True the
        # loop busy-spun forever and never waited on the event again. Re-check
        # on every iteration instead.
        while True:
            if not self._has_requests_in_progress():
                await self._request_tracker.wait_for_new_requests()
            self._step()
            # Yield control so other coroutines (e.g. generate()) can run.
            await asyncio.sleep(0)

    @property
    def is_running(self):
        """True while the background task exists and has not finished."""
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        """Start the background loop task; raises if it is already running."""
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        self._request_tracker.init_event()

        # NOTE(review): get_event_loop() is deprecated outside a running loop
        # on newer Pythons; kept for compatibility with existing callers.
        self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
        # Shield so an external cancellation of ``background_loop`` does not
        # kill the underlying task.
        self.background_loop = asyncio.shield(self.background_loop_unshielded)

    async def add_request(self, request_id: str, prompt: str, sampling_params: "SamplingParams"):
        """Hand a new request to the driver and register it with the tracker."""
        self.driver.add_input(request_id, prompt, sampling_params)
        self._request_tracker.add_request(request_id)

    async def generate(self, request_id: str, prompt: str, sampling_params: "SamplingParams"):
        """
        The only exposed func, adding new request and return a async generator that yields the existing results.
        """
        try:
            if not self.is_running:
                self.start_background_loop()

            await self.add_request(request_id, prompt, sampling_params)

            async for request_output in self._request_tracker:
                yield request_output

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or coroutine is cancelled, abort the request.
            self.abort_request(request_id)
            raise e
|
|
|