mirror of https://github.com/hpcaitech/ColossalAI
[Infer] Colossal-Inference serving example w/ TorchServe (single GPU case) (#4771)
* add Colossal-Inference serving example w/ TorchServe
* add dockerfile
* fix dockerfile
* fix dockerfile: fix commit hash, install curl
* refactor file structure
* revise readme
* trivial
* trivial: dockerfile format
* clean dir; revise readme
* fix comments: fix imports and configs
* fix formats
* remove unused requirements

pull/4849/head
parent ed06731e00
commit 3a74eb4b3a

Colossal_Inference_Handler.py
@@ -0,0 +1,193 @@
import logging
import os
import zipfile
from abc import ABC

import torch
import transformers
from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM
from ts.torch_handler.base_handler import BaseHandler

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)
logger.info("ColossalAI version %s", colossalai.__version__)

class ColossalInferenceHandler(BaseHandler, ABC):
    """
    Transformers handler class for testing
    """

    def __init__(self):
        super(ColossalInferenceHandler, self).__init__()
        self.infer_engine = None
        self.max_batch_size = None
        self.max_input_len = None
        self.max_output_len = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx):
        """Expected behaviour: the sharded Bloom/Llama model is loaded.

        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artefacts and parameters.
        """
        if ctx is None or not hasattr(ctx, "model_yaml_config"):
            logger.error("Context ctx and model-config are not appropriately passed in.")

        self.manifest = ctx.manifest
        gpu_id = ctx.system_properties.get("gpu_id", -1)
        model_dir = ctx.system_properties.get("model_dir")

        # Inference configs are collected together in model yaml config for handler use
        inference_config = ctx.model_yaml_config["handler"]
        self.inference_config = inference_config
        logger.info(self.inference_config)

        self.tp_size = self.inference_config.get("tp_size", 1)
        self.max_batch_size = self.inference_config.get("max_batch_size", 4)
        self.max_input_len = self.inference_config.get("max_input_len", 1024)
        self.max_output_len = self.inference_config.get("max_output_len", 128)

        self.device = torch.device("cuda:" + str(gpu_id) if torch.cuda.is_available() and gpu_id >= 0 else "cpu")
        logger.info(f"Device set to {self.device}")
        logger.info(f"torch.cuda.device_count() {torch.cuda.device_count()}")

        # Unpacking from model_dir
        model_dir_path = os.path.join(model_dir, "model")
        with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref:
            zip_ref.extractall(model_dir_path)
logger.info(f"Loading {self.inference_config['model_type']} pretrain model and tokenizer")
|
||||
if self.inference_config["model_type"] == "bloom":
|
||||
self.model = BloomForCausalLM.from_pretrained(
|
||||
model_dir_path,
|
||||
)
|
||||
self.tokenizer = BloomTokenizerFast.from_pretrained(model_dir_path, return_tensors="pt")
|
||||
elif self.inference_config["model_type"] == "llama":
|
||||
self.model = LlamaForCausalLM.from_pretrained(
|
||||
model_dir_path,
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir_path, return_tensors="pt")
|
||||
else:
|
||||
logger.warning(f"Model type {self.inference_config['model_type']} not supported yet.")
|
||||
|
||||
logger.info("Transformer model from path %s loaded successfully", model_dir)
|
||||
|
||||
        # NOTE world_size, rank, host, port here are used to launch colossalai dist environment
        # This world_size is different from the world size of TorchServe
        world_size = int(os.getenv("WORLD_SIZE", self.tp_size))
        assert world_size == 1, "Colossal-Inference with tensor parallel is not supported on TorchServe for now"
        rank = int(os.getenv("RANK", gpu_id))
        local_rank = int(os.getenv("LOCAL_RANK", gpu_id))
        host = os.getenv("MASTER_ADDR", "localhost")
        port = os.getenv("MASTER_PORT", free_port())  # use a random free port

        logger.info(
            f" world_size {world_size}" f" local_rank {local_rank}" f" rank {rank}" f" host {host}" f" port {port}"
        )

        torch.cuda.set_device(self.device)
        self.model.half()
        self.model.cuda()
        self.model.eval()

        colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
        logger.info("Initializing TPInferEngine ...")
        shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
        self.infer_engine = TPInferEngine(
            self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
        )
        logger.info("TPInferEngine initialized successfully")

        self.model = self.infer_engine.model
        self.initialized = True

    def preprocess(self, requests):
        """Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests: The Input data in the form of text is passed on to the preprocess
            function.
        Returns:
            list : The preprocess function returns a list of Tensor for the size of the word tokens.
        """
logger.info("Pre-processing requests")
|
||||
input_ids_batch = None
|
||||
attention_mask_batch = None
|
||||
for idx, data in enumerate(requests):
|
||||
input_text = data.get("data")
|
||||
if input_text is None:
|
||||
input_text = data.get("body")
|
||||
if isinstance(input_text, (bytes, bytearray)):
|
||||
input_text = input_text.decode("utf-8")
|
||||
|
||||
logger.info("Received text: '%s'", input_text)
|
||||
|
||||
inputs = self.tokenizer.encode_plus(
|
||||
input_text,
|
||||
max_length=self.max_input_len,
|
||||
padding=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
input_ids = inputs["input_ids"].to(self.device)
|
||||
attention_mask = inputs["attention_mask"].to(self.device)
|
||||
# making a batch out of the recieved requests
|
||||
# attention masks are passed for cases where input tokens are padded.
|
||||
if input_ids.shape is not None:
|
||||
if input_ids_batch is None:
|
||||
input_ids_batch = input_ids
|
||||
attention_mask_batch = attention_mask
|
||||
else:
|
||||
input_ids_batch = torch.cat((input_ids_batch, input_ids), 0)
|
||||
attention_mask_batch = torch.cat((attention_mask_batch, attention_mask), 0)
|
||||
return (input_ids_batch, attention_mask_batch)
|
||||
|
||||
    def inference(self, input_batch):
        """Generates text for the received input using the serialized
        transformers checkpoint and the Colossal-AI inference engine.
        Args:
            input_batch (list): List of Text Tensors from the pre-process function is passed here
        Returns:
            list : It returns a list of the generated text for the input text
        """
        input_ids_batch, attention_mask_batch = input_batch
        inferences = []

        do_sample = self.inference_config.get("do_sample", True)
        top_p = self.inference_config.get("top_p", 0.95 if do_sample else 1.0)
        top_k = self.inference_config.get("top_k", 60 if do_sample else 50)
        input_ids_batch = input_ids_batch.to(self.device)
        outputs = self.infer_engine.generate(
            dict(input_ids=input_ids_batch, attention_mask=attention_mask_batch),
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
        )

        for i, _ in enumerate(outputs):
            inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True))

        # For testing only
        logger.info(
            f"Generated text: {inferences}",
        )

        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into Torchserve readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the Predictions and Explanations.
        """
        return inference_output
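To make the request flow above concrete, here is a minimal local smoke test for the handler. This is a sketch and not part of the PR: the `SimpleNamespace`-based context, the `gpu_id`, and the `./serve_workdir` path are illustrative assumptions, and it needs a CUDA GPU plus a `model.zip` laid out as the README below describes. It mimics the `initialize` call TorchServe makes when a worker loads the model, followed by the `preprocess` → `inference` → `postprocess` chain run for each request batch.

```python
# Hypothetical smoke test for ColossalInferenceHandler outside TorchServe.
# Assumes the handler file is importable and that ./serve_workdir contains model.zip.
from types import SimpleNamespace

from Colossal_Inference_Handler import ColossalInferenceHandler

ctx = SimpleNamespace(
    manifest={"model": {"modelName": "bloom"}},
    system_properties={"gpu_id": 0, "model_dir": "./serve_workdir"},  # assumed layout
    model_yaml_config={
        "handler": {
            "mode": "text_generation",
            "model_type": "bloom",
            "tp_size": 1,
            "max_batch_size": 8,
            "max_input_len": 1024,
            "max_output_len": 128,
        }
    },
)

handler = ColossalInferenceHandler()
handler.initialize(ctx)  # unpacks model.zip, launches colossalai, builds TPInferEngine

# requests arrive as a list of dicts carrying "data" or "body", as preprocess() expects
batch = handler.preprocess([{"data": "Introduce some landmarks in Beijing"}])
generations = handler.inference(batch)
print(handler.postprocess(generations))
```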
README.md
@@ -0,0 +1,109 @@
# Colossal-Inference with TorchServe

## Overview

This demo is used for testing and demonstrating the usage of Colossal Inference from `colossalai.inference`, deployed with TorchServe. It imports inference modules from colossalai and is based on
https://github.com/hpcaitech/ColossalAI/tree/3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0. For now, single-GPU inference serving is supported.

## Environment for testing
### Option #1: Use Conda Env
Below are the steps we used to create a conda env for local testing. We might switch to docker or configure the env on a cloud platform later.

*NOTE*: It requires installing a JDK and setting `JAVA_HOME`. We recommend installing open-jdk-17 (please refer to https://openjdk.org/projects/jdk/17/).

```bash
# use python 3.8 or 3.9
conda create -n infer python=3.9

# use torch 1.13+cuda11.6 for inference
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

# conda cuda toolkit (e.g. nvcc, etc)
conda install -c "nvidia/label/cuda-11.6.2" cuda-toolkit

# install colossalai with PyTorch extensions
cd <path_to_ColossalAI_repo>
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-test.txt
CUDA_EXT=1 pip install -e .

# install torchserve
cd <path_to_torch_serve_repo>
python ./ts_scripts/install_dependencies.py --cuda=cu116
pip install torchserve torch-model-archiver torch-workflow-archiver
```

### Option #2: Use Docker
To use Docker, you can build the image for this example from the provided [Dockerfile](./docker/Dockerfile).

```bash
# build from dockerfile
cd ColossalAI/examples/inference/serving/torch_serve/docker
docker build -t hpcaitech/colossal-infer-ts:0.2.0 .
```

Once you have the image ready, you can launch a container from it with the following command:

```bash
cd ColossalAI/examples/inference/serving/torch_serve

# run the docker container
docker run --rm \
    -it --gpus all \
    --name <name_you_assign> \
    -v <your-data-dir>:/data/scratch \
    -w <ColossalAI_dir> \
    hpcaitech/colossal-infer-ts:0.2.0 \
    /bin/bash
```

## Steps to deploy a model

### 1. Download/prepare a model
We will download a Bloom model and then zip the downloaded model. You could download the model from [HuggingFace](https://huggingface.co/models) manually, or you might want to refer to this script [Download_model.py](https://github.com/pytorch/serve/blob/c3ca2599b4d36d2b61302064b02eab1b65e1908d/examples/large_models/utils/Download_model.py) provided by the pytorch-serve team to help you download a snapshot of the model.

```bash
# download snapshots
cd <path_to_torch_serve>/examples/large_models/utils/
huggingface-cli login
python Download_model.py --model_name bigscience/bloom-560m -o <path_to_store_downloaded_model>

# zip the model repo
cd <path_to_store_downloaded_model>/models--bigscience--bloom-560m/snapshots/<specific_revision>
zip -r <path_to_place_zipped_model>/model.zip *
```

> **_NOTE:_** The torch archiver and server will use the `/tmp/` folder. Depending on the disk quota limit, using torch-model-archiver might cause an OSError "Disk quota exceeded". To prevent this OSError, set the tmp dir environment variables as follows:
`export TMPDIR=<dir_with_enough_space>/tmp` and `export TEMP=<dir_with_enough_space>/tmp`,
or use relatively small models (as we did) for local testing.

### 2. Archive the model
With torch-model-archiver, we will pack the model file (.zip) as well as the handler file (.py) together into a .mar file, which TorchServe will unpack during serving. Relevant model configs and inference configs can be set in `model-config.yaml`.
```bash
cd ./ColossalAI/examples/inference/serving/torch_serve
# create a folder under the current directory to store the packed model created by torch archiver
mkdir model_store
torch-model-archiver --model-name bloom --version 0.1 --handler Colossal_Inference_Handler.py --config-file model-config.yaml --extra-files <dir_zipped_model>/model.zip --export-path ./model_store/
```

### 3. Launch serving

Modify `load_models` in `config.properties` to select the model(s) stored in the `model_store` directory to be deployed. By default we use `load_models=all` to load and deploy all the models (.mar) we have. For example, setting `load_models=bloom.mar` would deploy only the archive created in step 2.

```bash
torchserve --start --ncs --ts-config config.properties
```
We could set inference, management, and metrics addresses and other TorchServe settings in `config.properties`.

TorchServe will create a folder `logs/` under the current directory to store ts, model, and metrics logs.

### 4. Run inference

```bash
# check inference status
curl http://0.0.0.0:8084/ping

curl -X POST http://localhost:8084/predictions/bloom -T sample_text.txt
```
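If you prefer to call the endpoints programmatically, below is a minimal Python client sketch. It assumes the default addresses from `config.properties` (inference on port 8084, management on 8085) and that the `requests` package is installed; the describe-model call uses TorchServe's standard management API rather than anything specific to this example.

```python
import requests

# health check, equivalent to the curl ping above
print(requests.get("http://0.0.0.0:8084/ping").json())

# text generation request with the prompt from sample_text.txt
resp = requests.post(
    "http://0.0.0.0:8084/predictions/bloom",
    data="Introduce some landmarks in Beijing",
)
print(resp.text)

# describe the registered bloom model (worker status, batch size, etc.)
print(requests.get("http://0.0.0.0:8085/models/bloom").json())
```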

To stop TorchServe, run `torchserve --stop`.
config.properties
@@ -0,0 +1,10 @@
inference_address=http://0.0.0.0:8084
management_address=http://0.0.0.0:8085
metrics_address=http://0.0.0.0:8086
enable_envvars_config=true
install_py_dep_per_model=true
number_of_gpu=1
load_models=all
max_response_size=655350000
default_response_timeout=6000
model_store=./model_store
docker/Dockerfile
@@ -0,0 +1,57 @@
FROM hpcaitech/pytorch-cuda:1.13.0-11.6.0

# enable passwordless ssh
RUN mkdir ~/.ssh && \
    printf "Host * \n ForwardAgent yes\nHost *\n StrictHostKeyChecking no" > ~/.ssh/config && \
    ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa && \
    cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys

# install curl
RUN apt-get update && \
    apt-get -y install curl && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Download and extract OpenJDK 17
ENV JAVA_HOME /opt/openjdk-17
RUN apt-get update && \
    apt-get install -y wget && \
    wget -q https://download.java.net/openjdk/jdk17/ri/openjdk-17+35_linux-x64_bin.tar.gz -O /tmp/openjdk.tar.gz && \
    mkdir -p $JAVA_HOME && \
    tar xzf /tmp/openjdk.tar.gz -C $JAVA_HOME --strip-components=1 && \
    rm /tmp/openjdk.tar.gz && \
    apt-get purge -y --auto-remove wget && \
    rm -rf /var/lib/apt/lists/*

ENV PATH $JAVA_HOME/bin:$PATH
RUN export JAVA_HOME
RUN java -version

# install ninja
RUN apt-get update && \
    apt-get install -y --no-install-recommends ninja-build && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# install colossalai
ARG VERSION=main
RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git && \
    cd ./ColossalAI && \
    git checkout 3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0 && \
    CUDA_EXT=1 pip install -v --no-cache-dir .

# install titans
RUN pip install --no-cache-dir titans

# install transformers
RUN pip install --no-cache-dir transformers

# install triton
RUN pip install --no-cache-dir triton==2.0.0.dev20221202

# install torchserve
ARG VERSION=master
RUN git clone -b ${VERSION} https://github.com/pytorch/serve.git && \
    cd ./serve && \
    python ./ts_scripts/install_dependencies.py --cuda=cu116 && \
    pip install torchserve torch-model-archiver torch-workflow-archiver
model-config.yaml
@@ -0,0 +1,16 @@
# TS frontend parameters settings
minWorkers: 1          # minimum number of workers of a model
maxWorkers: 1          # maximum number of workers of a model
batchSize: 8           # batch size of a model
maxBatchDelay: 100     # maximum delay of a batch (ms)
responseTimeout: 120   # timeout of a specific model's response (*in sec)
deviceType: "gpu"
# deviceIds: [0, 1]    # setting CUDA_VISIBLE_DEVICES

handler:
    mode: "text_generation"
    model_type: "bloom"
    tp_size: 1
    max_batch_size: 8
    max_input_len: 1024
    max_output_len: 128
sample_text.txt
@@ -0,0 +1 @@
Introduce some landmarks in Beijing