ColossalAI/colossalai/inference
core [FP8] rebase main (#5963) 2024-08-06 16:29:37 +08:00
executor [pre-commit.ci] pre-commit autoupdate (#5572) 2024-07-01 17:16:41 +08:00
kv_cache [pre-commit.ci] pre-commit autoupdate (#5572) 2024-07-01 17:16:41 +08:00
modeling [FP8] rebase main (#5963) 2024-08-06 16:29:37 +08:00
server [Inference]Fix readme and example for API server (#5742) 2024-05-24 10:03:05 +08:00
spec [Fix] Fix spec-dec Glide LlamaModel for compatibility with transformers (#5837) 2024-06-19 15:37:53 +08:00
README.md [FP8] rebase main (#5963) 2024-08-06 16:29:37 +08:00
__init__.py [doc] updated inference readme (#5343) 2024-02-02 14:31:10 +08:00
batch_bucket.py [pre-commit.ci] pre-commit autoupdate (#5572) 2024-07-01 17:16:41 +08:00
config.py [FP8] rebase main (#5963) 2024-08-06 16:29:37 +08:00
flash_decoding_utils.py add paged-attetionv2: support seq length split across thread block (#5707) 2024-05-14 12:46:54 +08:00
graph_runner.py [fix] pytest and fix dyn grid bug 2024-03-13 17:28:32 +08:00
logit_processors.py [Inference] Fix Inference Generation Config and Sampling (#5710) 2024-05-19 15:08:42 +08:00
sampler.py [Inference] Fix Inference Generation Config and Sampling (#5710) 2024-05-19 15:08:42 +08:00
struct.py [FP8] rebase main (#5963) 2024-08-06 16:29:37 +08:00
utils.py [FP8] rebase main (#5963) 2024-08-06 16:29:37 +08:00

README.md

ColossalAI-Inference

📌 Introduction

ColossalAI-Inference is a module that accelerates inference for Transformers models, especially LLMs and DiT diffusion models. In ColossalAI-Inference, we leverage high-performance kernels, KV cache, paged attention, continuous batching, and other techniques to accelerate LLM inference. We also provide simple and unified APIs for ease of use. [blog]

🕹 Usage

➡️ Quick Start

The sample usage of the inference engine is given below:

import torch
import transformers
import colossalai
from colossalai.inference import InferenceEngine, InferenceConfig
from pprint import pprint

colossalai.launch_from_torch()

# Step 1: create a model in "transformers" way
model_path = "lmsys/vicuna-7b-v1.3"
model = transformers.LlamaForCausalLM.from_pretrained(model_path).cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

# Step 2: create an inference_config
inference_config = InferenceConfig(
    dtype=torch.float16,
    max_batch_size=4,
    max_input_len=1024,
    max_output_len=512,
    use_cuda_kernel=True,
)

# Step 3: create an engine with model and config
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# Step 4: try inference
prompts = ['Who is the best player in the history of NBA?']
response = engine.generate(prompts)
pprint(response)

You can run the sample code with:

colossalai run --nproc_per_node 1 your_sample_name.py

For detailed examples, you might want to check inference examples.

🔖 Customize your inference engine

Besides the basic quick-start inference, you can also customize your inference engine by modifying the inference config or by supplying your own models, policies, or decoding components (logits processors or sampling strategies).

Inference Config

Inference Config is a unified config for initializing the inference engine, controlling multi-GPU generation (Tensor Parallelism), as well as presetting generation configs. Below are some commonly used InferenceConfig's arguments:

  • max_batch_size: The maximum batch size. Defaults to 8.
  • max_input_len: The maximum input length (number of tokens). Defaults to 256.
  • max_output_len: The maximum output length (number of tokens). Defaults to 256.
  • dtype: The data type of the model for inference. This can be one of fp16, bf16, or fp32. Defaults to fp16.
  • kv_cache_dtype: The data type used for KVCache. Defaults to the same data type as the model (dtype). KVCache quantization will be automatically enabled if it is different from that of model (dtype).
  • use_cuda_kernel: Determine whether to use CUDA kernels or not. If disabled, Triton kernels will be used. Defaults to False.
  • tp_size: Tensor-Parallelism size. Defaults to 1 (tensor parallelism is turned off by default).

Generation Config

Refer to transformers GenerationConfig for the functionality and usage of specific configs. In ColossalAI-Inference, generation configs can be preset in InferenceConfig. Supported generation configs include:

  • do_sample: Whether or not to use sampling. Defaults to False (greedy decoding).
  • top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to 50.
  • top_p: If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to 1.0.
  • temperature: The value used to modulate the next token probabilities. Defaults to 1.0.
  • no_repeat_ngram_size: If set to int > 0, all ngrams of that size can only occur once. Defaults to 0.
  • repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
  • forced_eos_token_id: The id of the token to force as the last generated token when max_length is reached. Defaults to None.
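To make the interaction between do_sample, temperature, top_k, and top_p concrete, here is a minimal pure-Python sketch of how such options are commonly applied to next-token logits. It is an illustration only, not the ColossalAI-Inference sampler: the function names `softmax`, `filter_top_k_top_p`, and `pick_next_token` are hypothetical, and it assumes top-k filtering is applied before top-p.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; lower temperature sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def filter_top_k_top_p(probs, top_k=50, top_p=1.0):
    """Keep the top_k most probable tokens, then the smallest prefix whose
    renormalized mass reaches top_p. Returns {token_id: renormalized_prob}."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    k_mass = sum(probs[i] for i in ranked)
    kept, cumulative = [], 0.0
    for token_id in ranked:
        kept.append(token_id)
        cumulative += probs[token_id] / k_mass
        if cumulative >= top_p:
            break
    kept_mass = sum(probs[i] for i in kept)
    return {i: probs[i] / kept_mass for i in kept}


def pick_next_token(logits, do_sample=False, temperature=1.0, top_k=50, top_p=1.0, rng=random):
    """Greedy decoding when do_sample is False, otherwise filtered sampling."""
    if not do_sample:
        return max(range(len(logits)), key=lambda i: logits[i])
    candidates = filter_top_k_top_p(softmax(logits, temperature), top_k, top_p)
    ids = list(candidates)
    return rng.choices(ids, weights=[candidates[i] for i in ids], k=1)[0]
```

With the defaults listed above (do_sample=False), the sketch reduces to picking the highest-scoring token; the sampling path only runs when do_sample is enabled.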

Users can also create a transformers GenerationConfig and pass it as an input argument to the InferenceEngine.generate API. For example:

from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_length=128,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=1.0,
)
response = engine.generate(prompts=prompts, generation_config=generation_config)

🗺 Roadmap

We will develop the major features of ColossalAI-Inference according to the following roadmap:

  • Blocked KV Cache
  • Paged Attention
  • 🟩 Fused Kernels
  • Speculative Decoding
  • Continuous Batching
  • 🟩 Tensor Parallelism
  • Online Inference
  • Beam Search
  • SplitFuse

Notations:

  • Completed
  • 🟩 Model-specific and still in progress.

🪅 Support Matrix

Model Model Card Tensor Parallel Lazy Initialization Paged Attention Fused Kernels Speculative Decoding
Baichuan baichuan-inc/Baichuan2-7B-Base,
baichuan-inc/Baichuan2-13B-Base, etc
[ ] [ ]
ChatGLM [ ] [ ] [ ] [ ] [ ]
DeepSeek [ ] [ ] [ ] [ ] [ ]
Llama meta-llama/Llama-2-7b,
meta-llama/Llama-2-13b,
meta-llama/Meta-Llama-3-8B,
meta-llama/Meta-Llama-3-70B, etc
[ ]
Mixtral [ ] [ ] [ ] [ ] [ ]
Qwen [ ] [ ] [ ] [ ] [ ]
Vicuna lmsys/vicuna-13b-v1.3,
lmsys/vicuna-7b-v1.5
[ ]
Yi 01-ai/Yi-34B, etc [ ]

🛠 Design and Components

Overview

ColossalAI-Inference has 4 major components, namely engine, request handler, kv cache manager, and modeling.

colossalai-inference-components-overview

  • Engine: It orchestrates the inference step. During inference, it receives a request, calls the request handler to schedule a decoding batch, and executes the model forward pass to perform an iteration. It returns the inference results to the user at the end.
  • Request Handler: It manages requests and schedules a proper batch from existing requests.
  • KV Cache Manager: It is bound to the request handler, and updates cache blocks and logical block tables as scheduled by the request handler.
  • Modeling: We rewrite the model and layers of LLMs to simplify and optimize the forward pass for inference.

An overview of the inter-component interaction is given below (RPC version). More details are given in the next few sections.

colossalai-inference-framework-rpc

Engine

Engine is designed as the entry point where the user kickstarts an inference loop. Users can easily initialize an inference engine with the inference configurations and execute it with their requests. We provide several versions of the inference engine, namely InferenceEngine, RPCInferenceEngine, and AsyncInferenceEngine, which are used for different conditions and purposes.

For InferenceEngine and RPCInferenceEngine, we expose the following APIs for inference:

  • generate: main function which handles inputs, performs inference and returns outputs.
  • add_request: add a single or multiple requests to the inference engine.
  • step: perform one decoding iteration. The request handler first schedules a batch for prefill/decoding. Then, it invokes the model to generate a batch of tokens, and afterwards performs logit processing and sampling, and checks and decodes finished requests.
  • enable_spec_dec: used for speculative decoding. Enable speculative decoding for subsequent generations.
  • disable_spec_dec: used for speculative decoding. Disable speculative decoding for subsequent generations.
  • clear_spec_dec: clear structures and models related to speculative decoding, if any exist.
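The relationship between add_request, step, and generate can be sketched with a toy engine. This is a hedged illustration, not the ColossalAI implementation: `ToyEngine`, `toy_model`, and the dictionary-based request records are hypothetical, prompts are plain token lists, and the "model" is a stand-in function that emits an end-of-sequence token once a sequence holds four tokens.

```python
from collections import deque


def toy_model(tokens):
    """Stand-in for a model forward pass: returns the next token for one sequence."""
    return "<eos>" if len(tokens) >= 4 else f"t{len(tokens)}"


class ToyEngine:
    """Hypothetical engine showing how generate() is built from add_request() and step()."""

    def __init__(self, next_token_fn, eos_token="<eos>", max_batch_size=2, max_output_len=8):
        self.next_token_fn = next_token_fn
        self.eos_token = eos_token
        self.max_batch_size = max_batch_size
        self.max_output_len = max_output_len
        self.waiting = deque()  # requests not yet scheduled
        self.running = []       # requests in the current decoding batch
        self._next_id = 0

    def add_request(self, prompts):
        """Queue one prompt (a token list) or several prompts (a list of token lists)."""
        if prompts and not isinstance(prompts[0], list):
            prompts = [prompts]
        ids = []
        for prompt in prompts:
            self.waiting.append({"id": self._next_id, "prompt": list(prompt), "output": []})
            ids.append(self._next_id)
            self._next_id += 1
        return ids

    def step(self):
        """One iteration: schedule a batch, produce one token each, return finished requests."""
        while self.waiting and len(self.running) < self.max_batch_size:
            self.running.append(self.waiting.popleft())
        finished, still_running = [], []
        for req in self.running:
            token = self.next_token_fn(req["prompt"] + req["output"])
            req["output"].append(token)
            done = token == self.eos_token or len(req["output"]) >= self.max_output_len
            (finished if done else still_running).append(req)
        self.running = still_running
        return finished

    def generate(self, prompts):
        """Add the prompts, then call step() until every request has finished."""
        ids = self.add_request(prompts)
        outputs = {}
        while self.waiting or self.running:
            for req in self.step():
                outputs[req["id"]] = req["output"]
        return [outputs[i] for i in ids]
```

For example, `ToyEngine(toy_model).generate([["a"], ["b", "c", "d"]])` drives the loop to completion, while calling `add_request` and `step` directly gives the caller control over each iteration.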

For AsyncInferenceEngine, we expose the following APIs for inference:

  • add_request: async method. Add a request to the inference engine, as well as to the waiting queue of the background tracker.
  • generate: async method. Perform inference from a request.
  • step: async method. Perform one decoding iteration, if there exists any request in waiting queue.

For now, InferenceEngine is used for offline generation; AsyncInferenceEngine is used for online serving with a single card; and RPCInferenceEngine is used for online serving with multiple cards. In the future, we will focus on RPCInferenceEngine and improve the user experience of LLM serving.

KV cache

Drawing on PagedAttention from the vLLM team, we use a unified blocked KV cache and cache manager to allocate and manage memory. The physical memory is pre-allocated during initialization and represented by a logical block table. During decoding, the cache manager administers the physical memory through the block table of a batch, so that other components (e.g. the engine) can focus on the lightweight block table. More details are given below.

  • logical cache block: We group physical memory into memory blocks. A typical cache block has shape (num_kv_heads, block_size, head_size). We determine the number of blocks beforehand. Memory allocation and computation are executed at the granularity of a memory block.
  • block table: The block table is the logical representation of cache blocks. Concretely, the block table of a single sequence is a 1D tensor, with each element holding a block ID. A block ID of -1 means "Not Allocated". In each iteration, we pass the batch's block table to the corresponding model.


Example of block table for a batch
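The block-table bookkeeping described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the ColossalAI cache manager: `BlockAllocator` and its methods are hypothetical names, block tables are Python lists rather than tensors, and only block IDs are tracked (no actual key/value memory).

```python
class BlockAllocator:
    """Toy cache manager: hands out block IDs from a fixed pool decided up front."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size              # tokens stored per cache block
        self.free_blocks = list(range(num_blocks))

    def new_block_table(self, max_blocks_per_seq):
        # -1 marks a slot as "Not Allocated".
        return [-1] * max_blocks_per_seq

    def allocate_for_length(self, block_table, seq_len):
        """Make sure the table has enough blocks to hold seq_len tokens."""
        needed = -(-seq_len // self.block_size)   # ceiling division
        if needed > len(block_table):
            raise ValueError("sequence exceeds the block table capacity")
        for slot in range(needed):
            if block_table[slot] == -1:
                if not self.free_blocks:
                    raise RuntimeError("out of cache blocks")
                block_table[slot] = self.free_blocks.pop(0)
        return block_table

    def locate(self, block_table, token_pos):
        """Map a token position to (block ID, offset inside that block)."""
        block_id = block_table[token_pos // self.block_size]
        if block_id == -1:
            raise IndexError("token position is not backed by an allocated block")
        return block_id, token_pos % self.block_size

    def free(self, block_table):
        """Return all blocks of a finished sequence to the pool."""
        for slot, block_id in enumerate(block_table):
            if block_id != -1:
                self.free_blocks.append(block_id)
                block_table[slot] = -1
```

A batch block table is then just one such row per sequence; other components only need these integer IDs, which is what keeps the block table lightweight.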

Batching

Request handler is responsible for managing requests and scheduling a proper batch from existing requests. Building on Orca's and vLLM's research and work on batching requests, we apply continuous batching with unpadded sequences. This lets a varying number of sequences pass through the projections (i.e. Q, K, and V) together in different steps by hiding the number-of-sequences dimension, and reduces the latency of incoming sequences by inserting a prefill batch during a decoding step and then decoding them together.


Naive Batching: decode until each sequence encounters eos in a batch


Continuous Batching: dynamically adjust the batch size by popping out finished sequences and inserting prefill batch
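The difference between the two strategies can be seen in a small step-counting simulation. This is a deliberately simplified sketch, not the request handler's scheduler: the function names are hypothetical, each step is assumed to produce exactly one token per running sequence, prefill is folded into a sequence's first step, and every output length is assumed to be at least 1.

```python
def naive_batching_steps(output_lens, batch_size):
    """Each batch decodes until its longest sequence finishes; only then does the next batch start."""
    steps = 0
    for start in range(0, len(output_lens), batch_size):
        steps += max(output_lens[start:start + batch_size])
    return steps


def continuous_batching_steps(output_lens, batch_size):
    """Finished sequences are popped after every step and waiting ones fill the free slots."""
    waiting = list(output_lens)
    running = []  # remaining tokens to generate for each running sequence
    steps = 0
    while waiting or running:
        while waiting and len(running) < batch_size:
            running.append(waiting.pop(0))
        running = [remaining - 1 for remaining in running]
        running = [remaining for remaining in running if remaining > 0]
        steps += 1
    return steps
```

With output lengths `[2, 8, 3, 4]` and a batch size of 2, the naive strategy takes 12 steps while the continuous one takes 9, because the slot freed by the short first sequence is reused immediately.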

Modeling

Modeling contains models, layers, and policies, which are hand-crafted for better performance and easier usage. Because modeling is integrated with shardformer, users can define their own policy or use our preset policies for specific models. Our modeling files are aligned with Transformers. For more details about the usage of modeling and policy, please check colossalai/shardformer.

Online Service

Colossal-Inference supports a FastAPI-based online service. Both simple completion and chat are supported. Follow the commands below to set up a server with both completion and chat functionalities. For now we support Llama2, Llama3, and Baichuan2 models, among others; support for more models will be added soon.

API

  • GET '/ping': Ping is used to check if the server can receive and send information.
  • GET '/engine_check': Check whether the background engine is working.
  • POST '/completion': The completion API is used for single-sequence requests, such as answering a question or completing text.
  • POST '/chat': The chat API is used for conversation-style requests, which often include dialogue participants (i.e. roles) and their corresponding words. Because this input differs substantially from ordinary inputs, we introduce Chat-Template to match the data format of chat models.

chat-template

Following transformers, we add the chat-template argument. Chat models have been trained with very different formats for converting conversations into a single tokenizable string, so using a format that matches the training data is extremely important. This attribute (chat_template) is included in HuggingFace tokenizers and contains a Jinja template that converts conversation histories into a correctly formatted string. You can refer to the HuggingFace-blog for more information. We also provide a simple example template below. Both string-style and file-style chat templates are supported.
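To show what a chat template produces, the sketch below reproduces in plain Python the string that the simple example template from the Examples section would yield: each message becomes `<|im_start|>`, the role, a newline, the content, `<|im_end|>`, and a trailing newline. The function name `render_chat` is hypothetical, and the real server renders the Jinja template through the tokenizer rather than through code like this.

```python
def render_chat(messages):
    """Flatten a list of {"role": ..., "content": ...} messages into one prompt string."""
    return "".join(
        "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n"
        for message in messages
    )


conversation = [
    {"role": "system", "content": "you are a helpful assistant"},
    {"role": "user", "content": "what is 1+1?"},
]
prompt = render_chat(conversation)
```

The resulting `prompt` is the single tokenizable string that the model sees; a template that does not match the model's training format would produce a differently structured string here.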

Usage

Args for customizing your server

The configuration for the API server covers both the serving interface and the engine backend.

For the interface:

  • --host: The host URL on your device for the server.
  • --port: The port for the service.
  • --model: The model that the backend engine uses; both a local path and a transformers model card are supported.
  • --chat-template: The file path of the chat template, or the template string.
  • --response-role: The role that colossal-inference plays.

For the engine backend:

  • --block_size: The memory usage for each block.
  • --max_batch_size: The max batch size for the engine to infer. This affects the speed of inference.
  • --max_input_len: The max input length of a request.
  • --max_output_len: The output length of a response.
  • --dtype and --use_cuda_kernel: Decide the precision and kernel usage.

For more detailed arguments, please refer to the source code.

Examples

# First, launch an API server locally.
python3 -m colossalai.inference.server.api_server  --model path of your model --chat-template "{% for message in messages %}{{'<|im_start|>'+message['role']+'\n'+message['content']+'<|im_end|>'+'\n'}}{% endfor %}"

# Second, you can turn to the page `http://127.0.0.1:8000/docs` to check the api

# For completion service, you can invoke it
curl -X POST  http://127.0.0.1:8000/completion  -H 'Content-Type: application/json'  -d '{"prompt":"hello, who are you? "}'

# For chat service, you can invoke it
curl -X POST http://127.0.0.1:8000/chat -H 'Content-Type: application/json' -d '{"messages":[{"role":"system","content":"you are a helpful assistant"},{"role":"user","content":"what is 1+1?"}]}'

# You can check the engine status now
curl http://localhost:8000/engine_check
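The same requests can be issued from Python. The sketch below uses only the standard library and mirrors the curl commands above; the helper names `build_json_post` and `send` are hypothetical, the address assumes the server was launched locally on port 8000 as in the commands above, and the response is returned as raw text because its exact format is not specified here.

```python
import json
import urllib.request


def build_json_post(base_url, path, payload):
    """Build a POST request carrying a JSON body, like `curl -X POST ... -d '{...}'`."""
    return urllib.request.Request(
        url=base_url.rstrip("/") + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=30):
    """Send the request and return the response body as text."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read().decode("utf-8")


base_url = "http://127.0.0.1:8000"  # assumes a locally running server
completion_request = build_json_post(base_url, "/completion", {"prompt": "hello, who are you? "})
chat_request = build_json_post(
    base_url,
    "/chat",
    {
        "messages": [
            {"role": "system", "content": "you are a helpful assistant"},
            {"role": "user", "content": "what is 1+1?"},
        ]
    },
)

# Uncomment once the server is running:
# print(send(completion_request))
# print(send(chat_request))
```

Building the request separately from sending it keeps the payload easy to inspect before anything goes over the network.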

🌟 Acknowledgement

This project was written from scratch but we learned a lot from several other great open-source projects during development. Therefore, we wish to fully acknowledge their contribution to the open-source community. These projects include

# vllm
@inproceedings{kwon2023efficient,
  title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
  author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
  booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
  year={2023}
}

# flash attention v1 & v2
@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}
@article{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  year={2023}
}

# StreamingLLM
@article{xiao2023streamingllm,
  title={Efficient Streaming Language Models with Attention Sinks},
  author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
  journal={arXiv},
  year={2023}
}

# Distrifusion
@InProceedings{Li_2024_CVPR,
    author={Li, Muyang and Cai, Tianle and Cao, Jiaxin and Zhang, Qinsheng and Cai, Han and Bai, Junjie and Jia, Yangqing and Li, Kai and Han, Song},
    title={DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month={June},
    year={2024},
    pages={7183-7193}
}