[Infer] Colossal-Inference serving example w/ TorchServe (single GPU case) (#4771)

* add Colossal-Inference serving example w/ TorchServe

* add dockerfile

* fix dockerfile

* fix dockerfile: fix commit hash, install curl

* refactor file structure

* revise readme

* trivial

* trivial: dockerfile format

* clean dir; revise readme

* fix comments: fix imports and configs

* fix formats

* remove unused requirements
pull/4849/head
Yuanheng Zhao 2023-10-02 17:42:37 +08:00 committed by GitHub
parent ed06731e00
commit 3a74eb4b3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 386 additions and 0 deletions

View File

View File

@ -0,0 +1,193 @@
import logging
import os
import zipfile
from abc import ABC
import torch
import transformers
from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM
from ts.torch_handler.base_handler import BaseHandler
import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port
logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)
logger.info("ColossalAI version %s", colossalai.__version__)
class ColossalInferenceHandler(BaseHandler, ABC):
"""
Transformers handler class for testing
"""
def __init__(self):
super(ColossalInferenceHandler, self).__init__()
self.infer_engine = None
self.max_batch_size = None
self.max_input_len = None
self.max_output_len = None
self.tokenizer = None
self.initialized = False
def initialize(self, ctx):
"""Expected behaviour: the sharded Bloom/Llama model is loaded.
Args:
ctx (context): It is a JSON Object containing information
pertaining to the model artefacts parameters.
"""
if ctx is not None or not hasattr(ctx, "model_yaml_config"):
logger.error("Context ctx and model-config are not appropriately passed in.")
self.manifest = ctx.manifest
gpu_id = ctx.system_properties.get("gpu_id", -1)
model_dir = ctx.system_properties.get("model_dir")
# Inference configs are collected together in model yaml config for handler use
inference_config = ctx.model_yaml_config["handler"]
self.inference_config = inference_config
logger.info(self.inference_config)
self.tp_size = self.inference_config.get("tp_size", 1)
self.max_batch_size = self.inference_config.get("max_batch_size", 4)
self.max_input_len = self.inference_config.get("max_input_len", 1024)
self.max_output_len = self.inference_config.get("max_output_len", 128)
self.device = torch.device("cuda:" + str(gpu_id) if torch.cuda.is_available() and gpu_id >= 0 else "cpu")
logger.info(f"Device set to {self.device}")
logger.info(f"torch.cuda.device_count() {torch.cuda.device_count()}")
# Unpacking from model_dir
model_dir_path = os.path.join(model_dir, "model")
with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref:
zip_ref.extractall(model_dir_path)
logger.info(f"Loading {self.inference_config['model_type']} pretrain model and tokenizer")
if self.inference_config["model_type"] == "bloom":
self.model = BloomForCausalLM.from_pretrained(
model_dir_path,
)
self.tokenizer = BloomTokenizerFast.from_pretrained(model_dir_path, return_tensors="pt")
elif self.inference_config["model_type"] == "llama":
self.model = LlamaForCausalLM.from_pretrained(
model_dir_path,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_dir_path, return_tensors="pt")
else:
logger.warning(f"Model type {self.inference_config['model_type']} not supported yet.")
logger.info("Transformer model from path %s loaded successfully", model_dir)
# NOTE world_size, rank, host, port here are used to launch colossalai dist environment
# This world_size is different from the world size of TorchServe
world_size = int(os.getenv("WORLD_SIZE", self.tp_size))
assert world_size == 1, "Colossal-Inference with tensor parallel is not supported on TorchServe for now"
rank = int(os.getenv("RANK", gpu_id))
local_rank = int(os.getenv("LOCAL_RANK", gpu_id))
host = os.getenv("MASTER_ADDR", "localhost")
port = os.getenv("MASTER_PORT", free_port()) # use a random free port
logger.info(
f" world_size {world_size}" f" local_rank {local_rank}" f" rank {rank}" f" host {host}" f" port {port}"
)
torch.cuda.set_device(self.device)
self.model.half()
self.model.cuda()
self.model.eval()
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
logger.info("Initializing TPInferEngine ...")
shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
logger.info("TPInferEngine initialized successfully")
self.model = self.infer_engine.model
self.initialized = True
def preprocess(self, requests):
"""Basic text preprocessing, based on the user's chocie of application mode.
Args:
requests: The Input data in the form of text is passed on to the preprocess
function.
Returns:
list : The preprocess function returns a list of Tensor for the size of the word tokens.
"""
logger.info("Pre-processing requests")
input_ids_batch = None
attention_mask_batch = None
for idx, data in enumerate(requests):
input_text = data.get("data")
if input_text is None:
input_text = data.get("body")
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
logger.info("Received text: '%s'", input_text)
inputs = self.tokenizer.encode_plus(
input_text,
max_length=self.max_input_len,
padding=True,
add_special_tokens=True,
return_tensors="pt",
truncation=True,
)
input_ids = inputs["input_ids"].to(self.device)
attention_mask = inputs["attention_mask"].to(self.device)
# making a batch out of the recieved requests
# attention masks are passed for cases where input tokens are padded.
if input_ids.shape is not None:
if input_ids_batch is None:
input_ids_batch = input_ids
attention_mask_batch = attention_mask
else:
input_ids_batch = torch.cat((input_ids_batch, input_ids), 0)
attention_mask_batch = torch.cat((attention_mask_batch, attention_mask), 0)
return (input_ids_batch, attention_mask_batch)
def inference(self, input_batch):
"""Predict the class (or classes) of the received text using the
serialized transformers checkpoint.
Args:
input_batch (list): List of Text Tensors from the pre-process function is passed here
Returns:
list : It returns a list of the predicted value for the input text
"""
input_ids_batch, attention_mask_batch = input_batch
inferences = []
do_sample = self.inference_config.get("do_sample", True)
top_p = self.inference_config.get("top_p", 0.95 if do_sample else 1.0)
top_k = self.inference_config.get("top_k", 60 if do_sample else 50)
input_ids_batch = input_ids_batch.to(self.device)
outputs = self.infer_engine.generate(
dict(input_ids=input_ids_batch, attention_mask=attention_mask_batch),
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
)
for i, _ in enumerate(outputs):
inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True))
# For testing only
logger.info(
f"Generated text: {inferences}",
)
return inferences
def postprocess(self, inference_output):
"""Post Process Function converts the predicted response into Torchserve readable format.
Args:
inference_output (list): It contains the predicted response of the input text.
Returns:
(list): Returns a list of the Predictions and Explanations.
"""
return inference_output

