
[Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix bench mark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bug of ppinfer

* polish readme

* fix typo

* skip infer test
Bin Jia, 1 year ago, committed by GitHub. Commit: 1db6727678
19 changed files:
  1. colossalai/inference/__init__.py (3)
  2. colossalai/inference/pipeline/README.md (75)
  3. colossalai/inference/pipeline/benchmark/benchmark.py (7)
  4. colossalai/inference/pipeline/benchmark/run.sh (6)
  5. colossalai/inference/pipeline/engine.py (91)
  6. colossalai/inference/pipeline/microbatch_manager.py (100)
  7. colossalai/inference/pipeline/modeling/__init__.py (3)
  8. colossalai/inference/pipeline/modeling/_utils.py (67)
  9. colossalai/inference/pipeline/modeling/gpt2.py (280)
  10. colossalai/inference/pipeline/modeling/llama.py (483)
  11. colossalai/inference/pipeline/policies/__init__.py (3)
  12. colossalai/inference/pipeline/policies/llama.py (145)
  13. colossalai/inference/pipeline/policy/gpt2_ppinfer.py (74)
  14. colossalai/inference/pipeline/policy/llama_ppinfer.py (48)
  15. colossalai/inference/pipeline/utils.py (35)
  16. colossalai/inference/tensor_parallel/batch_infer_state.py (61)
  17. colossalai/pipeline/schedule/generate.py (82)
  18. colossalai/shardformer/shard/shard_config.py (3)
  19. tests/test_infer/test_pipeline_infer.py (19)

colossalai/inference/__init__.py (3)

@ -1,3 +1,4 @@
from .pipeline import PPInferEngine
__all__ = ["PPInferEngine"]
__all__ = ['PPInferEngine']

colossalai/inference/pipeline/README.md (75)

@ -17,7 +17,7 @@
Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).
1. `PPInferEngine` is the high-level API for users. It is responsible for the following tasks:
- Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`.
- Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`.
- Run the pipeline inference model.
2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks:
@ -31,54 +31,53 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag
### Example
```python
from colossalai.pipeline import PPInferEngine
# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example.
model = LlamaForCausalLM.from_pretrained('/path/to/model')
inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt")
engine = PPInferEngine(
pp_size=2,
dtype='fp16',
micro_batch_size=1,
new_length=10,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy())
output = engine.inference([inputs])
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
```
colossalai.launch_from_torch(config={})
model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
### Quick start
```shell
cd benchmark
sh run.sh
# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32)
input = ["Introduce a landmark in London","Introduce a landmark in Singapore"]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference(data.to('cuda'))
print(tokenizer.batch_decode(output))
```
## Performance
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G.
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `Pipeline Inference` and the `Hugging Face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G.
### Llama Throughput(tokens/s)
### Llama Throughput (tokens/s) | input length=1024, output length=128
#### 7b, fp16
#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
#### 7b, fp32
#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
#### 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |
#### A800 7b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
#### A800 13b, fp16
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
| :---: | :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |

colossalai/inference/pipeline/benchmark/benchmark.py (7)

@ -7,7 +7,7 @@ import transformers
import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
@ -117,8 +117,11 @@ if __name__ == "__main__":
micro_batch_size=args.mb_size,
new_length=args.new_length,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy(),
model_policy=LlamaModelInferPolicy(),
verbose=True,
max_batch_size=args.mb_size,
max_input_len=args.seq_len,
max_output_len=args.seq_len + args.new_length + 256,
)
data = data_gen(args.batch_size, args.seq_len)

colossalai/inference/pipeline/benchmark/run.sh (6)

@ -1,7 +1,7 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"
# 7b, fp32, 2 gpu, 1024, 128
# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do
--pp_size=2
done
# 7b, fp32, 2 gpu, 512, 512
# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do
--pp_size=2
done
# 7b, fp32, 2 gpu, 1024, 128
# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \

colossalai/inference/pipeline/engine.py (91)

@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
@ -7,6 +8,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from ..tensor_parallel.kvcache_manager import MemoryManager
from .microbatch_manager import MicroBatchManager
@ -23,20 +25,29 @@ class PPInferEngine:
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
new_length (int): the number of new tokens to generate for each input sequence.
early_stopping (bool): whether to stop early.
max_batch_size (int): the maximum batch size.
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.
Example:
```python
from colossalai.ppinference import PPInferEngine
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
# assume the model is infered with 4 pipeline stages
inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
colossalai.launch_from_torch(config={})
model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")
# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference([data.to('cuda').data])
input = ["Hello, my dog is cute, and I like"]
tokenized_input = tokenizer(input, return_tensors='pt')
output = inferengine.inference(tokenized_input)
```
"""
@ -51,6 +62,9 @@ class PPInferEngine:
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
max_batch_size: int = 4,
max_input_len: int = 32,
max_output_len: int = 32,
verbose: bool = False,
# TODO: implement early_stopping, and various gerneration options
early_stopping: bool = False,
@ -58,24 +72,53 @@ class PPInferEngine:
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
max_output_len = max(max_output_len, max_input_len + new_length)
self.pp_size = pp_size
if dtype == "fp16":
self.dtype = torch.float16
model.half()
elif dtype == "bf16":
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
else:
self.dtype = torch.float32
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.model = pp_model or self._shardformer(model, model_policy)
self.cache_manager_list = [
self._init_manager(max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
self.stage_manager.stage,
new_length,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == "fp16":
model.half()
elif dtype == "bf16":
model.to(torch.bfloat16)
self.model = pp_model or self._shardformer(model, model_policy)
def inference(self, input_list):
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
"""
Args:
input_list (BatchEncoding or dict): the tokenized input batch, e.g. the output of a Hugging Face tokenizer.
Returns:
out (list): a list of output data, each element is a list of generated token ids.
timestamp (float): the time cost of the inference, only returned when verbose is `True`.
"""
assert isinstance(
input_list, (BatchEncoding, dict)
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
if isinstance(input_list, BatchEncoding):
input_list = input_list.data
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
if self.verbose:
return out, timestamp
else:
@ -95,3 +138,17 @@ class PPInferEngine:
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()
def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
head_num = self.model.config.num_attention_heads
num_hidden_layers = (
self.model.config.num_hidden_layers
if hasattr(self.model.config, "num_hidden_layers")
else self.model.config.num_layers
)
layer_num = num_hidden_layers // self.pp_size
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
return cache_manager
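
To make the cache sizing above concrete, here is a minimal standalone sketch (the helper name and the byte estimate are illustrative, not part of this commit) of the arithmetic `_init_manager` performs: each buffered micro-batch gets its own `MemoryManager`, sized for `max_batch_size * (max_input_len + max_output_len)` tokens, and only for the decoder layers held by the local pipeline stage. The per-layer key/value buffer shape `[tokens, head_num, head_dim]` is inferred from how the buffers are indexed in the llama attention forward below.

```python
# Illustrative sketch only: mirrors the sizing arithmetic in _init_manager above
# and estimates the resulting fp16 KV-cache footprint of one cache manager.
def estimate_kv_cache_bytes(
    max_batch_size: int,
    max_input_len: int,
    max_output_len: int,
    hidden_size: int,
    num_attention_heads: int,
    num_hidden_layers: int,
    pp_size: int,
    bytes_per_elem: int = 2,  # fp16
) -> int:
    max_total_token_num = max_batch_size * (max_input_len + max_output_len)
    head_dim = hidden_size // num_attention_heads
    layer_num = num_hidden_layers // pp_size  # each stage caches only its own layers
    # factor 2 = separate key buffer and value buffer, one token slot per held layer
    return 2 * max_total_token_num * num_attention_heads * head_dim * layer_num * bytes_per_elem


# e.g. a Llama-7B-like config (hidden 4096, 32 heads, 32 layers) split over 2 stages,
# batch 4, 1024 input tokens, 160 output tokens:
print(estimate_kv_cache_bytes(4, 1024, 160, 4096, 32, 32, 2) / 1024**3, "GiB per cache manager")
```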

colossalai/inference/pipeline/microbatch_manager.py (100)

@ -1,8 +1,11 @@
from enum import Enum
from typing import Dict, Tuple
from typing import Dict
import torch
from ..tensor_parallel.batch_infer_state import BatchInferState
from ..tensor_parallel.kvcache_manager import MemoryManager
__all__ = "MicroBatchManager"
@ -27,21 +30,20 @@ class MicroBatchDescription:
def __init__(
self,
inputs_dict: Dict[str, torch.Tensor],
output_dict: Dict[str, torch.Tensor],
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
new_length: int,
) -> None:
assert output_dict.get("hidden_states") is not None
self.mb_length = output_dict["hidden_states"].shape[-2]
self.mb_length = inputs_dict["input_ids"].shape[-1]
self.target_length = self.mb_length + new_length
self.kv_cache = ()
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
if output_dict is not None:
self._update_kvcache(output_dict["past_key_values"])
self.infer_state = BatchInferState.init_from_batch(
batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
)
# print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}")
def _update_kvcache(self, kv_cache: Tuple):
assert type(kv_cache) == tuple
self.kv_cache = kv_cache
def update(self, *args, **kwargs):
pass
@property
def state(self):
@ -80,17 +82,21 @@ class HeadMicroBatchDescription(MicroBatchDescription):
"""
def __init__(
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
self,
inputs_dict: Dict[str, torch.Tensor],
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
new_length: int,
) -> None:
super().__init__(inputs_dict, output_dict, new_length)
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
assert inputs_dict is not None
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
self.input_ids = inputs_dict["input_ids"]
self.attn_mask = inputs_dict["attention_mask"]
self.new_tokens = None
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
super().update(output_dict, new_token)
def update(self, new_token: torch.Tensor = None):
if new_token is not None:
self._update_newtokens(new_token)
if self.state is not Status.DONE and new_token is not None:
@ -125,16 +131,17 @@ class BodyMicroBatchDescription(MicroBatchDescription):
Args:
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hidden states from the previous stage.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
"""
def __init__(
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
self,
inputs_dict: Dict[str, torch.Tensor],
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
new_length: int,
) -> None:
super().__init__(inputs_dict, output_dict, new_length)
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
super().update(output_dict, new_token)
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
@property
def cur_length(self):
@ -142,10 +149,7 @@ class BodyMicroBatchDescription(MicroBatchDescription):
The current sequence length of this micro-batch, taken as the maximum of `infer_state.seq_len`.
"""
if len(self.kv_cache) == 0:
return self.mb_length
else:
return self.kv_cache[0][0].shape[-2] + 1
return self.infer_state.seq_len.max().item()
class MicroBatchManager:
@ -160,16 +164,38 @@ class MicroBatchManager:
"""
def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
def __init__(
self,
stage: int,
new_length: int,
micro_batch_size: int,
micro_batch_buffer_size: int,
max_input_len: int,
max_output_len: int,
cache_manager_list: MemoryManager,
):
self.stage = stage
self.new_length = new_length
self.micro_batch_size = micro_batch_size
self.buffer_size = micro_batch_buffer_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.cache_manager_list = cache_manager_list
self.mb_descrption_buffer = {}
self.new_tokens_buffer = {}
self.idx = 0
def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
if self.stage == 0:
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
)
else:
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
)
def step(self, new_token: torch.Tensor = None):
"""
Update the state of the microbatch manager. There are 2 conditions:
1. For the first stage in PREFILL, which receives inputs and outputs, `add_descrption` will save its inputs.
@ -181,11 +207,7 @@ class MicroBatchManager:
new_token (torch.Tensor): the new token generated by current stage.
"""
# Add descrption first if the descrption is None
if inputs_dict is None and output_dict is None and new_token is None:
return Status.PREFILL
if self.mb_descrption_buffer.get(self.idx) is None:
self._add_descrption(inputs_dict, output_dict)
self.cur_descrption.update(output_dict, new_token)
self.cur_descrption.update(new_token)
return self.cur_state
def export_new_tokens(self):
@ -204,16 +226,12 @@ class MicroBatchManager:
def clear(self):
self.mb_descrption_buffer.clear()
for cache in self.cache_manager_list:
cache.free_all()
def next(self):
self.idx = (self.idx + 1) % self.buffer_size
def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
if self.stage == 0:
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length)
else:
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length)
def _remove_descrption(self):
self.mb_descrption_buffer.pop(self.idx)
@ -222,10 +240,10 @@ class MicroBatchManager:
return self.mb_descrption_buffer.get(self.idx)
@property
def cur_kv_cache(self):
def cur_infer_state(self):
if self.cur_descrption is None:
return None
return self.cur_descrption.kv_cache
return self.cur_descrption.infer_state
@property
def cur_state(self):
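
A small, self-contained sketch (hypothetical class, not the commit's code) of the slot-cycling logic `MicroBatchManager` uses: `self.idx` walks a ring of `micro_batch_buffer_size` slots, and each slot owns one entry of `cache_manager_list`, so every in-flight micro-batch keeps its KV cache in its own `MemoryManager`.

```python
# Hypothetical stand-in for the ring indexing in MicroBatchManager.next():
#   self.idx = (self.idx + 1) % self.buffer_size
class RingOfCacheSlots:
    def __init__(self, cache_manager_list):
        self.cache_manager_list = cache_manager_list
        self.buffer_size = len(cache_manager_list)
        self.idx = 0

    @property
    def cur_cache_manager(self):
        # the slot whose cache manager backs the micro-batch currently in flight
        return self.cache_manager_list[self.idx]

    def next(self):
        # advance to the next micro-batch slot, wrapping around the buffer
        self.idx = (self.idx + 1) % self.buffer_size


ring = RingOfCacheSlots(["cache_mgr_0", "cache_mgr_1"])  # buffer size == pp_size by default
visits = []
for _ in range(4):
    visits.append(ring.cur_cache_manager)
    ring.next()
print(visits)  # ['cache_mgr_0', 'cache_mgr_1', 'cache_mgr_0', 'cache_mgr_1']
```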

colossalai/inference/pipeline/modeling/__init__.py (3)

@ -0,0 +1,3 @@
from .llama import LlamaInferenceForwards
__all__ = ["LlamaInferenceForwards"]

colossalai/inference/pipeline/modeling/_utils.py (67)

@ -0,0 +1,67 @@
"""
Utils for model inference
"""
import os
import torch
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
"""
This function copies the key and value cache to the memory cache
Args:
layer_id : id of current layer
key_buffer : key cache
value_buffer : value cache
context_mem_index : index of memory cache in kv cache manager
mem_manager : cache manager
"""
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
def init_to_get_rotary(self, base=10000, use_elem=False):
"""
This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer
Args:
self : Model that holds the rotary positional embedding
base : base value used to compute the rotary frequencies
use_elem : activated when using chatglm-based models
"""
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
if ntk_alpha is not None:
ntk_alpha = float(ntk_alpha)
assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
if ntk_alpha > 1:
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula
n_elem = self.config.head_dim_
if use_elem:
n_elem //= 2
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
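
The cached tables built above are later consumed in `llama_model_forward`, which gathers one row per position with `torch.index_select` (by `position_ids` during prefill, by `seq_len - 1` during decode). A CPU-only sketch of that build-then-gather pattern, with illustrative shapes rather than the exact fp16/CUDA buffers used here:

```python
# Standalone illustration (hypothetical, CPU-only) of the cos/sin cache layout:
# one row per absolute position, gathered by position index at forward time.
import torch

head_dim, max_seq_len, base = 128, 2048, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)               # [max_seq_len, head_dim // 2]
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)

position_ids = torch.tensor([0, 1, 2, 3])      # prefill: gather one row per prompt position
position_cos = torch.index_select(cos_cached, 0, position_ids)
seq_len = torch.tensor([4])                    # decode: gather row seq_len - 1 per sequence
decode_cos = torch.index_select(cos_cached, 0, seq_len - 1)
print(position_cos.shape, decode_cos.shape)    # torch.Size([4, 64]) torch.Size([1, 64])
```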

colossalai/inference/pipeline/modeling/gpt2.py (280)

@ -1,280 +0,0 @@
from typing import Dict, List, Optional, Tuple, Union
import torch
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
class GPT2PipelineForwards:
"""
This class serves as a micro library for forward function substitution of GPT2 models
under pipeline setting.
"""
@staticmethod
def gpt2_model_forward(
self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if stage_manager.is_first_stage():
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
else:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i, layer_past in zip(range(start_idx, end_idx), past_key_values):
block = self.h[i]
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return {"hidden_states": hidden_states, "past_key_values": presents}
@staticmethod
def gpt2_lmhead_model_forward(
self: GPT2LMHeadModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# If is first stage and after warmup, go throught lm_head first
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {"logits": lm_logits}
# Not first stage or before warmup, go through gpt2 model
outputs = GPT2PipelineForwards.gpt2_model_forward(
self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
return outputs

colossalai/inference/pipeline/modeling/llama.py (483)

@ -1,36 +1,100 @@
from typing import List, Optional
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
from transformers.utils import logging
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.pipeline.stage_manager import PipelineStageManager
from ._utils import copy_kv_to_mem_cache
class LlamaPipelineForwards:
try:
from vllm import layernorm_ops, pos_encoding_ops
rms_norm = layernorm_ops.rms_norm
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
HAS_VLLM_KERNERL = True
except:
print("fall back to original rotary_embedding_neox of huggingface")
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
print(
"if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
)
HAS_VLLM_KERNERL = False
try:
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaInferenceForwards:
"""
This class serves as a micro library for forward function substitution of Llama models
under pipeline setting.
This class holds forwards for llama inference.
We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention used by LlamaForCausalLM.
"""
def llama_model_forward(
self: LlamaModel,
@staticmethod
def llama_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
@ -38,11 +102,57 @@ class LlamaPipelineForwards:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# If this is the first stage and warmup has finished, go through lm_head first
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {"logits": lm_logits}
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaInferenceForwards.llama_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
infer_state=infer_state,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
return outputs
@staticmethod
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
infer_state: BatchInferState = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
# batch_size = input_ids.shape[0] # input_ids.shape[0]
# print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}")
# infer_state = self.infer_state
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if stage_manager is None or stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
@ -56,6 +166,8 @@ class LlamaPipelineForwards:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
assert stage_manager is not None
assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}"
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
@ -63,167 +175,292 @@ class LlamaPipelineForwards:
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if infer_state.is_context_stage is False:
past_key_values_length = infer_state.cache_manager.past_key_values_length
seq_length_with_past = seq_length_with_past + past_key_values_length
# NOTE: differentiate from the prefill stage
# block_loc requires a different value-assigning method for each of the two stages
if use_cache and seq_length != 1:
# NOTE assume prefill stage
# allocate memory block
infer_state.is_context_stage = True # set prefill stage, notify attention layer
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
infer_state.init_block_loc(
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
)
else:
infer_state.is_context_stage = False
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
if position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
position_ids = position_ids.unsqueeze(0)
new_shape = [1] * position_ids.dim()
new_shape[0] = batch_size
position_ids = position_ids.repeat(*new_shape).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if infer_state.is_context_stage:
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
() if output_hidden_states else None
() if output_attentions else None
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
start_idx, end_idx = stage_index[0], stage_index[1]
if past_key_values is None:
past_key_values = tuple([None] * (end_idx - start_idx + 1))
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
decoder_layer = self.layers[idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
# past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
# NOTE: modify here for passing args to decoder layer
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
infer_state=infer_state,
)
infer_state.decode_layer_id += 1
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
# always return dict for intermediate stages
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
# TODO: fix this to necessary return
# if not return_dict:
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
# return BaseModelOutputWithPast(
# last_hidden_state=hidden_states,
# past_key_values=next_cache,
# hidden_states=all_hidden_states,
# attentions=all_self_attns,
# )
# print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}")
return {"hidden_states": hidden_states, "past_key_values": next_cache}
def llama_for_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
@staticmethod
def llama_decoder_layer_forward(
self: LlamaDecoderLayer,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
Returns:
hidden_states = self.input_layernorm(hidden_states)
Example:
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
infer_state=infer_state,
)
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
hidden_states = residual + hidden_states
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
outputs = (hidden_states,)
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
if output_attentions:
outputs += (self_attn_weights,)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if use_cache:
outputs += (present_key_value,)
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
return outputs
# If is first stage and after warmup, go throught lm_head first
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {"logits": lm_logits}
@staticmethod
def llama_flash_attn_kvcache_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
infer_state: Optional[BatchInferState] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert use_cache is True, "use_cache should be set to True using this llama attention"
bsz, q_len, _ = hidden_states.size()
# NOTE might think about better way to handle transposed k and v
# key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
# key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
if infer_state.is_context_stage:
# print(f"rank:{torch.distributed.get_rank()}, {infer_state}")
# first token generation
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.context_mem_index,
infer_state.cache_manager,
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
attn_output = torch.empty_like(query_states)
return outputs
llama_context_attn_fwd(
query_states,
key_states,
value_states,
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
]
cache_k.copy_(key_states)
cache_v.copy_(value_states)
else:
# if decode is not contiguous, use triton kernel to copy key and value cache
# k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_states,
value_states,
infer_state.decode_mem_index,
infer_state.cache_manager,
)
# second token and follows
# kv = torch.stack((key_states, value_states), dim=2)
# (batch_size, seqlen, nheads, headdim)
attn_output = torch.empty_like(query_states)
token_attention_fwd(
query_states,
infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
attn_output,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
# print(f"rank:{torch.distributed.get_rank()}, {attn_output}")
attn_output = self.o_proj(attn_output)
# return past_key_value as None
return attn_output, None, None
def get_llama_vllm_rmsnorm_forward():
if HAS_VLLM_KERNERL:
def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
return _vllm_rmsnorm_forward
else:
return None
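
To summarize the KV-cache write path in `llama_flash_attn_kvcache_forward` above: during prefill every prompt token is scattered into the slots reserved in `context_mem_index`, while each decode step writes one new token per sequence, taking a direct slice copy when the newly allocated slots are contiguous and falling back to an indexed copy otherwise. A simplified, self-contained sketch of that branching (plain tensor indexing instead of the Triton `copy_kv_cache_to_dest` kernel; all names illustrative):

```python
# Simplified sketch, not the Triton path used in this commit.
import torch

tokens_capacity, num_heads, head_dim = 64, 4, 8
key_buffer = torch.zeros(tokens_capacity, num_heads, head_dim)   # stands in for MemoryManager.key_buffer[layer]

def write_prefill(key_states: torch.Tensor, context_mem_index: torch.Tensor) -> None:
    # key_states: [num_prompt_tokens, num_heads, head_dim]
    key_buffer[context_mem_index] = key_states                   # scatter into the reserved slots

def write_decode(key_states: torch.Tensor, mem_index: torch.Tensor, contiguous: bool) -> None:
    # key_states: [batch_size, num_heads, head_dim], one new token per sequence
    if contiguous:
        start, end = mem_index[0].item(), mem_index[-1].item() + 1
        key_buffer[start:end].copy_(key_states)                  # fast path: one contiguous copy
    else:
        key_buffer[mem_index] = key_states                       # fallback: scatter by index

write_prefill(torch.randn(6, num_heads, head_dim), torch.arange(0, 6))
write_decode(torch.randn(2, num_heads, head_dim), torch.tensor([6, 7]), contiguous=True)
```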

colossalai/inference/pipeline/policies/__init__.py (3)

@ -0,0 +1,3 @@
from .llama import LlamaModelInferPolicy
__all__ = ["LlamaModelInferPolicy"]

colossalai/inference/pipeline/policies/llama.py (145)

@ -0,0 +1,145 @@
from functools import partial
from typing import List
import torch
from torch.nn import Module
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
try:
from colossalai.kernel.triton import rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
return _triton_rmsnorm_forward
else:
return None
class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.shard_config.inference_gptq:
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement=decoder_attribute_replacement,
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.k_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.v_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="self_attn.o_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.gate_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.up_proj",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 1},
),
SubModuleReplacementDescription(
suffix="mlp.down_proj",
target_module=RowCaiQuantLinear,
kwargs={"split_num": 1},
),
],
)
self.shard_config._infer()
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
)
infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaAttention
)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
)
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
else:
# NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {"forward": partial(infer_forward)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
)
return policy
def postprocess(self):
init_to_get_rotary(self.model.model)
return self.model
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_first_stage():
held_layers.append(self.model.lm_head)
return held_layers
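
The policy above swaps module forwards via `append_or_create_method_replacement`. Conceptually this amounts to binding a free function as an instance method, roughly like the toy sketch below (illustrative only, not ShardFormer's actual code path; `doubled_forward` is a made-up stand-in):

```python
import types

import torch
import torch.nn as nn

def doubled_forward(self: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Stand-in for an inference-optimized forward (e.g. a fused-kernel version).
    return x * 2

layer = nn.Identity()
# Bind the replacement onto the instance so layer(x) now dispatches to it.
layer.forward = types.MethodType(doubled_forward, layer)
print(layer(torch.ones(3)))  # tensor([2., 2., 2.])
```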

74
colossalai/inference/pipeline/policy/gpt2_ppinfer.py

@ -1,74 +0,0 @@
from functools import partial
from typing import Callable, Dict, List
from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.policies.gpt2 import GPT2Policy
from ..modeling.gpt2 import GPT2PipelineForwards
class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
module_policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
)
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
model_cls=GPT2LMHeadModel,
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
policy=module_policy,
)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
# make the tie weight lm_head and embedding in the same device to save memory
# if self.pipeline_stage_manager.is_first_stage():
if self.pipeline_stage_manager.is_first_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""The weights of wte and lm_head are shared."""
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
return []
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "GPT2Model":
module = self.model
else:
module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

48
colossalai/inference/pipeline/policy/llama_ppinfer.py

@ -1,48 +0,0 @@
from typing import List
from torch.nn import Module
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaPolicy
from ..modeling.llama import LlamaPipelineForwards
class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import LlamaForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
)
]
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_first_stage():
held_layers.append(self.model.lm_head)
return held_layers

35
colossalai/inference/pipeline/utils.py

@ -1,35 +0,0 @@
from typing import Set
import torch.nn as nn
from colossalai.shardformer._utils import getattr_, setattr_
def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None:
"""
Set all parameters and buffers of model to None
Args:
model (nn.Module): The model to set
"""
for module_suffix in include:
set_module = getattr_(model, module_suffix)
for n, p in set_module.named_parameters():
setattr_(set_module, n, None)
for n, buf in set_module.named_buffers():
setattr_(set_module, n, None)
setattr_(model, module_suffix, None)
def get_suffix_name(suffix: str, name: str):
"""
Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit,
and 'name' when `suffix` is empty.
Args:
suffix (str): The suffix of the suffix module
name (str): The name of the current module
"""
point = "" if suffix is "" else "."
suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"
return suffix_name
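
For reference, the removed `get_suffix_name` helper joins a parent suffix and a child name into a ShardFormer-style access path. A self-contained sketch of the same logic (using `==` rather than `is` for the empty-string check) behaves like this:

```python
def get_suffix_name(suffix: str, name: str) -> str:
    # Digit names become index-style "[name]", others dot-style ".name";
    # an empty suffix returns the bare name.
    point = "" if suffix == "" else "."
    return suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"

assert get_suffix_name("model.layers", "0") == "model.layers[0]"
assert get_suffix_name("model", "embed_tokens") == "model.embed_tokens"
assert get_suffix_name("", "lm_head") == "lm_head"
```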

61
colossalai/inference/tensor_parallel/batch_infer_state.py

@ -2,9 +2,11 @@
from dataclasses import dataclass
import torch
from transformers.tokenization_utils_base import BatchEncoding
from .kvcache_manager import MemoryManager
# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
@ -55,3 +57,62 @@ class BatchInferState:
]
start_index += cur_seq_len
return
@classmethod
def init_from_batch(
cls,
batch: torch.Tensor,
max_input_len: int,
max_output_len: int,
cache_manager: MemoryManager,
):
if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)):
raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state")
input_ids_list = None
attention_mask = None
if isinstance(batch, (BatchEncoding, dict)):
input_ids_list = batch["input_ids"]
attention_mask = batch["attention_mask"]
else:
input_ids_list = batch
if isinstance(input_ids_list[0], int): # for a single input
input_ids_list = [input_ids_list]
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
max_len_in_batch = -1
if isinstance(batch, (BatchEncoding, dict)):
for i, attn_mask in enumerate(attention_mask):
curr_seq_len = len(attn_mask)
seq_lengths[i] = curr_seq_len
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
else:
length = max(len(input_id) for input_id in input_ids_list)
for i, input_ids in enumerate(input_ids_list):
curr_seq_len = length
seq_lengths[i] = curr_seq_len
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda")
return cls(
batch_size=batch_size,
max_len_in_batch=max_len_in_batch,
seq_len=seq_lengths.to("cuda"),
start_loc=seq_start_indexes.to("cuda"),
block_loc=block_loc,
decode_layer_id=0,
past_key_values_len=0,
is_context_stage=True,
cache_manager=cache_manager,
)
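
A minimal usage sketch of the new classmethod, assuming `cache_manager` is an already-initialized `MemoryManager` sized for the target model and that the model path is only a placeholder:

```python
from transformers import AutoTokenizer

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState

# `cache_manager` is assumed to exist already (see
# colossalai/inference/tensor_parallel/kvcache_manager.py).
tokenizer = AutoTokenizer.from_pretrained("/path/to/model")  # placeholder path
batch = tokenizer("Hello, my dog is cute", return_tensors="pt")

infer_state = BatchInferState.init_from_batch(
    batch=batch,
    max_input_len=batch["input_ids"].shape[1],
    max_output_len=32,
    cache_manager=cache_manager,
)
# The prefill forward can then read infer_state.seq_len, start_loc and block_loc.
```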

82
colossalai/pipeline/schedule/generate.py

@ -93,9 +93,9 @@ class GenerateSchedule(PipelineSchedule):
Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
"""
model_inputs = (
{"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
)
model_inputs = {
'infer_state': self.mb_manager.cur_descrption.infer_state
}
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@ -108,9 +108,8 @@ class GenerateSchedule(PipelineSchedule):
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
"""
new_mask = self.mb_manager.cur_descrption.attn_mask
past_key_values = self.mb_manager.cur_descrption.kv_cache
return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)
return dict(input_ids=new_token, attention_mask=new_mask)
def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
last_hidden_state = hidden_state[:, -1]
@ -128,27 +127,38 @@ class GenerateSchedule(PipelineSchedule):
return self.comm.p2p_recv()
return self.comm.recv_forward()
def _init_infer_state_action(self) -> None:
"""
This action is only for non-first stages: load the batch and init the infer_state.
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
def _load_stage_action(self, model: Module) -> None:
"""
In this action, 1.load micro_batch 2.do the forward 3.step to update
This action is only for the first stage: load, init, and do the forward pass.
1.load micro_batch 2.do the forward 3.step to update
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
output_dict = model_forward(model, inputs_dict, None)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
def _gen_token_action(self, model: Module):
"""
In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update
This action is only for the first stage
1.do the forward with hidden_states to generate new tokens 2.step to update
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
hidden_states = {"hidden_states": hidden_states}
logits = model_forward(model, None, hidden_states)
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
@ -157,7 +167,7 @@ class GenerateSchedule(PipelineSchedule):
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(None, None, new_token)
self.mb_manager.step(new_token)
self.action_interval_buffer.new_token = new_token
self.action_interval_buffer.hidden_states = None
@ -168,20 +178,18 @@ class GenerateSchedule(PipelineSchedule):
new_token = self.action_interval_buffer.new_token
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
output_dict = model_forward(model, inputs_dict, None)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
inputs_dict = self._prepare_inputs_for_interval_stage()
hidden_states = {"hidden_states": hidden_states}
output_dict = model_forward(model, inputs_dict, hidden_states)
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
output_dict = model_forward(model, None, interval_inputs)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
"""
@ -218,6 +226,8 @@ class GenerateSchedule(PipelineSchedule):
actions.append(partial(self._gen_token_action, model))
# other stage
else:
if self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self._init_infer_state_action))
actions.append(partial(self._comm_action, True))
actions.append(partial(self._body_encoding_action, model))
@ -308,8 +318,9 @@ class GenerateSchedule(PipelineSchedule):
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
self.mb_manager.add_descrption(inputs_dict)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase
else:
# Get hidden_states from previous stage
@ -319,25 +330,28 @@ class GenerateSchedule(PipelineSchedule):
assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
logits = model_forward(model, None, hidden_states)
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
assert (
"logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(None, None, new_token)
assert 'logits' in logits, f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits['logits'])
self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
else:
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
inputs_dict = self._prepare_inputs_for_interval_stage()
output_dict = model_forward(model, inputs_dict, hidden_states)
self.mb_manager.step(inputs_dict, output_dict, None)
# inputs_dict = self._prepare_inputs_for_interval_stage()
inputs_dict = None
if self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# Current microbatch is not DONE, send hidden_state to next stage
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
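
The schedule's `_get_token_id` helper above reduces the model output to one next token per sequence by looking at the last position; the hunk does not show the full selection rule, but a greedy version of that step would look like the sketch below (an illustration, not the schedule's exact code):

```python
import torch

def greedy_next_token(logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch_size, seq_len, vocab_size]; keep only the last position
    # and pick the highest-scoring token id for each sequence.
    last_position = logits[:, -1, :]
    return torch.argmax(last_position, dim=-1, keepdim=True)  # [batch_size, 1]

logits = torch.randn(2, 8, 32000)  # toy batch of 2 sequences, vocab size 32000
print(greedy_next_token(logits).shape)  # torch.Size([2, 1])
```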

3
colossalai/shardformer/shard/shard_config.py

@ -76,4 +76,5 @@ class ShardConfig:
"""
Set default params for inference.
"""
assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
pass

19
tests/test_infer/test_pipeline_infer.py

@ -2,12 +2,15 @@ import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version
import colossalai
from colossalai.inference.pipeline.engine import PPInferEngine
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
from colossalai.inference.pipeline import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
@ -24,20 +27,21 @@ for k, v in inputs.items():
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))
engine = PPInferEngine(
pp_size=pp_size,
model=model,
model_policy=GPT2LMHeadModelPipelinePolicy(),
model_policy=LlamaModelInferPolicy(),
new_length=new_length,
micro_batch_size=micro_batch_size,
)
output = engine.inference([inputs])
output = engine.inference(inputs)
if dist.get_rank() == 0:
assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
@parameterize("pp_size", [4])
@parameterize("pp_size", [2])
@parameterize("new_length", [4, 8, 16])
@parameterize("micro_batch_size", [1, 4])
@clear_cache_before_run()
@ -51,11 +55,12 @@ def check_pipeline_inference(rank, world_size, port):
run_pipeline_inference_test()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=4)
spawn(check_pipeline_inference, nprocs=2)
if __name__ == "__main__":
