mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
1.0 KiB
34 lines
1.0 KiB
import asyncio |
|
|
|
from colossalai.inference.core.async_engine import AsyncInferenceEngine |
|
|
|
from .utils import id_generator |
|
|
|
|
|
class CompletionServing:
    """Serve completion requests by delegating generation to an async engine."""

    def __init__(self, engine: AsyncInferenceEngine, served_model: str):
        """Keep references to the engine and the served model identifier.

        Args:
            engine: Async inference engine; this class calls its ``generate``
                and ``abort`` methods and writes to its inner ``engine``.
            served_model: Identifier of the served model. It is only stored
                here; nothing in this class reads it back.
        """
        self.engine = engine
        self.served_model = served_model

        # Probe for a running event loop. The returned loop is discarded and
        # a missing loop (RuntimeError) is deliberately ignored, so
        # construction succeeds whether or not a loop is running.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass

    async def create_completion(self, request, generation_config):
        """Generate a completion for the prompt carried in ``request``.

        Args:
            request: Incoming request object. It must expose awaitable
                ``json()`` and ``is_disconnected()`` methods (presumably a
                Starlette/FastAPI ``Request`` -- confirm against the caller).
            generation_config: Generation settings. They are assigned onto
                the wrapped engine and also passed to ``engine.generate``.

        Returns:
            The result of awaiting ``engine.generate(...)``.

        Raises:
            KeyError: If the decoded JSON body is a dict without a
                ``"prompt"`` key.
            RuntimeError: If the client is found to be disconnected.
        """
        payload = await request.json()
        req_id = id_generator()
        prompt_text = payload.pop("prompt")

        # NOTE: this reaches into the wrapped engine to set its generation
        # config directly, which is admittedly not an intuitive way to pass it.
        self.engine.engine.generation_config = generation_config
        pending = self.engine.generate(req_id, prompt=prompt_text, generation_config=generation_config)

        if not await request.is_disconnected():
            return await pending

        # The client disconnected: abort the request and signal the failure.
        await self.engine.abort(req_id)
        raise RuntimeError("Client disconnected")
|
|
|