View File

@ -0,0 +1,109 @@
# Colossal-Inference with TorchServe
## Overview
This demo is used for testing and demonstrating the usage of Colossal Inference from `colossalai.inference` with deployment with TorchServe. It imports inference modules from colossalai and is based on
https://github.com/hpcaitech/ColossalAI/tree/3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0. For now, single-gpu inference serving is supported.
## Environment for testing
### Option #1: Use Conda Env
Records to create a conda env to test locally as follows. We might want to use docker or configure env on cloud platform later.
*NOTE*: It requires the installation of jdk and the set of `JAVA_HOME`. We recommend to install open-jdk-17 (Please refer to https://openjdk.org/projects/jdk/17/)
```bash
# use python 3.8 or 3.9
conda create -n infer python=3.9
# use torch 1.13+cuda11.6 for inference
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
# conda cuda toolkit (e.g. nvcc, etc)
conda install -c "nvidia/label/cuda-11.6.2" cuda-toolkit
# install colossalai with PyTorch extensions
cd <path_to_ColossalAI_repo>
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-test.txt
CUDA_EXT=1 pip install -e .
# install torchserve
cd <path_to_torch_serve_repo>
python ./ts_scripts/install_dependencies.py --cuda=cu116
pip install torchserve torch-model-archiver torch-workflow-archiver
```
### Option #2: Use Docker
To use the stable diffusion Docker image, you can build using the provided the [Dockerfile](./docker/Dockerfile).
```bash
# build from dockerfile
cd ColossalAI/examples/inference/serving/torch_serve/docker
docker build -t hpcaitech/colossal-infer-ts:0.2.0 .
```
Once you have the image ready, you can launch the image with the following command
```bash
cd ColossalAI/examples/inference/serving/torch_serve
# run the docker container
docker run --rm \
-it --gpus all \
--name <name_you_assign> \
-v <your-data-dir>:/data/scratch \
-w <ColossalAI_dir> \
hpcaitech/colossal-infer-ts:0.2.0 \
/bin/bash
```
## Steps to deploy a model
### 1.download/prepare a model
We will download a bloom model, and then zip the downloaded model. You could download the model from [HuggingFace](https://huggingface.co/models) manually, or you might want to refer to this script [download_model.py](https://github.com/pytorch/serve/blob/c3ca2599b4d36d2b61302064b02eab1b65e1908d/examples/large_models/utils/Download_model.py) provided by pytorch-serve team to help you download a snapshot of the model.
```bash
# download snapshots
cd <path_to_torch_serve>/examples/large_models/utils/
huggingface-cli login
python download_model.py --model_name bigscience/bloom-560m -o <path_to_store_downloaded_model>
# zip the model repo
cd <path_to_store_downloaded_model>/models--bigscience--bloom-560m/snapshots/<specific_revision>
zip -r <path_to_place_zipped_model>//model.zip *
```
> **_NOTE:_** The torch archiver and server will use `/tmp/` folder. Depending on the limit of disk quota, using torch-model-archiver might cause OSError "Disk quota exceeded". To prevent the OSError, set tmp dir environment variable as follows:
`export TMPDIR=<dir_with_enough_space>/tmp` and `export TEMP=<dir_with_enough_space>/tmp`,
or use relatively small models (as we did) for local testing.
### 2. Archive the model
With torch archiver, we will pack the model file (.zip) as well as handler file (.py) together into a .mar file. And then in serving process these files will be unpacked by TorchServe. Revelant model configs and inference configs can be set in `model-config.yaml`.
```bash
cd ./ColossalAI/examples/inference/serving/torch_serve
# create a folder under the current directory to store the packed model created by torch archiver
mkdir model_store
torch-model-archiver --model-name bloom --version 0.1 --handler Colossal_Inference_Handler.py --config-file model-config.yaml --extra-files <dir_zipped_model>/model.zip --export-path ./model_store/
```
### 3. Launch serving
Modify `load_models` in config.properties to select the model(s) stored in <model_store> directory to be deployed. By default we use `load_models=all` to load and deploy all the models (.mar) we have.
```bash
torchserve --start --ncs --ts-config config.properties
```
We could set inference, management, and metrics addresses and other TorchServe settings in `config.properties`.
TorchServe will create a folder `logs/` under the current directory to store ts, model, and metrics logs.
### 4. Run inference
```bash
# check inference status
curl http://0.0.0.0:8084/ping
curl -X POST http://localhost:8084/predictions/bloom -T sample_text.txt
```
To stop TorchServe, run `torchserve --stop`

View File

@ -0,0 +1,10 @@
inference_address=http://0.0.0.0:8084
management_address=http://0.0.0.0:8085
metrics_address=http://0.0.0.0:8086
enable_envvars_config=true
install_py_dep_per_model=true
number_of_gpu=1
load_models=all
max_response_size=655350000
default_response_timeout=6000
model_store=./model_store

View File

@ -0,0 +1,57 @@
FROM hpcaitech/pytorch-cuda:1.13.0-11.6.0
# enable passwordless ssh
RUN mkdir ~/.ssh && \
printf "Host * \n ForwardAgent yes\nHost *\n StrictHostKeyChecking no" > ~/.ssh/config && \
ssh-keygen -t rsa -N "" -f ~/.ssh/id_rsa && \
cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
# install curl
RUN apt-get update && \
apt-get -y install curl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Download and extract OpenJDK 17
ENV JAVA_HOME /opt/openjdk-17
RUN apt-get update && \
apt-get install -y wget && \
wget -q https://download.java.net/openjdk/jdk17/ri/openjdk-17+35_linux-x64_bin.tar.gz -O /tmp/openjdk.tar.gz && \
mkdir -p $JAVA_HOME && \
tar xzf /tmp/openjdk.tar.gz -C $JAVA_HOME --strip-components=1 && \
rm /tmp/openjdk.tar.gz && \
apt-get purge -y --auto-remove wget && \
rm -rf /var/lib/apt/lists/*
ENV PATH $JAVA_HOME/bin:$PATH
RUN export JAVA_HOME
RUN java -version
# install ninja
RUN apt-get update && \
apt-get install -y --no-install-recommends ninja-build && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# install colossalai
ARG VERSION=main
RUN git clone -b ${VERSION} https://github.com/hpcaitech/ColossalAI.git && \
cd ./ColossalAI && \
git checkout 3e05c07bb8921f2a8f9736b6f6673d4e9f1697d0 && \
CUDA_EXT=1 pip install -v --no-cache-dir .
# install titans
RUN pip install --no-cache-dir titans
# install transformers
RUN pip install --no-cache-dir transformers
# install triton
RUN pip install --no-cache-dir triton==2.0.0.dev20221202
# install torchserve
ARG VERSION=master
RUN git clone -b ${VERSION} https://github.com/pytorch/serve.git && \
cd ./serve && \
python ./ts_scripts/install_dependencies.py --cuda=cu116 && \
pip install torchserve torch-model-archiver torch-workflow-archiver

View File

@ -0,0 +1,16 @@
# TS frontend parameters settings
minWorkers: 1 # minimum number of workers of a model
maxWorkers: 1 # maximum number of workers of a model
batchSize: 8 # batch size of a model
maxBatchDelay: 100 # maximum delay of a batch (ms)
responseTimeout: 120 # timeout of a specific model's response (*in sec)
deviceType: "gpu"
# deviceIds: [0, 1] # seting CUDA_VISIBLE_DEVICES
handler:
mode: "text_generation"
model_type: "bloom"
tp_size: 1
max_batch_size: 8
max_input_len: 1024
max_output_len: 128

View File

@ -0,0 +1 @@
Introduce some landmarks in Beijing