mirror of https://github.com/hpcaitech/ColossalAI
36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
![]() |
import asyncio
|
||
|
|
||
|
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||
|
|
||
|
from .utils import id_generator
|
||
|
|
||
|
|
||
|
class CompletionServing:
    """Serve non-streaming text-completion requests on an ``AsyncInferenceEngine``.

    Each call to :meth:`create_completion` submits one prompt to the engine,
    drains the engine's result stream, and returns only the last item produced.
    """

    def __init__(self, engine: AsyncInferenceEngine, served_model: str):
        """Store the engine and the served model name.

        Args:
            engine: Async inference engine used to run generation requests.
            served_model: Name of the served model; kept as an attribute
                (not read by this class itself).
        """
        self.engine = engine
        self.served_model = served_model

    async def create_completion(self, request, generation_config):
        """Run one completion request to the end and return its final result.

        Args:
            request: Incoming request object; must provide awaitable ``json()``
                and ``is_disconnected()`` methods (presumably a Starlette /
                FastAPI ``Request`` -- confirm against the caller). Its JSON
                body must contain a ``"prompt"`` field.
            generation_config: Generation settings assigned to the inner engine
                before this request is submitted.

        Returns:
            The last item yielded by ``self.engine.generate`` for this request,
            ``None`` if it yielded nothing, or
            ``{"error_msg": "Client disconnected"}`` if the client disconnected
            while results were still arriving.

        Raises:
            KeyError: If the request body has no ``"prompt"`` field.
            asyncio.CancelledError: Re-raised after aborting the engine request
                when this coroutine is cancelled mid-generation.
        """
        request_dict = await request.json()
        request_id = id_generator()
        prompt = request_dict.pop("prompt")

        # The generation config is installed on the inner engine instead of
        # being passed with the request -- it is not an intuitive way.
        # NOTE(review): this mutates state shared by every in-flight request,
        # so concurrent requests with different configs may race; confirm when
        # the engine actually reads ``generation_config``.
        self.engine.engine.generation_config = generation_config
        result_generator = self.engine.generate(request_id, prompt=prompt)

        final_res = None
        disconnected = False
        try:
            async for res in result_generator:
                if await request.is_disconnected():
                    disconnected = True
                    break
                final_res = res
        except asyncio.CancelledError:
            # The serving task was cancelled while waiting on the engine:
            # abort so the engine does not keep working on an orphaned request.
            await self.engine.abort(request_id)
            raise

        if disconnected:
            # Abort the request if the client disconnects.
            await self.engine.abort(request_id)
            return {"error_msg": "Client disconnected"}
        return final_res