mirror of https://github.com/hpcaitech/ColossalAI
[Pipeline inference] Combine kvcache with pipeline inference (#4938)
* merge kvcache with pipeline inference and refactor the code structure
* support ppsize > 2
* refactor pipeline code
* do pre-commit
* modify benchmark
* fix benchmark
* polish code
* add docstring and update readme
* refactor the code
* fix some logic bug of ppinfer
* polish readme
* fix typo
* skip infer test

pull/5017/head
parent
c6cd629e7a
commit
1db6727678
|
@ -1,3 +1,4 @@
|
|||
from .pipeline import PPInferEngine
|
||||
|
||||
__all__ = ["PPInferEngine"]
|
||||
|
||||
__all__ = ['PPInferEngine']
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).
|
||||
|
||||
1. `PPInferEngine` is the high-level API for users. It is responsible for the following tasks:
|
||||
- Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`.
|
||||
- Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`.
|
||||
- Run the pipeline inference model.
|
||||
|
||||
2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks:
|
||||
|
@ -31,54 +31,53 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag
|
|||
|
||||
### Example

The previous example (removed in this commit):

```python
from colossalai.pipeline import PPInferEngine
# Suppose the pipeline size is 2, and use fp16 to do inference. Use Llama as an example.
model = LlamaForCausalLM.from_pretrained('/path/to/model')
inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt")
engine = PPInferEngine(
    pp_size=2,
    dtype='fp16',
    micro_batch_size=1,
    new_length=10,
    model=model,
    model_policy=LlamaForCausalLMPipelinePolicy())

output = engine.inference([inputs])
```

The updated example:

```python
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")

# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32)

input = ["Introduce a landmark in London", "Introduce a landmark in Singapore"]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference(data.to('cuda'))
print(tokenizer.batch_decode(output))
```

### Quick start

```shell
cd benchmark
sh run.sh
```
|
||||
|
||||
## Performance
|
||||
|
||||
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G.
|
||||
We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` of `Pipeline Inference` and the `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G.
|
||||
|
||||
### Llama Throughput(tokens/s)
|
||||
### Llama Throughput (tokens/s) | input length=1024, output length=128
|
||||
|
||||
#### 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |
| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM | OOM |

#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM | OOM |

#### 7b, fp32
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
| Hugging Face(1024, 128) | 19.80 | 29.37 | OOM | OOM |
| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |

#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
|
||||
|
||||
#### 13b, fp16
|
||||
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
|
||||
| :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
|
||||
| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
|
||||
| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
|
||||
| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |
|
||||
|
||||
#### A800 7b, fp16
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
|
||||
| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
|
||||
|
||||
|
||||
#### A800 13b, fp16
|
||||
| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
|
||||
| :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
|
||||
| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
|
||||
|
|
|
@ -7,7 +7,7 @@ import transformers
|
|||
|
||||
import colossalai
|
||||
from colossalai.inference import PPInferEngine
|
||||
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
|
||||
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
|
||||
|
||||
GIGABYTE = 1024**3
|
||||
MEGABYTE = 1024 * 1024
|
||||
|
@ -117,8 +117,11 @@ if __name__ == "__main__":
|
|||
micro_batch_size=args.mb_size,
|
||||
new_length=args.new_length,
|
||||
model=model,
|
||||
model_policy=LlamaForCausalLMPipelinePolicy(),
|
||||
model_policy=LlamaModelInferPolicy(),
|
||||
verbose=True,
|
||||
max_batch_size=args.mb_size,
|
||||
max_input_len=args.seq_len,
|
||||
max_output_len=args.seq_len + args.new_length + 256,
|
||||
)
|
||||
data = data_gen(args.batch_size, args.seq_len)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
script_dir=$(cd "$(dirname "$0")" && pwd)
|
||||
cd "${script_dir}"
|
||||
|
||||
# 7b, fp32, 2 gpu, 1024, 128
|
||||
# 7b, fp16, 2 gpu, 1024, 128
|
||||
for BATCH_SIZE in 2 4 8 16; do
|
||||
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
|
||||
--model="7b" \
|
||||
|
@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do
|
|||
--pp_size=2
|
||||
done
|
||||
|
||||
# 7b, fp32, 2 gpu, 512, 512
|
||||
# 7b, fp16, 2 gpu, 512, 512
|
||||
for BATCH_SIZE in 2 4 8 16 32; do
|
||||
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
|
||||
--model="7b" \
|
||||
|
@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do
|
|||
--pp_size=2
|
||||
done
|
||||
|
||||
# 7b, fp32, 2 gpu, 1024, 128
|
||||
# 7b, fp16, 2 gpu, 1024, 128
|
||||
for BATCH_SIZE in 2 4 8; do
|
||||
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
|
||||
--model="13b" \
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.schedule.generate import GenerateSchedule
|
||||
|
@ -7,6 +8,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
||||
from ..tensor_parallel.kvcache_manager import MemoryManager
|
||||
from .microbatch_manager import MicroBatchManager
|
||||
|
||||
|
||||
|
@ -23,20 +25,29 @@ class PPInferEngine:
|
|||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
new_length (int): the new length of the input sequence.
|
||||
early_stopping (bool): whether to stop early.
|
||||
max_batch_size (int): the maximum batch size.
|
||||
max_input_len (int): the maximum input length.
|
||||
max_output_len (int): the maximum output length.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
# Old example (replaced in this commit):
from colossalai.ppinference import PPInferEngine
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
# assume the model is inferred with 4 pipeline stages
inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})

input = ["Hello, my dog is cute, and I like"]
tokenized_input = tokenizer(input, return_tensors='pt')
output = engine.inference([tokenized_input])

# New example:
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
# assume the model is inferred with 2 pipeline stages
inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)

input = ["Introduce a landmark in China ", "Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference([data.to('cuda').data])
|
||||
```
|
||||
|
||||
"""
|
||||
|
@ -51,6 +62,9 @@ class PPInferEngine:
|
|||
new_length: int = 32,
|
||||
micro_batch_size: int = 1,
|
||||
micro_batch_buffer_size: int = None,
|
||||
max_batch_size: int = 4,
|
||||
max_input_len: int = 32,
|
||||
max_output_len: int = 32,
|
||||
verbose: bool = False,
|
||||
# TODO: implement early_stopping, and various generation options
|
||||
early_stopping: bool = False,
|
||||
|
@ -58,24 +72,53 @@ class PPInferEngine:
|
|||
num_beams: int = 1,
|
||||
) -> None:
|
||||
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
|
||||
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
|
||||
|
||||
max_output_len = max(max_output_len, max_input_len + new_length)
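# i.e. the per-sequence KV-cache budget must cover the full prompt plus the `new_length` tokens to be generated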
|
||||
|
||||
self.pp_size = pp_size
|
||||
if dtype == "fp16":
|
||||
self.dtype = torch.float16
|
||||
model.half()
|
||||
elif dtype == "bf16":
|
||||
self.dtype = torch.bfloat16
|
||||
model.to(torch.bfloat16)
|
||||
else:
|
||||
self.dtype = torch.float32
|
||||
self.pg_mesh = ProcessGroupMesh(pp_size)
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
|
||||
self.model = pp_model or self._shardformer(model, model_policy)
|
||||
self.cache_manager_list = [
|
||||
self._init_manager(max_batch_size, max_input_len, max_output_len)
|
||||
for _ in range(micro_batch_buffer_size or pp_size)
|
||||
]
|
||||
self.mb_manager = MicroBatchManager(
|
||||
self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
|
||||
self.stage_manager.stage,
|
||||
new_length,
|
||||
micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size,
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
self.cache_manager_list,
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
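# GenerateSchedule (colossalai.pipeline.schedule.generate) drives the micro batches through the stages and queries the MicroBatchManager for each micro batch's state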
|
||||
|
||||
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
|
||||
if dtype == "fp16":
|
||||
model.half()
|
||||
elif dtype == "bf16":
|
||||
model.to(torch.bfloat16)
|
||||
self.model = pp_model or self._shardformer(model, model_policy)
|
||||
|
||||
def inference(self, input_list):
|
||||
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
|
||||
"""
|
||||
Args:
|
||||
input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
|
||||
|
||||
Returns:
|
||||
out (list): a list of output data, each element is a list of token.
|
||||
timestamp (float): the time cost of the inference, only return when verbose is `True`.
|
||||
"""
|
||||
assert isinstance(
|
||||
input_list, (BatchEncoding, dict)
|
||||
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
|
||||
if isinstance(input_list, BatchEncoding):
|
||||
input_list = input_list.data
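# BatchEncoding.data is the underlying plain dict of tensors, which is the form the generate schedule consumes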
|
||||
out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
|
||||
if self.verbose:
|
||||
return out, timestamp
|
||||
else:
|
||||
|
@ -95,3 +138,17 @@ class PPInferEngine:
|
|||
shardformer = ShardFormer(shard_config=shardconfig)
|
||||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model.cuda()
|
||||
|
||||
def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
|
||||
max_total_token_num = max_batch_size * (max_input_len + max_output_len)
|
||||
head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
|
||||
head_num = self.model.config.num_attention_heads
|
||||
num_hidden_layers = (
|
||||
self.model.config.num_hidden_layers
|
||||
if hasattr(self.model.config, "num_hidden_layers")
|
||||
else self.model.config.num_layers
|
||||
)
|
||||
layer_num = num_hidden_layers // self.pp_size
|
||||
|
||||
cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
|
||||
return cache_manager
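To make the sizing above concrete, here is a small standalone sketch that reproduces the arithmetic of `_init_manager` for an assumed Llama-7B-like configuration (hidden_size=4096, 32 attention heads, 32 hidden layers) split across 2 pipeline stages; the numbers are illustrative assumptions, not part of this change.

```python
# Illustrative only: mirrors the sizing logic of _init_manager with assumed values.
max_batch_size, max_input_len, max_output_len = 4, 32, 32
hidden_size, num_attention_heads, num_hidden_layers, pp_size = 4096, 32, 32, 2

max_total_token_num = max_batch_size * (max_input_len + max_output_len)  # 4 * 64 = 256 KV-cache token slots
head_dim = hidden_size // num_attention_heads                            # 4096 // 32 = 128
head_num = num_attention_heads                                           # 32
layer_num = num_hidden_layers // pp_size                                 # each stage only caches its own 16 layers

print(max_total_token_num, head_dim, head_num, layer_num)                # 256 128 32 16
```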
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from enum import Enum
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from ..tensor_parallel.batch_infer_state import BatchInferState
|
||||
from ..tensor_parallel.kvcache_manager import MemoryManager
|
||||
|
||||
__all__ = "MicroBatchManager"
|
||||
|
||||
|
||||
|
@ -27,21 +30,20 @@ class MicroBatchDescription:
|
|||
def __init__(
|
||||
self,
|
||||
inputs_dict: Dict[str, torch.Tensor],
|
||||
output_dict: Dict[str, torch.Tensor],
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
new_length: int,
|
||||
) -> None:
|
||||
assert output_dict.get("hidden_states") is not None
|
||||
self.mb_length = output_dict["hidden_states"].shape[-2]
|
||||
self.mb_length = inputs_dict["input_ids"].shape[-1]
|
||||
self.target_length = self.mb_length + new_length
|
||||
self.kv_cache = ()
|
||||
self.infer_state = BatchInferState.init_from_batch(
|
||||
batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
|
||||
)
|
||||
# print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}")
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
if output_dict is not None:
|
||||
self._update_kvcache(output_dict["past_key_values"])
|
||||
|
||||
def _update_kvcache(self, kv_cache: Tuple):
|
||||
assert type(kv_cache) == tuple
|
||||
self.kv_cache = kv_cache
|
||||
def update(self, *args, **kwargs):
|
||||
pass
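# subclasses refine this: the head-stage description records newly generated tokens, the body-stage description needs no per-step update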
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
|
@ -80,17 +82,21 @@ class HeadMicroBatchDescription(MicroBatchDescription):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
|
||||
self,
|
||||
inputs_dict: Dict[str, torch.Tensor],
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
new_length: int,
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, output_dict, new_length)
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
|
||||
assert inputs_dict is not None
|
||||
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
|
||||
self.input_ids = inputs_dict["input_ids"]
|
||||
self.attn_mask = inputs_dict["attention_mask"]
|
||||
self.new_tokens = None
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
super().update(output_dict, new_token)
|
||||
def update(self, new_token: torch.Tensor = None):
|
||||
if new_token is not None:
|
||||
self._update_newtokens(new_token)
|
||||
if self.state is not Status.DONE and new_token is not None:
|
||||
|
@ -125,16 +131,17 @@ class BodyMicroBatchDescription(MicroBatchDescription):
|
|||
|
||||
Args:
|
||||
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hidden states from the previous stage.
|
||||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
|
||||
self,
|
||||
inputs_dict: Dict[str, torch.Tensor],
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager: MemoryManager,
|
||||
new_length: int,
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, output_dict, new_length)
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
super().update(output_dict, new_token)
|
||||
super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length)
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
|
@ -142,10 +149,7 @@ class BodyMicroBatchDescription(MicroBatchDescription):
|
|||
When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1
|
||||
|
||||
"""
|
||||
if len(self.kv_cache) == 0:
|
||||
return self.mb_length
|
||||
else:
|
||||
return self.kv_cache[0][0].shape[-2] + 1
|
||||
return self.infer_state.seq_len.max().item()
|
||||
|
||||
|
||||
class MicroBatchManager:
|
||||
|
@ -160,16 +164,38 @@ class MicroBatchManager:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
|
||||
def __init__(
|
||||
self,
|
||||
stage: int,
|
||||
new_length: int,
|
||||
micro_batch_size: int,
|
||||
micro_batch_buffer_size: int,
|
||||
max_input_len: int,
|
||||
max_output_len: int,
|
||||
cache_manager_list: MemoryManager,
|
||||
):
|
||||
self.stage = stage
|
||||
self.new_length = new_length
|
||||
self.micro_batch_size = micro_batch_size
|
||||
self.buffer_size = micro_batch_buffer_size
|
||||
self.max_input_len = max_input_len
|
||||
self.max_output_len = max_output_len
|
||||
self.cache_manager_list = cache_manager_list
|
||||
self.mb_descrption_buffer = {}
|
||||
self.new_tokens_buffer = {}
|
||||
self.idx = 0
|
||||
|
||||
def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
|
||||
if self.stage == 0:
|
||||
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
|
||||
)
|
||||
else:
|
||||
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
|
||||
inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length
|
||||
)
|
||||
|
||||
def step(self, new_token: torch.Tensor = None):
|
||||
"""
|
||||
Update the state of the microbatch manager under 2 conditions.
|
||||
1. For the first stage in PREFILL, it receives inputs and outputs; `add_descrption` saves its inputs.
|
||||
|
@ -181,11 +207,7 @@ class MicroBatchManager:
|
|||
new_token (torch.Tensor): the new token generated by current stage.
|
||||
"""
|
||||
# Add descrption first if the descrption is None
|
||||
if inputs_dict is None and output_dict is None and new_token is None:
|
||||
return Status.PREFILL
|
||||
if self.mb_descrption_buffer.get(self.idx) is None:
|
||||
self._add_descrption(inputs_dict, output_dict)
|
||||
self.cur_descrption.update(output_dict, new_token)
|
||||
self.cur_descrption.update(new_token)
|
||||
return self.cur_state
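The generate schedule that drives these calls is not part of this diff; the following is only a rough, hypothetical sketch of the intended call order for a stage-0 manager (the import path, the `Status.DONE` check, and the `sample_next_token` placeholder are assumptions inferred from this module, not code from this PR).

```python
from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status


def drive_head_stage(mb_manager: MicroBatchManager, inputs_dict, sample_next_token):
    # Register this micro batch: allocates its description and BatchInferState.
    mb_manager.add_descrption(inputs_dict)
    state = mb_manager.step()                 # first step: prefill, no new token yet
    while state != Status.DONE:
        token = sample_next_token()           # placeholder for one model forward + sampling step
        state = mb_manager.step(token)        # record the token and re-evaluate the state
    tokens = mb_manager.export_new_tokens()   # collect the generated tokens of this micro batch
    mb_manager.next()                         # rotate to the next buffer slot
    return tokens
```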
|
||||
|
||||
def export_new_tokens(self):
|
||||
|
@ -204,16 +226,12 @@ class MicroBatchManager:
|
|||
|
||||
def clear(self):
|
||||
self.mb_descrption_buffer.clear()
|
||||
for cache in self.cache_manager_list:
|
||||
cache.free_all()
|
||||
|
||||
def next(self):
|
||||
self.idx = (self.idx + 1) % self.buffer_size
|
||||
|
||||
def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
|
||||
if self.stage == 0:
|
||||
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length)
|
||||
else:
|
||||
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length)
|
||||
|
||||
def _remove_descrption(self):
|
||||
self.mb_descrption_buffer.pop(self.idx)
|
||||
|
||||
|
@ -222,10 +240,10 @@ class MicroBatchManager:
|
|||
return self.mb_descrption_buffer.get(self.idx)
|
||||
|
||||
@property
|
||||
def cur_kv_cache(self):
|
||||
def cur_infer_state(self):
|
||||
if self.cur_descrption is None:
|
||||
return None
|
||||
return self.cur_descrption.kv_cache
|
||||
return self.cur_descrption.infer_state
|
||||
|
||||
@property
|
||||
def cur_state(self):
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from .llama import LlamaInferenceForwards
|
||||
|
||||
__all__ = ["LlamaInferenceForwards"]
|
|
@ -0,0 +1,67 @@
|
|||
"""
|
||||
Utils for model inference
|
||||
"""
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
|
||||
|
||||
|
||||
def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
|
||||
"""
|
||||
This function copies the key and value cache to the memory cache
|
||||
Args:
|
||||
layer_id : id of current layer
|
||||
key_buffer : key cache
|
||||
value_buffer : value cache
|
||||
context_mem_index : index of memory cache in kv cache manager
|
||||
mem_manager : cache manager
|
||||
"""
|
||||
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
|
||||
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
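As a rough mental model (not the Triton kernel itself), the copy scatters the freshly computed key/value rows into the preallocated cache buffer at the slots handed out by the cache manager; a plain-PyTorch sketch of that effect follows, with made-up shapes and fp32 tensors standing in for the real fp16 CUDA buffers.

```python
import torch

key_buffer = torch.randn(8, 32, 128)             # 8 new tokens, 32 heads, head_dim 128 (assumed shapes)
layer_cache = torch.zeros(256, 32, 128)          # stand-in for mem_manager.key_buffer[layer_id]
context_mem_index = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10])

# Same effect as copy_kv_cache_to_dest(key_buffer, context_mem_index, layer_cache):
layer_cache[context_mem_index] = key_buffer
```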
|
||||
|
||||
|
||||
def init_to_get_rotary(self, base=10000, use_elem=False):
|
||||
"""
|
||||
This function initializes the rotary positional embedding. It is compatible with all models and is called in ShardFormer
|
||||
Args:
|
||||
self : Model that holds the rotary positional embedding
|
||||
base : calculation arg
|
||||
use_elem : activated when using chatglm-based models
|
||||
"""
|
||||
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
|
||||
if not hasattr(self.config, "rope_scaling"):
|
||||
rope_scaling_factor = 1.0
|
||||
else:
|
||||
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
|
||||
|
||||
if hasattr(self.config, "max_sequence_length"):
|
||||
max_seq_len = self.config.max_sequence_length
|
||||
elif hasattr(self.config, "max_position_embeddings"):
|
||||
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
|
||||
else:
|
||||
max_seq_len = 2048 * rope_scaling_factor
|
||||
base = float(base)
|
||||
|
||||
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
|
||||
|
||||
if ntk_alpha is not None:
|
||||
ntk_alpha = float(ntk_alpha)
|
||||
assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
|
||||
if ntk_alpha > 1:
|
||||
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
|
||||
max_seq_len *= ntk_alpha
|
||||
base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula
|
||||
|
||||
n_elem = self.config.head_dim_
|
||||
if use_elem:
|
||||
n_elem //= 2
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
|
||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
|
||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
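For reference, a small standalone sketch of the same frequency-table construction (default base 10000, no rope scaling or NTK) and of how rows are later looked up by position; the shapes and values are assumptions for illustration only.

```python
import torch

head_dim, max_seq_len, base = 128, 2048, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                  # (max_seq_len, head_dim // 2)

cos_cached = torch.cos(freqs)                     # analogous to self._cos_cached before .to(float16).cuda()
sin_cached = torch.sin(freqs)

position_ids = torch.tensor([0, 1, 2, 3])
cos = torch.index_select(cos_cached, 0, position_ids)
sin = torch.index_select(sin_cached, 0, position_ids)
print(cos.shape, sin.shape)                       # torch.Size([4, 64]) torch.Size([4, 64])
```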
|
|
@ -1,280 +0,0 @@
|
|||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
class GPT2PipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of GPT2 models
|
||||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def gpt2_model_forward(
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Preprocess passed in arguments
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
else:
|
||||
if hidden_states is None:
|
||||
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
batch_size, seq_length = input_shape[0], input_shape[1]
|
||||
device = hidden_states.device
|
||||
|
||||
# GPT2Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if stage_manager.is_first_stage():
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
else:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
# Going through held blocks.
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for i, layer_past in zip(range(start_idx, end_idx), past_key_values):
|
||||
block = self.h[i]
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
return {"hidden_states": hidden_states, "past_key_values": presents}
|
||||
|
||||
@staticmethod
|
||||
def gpt2_lmhead_model_forward(
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
|
||||
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
|
||||
Please refer to original code of transformers for more details.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# If it is the first stage and after warmup, go through lm_head first
|
||||
if stage_manager.is_first_stage() and hidden_states is not None:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
return {"logits": lm_logits}
|
||||
|
||||
# Not first stage or before warmup, go through gpt2 model
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
|
||||
return outputs
|
|
@ -1,158 +1,72 @@
|
|||
from typing import List, Optional
|
||||
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
|
||||
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from ._utils import copy_kv_to_mem_cache
|
||||
|
||||
class LlamaPipelineForwards:
|
||||
try:
|
||||
from vllm import layernorm_ops, pos_encoding_ops
|
||||
|
||||
rms_norm = layernorm_ops.rms_norm
|
||||
rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
|
||||
HAS_VLLM_KERNERL = True
|
||||
except:
|
||||
print("fall back to original rotary_embedding_neox of huggingface")
|
||||
print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
|
||||
print(
|
||||
"if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
|
||||
)
|
||||
HAS_VLLM_KERNERL = False
|
||||
|
||||
try:
|
||||
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
|
||||
|
||||
HAS_LIGHTLLM_KERNEL = True
|
||||
except:
|
||||
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
|
||||
HAS_LIGHTLLM_KERNEL = False
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
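A tiny worked example of `rotate_half` may help when reading the rotary math above: for a vector whose halves are (x1, x2) it returns (-x2, x1).

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])             # halves: [1, 2] and [3, 4]
x1, x2 = x[: x.shape[-1] // 2], x[x.shape[-1] // 2 :]
print(torch.cat((-x2, x1), dim=-1))                # tensor([-3., -4.,  1.,  2.])
```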
|
||||
|
||||
|
||||
class LlamaInferenceForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
under pipeline setting.
|
||||
This class holds forwards for llama inference.
|
||||
We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM.
|
||||
"""
|
||||
|
||||
def llama_model_forward(
|
||||
self: LlamaModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Preprocess passed in arguments
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * (end_idx - start_idx + 1))
|
||||
|
||||
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
|
||||
decoder_layer = self.layers[idx]
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
# past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
# always return dict for intermediate stage
|
||||
return {"hidden_states": hidden_states, "past_key_values": next_cache}
|
||||
|
||||
def llama_for_causal_lm_forward(
|
||||
@staticmethod
|
||||
def llama_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
|
@ -164,6 +78,7 @@ class LlamaPipelineForwards:
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
infer_state: BatchInferState = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
|
@ -175,24 +90,7 @@ class LlamaPipelineForwards:
|
|||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
@ -210,7 +108,7 @@ class LlamaPipelineForwards:
|
|||
return {"logits": lm_logits}
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = LlamaPipelineForwards.llama_model_forward(
|
||||
outputs = LlamaInferenceForwards.llama_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
|
@ -221,9 +119,348 @@ class LlamaPipelineForwards:
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
infer_state=infer_state,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def llama_model_forward(
|
||||
self: LlamaModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
infer_state: BatchInferState = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
# batch_size = input_ids.shape[0] # input_ids.shape[0]
|
||||
# print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}")
|
||||
|
||||
# infer_state = self.infer_state
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
assert stage_manager is not None
|
||||
assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}"
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if infer_state.is_context_stage is False:
|
||||
past_key_values_length = infer_state.cache_manager.past_key_values_length
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
# NOTE: differentiate with prefill stage
|
||||
# block_loc require different value-assigning method for two different stage
|
||||
if use_cache and seq_length != 1:
|
||||
# NOTE assume prefill stage
|
||||
# allocate memory block
|
||||
infer_state.is_context_stage = True # set prefill stage, notify attention layer
|
||||
infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
|
||||
infer_state.init_block_loc(
|
||||
infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
|
||||
)
|
||||
else:
|
||||
infer_state.is_context_stage = False
|
||||
alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
|
||||
if alloc_mem is not None:
|
||||
infer_state.decode_is_contiguous = True
|
||||
infer_state.decode_mem_index = alloc_mem[0]
|
||||
infer_state.decode_mem_start = alloc_mem[1]
|
||||
infer_state.decode_mem_end = alloc_mem[2]
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
else:
|
||||
print(f" *** Encountered allocation non-contiguous")
|
||||
print(
|
||||
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
|
||||
)
|
||||
infer_state.decode_is_contiguous = False
|
||||
alloc_mem = infer_state.cache_manager.alloc(batch_size)
|
||||
infer_state.decode_mem_index = alloc_mem
|
||||
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
|
||||
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
|
||||
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
new_shape = [1] * position_ids.dim()
|
||||
new_shape[0] = batch_size
|
||||
position_ids = position_ids.repeat(*new_shape).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if infer_state.is_context_stage:
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
|
||||
position_ids.view(-1).shape[0], -1
|
||||
)
|
||||
else:
|
||||
seq_len = infer_state.seq_len
|
||||
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
() if output_hidden_states else None
|
||||
() if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
infer_state.decode_layer_id = 0
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * (end_idx - start_idx + 1))
|
||||
|
||||
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
|
||||
decoder_layer = self.layers[idx]
|
||||
# NOTE: modify here for passing args to decoder layer
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
infer_state=infer_state,
|
||||
)
|
||||
infer_state.decode_layer_id += 1
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
# update indices
|
||||
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
|
||||
infer_state.seq_len += 1
|
||||
|
||||
# TODO: fix this to necessary return
|
||||
# if not return_dict:
|
||||
# return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
# return BaseModelOutputWithPast(
|
||||
# last_hidden_state=hidden_states,
|
||||
# past_key_values=next_cache,
|
||||
# hidden_states=all_hidden_states,
|
||||
# attentions=all_self_attns,
|
||||
# )
|
||||
# print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}")
|
||||
return {"hidden_states": hidden_states, "past_key_values": next_cache}
|
||||
|
||||
    @staticmethod
    def llama_decoder_layer_forward(
        self: LlamaDecoderLayer,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        infer_state: Optional[BatchInferState] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            infer_state=infer_state,
        )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
    @staticmethod
    def llama_flash_attn_kvcache_forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        infer_state: Optional[BatchInferState] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        assert use_cache is True, "use_cache should be set to True using this llama attention"

        bsz, q_len, _ = hidden_states.size()

        # NOTE might think about better way to handle transposed k and v
        # key_states            [bs, seq_len, num_heads, head_dim/embed_size_per_head]
        # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)

        # NOTE might want to revise
        #   need some way to record the length of past key values cache
        #   since we won't return past_key_value_cache right now
        if infer_state.decode_layer_id == 0:  # once per model.forward
            infer_state.cache_manager.past_key_values_length += q_len  # seq_len

        cos, sin = infer_state.position_cos, infer_state.position_sin

        llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
        llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)

        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)

        if infer_state.is_context_stage:
            # print(f"rank:{torch.distributed.get_rank()}, {infer_state}")
            # first token generation

            # copy key and value calculated in current step to memory manager
            copy_kv_to_mem_cache(
                infer_state.decode_layer_id,
                key_states,
                value_states,
                infer_state.context_mem_index,
                infer_state.cache_manager,
            )

            attn_output = torch.empty_like(query_states)

            llama_context_attn_fwd(
                query_states,
                key_states,
                value_states,
                attn_output,
                infer_state.start_loc,
                infer_state.seq_len,
                infer_state.cache_manager.past_key_values_length,
            )

        else:
            if infer_state.decode_is_contiguous:
                # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
                cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
                ]
                cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
                    infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
                ]
                cache_k.copy_(key_states)
                cache_v.copy_(value_states)
            else:
                # if decode is not contiguous, use triton kernel to copy key and value cache
                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
                copy_kv_to_mem_cache(
                    infer_state.decode_layer_id,
                    key_states,
                    value_states,
                    infer_state.decode_mem_index,
                    infer_state.cache_manager,
                )

            # second token and follows
            # kv = torch.stack((key_states, value_states), dim=2)
            # (batch_size, seqlen, nheads, headdim)
            attn_output = torch.empty_like(query_states)

            token_attention_fwd(
                query_states,
                infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
                infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
                attn_output,
                infer_state.block_loc,
                infer_state.start_loc,
                infer_state.seq_len,
                infer_state.cache_manager.past_key_values_length,
            )

        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        # print(f"rank:{torch.distributed.get_rank()}, {attn_output}")
        attn_output = self.o_proj(attn_output)

        # return past_key_value as None
        return attn_output, None, None
def get_llama_vllm_rmsnorm_forward():
    if HAS_VLLM_KERNERL:

        def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            x = hidden_states
            out = torch.empty_like(x)
            rms_norm(
                out,
                x,
                self.weight.data,
                self.variance_epsilon,
            )

            return out

        return _vllm_rmsnorm_forward
    else:
        return None
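A note on the control flow above: `llama_flash_attn_kvcache_forward` treats the prefill (context) pass and every later decode pass differently. During prefill, the K/V of the whole prompt are copied into the cache manager at `context_mem_index`; during decode, only the newest token's K/V are written, either by slicing directly into a contiguous region of the buffer or through `copy_kv_to_mem_cache`, before attention runs over the cached buffer. The snippet below is a minimal, self-contained sketch of that bookkeeping in plain PyTorch; `SimpleKVCache` and its `alloc` method are illustrative stand-ins, not the actual `MemoryManager` API.

```python
import torch


class SimpleKVCache:
    """Toy stand-in for the per-layer K/V cache bookkeeping used above (not the real MemoryManager)."""

    def __init__(self, capacity: int, num_heads: int, head_dim: int):
        # In production these buffers live on CUDA in fp16; float32 on CPU keeps the sketch runnable anywhere.
        self.key_buffer = torch.zeros(capacity, num_heads, head_dim)
        self.value_buffer = torch.zeros(capacity, num_heads, head_dim)
        self._next_free = 0

    def alloc(self, num_tokens: int) -> torch.Tensor:
        """Hand out slots for `num_tokens` new tokens (contiguous here for simplicity)."""
        idx = torch.arange(self._next_free, self._next_free + num_tokens)
        self._next_free += num_tokens
        return idx


def write_kv(cache: SimpleKVCache, key: torch.Tensor, value: torch.Tensor, mem_index: torch.Tensor) -> None:
    # key/value: [num_tokens, num_heads, head_dim]; mirrors the role of copy_kv_to_mem_cache.
    cache.key_buffer[mem_index] = key
    cache.value_buffer[mem_index] = value


cache = SimpleKVCache(capacity=64, num_heads=4, head_dim=8)

# Context (prefill) stage: cache the whole prompt's K/V in one shot.
prompt_len = 5
write_kv(cache, torch.randn(prompt_len, 4, 8), torch.randn(prompt_len, 4, 8), cache.alloc(prompt_len))

# Decode stage: each step appends exactly one token's K/V, then attends over the whole buffer.
write_kv(cache, torch.randn(1, 4, 8), torch.randn(1, 4, 8), cache.alloc(1))
```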
@@ -0,0 +1,3 @@
from .llama import LlamaModelInferPolicy

__all__ = ["LlamaModelInferPolicy"]
@@ -0,0 +1,145 @@
from functools import partial
from typing import List

import torch
from torch.nn import Module
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
)

from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
    from colossalai.kernel.triton import rmsnorm_forward

    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:

        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

        return _triton_rmsnorm_forward
    else:
        return None

class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()

        if self.shard_config.inference_gptq:
            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

            decoder_attribute_replacement = {
                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            }
            policy[LlamaDecoderLayer] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attn.q_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.k_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.v_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attn.o_proj",
                        target_module=RowCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.gate_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.up_proj",
                        target_module=ColCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                    SubModuleReplacementDescription(
                        suffix="mlp.down_proj",
                        target_module=RowCaiQuantLinear,
                        kwargs={"split_num": 1},
                    ),
                ],
            )

        self.shard_config._infer()

        infer_forward = LlamaInferenceForwards.llama_model_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

        infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
        )

        infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaAttention
        )

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
            )
        infer_forward = None
        if HAS_TRITON_RMSNORM:
            infer_forward = get_triton_rmsnorm_forward()
        else:
            # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
            infer_forward = get_llama_vllm_rmsnorm_forward()

        if infer_forward is not None:
            method_replacement = {"forward": partial(infer_forward)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
            )

        return policy

    def postprocess(self):
        init_to_get_rotary(self.model.model)
        return self.model

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_first_stage():
            held_layers.append(self.model.lm_head)
        return held_layers
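`LlamaModelInferPolicy` swaps each module's `forward` by registering `{"forward": partial(infer_forward)}` through `append_or_create_method_replacement`. The sketch below only illustrates the underlying substitution idea on a toy module; `ToyAttention` and `custom_forward` are made-up names, and ShardFormer's real policy machinery performs the binding differently.

```python
from functools import partial

import torch
import torch.nn as nn


class ToyAttention(nn.Module):  # illustrative stand-in, not a ColossalAI or transformers class
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def custom_forward(self: ToyAttention, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Replacement forward with the same shape of signature as the original, plus a pre-bound extra argument.
    return self.proj(x) * scale


module = ToyAttention()
# Pre-fill extra arguments with functools.partial and bind the result as the module's forward,
# loosely analogous to {"forward": partial(infer_forward)} in the policy above.
module.forward = partial(custom_forward, module, scale=0.5)

out = module(torch.randn(2, 8))
print(out.shape)
```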
@@ -1,74 +0,0 @@
from functools import partial
from typing import Callable, Dict, List

from torch import Tensor, nn

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.policies.gpt2 import GPT2Policy

from ..modeling.gpt2 import GPT2PipelineForwards


class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

        module_policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                GPT2LMHeadModel: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
                        )
                    ]
                )
            }
            module_policy.update(addon_module)

        if self.pipeline_stage_manager is not None:
            self.set_pipeline_forward(
                model_cls=GPT2LMHeadModel,
                new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
                policy=module_policy,
            )
        return module_policy

    def get_held_layers(self) -> List[nn.Module]:
        held_layers = super().get_held_layers()
        # make the tie weight lm_head and embedding in the same device to save memory
        # if self.pipeline_stage_manager.is_first_stage():
        if self.pipeline_stage_manager.is_first_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """The weights of wte and lm_head are shared."""
        module = self.model
        stage_manager = self.pipeline_stage_manager
        if stage_manager is not None:
            if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
                first_stage, last_stage = 0, stage_manager.num_stages - 1
                return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
        return []

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replacing the original forward method of huggingface
        to customized forward method, and add this changing to policy."""
        if not self.pipeline_stage_manager:
            raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
        stage_manager = self.pipeline_stage_manager
        if self.model.__class__.__name__ == "GPT2Model":
            module = self.model
        else:
            module = self.model.transformer

        layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
        stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
        method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -1,48 +0,0 @@
from typing import List

from torch.nn import Module

from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaPolicy

from ..modeling.llama import LlamaPipelineForwards


class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers import LlamaForCausalLM

        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                LlamaForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_first_stage():
            held_layers.append(self.model.lm_head)
        return held_layers
@@ -1,35 +0,0 @@
from typing import Set

import torch.nn as nn

from colossalai.shardformer._utils import getattr_, setattr_


def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None:
    """
    Set all parameters and buffers of model to None

    Args:
        model (nn.Module): The model to set
    """
    for module_suffix in include:
        set_module = getattr_(model, module_suffix)
        for n, p in set_module.named_parameters():
            setattr_(set_module, n, None)
        for n, buf in set_module.named_buffers():
            setattr_(set_module, n, None)
        setattr_(model, module_suffix, None)


def get_suffix_name(suffix: str, name: str):
    """
    Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit,
    and 'name' when `suffix` is empty.

    Args:
        suffix (str): The suffix of the suffix module
        name (str): The name of the current module
    """
    point = "" if suffix is "" else "."
    suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"
    return suffix_name
@@ -2,9 +2,11 @@
from dataclasses import dataclass

import torch
from transformers.tokenization_utils_base import BatchEncoding

from .kvcache_manager import MemoryManager


# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
@@ -55,3 +57,62 @@ class BatchInferState:
            ]
            start_index += cur_seq_len
        return

    @classmethod
    def init_from_batch(
        cls,
        batch: torch.Tensor,
        max_input_len: int,
        max_output_len: int,
        cache_manager: MemoryManager,
    ):
        if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)):
            raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state")

        input_ids_list = None
        attention_mask = None

        if isinstance(batch, (BatchEncoding, dict)):
            input_ids_list = batch["input_ids"]
            attention_mask = batch["attention_mask"]
        else:
            input_ids_list = batch
        if isinstance(input_ids_list[0], int):  # for a single input
            input_ids_list = [input_ids_list]
            attention_mask = [attention_mask] if attention_mask is not None else attention_mask

        batch_size = len(input_ids_list)

        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        start_index = 0

        max_len_in_batch = -1
        if isinstance(batch, (BatchEncoding, dict)):
            for i, attn_mask in enumerate(attention_mask):
                curr_seq_len = len(attn_mask)
                seq_lengths[i] = curr_seq_len
                seq_start_indexes[i] = start_index
                start_index += curr_seq_len
                max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
        else:
            length = max(len(input_id) for input_id in input_ids_list)
            for i, input_ids in enumerate(input_ids_list):
                curr_seq_len = length
                seq_lengths[i] = curr_seq_len
                seq_start_indexes[i] = start_index
                start_index += curr_seq_len
                max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
        block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda")

        return cls(
            batch_size=batch_size,
            max_len_in_batch=max_len_in_batch,
            seq_len=seq_lengths.to("cuda"),
            start_loc=seq_start_indexes.to("cuda"),
            block_loc=block_loc,
            decode_layer_id=0,
            past_key_values_len=0,
            is_context_stage=True,
            cache_manager=cache_manager,
        )
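A hedged usage sketch for `init_from_batch`: tokenize a small batch, build a cache manager, and let the classmethod derive the per-sequence lengths and start offsets. The module paths and the `MemoryManager(...)` argument list below are assumptions for illustration (only the `init_from_batch` signature comes from the diff), and a CUDA device is required because the state tensors are allocated on `cuda`.

```python
import torch
from transformers import AutoTokenizer

# Assumed import paths; adjust to wherever BatchInferState and MemoryManager live in your checkout.
from colossalai.inference.pipeline.batch_infer_state import BatchInferState
from colossalai.inference.pipeline.kvcache_manager import MemoryManager

tokenizer = AutoTokenizer.from_pretrained("/path/to/model")
batch = tokenizer(["Hello world", "A longer test prompt"], return_tensors="pt", padding=True)

max_input_len, max_output_len = 64, 32
# Guessed constructor arguments: total cache slots, dtype, and per-layer head geometry.
cache_manager = MemoryManager(2 * (max_input_len + max_output_len), torch.float16, 32, 128, 32)

infer_state = BatchInferState.init_from_batch(
    batch=batch.to("cuda"),
    max_input_len=max_input_len,
    max_output_len=max_output_len,
    cache_manager=cache_manager,
)
print(infer_state.seq_len, infer_state.start_loc)
```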
@@ -93,9 +93,9 @@ class GenerateSchedule(PipelineSchedule):
        Returns:
            dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
        """
        model_inputs = (
            {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
        )
        model_inputs = {
            'infer_state': self.mb_manager.cur_descrption.infer_state
        }
        return model_inputs

    def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):

@@ -108,9 +108,8 @@ class GenerateSchedule(PipelineSchedule):
            dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
        """
        new_mask = self.mb_manager.cur_descrption.attn_mask
        past_key_values = self.mb_manager.cur_descrption.kv_cache

        return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)
        return dict(input_ids=new_token, attention_mask=new_mask)

    def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
        last_hidden_state = hidden_state[:, -1]

@@ -128,27 +127,38 @@ class GenerateSchedule(PipelineSchedule):
            return self.comm.p2p_recv()
        return self.comm.recv_forward()

    def _load_stage_action(self, model: Module) -> None:
    def _init_infer_state_action(self) -> None:
        """
        In this action, 1.load micro_batch 2.do the forward 3.step to update
        This action is only for non-first stages: load the batch and init infer_state.
        1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
        """
        inputs_dict = self.load_micro_batch()
        self.mb_manager.add_descrption(inputs_dict)

    def _load_stage_action(self, model: Module) -> None:
        """
        This action is only for the first stage: load, init and do the forward.
        1.load micro_batch 2.do the forward 3.step to update
        """
        inputs_dict = self.load_micro_batch()
        self.mb_manager.add_descrption(inputs_dict)
        if self.verbose and self.stage_manager.is_first_stage():
            torch.cuda.synchronize()
            self.timestamps[self.mb_manager.idx].append(time.time())
        output_dict = model_forward(model, inputs_dict, None)
        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
        output_dict = model_forward(model, inputs_dict, interval_inputs)

        self.mb_manager.step(inputs_dict, output_dict, None)
        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

    def _gen_token_action(self, model: Module):
        """
        In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update
        This action is only for the first stage
        1.do the forward with hidden_states to generate new tokens 2.step to update
        """
        hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
        hidden_states = {"hidden_states": hidden_states}
        logits = model_forward(model, None, hidden_states)
        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
        logits = model_forward(model, None, interval_inputs)
        if self.verbose and self.stage_manager.is_first_stage():
            torch.cuda.synchronize()
            self.timestamps[self.mb_manager.idx].append(time.time())
@@ -157,7 +167,7 @@ class GenerateSchedule(PipelineSchedule):
        ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
        new_token = self._get_token_id(logits["logits"])

        self.mb_manager.step(None, None, new_token)
        self.mb_manager.step(new_token)
        self.action_interval_buffer.new_token = new_token
        self.action_interval_buffer.hidden_states = None

@@ -168,20 +178,18 @@ class GenerateSchedule(PipelineSchedule):
        new_token = self.action_interval_buffer.new_token
        assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
        inputs_dict = self._prepare_inputs_for_new_token(new_token)
        output_dict = model_forward(model, inputs_dict, None)
        interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
        output_dict = model_forward(model, inputs_dict, interval_inputs)

        self.mb_manager.step(inputs_dict, output_dict, None)
        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

    def _body_encoding_action(self, model: Module):
        hidden_states = self.action_interval_buffer.hidden_states
        assert hidden_states is not None, "When not first stage, the hidden states should not be None"
        inputs_dict = self._prepare_inputs_for_interval_stage()
        hidden_states = {"hidden_states": hidden_states}
        output_dict = model_forward(model, inputs_dict, hidden_states)
        interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
        output_dict = model_forward(model, None, interval_inputs)

        self.mb_manager.step(inputs_dict, output_dict, None)
        self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
        self.action_interval_buffer.hidden_states = output_dict['hidden_states']

    def _comm_action(self, recv_pre: bool) -> torch.Tensor:
        """

@@ -218,6 +226,8 @@ class GenerateSchedule(PipelineSchedule):
                actions.append(partial(self._gen_token_action, model))
            # other stage
            else:
                if self.mb_manager.cur_state is Status.PREFILL:
                    actions.append(partial(self._init_infer_state_action))
                actions.append(partial(self._comm_action, True))
                actions.append(partial(self._body_encoding_action, model))

@@ -308,8 +318,9 @@ class GenerateSchedule(PipelineSchedule):
                if self.verbose and self.stage_manager.is_first_stage():
                    torch.cuda.synchronize()
                    self.timestamps[self.mb_manager.idx].append(time.time())
                output_dict = model_forward(model, inputs_dict, None)
                self.mb_manager.step(inputs_dict, output_dict, None)
                self.mb_manager.add_descrption(inputs_dict)
                interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
                output_dict = model_forward(model, inputs_dict, interval_inputs)
            # In GENERATE phase
            else:
                # Get hidden_states from previous stage

@@ -319,25 +330,28 @@ class GenerateSchedule(PipelineSchedule):
                assert (
                    hidden_states is not None
                ), "When first stage in GENERATE phase, the hidden states should not be None"
                logits = model_forward(model, None, hidden_states)
                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
                logits = model_forward(model, None, interval_inputs)
                if self.verbose and self.stage_manager.is_first_stage():
                    torch.cuda.synchronize()
                    self.timestamps[self.mb_manager.idx].append(time.time())
                assert (
                    "logits" in logits
                ), f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
                new_token = self._get_token_id(logits["logits"])
                self.mb_manager.step(None, None, new_token)
                assert 'logits' in logits, f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
                new_token = self._get_token_id(logits['logits'])
                self.mb_manager.step(new_token)
                # If the current micro batch is not DONE, go through blocks
                if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                    inputs_dict = self._prepare_inputs_for_new_token(new_token)
                    output_dict = model_forward(model, inputs_dict, None)
                    self.mb_manager.step(inputs_dict, output_dict, None)
                    interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
                    output_dict = model_forward(model, inputs_dict, interval_inputs)
            else:
                assert hidden_states is not None, "When not first stage, the hidden states should not be None"
                inputs_dict = self._prepare_inputs_for_interval_stage()
                output_dict = model_forward(model, inputs_dict, hidden_states)
                self.mb_manager.step(inputs_dict, output_dict, None)
                # inputs_dict = self._prepare_inputs_for_interval_stage()
                inputs_dict = None
                if self.mb_manager.cur_state is Status.PREFILL:
                    inputs_dict = self.load_micro_batch()
                    self.mb_manager.add_descrption(inputs_dict)
                interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
                output_dict = model_forward(model, inputs_dict, interval_inputs)

            # Current microbatch is not DONE, send hidden_state to next stage
            if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
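Seen end to end, the schedule drives each micro-batch through one PREFILL pass (whole prompt, KV cache filled) followed by repeated GENERATE passes that reuse the cached `infer_state`, and only the first stage turns logits into tokens while the other stages forward hidden states. The sketch below is a single-process illustration of that control flow only; `prefill` and `decode_step` are placeholders, not ColossalAI APIs.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class MicroBatchDescription:
    # Placeholder for the per-micro-batch bookkeeping kept by the micro-batch manager.
    prompt: str
    new_tokens: List[int] = field(default_factory=list)


def prefill(desc: MicroBatchDescription) -> int:
    """Placeholder: run the full prompt once, fill the KV cache, return the first new token."""
    return len(desc.prompt)  # dummy token id


def decode_step(desc: MicroBatchDescription, last_token: int) -> int:
    """Placeholder: run a single-token forward against the cached keys/values."""
    return last_token + 1  # dummy token id


def run_schedule(prompts: List[str], new_length: int) -> List[MicroBatchDescription]:
    descriptions = [MicroBatchDescription(p) for p in prompts]
    for desc in descriptions:
        token = prefill(desc)                      # PREFILL: one pass over the whole prompt
        desc.new_tokens.append(token)
        while len(desc.new_tokens) < new_length:   # GENERATE: one token per pass, reusing the cache
            token = decode_step(desc, token)
            desc.new_tokens.append(token)
    return descriptions


print(run_schedule(["Hello", "Hi"], new_length=4))
```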
@@ -76,4 +76,5 @@ class ShardConfig:
        """
        Set default params for inference.
        """
        assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
        # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
        pass
@@ -2,12 +2,15 @@ import pytest
import torch
import torch.distributed as dist
import transformers
from packaging import version

import colossalai
from colossalai.inference.pipeline.engine import PPInferEngine
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
from colossalai.inference.pipeline import PPInferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


def data_gen():
    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)

@@ -24,20 +27,21 @@ for k, v in inputs.items():


def pipeline_inference_test(pp_size, new_length, micro_batch_size):
    model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
    model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4))

    engine = PPInferEngine(
        pp_size=pp_size,
        model=model,
        model_policy=GPT2LMHeadModelPipelinePolicy(),
        model_policy=LlamaModelInferPolicy(),
        new_length=new_length,
        micro_batch_size=micro_batch_size,
    )
    output = engine.inference([inputs])
    output = engine.inference(inputs)
    if dist.get_rank() == 0:
        assert len(output[0]) == new_length, f"{len(output)}, {new_length}"


@parameterize("pp_size", [4])
@parameterize("pp_size", [2])
@parameterize("new_length", [4, 8, 16])
@parameterize("micro_batch_size", [1, 4])
@clear_cache_before_run()

@@ -51,11 +55,12 @@ def check_pipeline_inference(rank, world_size, port):
    run_pipeline_inference_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
    spawn(check_pipeline_inference, nprocs=4)
    spawn(check_pipeline_inference, nprocs=2)


if __name__ == "__main__":