mirror of https://github.com/hpcaitech/ColossalAI
[Infer] Serving example w/ ray-serve (multiple GPU case) (#4841)
* fix imports
* add ray-serve with Colossal-Infer tp
* trivial: send requests script
* add README
* fix worker port
* fix readme
* use app builder and autoscaling
* trivial: input args
* clean code; revise readme
* testci (skip example test)
* use auto model/tokenizer
* revert imports fix (fixed in other PRs)
parent 3a74eb4b3a
commit 573f270537
@@ -0,0 +1,151 @@
import logging
import os
from typing import Any, List, Union

import ray
import ray.util.collective as collective
import starlette
import torch
from pydantic import BaseModel
from ray import serve
from ray.serve import Application
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port

ray_serve_logger = logging.getLogger("ray.serve")


class GenConfigArgs(BaseModel):
    """Config for generation"""

    path: str
    tp_size: int = 2
    max_batch_size: int = 4
    max_input_len: int = 128
    max_output_len: int = 32


def log_cuda_info(scope_name: str):
    ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
    ray_serve_logger.info(
        f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
    )
    if torch.cuda.is_available():
        ray_serve_logger.info(
            f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
        )
    else:
        ray_serve_logger.info(f" {scope_name}: cuda is not available!")


# one Ray actor per GPU; each worker holds one tensor-parallel shard of the model
@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, model_path: str, tp_size: int, max_batch_size: int, max_input_len: int, max_output_len: int):
        log_cuda_info("Worker.init")
        self.tp_size = tp_size
        self.model_path = model_path
        self.max_batch_size = max_batch_size
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len

    def setup(self, world_size, rank, port):
        # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # initialize and set distributed environment
        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up...")
        log_cuda_info("Worker.setup")

        # Load model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
        )

        # shard the model for tensor parallelism and wrap it in the inference engine
        shard_config = ShardConfig(enable_tensor_parallelism=(world_size > 1), inference_only=True)
        self.infer_engine = TPInferEngine(
            self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
        )
        self.generate_kwargs = dict(max_new_tokens=self.max_output_len, do_sample=False)

        return True

    def generate(self, text: Union[str, List[str]]) -> List[str]:
        input_tokens = self.tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
        ray_serve_logger.info(f"text: {text},\ninput_tokens: {input_tokens}")

        model_output = self.infer_engine.generate(input_tokens, **self.generate_kwargs)
        ray_serve_logger.info(f"model_output.shape: {model_output.shape}")

        text_output = []
        for i in range(len(model_output)):
            text_output.append(self.tokenizer.decode(model_output[i]))
        ray_serve_logger.info(f"output: {text_output}")

        return text_output


@serve.deployment(
    ray_actor_options={"num_cpus": 1, "num_gpus": 0},
    max_concurrent_queries=5,
    autoscaling_config={
        "target_num_ongoing_requests_per_replica": 1,
        "min_replicas": 1,
        "initial_replicas": 1,
        "max_replicas": 1,
    },
)
class Driver:
    def __init__(self, config: GenConfigArgs):
        log_cuda_info("Driver.init")
        model_path = config.path
        tp_size = config.tp_size

        self.num_workers = tp_size
        self.workers = []
        init_rets = []

        # Just grab a free port on localhost
        # NOTE workers in this communication group listen to the same port
        available_port = free_port()

        for i in range(self.num_workers):
            worker_name = "worker_idx_{}".format(i)
            w = Worker.options(name=worker_name).remote(
                model_path, self.num_workers, config.max_batch_size, config.max_input_len, config.max_output_len
            )
            self.workers.append(w)
            init_rets.append(w.setup.remote(self.num_workers, i, available_port))
        _options = {
            "group_name": "default_driver",
            "world_size": self.num_workers,
            "ranks": [i for i in range(self.num_workers)],
            "backend": "nccl",
        }
        collective.create_collective_group(self.workers, **_options)
        _ = ray.get(init_rets)

    # set batch wait delay in seconds and maximum number of sequences in a batch
    @serve.batch(batch_wait_timeout_s=0.8, max_batch_size=4)
    async def batch_generate(self, requests: List[str]):
        ray_serve_logger.info(f"Driver.batch_generate: requests length: {len(requests)}\n requests: {requests}")
        results = ray.get([w.generate.remote(requests) for w in self.workers])
        text_res = results[0]  # get any one of the copies
        return text_res

    async def __call__(self, request: starlette.requests.Request) -> Any:
        return await self.batch_generate(request.query_params["text"])


def app(args: GenConfigArgs) -> Application:
    print(args)
    if args.path is None or not os.path.exists(args.path):
        raise ValueError("Model path not provided or invalid path!")

    return Driver.options(name="Colossal-Inference-Driver").bind(config=args)

@@ -0,0 +1,86 @@
# Colossal-Inference with Ray Serve

This example demonstrates and tests deployment of Colossal-Inference from `colossalai.inference` with [Ray Serve](https://docs.ray.io/en/latest/serve/index.html). It imports inference modules from colossalai and is based on https://github.com/hpcaitech/ColossalAI/tree/a22706337a57dd1c98b95739dd09d98bd55947a0.

Both single-GPU inference and multi-GPU (i.e. tensor parallel) inference serving are supported; the degree of tensor parallelism is set by the `tp_size` field of `GenConfigArgs` (default: 2).

## Installation

### Conda Environment
```bash
# create a new conda env with python 3.8
conda create -n ray_test python=3.8.18

# use torch 1.13 + cuda 11.6
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

# install ray from wheels
pip install -U "ray[default,serve]"

# install cuda toolkit (e.g. nvcc, etc.)
conda install -c "nvidia/label/cuda-11.6.2" cuda-toolkit

# install cuDNN, cuTENSOR, and NCCL
conda install -c conda-forge cupy cudnn cutensor nccl cuda-version=11.6

# install colossalai with PyTorch extensions
cd <path_to_ColossalAI_repo>
CUDA_EXT=1 pip install -e .

# install other dependencies
pip install triton==2.0.0.dev20221202
pip install transformers
```
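
As an optional sanity check (not part of the original example), you can verify that PyTorch sees the GPUs and that Ray starts up in this environment:
```bash
python -c "import torch, ray; print(torch.cuda.is_available()); ray.init(); print(ray.cluster_resources())"
```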

## Launch Ray Serve and run the app

### Method #1. CLI command

Under the current directory, we can launch the app with the following command:
```bash
RAY_DEDUP_LOGS=0 serve run Colossal_Inference_rayserve:app path="PATH_TO_YOUR_MODEL_DIR"
```

By default, Ray deduplicates logs across the cluster. Here we set `RAY_DEDUP_LOGS=0` to disable log deduplication, so that each actor can log its information to the CLI. `serve run` runs an application from the specified import path; the format should be `<filename>:<app_name>`.
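
The other fields of `GenConfigArgs` can be passed the same way as `path`. For example, for the multi-GPU (tensor parallel) case, a sketch of the command might look like this (the argument names come from `GenConfigArgs`; adjust the values to your hardware):
```bash
RAY_DEDUP_LOGS=0 serve run Colossal_Inference_rayserve:app path="PATH_TO_YOUR_MODEL_DIR" tp_size=2 max_batch_size=4
```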

Then we can send requests by running the Python script in another window:
```bash
python send_request.py
```
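
Since the driver's `__call__` simply reads the `text` query parameter of an HTTP GET request, any HTTP client can be used as well. A minimal sketch with `curl` (remember to URL-encode the prompt):
```bash
curl "http://localhost:8000/?text=Introduce%20some%20landmarks%20in%20Beijing"
```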

### Method #2. Run inside script

We can also launch Ray Serve and run the app inside a single script by making a few modifications.

To prevent the Ray handle from raising an error when serializing pydantic objects, replace the config class `class GenConfigArgs(BaseModel)` with a dataclass:
```python
from dataclasses import dataclass

@dataclass
class GenConfigArgs:
    # attributes remain unchanged
```
Comment out the app builder:
```python
# def app(args: GenConfigArgs) -> Application:
#     ...
#     return Driver.options(name="Colossal-Inference-Driver").bind(config=args)
```
And append the following lines to the end of the file:
```python
from ray.serve.handle import DeploymentHandle, DeploymentResponse

app = Driver.bind(config=GenConfigArgs(path="<Path_to_model_dir>"))
handle: DeploymentHandle = serve.run(app).options(use_new_handle_api=True)
response: DeploymentResponse = handle.batch_generate.remote(requests="Introduce some landmarks in Beijing")
print(response.result())
```
Then we can run the script:
```bash
python Colossal_Inference_rayserve.py
```

### Terminate Ray Serve

Ray Serve and the application terminate automatically when you use the second method (running the app inside a script). If you chose the first method (`serve run`), press `ctrl+c` to shut down the application, or use `serve shutdown` to shut down Serve and delete all applications on the Ray cluster:
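```bash
serve shutdown
```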

To make sure all the active Ray processes are killed, run:
```bash
ray stop
```

@@ -0,0 +1,15 @@
import ray
import requests


@ray.remote
def send_query(text):
    # query the serve app over HTTP; the driver reads the `text` query parameter
    resp = requests.get("http://localhost:8000/?text={}".format(text))
    return resp.text


test_sentence = "Introduce some landmarks in Beijing"

result = ray.get(send_query.remote(test_sentence))
print("Result returned:")
print(result)

@@ -0,0 +1,27 @@
import ray
import requests


@ray.remote
def send_query(text):
    resp = requests.get("http://localhost:8000/?text={}".format(text))
    return resp.text


test_sentences = [
    "Introduce some landmarks in Beijing",
    "What is the weather today",
    "Coding requires practice and patience",
    "Rainy days inspire cozy reading",
    "Laughter is contagious and heartwarming",
    "Hiking mountains builds strength and resilience",
    "Family bonds grow stronger with time",
    "Science unlocks mysteries of the universe",
    "Music soothes the soul and ignites passion",
    "Artistic expression knows no boundaries",
]

# fire all queries concurrently; the driver's @serve.batch groups them into batches of up to 4
results = ray.get([send_query.remote(text) for text in test_sentences])
print("Result returned:")
for res in results:
    print(res)