# ⚡️ ColossalAI-Inference
## 📚 Table of Contents
- [⚡️ ColossalAI-Inference](#️-colossalai-inference)
- [📚 Table of Contents](#-table-of-contents)
- [📌 Introduction](#-introduction)
- [🕹 Usage](#-usage)
- [🗺 Roadmap](#-roadmap)
- [🪅 Support Matrix](#-support-matrix)
- [🛠 Design and Components](#-design-and-components)
- [Overview](#overview)
- [Engine](#engine)
- [Blocked KV Cache Manager](#kv-cache)
- [Batching](#batching)
- [Modeling](#modeling)
- [🌟 Acknowledgement](#-acknowledgement)
## 📌 Introduction
ColossalAI-Inference is a module that accelerates the inference of Transformers models, especially LLMs. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching, and other techniques to accelerate LLM inference. We also provide simple and unified APIs for the sake of user-friendliness.
## 🕹 Usage
### :arrow_right: Quick Start
A sample usage of the inference engine is given below:
```python
import torch
import transformers
import colossalai
from colossalai.inference import InferenceEngine, InferenceConfig
from pprint import pprint
colossalai.launch_from_torch()
# Step 1: create a model in "transformers" way
model_path = "lmsys/vicuna-7b-v1.3"
model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
# Step 2: create an inference_config
inference_config = InferenceConfig(
dtype=torch.float16,
max_batch_size=4,
max_input_len=1024,
max_output_len=512,
use_cuda_kernel=True,
)
# Step 3: create an engine with model and config
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
# Step 4: try inference
prompts = ['Who is the best player in the history of NBA?']
response = engine.generate(prompts)
pprint(response)
```
You can run the sample code with:
```bash
colossalai run --nproc_per_node 1 your_sample_name.py
```
For detailed examples, you might want to check [inference examples](../../examples/inference/llama/README.md).
### :bookmark: Customize your inference engine
Besides the basic quick-start usage, you can also customize your inference engine by modifying the inference config or providing your own models, policies, or decoding components (logits processors or sampling strategies).
#### Inference Config
Inference Config is a unified config for initializing the inference engine, controlling multi-GPU generation (Tensor Parallelism), and presetting generation configs. Commonly used `InferenceConfig` arguments are listed below, followed by a short configuration sketch:
- `max_batch_size`: The maximum batch size. Defaults to 8.
- `max_input_len`: The maximum input length (number of tokens). Defaults to 256.
- `max_output_len`: The maximum output length (number of tokens). Defaults to 256.
- `dtype`: The data type of the model for inference. This can be one of `fp16`, `bf16`, or `fp32`. Defaults to `fp16`.
- `kv_cache_dtype`: The data type used for the KV cache. Defaults to the same data type as the model (`dtype`). KV cache quantization is enabled automatically when this differs from the model `dtype`.
- `use_cuda_kernel`: Whether to use CUDA kernels. If disabled, Triton kernels will be used instead. Defaults to False.
- `tp_size`: Tensor-Parallelism size. Defaults to 1 (tensor parallelism is turned off by default).
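For example, a minimal sketch of a two-GPU tensor-parallel setup (the values below are illustrative, not recommended defaults):
```python
import torch
from colossalai.inference import InferenceConfig

# Illustrative values only: fp16 weights, a larger context budget,
# Triton kernels (use_cuda_kernel=False), and 2-way tensor parallelism.
inference_config = InferenceConfig(
    dtype=torch.float16,
    max_batch_size=8,
    max_input_len=1024,
    max_output_len=512,
    use_cuda_kernel=False,
    tp_size=2,
)
```
With `tp_size=2`, the script should be launched with a matching number of processes, e.g. `colossalai run --nproc_per_node 2 your_sample_name.py`.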
#### Generation Config
Refer to the transformers [GenerationConfig](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig) documentation for the functionality and usage of specific configs. In ColossalAI-Inference, generation configs can be preset in `InferenceConfig` (see the sketch after this list). Supported generation configs include:
- `do_sample`: Whether or not to use sampling. Defaults to False (greedy decoding).
- `top_k`: The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to 50.
- `top_p`: If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to 1.0.
- `temperature`: The value used to modulate the next token probabilities. Defaults to 1.0.
- `no_repeat_ngram_size`: If set to int > 0, all ngrams of that size can only occur once. Defaults to 0.
- `repetition_penalty`: The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
- `forced_eos_token_id`: The id of the token to force as the last generated token when max_length is reached. Defaults to `None`.
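A minimal sketch of presetting these options, assuming they can be passed directly as `InferenceConfig` keyword arguments:
```python
from colossalai.inference import InferenceConfig

# Assumed usage: sampling-related options preset together with the engine config.
inference_config = InferenceConfig(
    max_batch_size=4,
    do_sample=True,        # sample instead of greedy decoding
    top_k=50,
    top_p=0.9,
    temperature=0.7,
    repetition_penalty=1.1,
)
```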
Users can also create a transformers [GenerationConfig](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig) and pass it as an input argument to the `InferenceEngine.generate` API. For example:
```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
max_length=128,
do_sample=True,
temperature=0.7,
top_k=50,
top_p=1.0,
)
response = engine.generate(prompts=prompts, generation_config=generation_config)
```
## 🗺 Roadmap
We plan to develop the major features of ColossalAI-Inference following the roadmap below:
- [x] Blocked KV Cache
- [x] Paged Attention
- 🟩 Fused Kernels
- [x] Speculative Decoding
- [x] Continuous Batching
- 🟩 Tensor Parallelism
- [ ] Online Inference
- [ ] Beam Search
- [ ] SplitFuse
Notations:
- [x] Completed
- 🟩 Model-specific and still in progress
## 🪅 Support Matrix
| Model | Model Card | Tensor Parallel | Lazy Initialization | Paged Attention | Fused Kernels | Speculative Decoding |
|-----------|------------------------------------------------------------------------------------------------|-----------------|---------------------|-----------------|---------------|----------------------|
| Baichuan  | `baichuan-inc/Baichuan2-7B-Base`, `baichuan-inc/Baichuan2-13B-Base`, etc                                            | ✅ | ❌ | ✅ | ✅ | ❌ |
| ChatGLM   |                                                                                                                      | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSeek  |                                                                                                                      | ❌ | ❌ | ❌ | ❌ | ❌ |
| Llama     | `meta-llama/Llama-2-7b`, `meta-llama/Llama-2-13b`, `meta-llama/Meta-Llama-3-8B`, `meta-llama/Meta-Llama-3-70B`, etc | ✅ | ❌ | ✅ | ✅ | ✅ |
| Mixtral   |                                                                                                                      | ❌ | ❌ | ❌ | ❌ | ❌ |
| Qwen      |                                                                                                                      | ❌ | ❌ | ❌ | ❌ | ❌ |
| Vicuna    | `lmsys/vicuna-13b-v1.3`, `lmsys/vicuna-7b-v1.5`                                                                      | ✅ | ❌ | ✅ | ✅ | ✅ |
| Yi        | `01-ai/Yi-34B`, etc                                                                                                  | ✅ | ❌ | ✅ | ✅ | ✅ |
## 🛠 Design and Components
### Overview
ColossalAI-Inference has **4** major components, namely `engine`, `request handler`, `kv cache manager`, and `modeling`.
*Figure: example of a block table for a batch*

*Naive batching: decode until each sequence in the batch reaches EOS*

*Continuous batching: dynamically adjust the batch size by popping out finished sequences and inserting prefill batches*
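To make the contrast concrete, here is a rough, framework-agnostic sketch of a continuous-batching scheduling loop (the `Seq` class and `decode_step` callback are illustrative placeholders, not ColossalAI-Inference APIs):
```python
from dataclasses import dataclass, field

@dataclass
class Seq:
    """Toy stand-in for a request: a prompt plus generated tokens."""
    prompt: str
    max_new_tokens: int
    tokens: list = field(default_factory=list)
    finished: bool = False

def continuous_batching_loop(waiting, decode_step, max_batch_size=8):
    """Toy scheduler: after every decoding step, evict finished sequences and
    back-fill the batch from the waiting queue, instead of waiting for the
    whole batch to finish as naive batching does."""
    batch = []
    while batch or waiting:
        # Back-fill: admit waiting requests up to the batch-size budget.
        while waiting and len(batch) < max_batch_size:
            batch.append(waiting.pop(0))  # new sequence enters its prefill step
        decode_step(batch)                # one prefill/decode step for every live sequence
        batch = [s for s in batch if not s.finished]  # evict finished sequences immediately
```
Because finished sequences leave the batch (and free their KV cache blocks) right away, newly arrived requests do not have to wait for the longest sequence in the batch, which is the main throughput advantage over naive batching.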