mirror of https://github.com/hpcaitech/ColossalAI
resolve rebase conflicts on branch feat/online-serving

parent 61a1b2e798
commit bc9063adf1
```diff
@@ -527,16 +527,9 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
-<<<<<<< HEAD
             if isinstance(prompts, str) and isinstance(request_ids, int):
                 prompts = [prompts]
                 request_ids = [request_ids]
-=======
-            if prompts is not None or prompts_token_ids is not None:
-                self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
->>>>>>> [Inference] Fix bugs and docs for feat/online-server (#5598)

             if prompts is not None or prompts_token_ids is not None:
                 gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                 self.add_request(
@@ -545,7 +538,7 @@ class InferenceEngine:
                 prompts_token_ids=prompts_token_ids,
                 **gen_config_dict,
             )

         output_seqs_list = []
         total_tokens_list = []
```
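Read together, the two hunks above leave a single submission path in `InferenceEngine.generate`: normalize a lone prompt/request id into lists, then hand everything to `add_request` with the generation config unpacked as keyword arguments. A minimal consolidated sketch follows; the control flow mirrors the resolved diff, while the method signature, the `add_request` stub, and the trailing decode loop are assumptions added for illustration.

```python
import torch


class InferenceEngine:
    def add_request(self, **kwargs):
        # Placeholder for the real scheduler submission in colossalai.inference.
        ...

    def generate(self, request_ids=None, prompts=None, prompts_token_ids=None, generation_config=None):
        with torch.inference_mode():
            # Kept from HEAD: a single prompt / request id is normalized into lists.
            if isinstance(prompts, str) and isinstance(request_ids, int):
                prompts = [prompts]
                request_ids = [request_ids]

            # The request is submitted once, with generation settings passed as kwargs.
            if prompts is not None or prompts_token_ids is not None:
                gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
                self.add_request(
                    request_ids=request_ids,
                    prompts=prompts,
                    prompts_token_ids=prompts_token_ids,
                    **gen_config_dict,
                )

            output_seqs_list = []
            total_tokens_list = []
            # ... the decoding loop continues here in the real implementation ...
```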
@@ -0,0 +1,27 @@

# Online Service

Colossal-Inference supports a FastAPI-based online service. Both simple completion and chat are supported. Follow the commands below to build a server that offers completion and chat functionality. For now we only support `Llama` models; support for more models will be added soon.

# Usage

```bash
# First, launch an API server locally.
python3 -m colossalai.inference.server.api_server --model <path of your llama2 model> --chat_template "{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"


# Second, open the page `http://127.0.0.1:8000/docs` to check the API.

# For the completion service, invoke it with:
curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'

# For the chat service, invoke it with:
curl -X POST http://127.0.0.1:8000/completion -H 'Content-Type: application/json' -d '{"converation":
[{"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"}],
"stream": "False"}'

# If you just want to test simple generation, use the generate API:
curl -X POST http://127.0.0.1:8000/generate -H 'Content-Type: application/json' -d '{"prompt":"hello, who are you? ","stream":"False"}'
```

We also support streaming output: simply change `stream` to `True` in the request body.
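Since the service speaks plain JSON over HTTP, the same requests can be issued from Python. Below is a minimal sketch using the `requests` library; the endpoint paths and request bodies mirror the curl examples above, while the base URL constant and the way the streaming response is consumed are assumptions for illustration.

```python
import requests

BASE_URL = "http://127.0.0.1:8000"  # the server launched by api_server above

# Non-streaming completion, mirroring the curl example.
resp = requests.post(
    f"{BASE_URL}/completion",
    json={"prompt": "hello, who are you? ", "stream": "False"},
)
print(resp.text)

# Streaming: set `stream` to True in the body and consume the response incrementally.
with requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "hello, who are you? ", "stream": "True"},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)
```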
```diff
@@ -598,8 +598,6 @@ def decoding_fused_rotary_embedding(
     """
     q_total_tokens, q_head_num, head_dim = q.shape
     assert q.size(0) == k.size(0) == v.size(0)
-    assert k.size(1) == v.size(1)
-    assert k_cache.size(-1) == v_cache.size(-1)

     if head_dim >= 512:
         num_warps = 16
```
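After this change, the only remaining shape check ties q, k and v together on the token dimension. A minimal sketch of shapes that satisfy it is below; the `[total_tokens, head_num, head_dim]` layout is inferred from the unpacking of `q.shape` in the diff, and the concrete head counts are assumptions.

```python
import torch

# Illustrative shapes only; the layout and head counts are assumptions.
total_tokens, q_head_num, kv_head_num, head_dim = 8, 32, 8, 128

q = torch.randn(total_tokens, q_head_num, head_dim)
k = torch.randn(total_tokens, kv_head_num, head_dim)
v = torch.randn(total_tokens, kv_head_num, head_dim)

# The check this commit keeps: q, k and v agree on the token dimension,
# while their head counts may differ (e.g. under grouped-query attention).
assert q.size(0) == k.size(0) == v.size(0)
```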
```diff
@@ -89,7 +89,7 @@ def check_continuous_batching(prompt_template):


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_continuous_batching()

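```

The only functional change here is that `colossalai.launch` no longer receives the legacy `config={}` dict. A minimal sketch of driving such a distributed test entrypoint follows; the launch call matches the diff, while the spawning code, the chosen port, and the placeholder test body are assumptions.

```python
import torch.multiprocessing as mp

import colossalai


def check_continuous_batching():
    # Placeholder for the actual test body in the source file.
    pass


def run_dist(rank, world_size, port):
    # Updated call: no legacy `config={}` argument is passed anymore.
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_continuous_batching()


if __name__ == "__main__":
    world_size = 1
    port = 29500  # assumed free port
    mp.spawn(run_dist, args=(world_size, port), nprocs=world_size)
```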